前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch Autograd 基础(三)

Pytorch Autograd 基础(三)

作者头像
用户6021899
发布2022-04-15 08:28:37
2060
发布2022-04-15 08:28:37
举报
文章被收录于专栏:Python编程 pyqt matplotlib

本篇介绍如何关闭和打开Autograd。

  • 关闭和打开Autograd的最简单的方法是更改tensor的requires_grad 属性。
代码语言:javascript
复制
import torch

a = torch.ones(2, 3, requires_grad=True)
print(a)

b1 = 2 * a  # b1 由 a 计算得来,继承了 a 当前额 requires_grad属性
print(b1)

a.requires_grad = False # 关闭 Autograd,不再追踪计算历史
b2 = 2 * a  # b2 由 a 计算得来,继承了 a 当前额 requires_grad属性
print(b2) # b2 也 关闭了 Autograd
print(b2.requires_grad) 
代码语言:javascript
复制
tensor([[1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
tensor([[2., 2., 2.],
        [2., 2., 2.]], grad_fn=<MulBackward0>)
tensor([[2., 2., 2.],
        [2., 2., 2.]])
False

再次打开a的Autograd,并不影响b2。

代码语言:javascript
复制
a.requires_grad = True
print(b2)
代码语言:javascript
复制
tensor([[2., 2., 2.],
        [2., 2., 2.]])

还是可以将b2的requires_grad 属性设为True

代码语言:javascript
复制
b2.requires_grad= True
print(b2)
代码语言:javascript
复制
tensor([[2., 2., 2.],
        [2., 2., 2.]], requires_grad=True)
  • 如果只是想临时的关闭Augograd,最好的方式是用torch.no_grad()。
代码语言:javascript
复制
a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3
c1 = a + b  # Autograd 自动打开
print(c1)

with torch.no_grad(): # 在这个上下文中临时关闭 Autograd
    c2 = a + b
    
print(c2)
c3 = a * b  # Autograd 任然自动打开
print(c3)
代码语言:javascript
复制
tensor([[5., 5., 5.],
        [5., 5., 5.]], grad_fn=<AddBackward0>)
tensor([[5., 5., 5.],
        [5., 5., 5.]])
tensor([[6., 6., 6.],
        [6., 6., 6.]], grad_fn=<MulBackward0>)
  • torch.no_grad() 可以用做函数或者方法的装饰器,来关闭Autograd
代码语言:javascript
复制
def add_tensors1(x, y):
    return x + y


@torch.no_grad() # 关闭 Augograd
def add_tensors2(x, y):
    return x + y


a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3
c1 = add_tensors1(a, b) # c1由a和b计算而来,跟随a和b,打开Autograd
print(c1)

c2 = add_tensors2(a, b) # 由于有@torch.no_grad(),c2关闭了Autograd
print(c2)
代码语言:javascript
复制
tensor([[5., 5., 5.],
        [5., 5., 5.]], grad_fn=<AddBackward0>)
tensor([[5., 5., 5.],
        [5., 5., 5.]])
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-04-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档