前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

作者头像
狼啸风云
修改2022-09-02 22:29:12
3.5K0
修改2022-09-02 22:29:12
举报

relu多种实现之间的关系:

relu 函数在 pytorch 中总共有 3 次出现:

  1. torch.nn.ReLU()
  2. torch.nn.functional.relu_() torch.nn.functional.relu_()
  3. torch.relu() torch.relu_()

而这3种不同的实现其实是有固定的包装关系,由上至下是由表及里的过程。其中最后一个实际上并不被 pytorch 的官方文档包含,同时也找不到对应的 python 代码,只是在 __init__.pyi 中存在,因为他们来自于通过C++编写的THNN库。

下面通过分析源码来进行具体分析:

1、torch.nn.ReLU()

torch.nn 中的类代表的是神经网络层,这里我们看到作为类出现的 ReLU() 实际上只是调用了 torch.nn.functional 中的 relu relu_ 实现。

代码语言:javascript
复制
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x)= \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input),m(-input)))
    """
    __constants__ = ['inplace']

    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
      # F 来自于 import nn.functional as F
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str

2、torch.nn.functional.relu() torch.nn.functional.relu_()

其实这两个函数也是调用了 torch.relu() and torch.relu_()

代码语言:javascript
复制
def relu(input, inplace=False):
    # type: (Tensor, bool) -> Tensor
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    if inplace:
        result = torch.relu_(input)
    else:
        result = torch.relu(input)
    return result


relu_ = _add_docstr(torch.relu_, r"""
relu_(input) -> Tensor

In-place version of :func:`~relu`.
""")

至此我们对 RELU 函数在 torch 中的出现有了一个深入的认识。实际上作为基础的两个包,torch.nntorch.nn.functional 的关系是引用与包装的关系。

3、torch.nn 与 torch.nn.functional 的区别与联系

结合上述对 relu 的分析,我们能够更清晰的认识到两个库之间的联系。

通常来说 torch.nn.functional 调用了 THNN库,实现核心计算,但是不对 learnable_parameters 例如 weight bias ,进行管理,为模型的使用带来不便。而 torch.nn 中实现的模型则对 torch.nn.functional,本质上是官方给出的对 torch.nn.functional的使用范例,我们通过直接调用这些范例能够快速方便的使用 pytorch ,但是范例可能不能够照顾到所有人的使用需求,因此保留 torch.nn.functional 来为这些用户提供灵活性,他们可以自己组装需要的模型。因此 pytorch 能够在灵活性与易用性上取得平衡。

特别注意的是,torch.nn不全都是对torch.nn.functional的范例,有一些调用了来自其他库的函数,例如常用的RNN型神经网络族即没有在torch.nn.functional中出现。

我们带着这样的思考再来看下一个例子作为结束:

对于Linear请注意对比两个库下实现的不同:

  1. learnable parameters的管理
  2. 相互之间的调用关系
  3. 初始化过程
代码语言:javascript
复制
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-05-03 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、torch.nn.ReLU()
  • 2、torch.nn.functional.relu() torch.nn.functional.relu_()
  • 3、torch.nn 与 torch.nn.functional 的区别与联系
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档