我看到了以下用于扩展nn.Mudule
的代码片段。我不理解的是forward
函数中的input_ @ self.weight
。我可以理解,这是在尝试使用input_
的权重信息。但是@
总是被用作装饰器,为什么它可以这样使用呢?
class Linear(nn.Module):
def __init__(self, in_size, out_size):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_size, out_size))
self.bias = nn.Parameter(torch.randn(out_size))
def forward(self, input_):
return self.bias + input_ @ self.weight
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
发布于 2021-10-05 06:10:46
@
是__matmul__
函数的简写:矩阵乘法运算符。
https://stackoverflow.com/questions/69451996
复制