pytorch中retain_graph参数的作用

前言

在pytorch神经网络迁移的官方教程中有这样一个损失层函数(具体看这里提供0.3.0版中文链接:https://oldpan.me/archives/pytorch-neural-transfer)。

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        # 因为这里只是需要target这个数值,这个数值是一种状态,不计入计算树中。
        # 这里单纯将其当做常量对待,因此用了detach则在backward中计算梯度时不对target之前所在的计算图存在任何影响。
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss

看到上面的代码,我们在内容损失层中定义了一个backward()反向反馈函数。这个函数在整个神经网络在反向循环的时候会执行loss的backward从而实现对loss的更新。

但是在这个代码中,我们设置了retain_graph=True,这个参数的作用是什么,官方定义为:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

正文

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它:

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。然后我们对两个output执行backward。

In[3]: import torch
In[5]: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
In[6]: y = x ** 2
In[7]: z = y * 4
In[8]: output1 = z.mean()
In[9]: output2 = z.sum()
In[10]: output1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
In[11]: output2.backward()    # 这时会引发错误
Traceback (most recent call last):
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-32d5139229de>", line 1, in <module>
    output2.backward()
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

如果我们这样写:

In[3]: import torch
  ...: from torch.autograd import Variable
  ...: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
  ...: y = x ** 2
  ...: z = y * 4
  ...: output1 = z.mean()
  ...: output2 = z.sum()
  ...: output1.backward(retain_graph=True)   # 这里参数表明保留backward后的中间参数。
  ...: output2.backward()

有两个输出的时候就需要用到这个参数,这就和之前提到的风格迁移中Content Loss层为什么使用这个参数有了联系,因为在风格迁移中不只有Content Loss层还有Style Loss层,两个层都公用一个神经网络的参数但是有两个loss的输出,因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

也就相当于,假如你有两个Loss:

# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

这样就比较容易理解了。

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

洛谷 P1313 计算系数

题目描述 给定一个多项式 ,请求出多项式展开后 项的系数。 输入输出格式 输入格式: 输入文件名为factor.in。 共一行,包含5 个整数,分别为 a...

3023
来自专栏李智的专栏

Python中对文件夹下的特定格式图像全部读取并转化为数组保存(也可转化为txt文件)

python下对图像进行批处理少不了读取文件夹下的全部图像,下面就以具体实例分享下对文件夹下的特定格式图像全部读取并转化为数组保存的代码,代码详解请见注释 ...

792
来自专栏数据结构与算法

1230 元素查找

1230 元素查找 时间限制: 1 s 空间限制: 128000 KB 题目等级 : 钻石 Diamond 题目描述 Description 给出...

1853
来自专栏CNN

从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)

Tensorflow官方提供的Tensorboard可以可视化神经网络结构图,但是说实话,我几乎从来不用。主要是因为Tensorboard中查看到的图结构太混乱...

3165
来自专栏null的专栏

数据结构和算法——动态规划

一、动态规划的思想     动态规划(dynamic programming)是一种算法设计的思想,主要是将一个问题划分成几个更小的问题,并对这样更小的问题进行...

3044
来自专栏瓜大三哥

视频压缩编码技术(H.264) 之哈夫曼编码

第二步,将两个最小概率组成一组,划成2 个分支域,并标以0 和1;再把2 个分支域合并成1个支域,标以两个概率之和;

772
来自专栏漫漫深度学习路

tensorflow学习笔记(三十三):ExponentialMovingAverage

ExponentialMovingAverage Some training algorithms, such as GradientDescent and M...

3776
来自专栏Petrichor的专栏

tensorflow: tf.assign探究

  Defined in tensorflow/python/ops/state_ops.py.

723
来自专栏有趣的django

第一章 魔法函数

1960
来自专栏漫漫深度学习路

pytorch: 常见bug与debug

本博文用来记录自己的 pytorch 踩坑过程,虽然 pytorch 命令式编程,声称容易 debug,可是 代码给出的错误提示可是令人相当头疼,所以在本文中记...

3458

扫码关注云+社区