pytorch学习笔记(十七):python 端扩展 pytorch

pytorch 虽然提供了很多的 op 使得我们很容易的使用。但是当已有的 op 无法满足我们的要求的时候,那就需要自己动手来扩展。 pytorch 提供了两种方式来扩展 pytorch 的基础功能。

  • 通过继承 autograd.Function
  • 通过 C 来扩展

本篇博客主要介绍 继承 autograd.Function 来扩展 pytorch

官方文档链接

继承 autograd.Function 的 子类 只需要 实现两个 静态方法:

  • forward : 计算 op 的前向过程.
    • 在执行 forward 之前,Variable 参数已经被转换成了 Tensor
    • forward 的形参可以有默认参数,默认参数可以是任意 python 对象。
    • 可以返回任意多个 Tensor
    • 里面可以使用任何 python 操作,但是 return 的值必须是 Tensor !!!
  • backward : 计算 梯度,
    • forward 返回几个 值, 这里就需要几个 形参,还得外加一个 ctx
    • forward 有几个 形参(不包含 ctx) ,backward 就得返回几个值。
    • bacward 实参也是 Variable
    • backward 返回的得是 Variable

一个 Demo(来自官网)

class LinearFunction(Function):
    # forward 和 backward 都得是 静态方法!!!!!
    @staticmethod
    # bias 是个可选参数,有个 默认值 None
    def forward(ctx, input, weight, bias=None):
        # input,weight 都已经变成了 Tensor
        # 用 ctx 把该存的存起来,留着 backward 的时候用
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # 由于 forward 只有一个 返回值,所以 backward 只需要一个参数 接收 梯度。
    @staticmethod
    def backward(ctx, grad_output):
        # grad_output 是 Variable 类型。
        # 在开头的地方将保存的 tensor 给 unpack 了
        # 然后 给 所有应该返回的 梯度 以 None 初始化。
        # saved_variables 返回的是 Variable!!! 不是 Tensor 了。
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None

        # needs_input_grad 检查是可选的。如果想使得 代码更简单的话,可以忽略。
        # 给不需要梯度的 参数返回梯度 不是一个错误。
        # 返回值 的个数 需要和 forward 形参的个数(不包含 ctx)一致
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        # 梯度的顺序和 forward 形参的顺序要对应。
        return grad_input, grad_weight, grad_bias

关于 ctx

  • save_for_backward 只能存 tensor, None, 其余都不能存。
  • save_for_backward 只保存 forward 的实参,或者 forward 的返回值。

上面就是继承 Function 的全过程,然后该怎么使用呢?

# input, weight, 是 Variable
def linear(input, weight, bias=None):
    # 一定是要 通过调用 apply 来用的。 Function.apply 中估计做了不少事情。
    return LinearFunction.apply(input, weight, bias)

也可以将 LinearFunction 封装到 nn.Module 里面,以便更简单的使用。

检查梯度计算是否正确

pytorch 提供了一个简单的 接口用来检查 定义的 梯度计算是否正确

from torch.autograd import gradcheck
# Check gradients computed via small finite differences against analytical gradients

# 检查的是 inputs 中 requires_grad=True 的梯度,
# 一定要记得 double() 一下!!!!!!
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
             Variable(torch.randn(30, 20).double(), requires_grad=True),)
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# 如果通过,最后会打印一个 True
print(test)

总结

  • forward 的形参是 Tensorreturn 的也是 Tensor
  • backward 的形参是 Variablereturn 也需要是 Variable
  • gradcheck 的时候,记得将 Tensor 的类型转成 double, 使用 float 会导致检查失败。

GlobalMaxPool例子

class GlobalMaxPool(Function):
    @staticmethod
    def forward(ctx, inputs):
        bs, c, h, w = inputs.size()
        flatten_hw = inputs.view(bs, c, -1)
        max_val, indices = torch.max(flatten_hw, dim=-1, keepdim=True)
        max_val = max_val.view(bs, c, 1, 1)
        ctx.save_for_backward(inputs, indices)
        # 只有返回 indices, 才让 save_for_backward。。。 迫不得已。
        return max_val, indices

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_max_val, grad_indices):
        inputs, indices = ctx.saved_variables

        bs, c, h, w = inputs.size()
        grad_inputs = inputs.data.new().resize_as_(inputs.data).zero_().view(bs, c, -1)
        grad_inputs.scatter_(-1, indices.data,
                             torch.squeeze(grad_max_val.data).contiguous().view(bs, c, 1))
        grad_inputs = grad_inputs.view_as(inputs.data)

        return Variable(grad_inputs, volatile=grad_max_val.volatile)


def global_max_pool(input):
    return GlobalMaxPool.apply(input)


if __name__ == '__main__':
    in_ = Variable(torch.randn(2, 1, 3, 3).double(), requires_grad=True)
    res, _ = global_max_pool(in_)
    # print(res)

    res.sum().backward()
    res = gradcheck(GlobalMaxPool.apply, (in_,))
    print(res)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏计算机视觉与深度学习基础

Leetcode 220. Contains Duplicate III

Given an array of integers, find out whether there are two distinct indices i an...

1759
来自专栏和蔼的张星的图像处理专栏

397. 最长上升连续子序列

给定一个整数数组(下标从 0 到 n-1, n 表示整个数组的规模),请找出该数组中的最长上升连续子序列。(最长上升连续子序列可以定义为从右到左或从左到右的序列...

742
来自专栏mathor

枚举+优化(7)——前缀和1

1457
来自专栏书山有路勤为径

递归函数基础

函数代码中调用自己时称为递归,该函数被称为递归函数。递归函数是一个很高效的 开发技巧,可以极大的简化代码提高开发效率。递归函数与循环类似,循环可以完成的 事情,...

703
来自专栏程序员互动联盟

【编程基础】C语言复合赋值运算符

在C语言的赋值中有一种特殊的赋值运算符,就是复合赋值运算符。复合赋值运算符就是在赋值符“=”之前加上其它二目运算符可构成。比如大家可能最常看到这样的语句: n ...

2706
来自专栏Java技术栈

神奇,教你用随机数打印hello world

下面是一段随机数程序。 public static void main(String[] args) { System.out.println(rand...

3355
来自专栏计算机视觉与深度学习基础

Leetcode 282. Expression Add Operators

Given a string that contains only digits 0-9 and a target value, return all pos...

1658
来自专栏程序生活

最大连续子序列和

https://blog.csdn.net/bitcarmanlee/article/details/51526010

572
来自专栏ACM算法日常

字符串的距离(动态规划) - leetcode 72

,因为在刷leetcode的动态规划专题。动态规划虽然定义很简单,但是对于复杂的动态规划题目,很多时候还是很棘手的。

762
来自专栏AI派

Numpy 修炼之道 (8)—— 常用函数

上一篇:Numpy 修炼之道 (7)—— 形状操作 在了解了 Numpy 的基本运算操作,下面来看下 Numpy常用的函数。 数学运算函数 按元素添加...

3125

扫码关注云+社区