pytorch学习笔记(十二):详解 Module 类

Modulepytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单。

本文主要关注 Module 类的内部是怎么样的。

初始化方法中做了什么

def __init__(self):
    self._backend = thnn_backend
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._modules = OrderedDict()
    self.training = True

这是 Module 的初始化方法:

  • self._parameters 用来存放注册的 Parameter 对象
  • self._buffers 用来存放注册的 Buffer 对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)
  • self._modules 用来保存注册的 Module 对象。
  • self.training 标志位,用来表示是不是在 training 状态下
  • ...hooks 用来保存 注册的 hook

__setattr____getattr__

__setattr__ 每次给属性赋值的时候,都会调用这个方法。

__setattr__ 的代码比较多,我们一点一点看。

  • remove_from :工具函数, 用来从 self.__dict__, self._buffers, self._modules 中删除对象。

第一种情况: value 的类型是 Paramter

  • 从 三大 字典中将 同名的 对象删掉
  • 然后,注册 paramter

第二种情况: value不是 Parameter对象, nameself._parameter

  • self._parameters[name] = None

已经考虑了 valueParameter对象,剩下的就是考虑 valuebufferModule

第三种情况:value不是 Parameter对象, valueModule 对象

  • 从三大字典里面移除同名 对象
  • 然后直接向 self._modules 字典里添加 value

第四种情况:value不是Parameter对象, value不为 Module对象, 但是 nameself._modules

  • self._modules[name]=None

第五种情况:value不是Parameter对象, value不为 Module对象, name 存在 self._buffers

  • self._buffers[name]=None

最后一种情况: 就是 普通的属性了。

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

    params = self.__dict__.get('_parameters')

    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not torch.is_tensor(value):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)

__getattr__ : 当获取 self.__dict__ 中没有的键所对应的值的时候,就会调用这个方法 因为 parameter, module, buffer 的键值对存在与 self._parameters, self._modules, self.buffer 中,所以,当想获取这些 值时, 就会调用这个方法。

def __getattr__(self, name):
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

register_parameter

向模型中注册 Parameter

def register_parameter(self, name, param):
    """Adds a parameter to the module.

    The parameter can be accessed as an attribute using given name.
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Variable to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another variable, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param

Module.training 标志 如何影响 前向过程

nn.Dropout 来看 Module.training

class Dropout(Module):
    def __init__(self, p=0.5, inplace=False):
        super(Dropout, self).__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace

    def forward(self, input):
        return F.dropout(input, self.p, self.training, self.inplace)

可以看出,在forward 过程中,直接获取,父类的training的值。

我们 通常通过 module.train()module.eval() 来切换模型的 训练测试阶段。

def train(self, mode=True):
    """Sets the module in training mode.
    This has any effect only on modules such as Dropout or BatchNorm.
    """
    self.training = mode

    for module in self.children():
        # 递归调用子模块 train 函数, 来设定所有 module 的 training 值。
        module.train(mode)
        return self

需要注意的是:module.eval() 仅仅设置 moduletraining 属性,如果我们想获得最快的推断速度, 还需要 设置 输入 Variablevolatile 属性为 True

参考资料

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习算法全栈工程师

史上最透彻的KMP算法讲解

作 者:柳行刚 编 辑:李文臣 1 字符串匹配是经典的KMP算法。下面以字符串"BBC ABCDAB ABCDABCDABDE"为例,查找是否包含串"ABCDA...

33711
来自专栏数据科学与人工智能

【Python环境】Python面试题汇总(一)

拿网络上关于Python的面试题汇总了,给出了自认为合理的答案,有些题目不错,可以从中学到点什么,答案如不妥,请指正...... +++++++++++++++...

2346
来自专栏北京马哥教育

Python 中被忽略的 else

1294
来自专栏奇点大数据

Scala语言学习笔记二

在开始今天的内容前,先回复下在上一篇笔记的热心读者的问题: 1 既然是读书笔记,是读的哪本书? 这本书的名字叫《快学scala》,虽然是本比较久远的书,但是也...

2838
来自专栏Spark学习技巧

理解Spark里的闭包

1132
来自专栏全栈之路

android 实现本地定时推送(兼容)

首先写几点感悟: - 做兼容真的很累很费劲~ - android 8.0 广播部分不再支持动态注册,所以应该用service来实现定时推送功能 - 无论是闹钟还...

1602
来自专栏程序员叨叨叨

7.4 输入\输出修辞符(in\out\inout)

参数传递是指:函数调用实参值初始化函数形参的过程。在 C\C++中,根据形参值的改变是否会导致实参值的改变,参数传递分为“值传递(pass-by-value) ...

581
来自专栏数据结构与算法

洛谷P3224 [HNOI2012]永无乡

题目描述 永无乡包含 n 座岛,编号从 1 到 n,每座岛都有自己的独一无二的重要度,按照重要度可 以将这 n 座岛排名,名次用 1 到 n 来表示。某些岛之间...

2655
来自专栏Python小屋

Python花式编程案例锦集(1)

首先解答上一篇文章详解Python中的序列解包(2)中最后的习题,该题答案为5,表达式功能为迭代求解序列中元素的最大值。 -----------------分割...

3245
来自专栏AI星球

值得玩味儿的14个Python编程小技巧

最近的工作中经常使用Python编写一些基本解决一些NLP的小问题,在自己的工作日记里面也记录了不少的python编程中常用的小"Trick",看到最近大...

942

扫码关注云+社区