前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >@once_differentiable有什么用?

@once_differentiable有什么用?

作者头像
狼啸风云
修改2022-09-02 13:31:25
2.1K0
修改2022-09-02 13:31:25
举报
代码语言:javascript
复制
from torch.autograd.function import once_differentiable
class GOF_Function(Function):
    @staticmethod #一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法。 而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。
    def forward(ctx, weight, gaborFilterBank):# 在forward中,需要定义GOF_Function这个运算的forward计算过程
        ctx.save_for_backward(weight, gaborFilterBank)  # 将输入保存起来,在backward时使用
        output = _C.gof_forward(weight, gaborFilterBank)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        weight, gaborFilterBank = ctx.saved_tensors
        grad_weight = _C.gof_backward(grad_output, gaborFilterBank)
        return grad_weight, None

Pytorch提供了包torch.autograd用于自动求导。在前向过程中PyTorch会构建计算图,每个节点用Variable表示,边表示由输入节点到输出节点的函数(torch.autograd.Function对象)。Function对象不仅负责执行前向计算,在反向过程中,每个Function对象会调用.backward()函数计算输出对输入的梯度,然后将梯度传递给下一个Function对象。但是一些操作是不可导的,当你自定义的函数不可导时,在写backward函数时,就需要使用@once_differentiable。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-03-28 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档