首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >PyTorch模块如何进行后部支撑

PyTorch模块如何进行后部支撑
EN

Stack Overflow用户
提问于 2018-04-01 12:39:38
回答 1查看 3.5K关注 0票数 6

在遵循extending PyTorch - adding a module的说明时,我注意到在扩展Module时,我们实际上并不需要实现向后的函数。我们唯一需要做的就是在forward函数中应用Function实例,当做back prop时,PyTorch可以自动调用function实例中的backward函数。这在我看来就像是魔术,因为我们甚至没有注册我们使用的函数实例。我查看了源代码,但没有发现任何相关内容。有没有人能给我指出所有这些事情真正发生的地方?

EN

回答 1

Stack Overflow用户

发布于 2018-04-01 14:13:30

也许我是错的,但我有不同的观点。

反向函数被定义,并由正向函数调用。

例如:

代码语言:javascript
复制
#!/usr/bin/env python
# encoding: utf-8

###############################################################
# Parametrized example
# --------------------
#
# This implements a layer with learnable weights.
#
# It implements the Cross-correlation with a learnable kernel.
#
# In deep learning literature, it’s confusingly referred to as
# Convolution.
#
# The backward computes the gradients wrt the input and gradients wrt the
# filter.
#
# **Implementation:**
#
# *Please Note that the implementation serves as an illustration, and we
# did not verify it’s correctness*

import torch
from torch.autograd import Function
from torch.autograd import Variable

from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_output = grad_output.data
        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')

        return Variable(grad_output.new(grad_input)), \
            Variable(grad_output.new(grad_filter))


class ScipyConv2d(Module):

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)

###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)

在本例中,反向函数由ScipyConv2dFunction函数定义。

ScipyConv2dFunction由转发函数调用。

我说的对吗?

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49594858

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档