前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【他山之石】几行代码让你搞懂torch.no_grad

【他山之石】几行代码让你搞懂torch.no_grad

作者头像
马上科普尚尚
发布2021-09-16 09:55:12
4890
发布2021-09-16 09:55:12
举报

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

来源:知乎—小白

地址:https://zhuanlan.zhihu.com/p/406823590

https://blog.csdn.net/weixin_43178406/article/details/89517008

首先明确一点,no_grad与detach有异曲同工之妙,都是逃避autograd的追踪。

接下来我们做个实验:

代码语言:javascript
复制
代码语言:javascript
复制
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a = a + a.grad
    print(a, a.grad, a.requires_grad )
    # a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.]) None False
tensor([3., 3.]) None False
代码语言:javascript
复制

我们在with torch.nograd()下使用了 =+的操作,这实际上生成了一个新的变量a,因为torch.no_grad的作用下使得a变量没法求梯度。

如果使用-+的操作:

代码语言:javascript
复制
代码语言:javascript
复制
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a += a.grad
    print(a, a.grad, a.requires_grad )
    # a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
---------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([4., 4.]) True
代码语言:javascript
复制

可以发现,-+的原地修改本来是不行的,因为autograd会检测你这个值是否变化,但是如果加上torch.no_grad()后就逃避了autograd的检测,在上下文管理器中只修改了tensor的data,属性没有修改,这样的话就可以对a进行求梯度的了,但是我们发现这个梯度被累加了,本来想要第二次反向传播的时候,最后a的输出不包含上一次的梯度。假定我在做一个梯度的更新操作,这个梯度累计越来越大,更新的步长越来越大,loss直接跑飞。所以得加一个梯度清零的操作。

代码语言:javascript
复制
代码语言:javascript
复制
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a += a.grad
    print(a, a.grad, a.requires_grad )
    a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-----------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
代码语言:javascript
复制

这种过程,就相当于梯度的更新了,在完成原地修改的时候能不被autograd检测到,就是torch.no_grad的一种使用场景。

接下来,就是no_grad的其它作用了,这种在本文的首页链接中可以仔细体会。

代码语言:javascript
复制
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“他山之石”历史文章
TensorRT部署深度学习模型
GPU 显存不足怎么办?
Pytorch Debug指南:15条重要建议
开源GPU显存虚拟化项目,你的2080Ti还能救一下
盘点来自工业界的GPU共享方案
Visualizer!简化你的Vision Transformer可视化!
实战:基于深度学习的道路损坏检测
深度学习Pytorch框架Tensor张量
人脸属性pytorch示例
快速掌握BERT源代码(pytorch)
PyTorch 51.BatchNorm和Dropout层的不协调现象
高效语义分割模型Fast-SCNN实操代码解析分享
图神经网络:PyG的MessagePassing机制实现
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-09-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能前沿讲习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档