首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >用梯度将PyTorch张量调整为小尺寸

用梯度将PyTorch张量调整为小尺寸
EN

Stack Overflow用户
提问于 2020-03-13 03:33:14
回答 2查看 1.5K关注 0票数 4

我试图将张量从(3,3)缩小到(1, 1),但我想保留原来的张量:

代码语言:javascript
运行
复制
import torch

a = torch.rand(3, 3)
a_copy = a.clone()
a_copy.resize_(1, 1)

我需要初始张量中的requires_grad=True,但是PyTorch禁止我试图调整副本的大小:

代码语言:javascript
运行
复制
a = torch.rand(3, 3, requires_grad=True)
a_copy = a.clone()
a_copy.resize_(1, 1)

抛出错误:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File "pytorch_test.py", line 7, in <module>
    a_copy.resize_(1, 1)
RuntimeError: cannot resize variables that require grad

克隆与分离

我也尝试过.clone().detach()

代码语言:javascript
运行
复制
a = torch.rand(3, 3, requires_grad=True)
a_copy = a.clone().detach()

with torch.no_grad():
    a_copy.resize_(1, 1)

这将产生此错误:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File "pytorch_test.py", line 14, in <module>
    a_copy.resize_(1, 1)
RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
    x.data.set_(y)
to:
    with torch.no_grad():
        x.set_(y)

这一行为已在医生们#15070中作了说明。

no_grad()

因此,按照他们在错误消息中所说的,我删除了.detach()并使用了no_grad()

代码语言:javascript
运行
复制
a = torch.rand(3, 3, requires_grad=True)
a_copy = a.clone()

with torch.no_grad():
    a_copy.resize_(1, 1)

但这还是给了我一个关于毕业生的错误:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File "pytorch_test.py", line 21, in <module>
    a_copy.resize_(1, 1)
RuntimeError: cannot resize variables that require grad

类似问题

我已经看过调整PyTorch张量了,但是在这个例子中,张量保留了所有原始值。我也看过比托尔更喜欢复制张量,这是我用来复制张量的方法。

我使用的是PyTorch版本1.4.0

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-03-13 05:41:03

我认为你应该先分离,然后克隆:

代码语言:javascript
运行
复制
a = torch.rand(3, 3, requires_grad=True)
a_copy = a.detach().clone()
a_copy.resize_(1, 1)

注意:a.detach()返回一个从当前图中分离出来的新张量(它不像a.detach_()那样将a本身从图中分离出来)。但是因为它与a共享存储,所以您也应该克隆它。这样,无论您对a_copy做什么都不会影响a。然而,我不知道为什么a.detach().clone()会工作,但是a.clone().detach()会出错。

编辑

以下代码也可以工作(这可能是一个更好的解决方案):

代码语言:javascript
运行
复制
a = torch.rand(3, 3, requires_grad=True)

with torch.no_grad():
    a_copy = a.clone()
    a_copy.resize_(1, 1)
票数 -1
EN

Stack Overflow用户

发布于 2020-09-23 10:40:45

有一个narrow()函数:

代码语言:javascript
运行
复制
def samestorage(x,y):
    if x.storage().data_ptr()==y.storage().data_ptr():
        print("same storage")
    else:
        print("different storage")
def contiguous(y):
    if True==y.is_contiguous():
        print("contiguous")
    else:
        print("non contiguous")
# narrow => same storage contiguous tensors
import torch
x = torch.randn(3, 3, requires_grad=True)
y = x.narrow(0, 1, 2) #dim, start, len  
print(x)
print(y)
contiguous(y)
samestorage(x,y)

退出:

代码语言:javascript
运行
复制
tensor([[ 1.1383, -1.2937,  0.8451],
        [ 0.0151,  0.8608,  1.4623],
        [ 0.8490, -0.0870, -0.0254]], requires_grad=True)
tensor([[ 0.0151,  0.8608,  1.4623],
        [ 0.8490, -0.0870, -0.0254]], grad_fn=<SliceBackward>)
contiguous
same storage
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60664524

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档