前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch 自动微分

Pytorch 自动微分

作者头像
Michael阿明
发布2021-02-19 15:03:05
3840
发布2021-02-19 15:03:05
举报
文章被收录于专栏:Michael阿明学习之路

参考 http://pytorch123.com/

  • Tensor.requires_grad = True 记录对Tensor的所有操作,后序.backward() 自动计算所有梯度到 .grad 属性
代码语言:javascript
复制
import torch
x = torch.ones(2,2, requires_grad=True) # 默认是False
print(x)

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
  • 停止记录调用.detach()
代码语言:javascript
复制
x.detach_()
print(x.requires_grad) # False
  • .grad_fn 保存了创建张量的 Function 的引用
代码语言:javascript
复制
x = torch.ones(2,2, requires_grad=True)
y = x + 2
print(y)
print(y.grad_fn)

tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x0000015716529D68>
代码语言:javascript
复制
z = y*y*3
out = z.mean()
print(z, out)

tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>) 

tensor(27., grad_fn=<MeanBackward0>)
代码语言:javascript
复制
# requires_grad 默认为 False
a = torch.randn(2, 2)
a = ((a*3)/(a-1))
print(a.requires_grad)  # False
b = (a*a).sum()
print(b.grad_fn)  # None

a.requires_grad_(True)  # 设置为 True
print(a.requires_grad)  # True
b = (a*a).sum()
print(b.grad_fn)
# <SumBackward0 object at 0x0000015717DC69E8>
  • backward() 后向传播
代码语言:javascript
复制
z = y*y*3
y = x+2
计算 d(out)/dx

o u t = 1 4 ( ∑ 3 ( x i + 2 ) 2 ) → d o u t d x i = 3 2 ( x i + 2 ) out = \frac{1}{4}(\sum3(x_i+2)^2) \rightarrow \frac{d_{out}}{dx_i} = \frac{3}{2}(x_i+2) out=41​(∑3(xi​+2)2)→dxi​dout​​=23​(xi​+2) x i = 1 , d o u t / d x i = 4.5 x_i = 1, d_{out}/dx_i = 4.5 xi​=1,dout​/dxi​=4.5

代码语言:javascript
复制
out.backward()
print(y.grad) # None, 为什么?是 None
print(x.grad)
tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

J = ( ∂ y 1 ∂ x 1 ⋯ ∂ y m ∂ x 1 ⋮ ⋱ ⋮ ∂ y 1 ∂ x n ⋯ ∂ y m ∂ x n ) J=\left(\begin{array}{ccc}\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}\end{array}\right) J=⎝⎜⎛​∂x1​∂y1​​⋮∂xn​∂y1​​​⋯⋱⋯​∂x1​∂ym​​⋮∂xn​∂ym​​​⎠⎟⎞​

  • 当又使用了一个函数 l = g ( y ) l = g(y) l=g(y),v 是 l l l 对 y y y 的导数,链式求导相乘,得到 l l l 对 x x x 的导数 J ⋅ v = ( ∂ y 1 ∂ x 1 ⋯ ∂ y m ∂ x 1 ⋮ ⋱ ⋮ ∂ y 1 ∂ x n ⋯ ∂ y m ∂ x n ) ( ∂ l ∂ y 1 ⋮ ∂ l ∂ y m ) = ( ∂ l ∂ x 1 ⋮ ∂ l ∂ x n ) J \cdot v=\left(\begin{array}{ccc}\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}\end{array}\right)\left(\begin{array}{c}\frac{\partial l}{\partial y_{1}} \\ \vdots \\ \frac{\partial l}{\partial y_{m}}\end{array}\right)=\left(\begin{array}{c}\frac{\partial l}{\partial x_{1}} \\ \vdots \\ \frac{\partial l}{\partial x_{n}}\end{array}\right) J⋅v=⎝⎜⎛​∂x1​∂y1​​⋮∂xn​∂y1​​​⋯⋱⋯​∂x1​∂ym​​⋮∂xn​∂ym​​​⎠⎟⎞​⎝⎜⎛​∂y1​∂l​⋮∂ym​∂l​​⎠⎟⎞​=⎝⎜⎛​∂x1​∂l​⋮∂xn​∂l​​⎠⎟⎞​

上面代码改为:

代码语言:javascript
复制
v = torch.tensor(2, dtype=torch.float)
out.backward(v)
print(x.grad)

# 梯度乘以了 2
tensor([[9., 9.],
        [9., 9.]])
  • 评估阶段可以使用 with torch.no_grad(): 不需要梯度计算和更新
代码语言:javascript
复制
print(x.requires_grad) # True
print((x ** 2).requires_grad) # True

# 取消梯度记录
with torch.no_grad():
    print((x ** 2).requires_grad) # False
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/12/24 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档