PyTorch模块如何做后备道具

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (2)
  • 关注 (0)
  • 查看 (63)

按照扩展PyTorch的指示- 添加模块,我注意到在扩展Module时,我们实际上不必实现向后功能。我们唯一需要的是在forward函数中应用Function实例,PyTorch在执行后向prop时可以自动调用Function实例中的后向实例。这对我来说似乎很神奇,因为我们甚至没有注册我们使用的Function实例。我查看了源代码,但没有发现任何相关内容。有人可以指出我所有实际发生过的地方吗?

提问于
用户回答回答于

不必实现backward()PyTorch或任何其他DL框架如此有价值的原因。实际上,backward()只有在需要弄乱网络梯度的特定情况下(或者当您创建无法使用PyTorch的内置函数表示的自定义函数时)才能实现。

PyTorch使用计算图计算后向梯度,计算图跟踪正向传递期间已完成的操作。任何操作都是在Variable这里隐式获得注册的。然后,这是从调用它的变量向后遍历图形,并应用衍生链规则来计算梯度的问题。

PyTorch的“ 关于”页面具有很好的图形可视化以及它通常如何工作。如果您需要更多详细信息,我还建议您在Google上查找计算图和自动编排机制。

编辑:所有这些发生的源代码将在PyTorch的代码库的C部分,其中实现了实际的图形。经过一番挖掘,我发现了这个

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
    profiler::RecordFunction rec(this);
    if (jit::tracer::isTracingVar(inputs)) {
        return traced_apply(inputs);
    }
    return apply(inputs);
}

因此,在每个Function中,PyTorch首先检查其输入是否需要跟踪,并执行此处实现的trace_apply()。您可以看到正在创建的节点并将其附加到图表中:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
    var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);

我最好的猜测是每个Function对象在执行时注册自己及其输入(如果需要)。每个非函数调用(例如,variable.dot())只是按照相应的函数,所以这仍然适用。

注意:我不参与PyTorch的开发,也不是其架构方面的专家。任何更正或补充都将受到欢迎。

用户回答回答于

也许我不对,但我有不同的看法。

后向功能由前向功能定义并调用。

例如:

#!/usr/bin/env python
# encoding: utf-8

###############################################################
# Parametrized example
# --------------------
#
# This implements a layer with learnable weights.
#
# It implements the Cross-correlation with a learnable kernel.
#
# In deep learning literature, it’s confusingly referred to as
# Convolution.
#
# The backward computes the gradients wrt the input and gradients wrt the
# filter.
#
# **Implementation:**
#
# *Please Note that the implementation serves as an illustration, and we
# did not verify it’s correctness*

import torch
from torch.autograd import Function
from torch.autograd import Variable

from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_output = grad_output.data
        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')

        return Variable(grad_output.new(grad_input)), \
            Variable(grad_output.new(grad_filter))


class ScipyConv2d(Module):

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)

###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)

在此示例中,向后函数由ScipyConv2dFunction函数定义。

ScipyConv2dFunction由forward函数调用。

我对么?

扫码关注云+社区

领取腾讯云代金券