pytorch中retain_graph参数的作用

前言

在pytorch神经网络迁移的官方教程中有这样一个损失层函数(具体看这里提供0.3.0版中文链接:https://oldpan.me/archives/pytorch-neural-transfer)。

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        # 因为这里只是需要target这个数值,这个数值是一种状态,不计入计算树中。
        # 这里单纯将其当做常量对待,因此用了detach则在backward中计算梯度时不对target之前所在的计算图存在任何影响。
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss

看到上面的代码,我们在内容损失层中定义了一个backward()反向反馈函数。这个函数在整个神经网络在反向循环的时候会执行loss的backward从而实现对loss的更新。

但是在这个代码中,我们设置了retain_graph=True,这个参数的作用是什么,官方定义为:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

正文

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它:

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。然后我们对两个output执行backward。

In[3]: import torch
In[5]: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
In[6]: y = x ** 2
In[7]: z = y * 4
In[8]: output1 = z.mean()
In[9]: output2 = z.sum()
In[10]: output1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
In[11]: output2.backward()    # 这时会引发错误
Traceback (most recent call last):
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-32d5139229de>", line 1, in <module>
    output2.backward()
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

如果我们这样写:

In[3]: import torch
  ...: from torch.autograd import Variable
  ...: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
  ...: y = x ** 2
  ...: z = y * 4
  ...: output1 = z.mean()
  ...: output2 = z.sum()
  ...: output1.backward(retain_graph=True)   # 这里参数表明保留backward后的中间参数。
  ...: output2.backward()

有两个输出的时候就需要用到这个参数,这就和之前提到的风格迁移中Content Loss层为什么使用这个参数有了联系,因为在风格迁移中不只有Content Loss层还有Style Loss层,两个层都公用一个神经网络的参数但是有两个loss的输出,因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

也就相当于,假如你有两个Loss:

# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

这样就比较容易理解了。

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏落影的专栏

程序员进阶之算法练习(十四)

前言 坚持做算法练习对开发的好处是抽象能力变强,拿到一个需求能很快对其进行抽象,然后再用学过的设计模式相关知识进行整理,最后用代码实现。 最大的好处在于:对...

36670
来自专栏码云1024

Tensorflow 搭建神经网络 (一)

本文为中国大学MOOC课程《人工智能实践:Tensorflow笔记》的笔记中搭建神经网络,总结搭建八股的部分

680150
来自专栏翻译

路径查找器AI

问题源于我想建立一个游戏AI,它要能够定义一条从起点到终点的路径,同时避开路上的墙壁障碍物。为此,我写了一个C#库(path.dll),它允许定义一个二维空间(...

24470
来自专栏码匠的流水账

HashedWheelTimer算法详解

George Varghese 和 Tony Lauck 1996 年的论文:Hashed and Hierarchical Timing Wheels: da...

24210
来自专栏贾志刚-OpenCV学堂

TensorFlow中的feed与fetch

TensorFlow中的feed与fetch 一:占位符(placeholder)与feed 当我们构建一个模型的时候,有时候我们需要在运行时候输入一些初始数...

55770
来自专栏王小雷

R语言基础命令与安装

1. R的安装过程 1.1.首先附上清华线路的下载链接Windows版3.3.1 1.2. 选择安装路径 ? 1.3. 注意根据自己的计算机位数选择,如我的是6...

25450
来自专栏深度学习那些事儿

在pytorch中实现与TensorFlow类似的"same"方式padding

文章来自Oldpan博客:https://oldpan.me/archives/pytorch-same-padding-tflike

3.7K70
来自专栏转载gongluck的CSDN博客

[C++]:A*——A Star算法简介

A*算法 求最优解 算法一直维护两个表: Open和Close 将起点S加入Open中 将所有S可到达的点(障碍物以及位于Close表中的点均看成不可达...

49760
来自专栏雨过天晴

原 Hash一致性算法实现

17730
来自专栏吉浦迅科技

DAY44:阅读Texture Reference API

fetches from the region of linear memory bound to the one-dimensional texture re...

11420

扫码关注云+社区

领取腾讯云代金券