pytorch学习笔记(十七):python 端扩展 pytorch

pytorch 虽然提供了很多的 op 使得我们很容易的使用。但是当已有的 op 无法满足我们的要求的时候,那就需要自己动手来扩展。 pytorch 提供了两种方式来扩展 pytorch 的基础功能。

  • 通过继承 autograd.Function
  • 通过 C 来扩展

本篇博客主要介绍 继承 autograd.Function 来扩展 pytorch

官方文档链接

继承 autograd.Function 的 子类 只需要 实现两个 静态方法:

  • forward : 计算 op 的前向过程.
    • 在执行 forward 之前,Variable 参数已经被转换成了 Tensor
    • forward 的形参可以有默认参数,默认参数可以是任意 python 对象。
    • 可以返回任意多个 Tensor
    • 里面可以使用任何 python 操作,但是 return 的值必须是 Tensor !!!
  • backward : 计算 梯度,
    • forward 返回几个 值, 这里就需要几个 形参,还得外加一个 ctx
    • forward 有几个 形参(不包含 ctx) ,backward 就得返回几个值。
    • bacward 实参也是 Variable
    • backward 返回的得是 Variable

一个 Demo(来自官网)

class LinearFunction(Function):
    # forward 和 backward 都得是 静态方法!!!!!
    @staticmethod
    # bias 是个可选参数,有个 默认值 None
    def forward(ctx, input, weight, bias=None):
        # input,weight 都已经变成了 Tensor
        # 用 ctx 把该存的存起来,留着 backward 的时候用
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # 由于 forward 只有一个 返回值,所以 backward 只需要一个参数 接收 梯度。
    @staticmethod
    def backward(ctx, grad_output):
        # grad_output 是 Variable 类型。
        # 在开头的地方将保存的 tensor 给 unpack 了
        # 然后 给 所有应该返回的 梯度 以 None 初始化。
        # saved_variables 返回的是 Variable!!! 不是 Tensor 了。
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None

        # needs_input_grad 检查是可选的。如果想使得 代码更简单的话,可以忽略。
        # 给不需要梯度的 参数返回梯度 不是一个错误。
        # 返回值 的个数 需要和 forward 形参的个数(不包含 ctx)一致
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        # 梯度的顺序和 forward 形参的顺序要对应。
        return grad_input, grad_weight, grad_bias

关于 ctx

  • save_for_backward 只能存 tensor, None, 其余都不能存。
  • save_for_backward 只保存 forward 的实参,或者 forward 的返回值。

上面就是继承 Function 的全过程,然后该怎么使用呢?

# input, weight, 是 Variable
def linear(input, weight, bias=None):
    # 一定是要 通过调用 apply 来用的。 Function.apply 中估计做了不少事情。
    return LinearFunction.apply(input, weight, bias)

也可以将 LinearFunction 封装到 nn.Module 里面,以便更简单的使用。

检查梯度计算是否正确

pytorch 提供了一个简单的 接口用来检查 定义的 梯度计算是否正确

from torch.autograd import gradcheck
# Check gradients computed via small finite differences against analytical gradients

# 检查的是 inputs 中 requires_grad=True 的梯度,
# 一定要记得 double() 一下!!!!!!
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
             Variable(torch.randn(30, 20).double(), requires_grad=True),)
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# 如果通过,最后会打印一个 True
print(test)

总结

  • forward 的形参是 Tensorreturn 的也是 Tensor
  • backward 的形参是 Variablereturn 也需要是 Variable
  • gradcheck 的时候,记得将 Tensor 的类型转成 double, 使用 float 会导致检查失败。

GlobalMaxPool例子

class GlobalMaxPool(Function):
    @staticmethod
    def forward(ctx, inputs):
        bs, c, h, w = inputs.size()
        flatten_hw = inputs.view(bs, c, -1)
        max_val, indices = torch.max(flatten_hw, dim=-1, keepdim=True)
        max_val = max_val.view(bs, c, 1, 1)
        ctx.save_for_backward(inputs, indices)
        # 只有返回 indices, 才让 save_for_backward。。。 迫不得已。
        return max_val, indices

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_max_val, grad_indices):
        inputs, indices = ctx.saved_variables

        bs, c, h, w = inputs.size()
        grad_inputs = inputs.data.new().resize_as_(inputs.data).zero_().view(bs, c, -1)
        grad_inputs.scatter_(-1, indices.data,
                             torch.squeeze(grad_max_val.data).contiguous().view(bs, c, 1))
        grad_inputs = grad_inputs.view_as(inputs.data)

        return Variable(grad_inputs, volatile=grad_max_val.volatile)


def global_max_pool(input):
    return GlobalMaxPool.apply(input)


if __name__ == '__main__':
    in_ = Variable(torch.randn(2, 1, 3, 3).double(), requires_grad=True)
    res, _ = global_max_pool(in_)
    # print(res)

    res.sum().backward()
    res = gradcheck(GlobalMaxPool.apply, (in_,))
    print(res)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏一个会写诗的程序员的博客

使用 Kotlin 实现 Y 组合子(Y-Combinator)使用 Kotlin 实现 Y 组合子(Y-Combinator)《Kotlin极简教程》正式上架:

我们可以使用 Kotlin FP (Lambda, function) 写一个 Y-combinator 函数吗?

432
来自专栏深度学习那些事儿

pytorch中autograd以及hook函数详解

有些公式为图片,如果这个页面加载不出来,请看这里:https://oldpan.me/archives/pytorch-autograd-hook

95610
来自专栏数据结构与算法

极值问题

背景 小铭的数学之旅2。 描述 已知m、n为整数,且满足下列两个条件: ① m、n∈1,2,…,K ② (n^ 2-mn-m^2)^2=1 编一程序,对给定K,...

3245
来自专栏码生

Python socket 解析客户端数据

773
来自专栏算法channel

玩转Pandas,让数据处理更easy系列2

上一篇总结了Pandas中最重要的两个数据结构:Series和DataFrame,前者相当于更加强大的一维数组,是数组和字典的组合,因为既可以按照位置,也能通过...

833
来自专栏CDA数据分析师

提升R代码运算效率的11个实用方法

众所周知,当我们利用R语言处理大型数据集时,for循环语句的运算效率非常低。有许多种方法可以提升你的代码运算效率,但或许你更想了解运算效率能得到多大的提升。本文...

1888
来自专栏AI研习社

TensorFlow全新的数据读取方式:Dataset API入门教程

Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。 此前,在TensorFlow...

3623
来自专栏深度学习那些事儿

pytorch中autograd以及hook函数详解

pytorch中的Autograd mechanics(自动求梯度机制)是实现前向以及后向反馈运算极为重要的一环,pytorch官方专门针对这个机制进行了一个版...

2315
来自专栏数据结构与算法

HDU4352 XHXJ's LIS(LIS 状压)

刚开始的思路是$f[i][j]$表示到第$i$位,LIS长度为$j$的方案。 然而发现根本不能转移,除非知道了之前的状态然后重新dp一遍。。

653
来自专栏C语言及其他语言

【每日一题】问题1075: 台球碰撞(此次问题较难)

题目描述 在平面直角坐标系下,台球桌是一个左下角在(0,0),右上角在(L,W)的矩形。有一个球心在(x,y),半径为R的圆形母球放在台球桌上(整个球都在台球...

2595

扫码关注云+社区