专栏首页深度学习那些事儿pytorch中retain_graph参数的作用

pytorch中retain_graph参数的作用

前言

在pytorch神经网络迁移的官方教程中有这样一个损失层函数(具体看这里提供0.3.0版中文链接:https://oldpan.me/archives/pytorch-neural-transfer)。

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        # 因为这里只是需要target这个数值,这个数值是一种状态,不计入计算树中。
        # 这里单纯将其当做常量对待,因此用了detach则在backward中计算梯度时不对target之前所在的计算图存在任何影响。
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss

看到上面的代码,我们在内容损失层中定义了一个backward()反向反馈函数。这个函数在整个神经网络在反向循环的时候会执行loss的backward从而实现对loss的更新。

但是在这个代码中,我们设置了retain_graph=True,这个参数的作用是什么,官方定义为:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

正文

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它:

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。然后我们对两个output执行backward。

In[3]: import torch
In[5]: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
In[6]: y = x ** 2
In[7]: z = y * 4
In[8]: output1 = z.mean()
In[9]: output2 = z.sum()
In[10]: output1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
In[11]: output2.backward()    # 这时会引发错误
Traceback (most recent call last):
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-32d5139229de>", line 1, in <module>
    output2.backward()
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/prototype/anaconda3/envs/pytorch-env/lib/python3.6/site-packages/torch/autograd/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

如果我们这样写:

In[3]: import torch
  ...: from torch.autograd import Variable
  ...: x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
  ...: y = x ** 2
  ...: z = y * 4
  ...: output1 = z.mean()
  ...: output2 = z.sum()
  ...: output1.backward(retain_graph=True)   # 这里参数表明保留backward后的中间参数。
  ...: output2.backward()

有两个输出的时候就需要用到这个参数,这就和之前提到的风格迁移中Content Loss层为什么使用这个参数有了联系,因为在风格迁移中不只有Content Loss层还有Style Loss层,两个层都公用一个神经网络的参数但是有两个loss的输出,因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

也就相当于,假如你有两个Loss:

# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

这样就比较容易理解了。

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 探讨pytorch中nn.Module与nn.autograd.Function的backward()函数

    本文讲解基于pytorch0.4.0版本,如不清楚版本信息请看这里。backward()在pytorch中是一个经常出现的函数,我们一般会在更新loss的时候使...

    OLDPAN
  • 探讨pytorch中nn.Module与nn.autograd.Function的backward()函数

    本文讲解基于pytorch0.4.0版本,如不清楚版本信息请看这里。backward()在pytorch中是一个经常出现的函数,我们一般会在更新loss的时候使...

    OLDPAN
  • 浅谈Pytorch与Torch的关系

    Pytorch发布已经有一段时间了,我们在使用中也发现了其独特的动态图设计,让我们可以高效地进行神经网络的构造、实现我们的想法。那么Pytorch是怎么来的,追...

    OLDPAN
  • 【致敬周杰伦】基于TensorFlow让机器生成周董的歌词(附源码)

    ? 周杰伦 深深地 影响了我们 一代人 这句话 不足为过 前言 今日推文将介绍如何使用TensorFlow一步步来搭建一个序列建模的应用——机器创作歌词,训练...

    量化投资与机器学习微信公众号
  • MJRefresh 源码阅读

    用户2215591
  • Python在windows CMD中输出中文

    happy123.me
  • iOS 仿系统指南针

    且行且珍惜_iOS
  • python面向对象小练习

    流川疯
  • 小Q-百度定位及递增递减的Label数字未完待续,持续更新中

    GuangdongQi
  • 自动机器学习新进展!性能超过人类调参师6个点,AutoGluon 低调开源

    近日,亚马逊正式推出了开源代码库 AutoGluon。开发者依靠仅仅几行代码,就可以编写出 AI 嵌入应用程序。

    Datawhale

扫码关注云+社区

领取腾讯云代金券