专栏首页中科院渣渣博肆僧一枚Pytorch张量(Tensor)复制

Pytorch张量(Tensor)复制

tensor复制可以使用clone()函数和detach()函数即可实现各种需求。

clone

clone()函数可以返回一个完全相同的tensor,新的tensor开辟新的内存,但是仍然留在计算图中。

detach

detach()函数可以返回一个完全相同的tensor,新的tensor开辟与旧的tensor共享内存,新的tensor会脱离计算图,不会牵扯梯度计算。此外,一些原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在两者任意一个执行都会引发错误。

使用分析

# Operation

New/Shared memory

Still in computation graph

tensor.clone()

New

Yes

tensor.detach()

Shared

No

如下执行一些实例: 首先导入包并固定随机种子

import torch
torch.manual_seed(0)

1.clone()之后的tensor requires_grad=True,detach()之后的tensor requires_grad=False,但是梯度并不会流向clone()之后的tensor

x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.clone().detach()

f = torch.nn.Linear(3, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)

Output:
--------------------------------------------
tensor([-0.0043,  0.3097, -0.4752])
True
None
False
False
--------------------------------------------

2.将计算图中参与运算tensor变为clone()后的tensor。此时梯度仍然只流向了原始的tensor。

x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.detach().clone()

f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()

print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)

Output:
------------------------------------
tensor([-0.0043,  0.3097, -0.4752])
None
False
False
------------------------------------

3.将原始tensor设为requires_grad=False,clone()后的梯度设为.requires_grad_(),clone()后的tensor参与计算图的运算,则梯度穿向clone()后的tensor。

x= torch.tensor([1., 2., 3.], requires_grad=False)
clone_x = x.clone().requires_grad_()
detach_x = x.detach()
clone_detach_x = x.detach().clone()

f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()

print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)


Output:
--------------------------------------
None
tensor([-0.0043,  0.3097, -0.4752])
False
False
--------------------------------------

4.detach()后的tensor由于与原始tensor共享内存,所以原始tensor在计算图中数值反向传播更新之后,detach()的tensor值也发生了改变。

x = torch.tensor([1., 2., 3.], requires_grad=True)
f = torch.nn.Linear(3, 1)
w = f.weight.detach()
print(f.weight)
print(w)

y = f(x)
y.backward()

optimizer = torch.optim.SGD(f.parameters(), 0.1)
optimizer.step()

print(f.weight)
print(w)


Output:
----------------------------------------------------------
Parameter containing:
tensor([[-0.0043,  0.3097, -0.4752]], requires_grad=True)
tensor([[-0.0043,  0.3097, -0.4752]])
Parameter containing:
tensor([[-0.1043,  0.1097, -0.7752]], requires_grad=True)
tensor([[-0.1043,  0.1097, -0.7752]])
----------------------------------------------------------

承接Matlab、Python和C++的编程,机器学习、计算机视觉的理论实现及辅导,本科和硕士的均可,咸鱼交易,专业回答请走知乎,详谈请联系QQ号757160542,非诚勿扰。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • torch.no_grad

    不能进行梯度计算的上下文管理器。当你确定你不调用Tensor.backward()时,不能计算梯度对测试来讲非常有用。对计算它将减少内存消耗,否则require...

    于小勇
  • numpy.ndarray()

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    于小勇
  • 集成学习

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    于小勇
  • Python实现图片手绘效果

    在家闲来无事看了会MOOC上的数据分析相关视频,以下部分内容来自于北京理工大学MOOC上的《Python数据分析与展示》。

    啤酒单恋小龙虾
  • Python实现手绘图效果实例分享

    到此这篇关于Python实现手绘图效果实例分享的文章就介绍到这了,更多相关Python实现手绘图效果内容请搜索ZaLou.Cn以前的文章或继续浏览下面的相关文章...

    砸漏
  • DARPA拟解开机器学习的黑匣子

    美国《MIT技术评论》杂志发表了Will Knight的文章,名为:美国军方要求自动化机器能够解释自身机制。文章表示:最新的机器学习技术本质上是一个黑匣子。美国...

    人工智能快报
  • 如何用 3 个月零基础入门机器学习?

    写这篇文章的初衷是大部分私信我的朋友都想了解如何入门 / 转行机器学习,搭上人工智能这列二十一世纪的快车。再加上这个问题每隔一阵子就会在知乎时间线上出现一次,因...

    AI研习社
  • 【收藏】这个时候才是最好的自学时间!深度学习-机器学习-GNN-NLP等AI课程超级大列表汇总,拿走不谢

    https://deep-learning-drizzle.github.io/index.html#contents

    zenRRan
  • 如何用3个月零基础入门机器学习?

    用户1737318
  • 什么,这些人你还不认识?!一文带你有姿势地侃深度学习大佬

    大数据文摘

扫码关注云+社区

领取腾讯云代金券