前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch学习笔记(十七):python 端扩展 pytorch

pytorch学习笔记(十七):python 端扩展 pytorch

作者头像
ke1th
发布2018-01-02 11:14:15
1.2K0
发布2018-01-02 11:14:15
举报

pytorch 虽然提供了很多的 op 使得我们很容易的使用。但是当已有的 op 无法满足我们的要求的时候,那就需要自己动手来扩展。 pytorch 提供了两种方式来扩展 pytorch 的基础功能。

  • 通过继承 autograd.Function
  • 通过 C 来扩展

本篇博客主要介绍 继承 autograd.Function 来扩展 pytorch

官方文档链接

继承 autograd.Function 的 子类 只需要 实现两个 静态方法:

  • forward : 计算 op 的前向过程.
    • 在执行 forward 之前,Variable 参数已经被转换成了 Tensor
    • forward 的形参可以有默认参数,默认参数可以是任意 python 对象。
    • 可以返回任意多个 Tensor
    • 里面可以使用任何 python 操作,但是 return 的值必须是 Tensor !!!
  • backward : 计算 梯度,
    • forward 返回几个 值, 这里就需要几个 形参,还得外加一个 ctx
    • forward 有几个 形参(不包含 ctx) ,backward 就得返回几个值。
    • bacward 实参也是 Variable
    • backward 返回的得是 Variable

一个 Demo(来自官网)

代码语言:javascript
复制
class LinearFunction(Function):
    # forward 和 backward 都得是 静态方法!!!!!
    @staticmethod
    # bias 是个可选参数,有个 默认值 None
    def forward(ctx, input, weight, bias=None):
        # input,weight 都已经变成了 Tensor
        # 用 ctx 把该存的存起来,留着 backward 的时候用
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # 由于 forward 只有一个 返回值,所以 backward 只需要一个参数 接收 梯度。
    @staticmethod
    def backward(ctx, grad_output):
        # grad_output 是 Variable 类型。
        # 在开头的地方将保存的 tensor 给 unpack 了
        # 然后 给 所有应该返回的 梯度 以 None 初始化。
        # saved_variables 返回的是 Variable!!! 不是 Tensor 了。
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None

        # needs_input_grad 检查是可选的。如果想使得 代码更简单的话,可以忽略。
        # 给不需要梯度的 参数返回梯度 不是一个错误。
        # 返回值 的个数 需要和 forward 形参的个数(不包含 ctx)一致
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        # 梯度的顺序和 forward 形参的顺序要对应。
        return grad_input, grad_weight, grad_bias

关于 ctx

  • save_for_backward 只能存 tensor, None, 其余都不能存。
  • save_for_backward 只保存 forward 的实参,或者 forward 的返回值。

上面就是继承 Function 的全过程,然后该怎么使用呢?

代码语言:javascript
复制
# input, weight, 是 Variable
def linear(input, weight, bias=None):
    # 一定是要 通过调用 apply 来用的。 Function.apply 中估计做了不少事情。
    return LinearFunction.apply(input, weight, bias)

也可以将 LinearFunction 封装到 nn.Module 里面,以便更简单的使用。

检查梯度计算是否正确

pytorch 提供了一个简单的 接口用来检查 定义的 梯度计算是否正确

代码语言:javascript
复制
from torch.autograd import gradcheck
# Check gradients computed via small finite differences against analytical gradients

# 检查的是 inputs 中 requires_grad=True 的梯度,
# 一定要记得 double() 一下!!!!!!
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
             Variable(torch.randn(30, 20).double(), requires_grad=True),)
test = gradcheck(LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# 如果通过,最后会打印一个 True
print(test)

总结

  • forward 的形参是 Tensorreturn 的也是 Tensor
  • backward 的形参是 Variablereturn 也需要是 Variable
  • gradcheck 的时候,记得将 Tensor 的类型转成 double, 使用 float 会导致检查失败。

GlobalMaxPool例子

代码语言:javascript
复制
class GlobalMaxPool(Function):
    @staticmethod
    def forward(ctx, inputs):
        bs, c, h, w = inputs.size()
        flatten_hw = inputs.view(bs, c, -1)
        max_val, indices = torch.max(flatten_hw, dim=-1, keepdim=True)
        max_val = max_val.view(bs, c, 1, 1)
        ctx.save_for_backward(inputs, indices)
        # 只有返回 indices, 才让 save_for_backward。。。 迫不得已。
        return max_val, indices

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_max_val, grad_indices):
        inputs, indices = ctx.saved_variables

        bs, c, h, w = inputs.size()
        grad_inputs = inputs.data.new().resize_as_(inputs.data).zero_().view(bs, c, -1)
        grad_inputs.scatter_(-1, indices.data,
                             torch.squeeze(grad_max_val.data).contiguous().view(bs, c, 1))
        grad_inputs = grad_inputs.view_as(inputs.data)

        return Variable(grad_inputs, volatile=grad_max_val.volatile)


def global_max_pool(input):
    return GlobalMaxPool.apply(input)


if __name__ == '__main__':
    in_ = Variable(torch.randn(2, 1, 3, 3).double(), requires_grad=True)
    res, _ = global_max_pool(in_)
    # print(res)

    res.sum().backward()
    res = gradcheck(GlobalMaxPool.apply, (in_,))
    print(res)
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 检查梯度计算是否正确
  • 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档