前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >探讨pytorch中nn.Module与nn.autograd.Function的backward()函数

探讨pytorch中nn.Module与nn.autograd.Function的backward()函数

原创
作者头像
老潘
发布于 2018-05-23 10:04:37
发布于 2018-05-23 10:04:37
5.2K00
代码可运行
举报
运行总次数:0
代码可运行

前言

本文讲解基于pytorch0.4.0版本,如不清楚版本信息请看这里。backward()在pytorch中是一个经常出现的函数,我们一般会在更新loss的时候使用它,比如loss.backward()。通过对loss进行backward来实现从输出到输入的自动求梯度运算。但是这里的backward()如果追根溯源一下,或者说Go to definition一下,我们会发现,其实这个backward是来源于torch.autograd.backward

上面是官方的截图信息。但是这个函数我们可能不常见,那么这个函数在哪儿呢,就在Tensor这个类中(之前是在Variable类中,现在Variable和tensor合并)。而Tensor这个类中有一个函数:

backward()函数,这个函数返回的就是torch.autograd.backward()。也就是说,我们在训练中输入我们数据,然后经过一系列神经网络运算,最后计算loss,然后loss.backward()。这里的backward()归根绝地就是,上面说的这个函数。

正文

本文要说明的两个backward,一个nn.Module中的backward()和torch.autograd.Function中的backward(),其实有一个是假的backward()。

Fake Backward

很容易发现,我们在自己定义一个全新的网络层的时候会继承nn.Module,但是我们只需要实现__init__和forward()即可,不需要实现也没必要实现backward()函数,即使你实现了,你继承了nn.Module并且编写了一个backward()函数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class ContentLoss(nn.Module):
   def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        self.target = target.detach() * weight
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
       # will throw an error.
        self.weight = weight
        self.criterion = nn.MSELoss()
  def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output
  def backward(self, retain_graph=True):
        print('ContentLoss Backward works')
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
    ...
    # 执行backward语句,具体代码请看下方的连接。
for sl in style_losses:
    style_score += sl.backward()
for cl in content_losses:
    content_score += cl.backward()

上面这段代码是利用pytorch实现风格迁移的自定义内容损失层,如果不懂看这里:传送门。如果正常操作,在实际运行中上面的backward函数并不会执行也不回打印执行信息。上面定义的backward称为fake backward函数,也就是假的backward函数,不会在pytorch的自动求梯度图中执行。但是为什么这么写,在pytorch官方0.3.0的教程中,可以在loss更新的时候,不使用loss.backward(),而是直接使用类中的.backward()方法然后返回loss即可。

但是在官方的0.4.0的风格迁移示例代码中,上面的代码发生了变化:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class ContentLoss(nn.Module):
     def __init__(self, target, ):
         super(ContentLoss, self).__init__()
         # we 'detach' the target content from the tree used
         # to dynamically compute the gradient: this is a stated value,
         # not a variable. Otherwise the forward method of the criterion
         # will throw an error.
         self.target = target.detach()
     def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input
...
# 执行代码,具体看官网的最新0.4.0风格迁移教程
for sl in style_losses:
    style_score += sl.loss
for cl in content_losses:
    content_score += cl.loss
loss = style_score + content_score
loss.backward()

我们发现没有backward函数了,而且使用的loss function发生了变化,从nn.MSELoss() ==> F.mse_loss()。

上面的这段代码没有定义backward函数,也没有执行retain_grad操作。为什么两个版本的不一样,其实第一个版本(0.3.0)完全没必要写backward函数,也没必要再单独执行backward()函数,因为最终目的都是一样的,都是要实现对loss的backward,在forward中进行操作的时候,其实我们已经对torch.autograd.Function的subclass进行了操作。也就是说在我们对tensor进行每一步操作运算的时候都会生成一个Function类的子类,里面定了好了forward和backward操作,最后连成计算图,所以没有必要多此一举。

说了这么多,既然不建议在nn.Module中定义backward。那我们能不能自己定义backward函数。

Real Backward

可以的。

通过继承torch.autograd.Function来定义。这一方面官方有教程,这里就不赘述。(下方是官方示例程序)

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class MyReLU(torch.autograd.Function):
 """
 We can implement our own custom autograd Functions by subclassing
  torch.autograd.Function and implementing the forward and backward passes
  which operate on Tensors.
 """
  @staticmethod
  def forward(ctx, x):
    """
 In the forward pass we receive a context object and a Tensor containing the
 input; we must return a Tensor containing the output, and we can use the
    context object to cache objects for use in the backward pass.
 """
    ctx.save_for_backward(x)
    return x.clamp(min=0)
  def backward(ctx, grad_output):
    """
 In the backward pass we receive the context object and a Tensor containing
    the gradient of the loss with respect to the output produced during the
    forward pass. We can retrieve cached data from the context object, and must
    compute and return the gradient of the loss with respect to the input to the
    forward function.
 """
    x, = ctx.saved_tensors
    grad_x = grad_output.clone()
    grad_x[x < 0] = 0
 return grad_x

这里讲一下我们在什么情况下需要自己定义:

我们平常使用的nn.Module其实说白了就是一层包装(Contain),比如nn.Conv2继承了nn.Module,但是里面的核心函数是torch.nn.function.conv2d,为什么要包装下,原因很简单,为了方便,因为我们使用的卷积层是有参数的,这些参数是可以学习的(learnable parameters)。在这个包装类中我们通过torch.nn.parameter的Parameter类把参数进行包装然后传递给torch.nn.function中的函数进行计算,这样也就简化了我们的操作。

那么什么时候需要使用torch.autograd.Function去定义自己的层,在有些操作通过组合pytorch中已有的层实现不了的时候,比如你要实现一个新的梯度下降算法,那么就可以尝试着写这些东西。但是要注意,因为这个涉及到了底层,你需要forward和backward一起写,然后自己写对中间变量的操作,比如gradinput以及gradoutput。

比如这样写:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class my_function(Function):
 def forward(self, input, parameters):
        self.saved_for_backward = [input, parameters]
        # output = [对输入和参数进行的操作,这里省略]
        return output
 def backward(self, grad_output):
        input, parameters = self.saved_for_backward
        # grad_input = [forward(input)关于 parameters 的导数] * grad_output
        return grad_input
# 然后通过定义一个Module来包装一下
class my_module(nn.Module):
      def __init__(self, ...):
         super(my_module, self).__init__()
         self.parameters = # 初始化一些参数
      def backward(self, input):
          output = my_function(input, self.parameters) # 在这里执行你之前定义的function!
          return output

这样你就可以通过自定义层然后包装,然后来使用了。

后记

对于这个包装,其实包不包装对于执行效率的影响几乎可以不计,对于没有学习参数的层,比如Relu(nn.ReLU vs F.relu),其实包不包装对于使用起来没什么区别。

文章来自Oldpan博客:https://cloud.tencent.com/developer/article/1149993

参考链接

https://discuss.pytorch.org/t/defining-backward-function-in-nn-module/5047

https://discuss.pytorch.org/t/whats-the-difference-between-torch-nn-functional-and-torch-nn/681

https://discuss.pytorch.org/t/difference-of-methods-between-torch-nn-and-functional/1076

https://discuss.pytorch.org/t/whats-the-difference-between-torch-nn-functional-and-torch-nn/681/4

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Kafka和RocketMQ底层存储之那些你不知道的事
我们都知道 RocketMQ 和 Kafka 消息都是存在磁盘中的,那为什么消息存磁盘读写还可以这么快?有没有做了什么优化?都是存磁盘它们两者的实现之间有什么区别么?各自有什么优缺点? 今天我们就来一
程序员小猿
2021/01/19
7080
Kafka和RocketMQ底层存储之那些你不知道的事
RocketMQ(三):面对高并发请求,如何高效持久化消息?
上篇文章我们分析完RocketMQ发送消息的原理,得到结果客户端会通过RPC组件向Broker进行通信
菜菜的后端私房菜
2024/09/18
1.1K0
【详解文件IO系列】讲讲 MQ 消息中间件(Kafka,RocketMQ等与 MMAP、PageCache 的故事
首先来看一下一般的IO调用。在传统的文件IO操作中,我们都是调用操作系统提供的底层标准IO系统调用函数 read()、write() ,此时调用此函数的进程(在JAVA中即java进程)由当前的用户态切换到内核态,然后OS的内核代码负责将相应的文件数据读取到内核的IO缓冲区,然后再把数据从内核IO缓冲区拷贝到进程的私有地址空间中去,这样便完成了一次IO操作。如下图所示。
一个会写诗的程序员
2021/03/24
1.3K0
【详解文件IO系列】讲讲 MQ 消息中间件(Kafka,RocketMQ等与 MMAP、PageCache 的故事
3分钟白话RocketMQ系列—— 如何存储消息
RocketMQ使用了一种基于日志的存储方式,将消息以顺序写入的方式追加到文件中,从而实现高性能的消息存储和读取。
阿丸笔记
2023/10/22
5490
3分钟白话RocketMQ系列—— 如何存储消息
深度解读 RocketMQ 存储机制
RocketMQ 实现了灵活的多分区和多副本机制,有效的避免了集群内单点故障对于整体服务可用性的影响。存储机制和高可用策略是 RocketMQ 稳定性的核心,社区上关于 RocketMQ 目前存储实现的分析与讨论一直是一个热议的话题。近期我一直在负责 RocketMQ 消息多副本和高可用能力的建设,和大家分享下一些有趣的想法。
从大数据到人工智能
2022/09/08
7470
?【Alibaba中间件技术系列】「RocketMQ技术专题」服务底层高性能存储设计分析
消息中间件的本身定义来考虑,应该尽量减少对于外部第三方中间件的依赖。一般来说依赖的外部系统越多,也会使得本身的设计越复杂,采用文件系统作为消息存储的方式。
码界西柚
2022/01/23
7380
?【Alibaba中间件技术系列】「RocketMQ技术专题」服务底层高性能存储设计分析
深入剖析 RocketMQ 源码 - 消息存储模块
RocketMQ 是阿里巴巴开源的分布式消息中间件,它借鉴了 Kafka 实现,支持消息订阅与发布、顺序消息、事务消息、定时消息、消息回溯、死信队列等功能。RocketMQ 架构上主要分为四部分,如下图所示:
2020labs小助手
2021/11/09
1.5K0
消息的存储-RocketMQ知识体系3
上一篇了解了RocketMQ消息发送,本文开始聊聊消息发送到Broker端后,消息存储相关的逻辑。
DougWang
2021/07/21
5460
RocketMQ存储--消息追加【源码笔记】
commitLog内存(ByteBuffer)写入位点,标记消息写到哪了,下次从该位置开始写。
瓜农老梁
2019/08/20
9790
RocketMQ存储--消息追加【源码笔记】
万字聊一聊RocketMQ一条消息短暂而又精彩的一生
我们都知道,消息是由业务系统在运行过程产生的,当我们的业务系统产生了消息,我们就可以调用RocketMQ提供的API向RocketMQ发送消息,就像下面这样
三友的java日记
2024/06/07
1560
万字聊一聊RocketMQ一条消息短暂而又精彩的一生
面试系列之-rocketmq文件数据存储
Broker上的Topic上的消息都会顺序的写入到commitlog文件下,然后再异步转存到consumequeue以及indexFile文件;该消息的元信息存储着消息所在的Topic与Queue,当消费者要进行消费时,会通过ConsumerQueue文件来找到自己想要消费的队列;该队列不存储具体的消息,而是存储消息的基本信息与偏移量。消费者通过偏移量去CommitLog中找到自己需要消费的信息然后取出,就可以进行消费;并且Broker还可以对CommitLog来建立Hash索引文件IndexFile,这样就可以通过消息的key来找到消息;
用户4283147
2022/12/29
6910
面试系列之-rocketmq文件数据存储
消息中间件—RocketMQ消息存储(一)一、MQ消息队列的一般存储方式二、RocketMQ消息存储整体架构三、RocketMQ文件存储模型层次结构四、总结
文章摘要:MQ分布式消息队列大致流程在于消息的一发一收一存,本篇将为大家主要介绍下RocketMQ存储部分的架构 消息存储是MQ消息队列中最为复杂和最为重要的一部分,所以小编也就放在RocketMQ系列篇幅中最后一部分来进行阐述和介绍。本文先从目前几种比较常用的MQ消息队列存储方式出发,为大家介绍RocketMQ选择磁盘文件存储的原因。然后,本文分别从RocketMQ的消息存储整体架构和RocketMQ文件存储模型层次结构两方面进行深入分析介绍。使得大家读完本文后对RocketMQ消息存储部分有一个大致的了解和认识。 这里先回顾往期RocketMQ技术分享的篇幅(如果有童鞋没有读过之前的文章,建议先好好读下之前小编写的篇幅或者其他网上相关的博客,把RocketMQ消息发送和消费部分的流程先大致搞明白): (1)消息中间件—RocketMQ的RPC通信(一) (2)消息中间件—RocketMQ的RPC通信(二) (3)消息中间件—RocketMQ消息发送 (4)消息中间件—RocketMQ消息消费(一) (5)消息中间件—RocketMQ消息消费(二)(push模式实现) (6)消息中间件—RocketMQ消息消费(三)(消息消费重试)
用户2991389
2018/10/10
3.1K0
消息中间件—RocketMQ消息存储(一)一、MQ消息队列的一般存储方式二、RocketMQ消息存储整体架构三、RocketMQ文件存储模型层次结构四、总结
RocketMQ MappedFile 预热原理解析
从代码中可以看出,只有 MappedFile 的大小等于或大于 CommitLog 的大小并且开启文件预热功能才会预加载文件。 CommitLog 文件的大小默认为 1 G。
java404
2019/01/03
2.9K0
10 张图告诉你 RocketMQ 是怎样保存消息的
首先,在 RocketMQ 集群中创建一个 Topic,叫做 MyTestTopic,配置如下图:
jinjunzhu
2022/12/20
8810
10 张图告诉你 RocketMQ 是怎样保存消息的
RocketMQ原理—4.消息读写的性能优化
Producer发送消息时需要指定一个Topic,需要知道Topic里有哪些Queue,以及这些Queue分别分布在哪些Broker上。因此,Producer发送消息到Broker的流程如下:
东阳马生架构
2025/04/03
1150
RocketMQ-消息存储设计
消息存储是RocketMQ中最为复杂和最为重要的一部分,将分别从RocketMQ的消息存储整体架构、PageCache与Mmap内存映射以及RocketMQ中两种不同的刷盘方式三方面来分别展开叙述。
潇洒
2023/10/20
5820
RocketMQ-消息存储设计
RocketMQ原理—1.RocketMQ整体运行原理
这个介绍顺序就基本涵盖了RocketMQ的整体运行流程,接下来首先分析RocketMQ生产者的工作原理。
东阳马生架构
2025/04/01
890
天池中间件大赛百万队列存储设计总结【复赛】
维持了 20 天的复赛终于告一段落了,国际惯例先说结果,复赛结果不太理想,一度从第 10 名掉到了最后的第 36 名,主要是写入的优化卡了 5 天,一直没有进展,最终排名也是定格在了排行榜的第二页。痛定思痛,这篇文章将自己复赛中学习的知识,成功的优化,未成功的优化都罗列一下。
kirito-moe
2018/07/25
1K0
天池中间件大赛百万队列存储设计总结【复赛】
RocketMQ存储设计精髓
对于一款消息中间件来说,优良的数据存储设计,是实现高性能消息吞吐以及消息查询的关键所在。因为消息中间件对于外部来说就是发消息消费消息的一个平台基础设施,但是从其本身来说,需要将海量消息数据信息持久化在RocketMQ节点所在的服务器上,这样即便是服务器断电,重启等情况下,也不至于丢失消息数据。另外在进行消息消费的时候,RocketMQ如何能借助自身的存储设计快速检索到对应的消息也是非常重要的,因此本文主要对RocketMQ存储设计进行了设计分析。
慕枫技术笔记
2023/03/20
4050
RocketMQ存储设计精髓
RocketMQ为什么这么快?我从源码中扒出了10大原因!
如果你对RocketMQ还不了解,可以从公众号后台菜单栏中查看我之前写的关于RocketMQ的几篇文章
三友的java日记
2024/03/18
4520
RocketMQ为什么这么快?我从源码中扒出了10大原因!
推荐阅读
相关推荐
Kafka和RocketMQ底层存储之那些你不知道的事
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验