torch.nn.Module

nn.Module基类的构造函数:

def __init__(self):
    self._parameters = OrderedDict()
    self._modules = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self.training = True

其中每个属性的解释如下:

  • _parameters:字典,保存用户直接设置的parameter,self.param1 = nn.Parameter(t.randn(3, 3))会被检测到,在字典中加入一个key为'param',value为对应parameter的item。而self.submodule = nn.Linear(3, 4)中的parameter则不会存于此。
  • _modules:子module,通过self.submodel = nn.Linear(3, 4)指定的子module会保存于此。
  • _buffers:缓存。如batchnorm使用momentum机制,每次前向传播需用到上一次前向传播的结果。
  • _backward_hooks_forward_hooks:钩子技术,用来提取中间变量,类似variable的hook。
  • training:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

上述几个属性中,_parameters_modules_buffers这三个字典中的键值,都可以通过self.key方式获得,效果等价于self._parameters['key'].

定义一个Module,这个Module即包含自己的Parameters有包含子Module及其Parameters,

import torch as t

from torch import nn

from torch.autograd import Variable as V


class Net(nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        # 等价与self.register_parameter('param1' ,nn.Parameter(t.randn(3, 3)))

        self.param1 = nn.Parameter(t.rand(3, 3))

        self.submodel1 = nn.Linear(3, 4)

    def forward(self, input):

        x = self.param1.mm(input)

        x = self.submodel11(x)

        return x

net = Net()

一、modules

# 打印网络对象的话会输出子module结构
print(net)

Net(
  (submodel1): Linear(in_features=3, out_features=4)
)

# ._modules输出的也是子module结构,不过数据结构和上面的有所不同
print(net.submodel1)
print(net._modules) # 字典子类

Linear(in_features=3, out_features=4)
OrderedDict([('submodel1', Linear(in_features=3, out_features=4))])

for name, submodel in net.named_modules():
    print(name, submodel)

 Net(
  (submodel1): Linear(in_features=3, out_features=4)
)
submodel1 Linear(in_features=3, out_features=4)

print(list(net.named_modules())) # named_modules其实是包含了本层的module集合

[('', Net(
  (submodel1): Linear(in_features=3, out_features=4)
)), ('submodel1', Linear(in_features=3, out_features=4))]

二、_parameters

# ._parameters存储的也是这个结构
print(net.param1)
print(net._parameters) # 字典子类,仅仅包含直接定义的nn.Parameters参数

Parameter containing:
 0.6135  0.8082  0.4519
 0.9052  0.5929  0.2810
 0.6825  0.4437  0.3874
[torch.FloatTensor of size 3x3]

OrderedDict([('param1', Parameter containing:
 0.6135  0.8082  0.4519
 0.9052  0.5929  0.2810
 0.6825  0.4437  0.3874
[torch.FloatTensor of size 3x3]
)])


for name, param in net.named_parameters():
    print(name, param.size())

param1 torch.Size([3, 3])
submodel1.weight torch.Size([4, 3])
submodel1.bias torch.Size([4])

三、_buffers

bn = nn.BatchNorm1d(2)

input = V(t.rand(3, 2), requires_grad=True)

output = bn(input)

bn._buffers



Output:
--------------------------------------------------------------------
OrderedDict([('running_mean', 
              1.00000e-02 *
                9.1559
                1.9914
              [torch.FloatTensor of size 2]), ('running_var', 
               0.9003
               0.9019
              [torch.FloatTensor of size 2])])
--------------------------------------------------------------------

四、training

input = V(t.arange(0, 12).view(3, 4))

model = nn.Dropout()

# 在训练阶段,会有一半左右的数被随机置为0

model(input)


Output:
-----------------------------------
Variable containing:
  0   2   4   0
  8  10   0   0
  0  18   0  22
[torch.FloatTensor of size 3x4]
-----------------------------------
model.training  = False

# 在测试阶段,dropout什么都不做

model(input)



Output:
--------------------------------------
Variable containing:
  0   1   2   3
  4   5   6   7
  8   9  10  11
[torch.FloatTensor of size 3x4]
--------------------------------------

 Module.train()、Module.eval() 方法和 Module.training属性的关系

print(net.training, net.submodel1.training)

net.train() # 将本层及子层的training设定为True

net.eval() # 将本层及子层的training设定为False

net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响

net.training, net.submodel1.training



Output:
----------------
True True
(True, False)
----------------

承接Matlab、Python和C++的编程,机器学习、计算机视觉的理论实现及辅导,本科和硕士的均可,咸鱼交易,专业回答请走知乎,详谈请联系QQ号757160542,非诚勿扰。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Python 类特殊方法__getitem__

    凡是在类中定义了这个__getitem__ 方法,那么它的实例对象(假定为p),可以像这样

    于小勇
  • python使用moviepy模块对视频进行操作

    前段时间需要对多个视频进行合并,还需要对一个视频按需求进行截切成多个视频,然而网上虽然有现成的工具。

    于小勇
  • Pytorch的Sampler详解

    其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只...

    于小勇
  • 接口应用小玩具-博客园积分排名变动监控工具

    小玩具-博客园积分排名变动监控工具 一个简单的在线服务监控和提醒工具 1   概述 前段时间自己准备重新开启自己的博客园,然后还和一些圈子里面的朋友夸下海口,自...

    用户1170933
  • python面向对象基础

    面向过程的程序设计的核心是过程,过程即解决问题的步骤,面向过程的设计就好比精心设计好一条流水线,考虑周全什么时候处理什么东西。

    菲宇
  • 隐马尔科夫-维特比算法

    概念介绍:   继上篇贝叶斯(https://cloud.tencent.com/developer/article/1056640)后,一直想完成隐马尔科夫这...

    知然
  • Jupyter Notebook折叠输出的内容实例

    当Jupyter Notebook的输出内容很多时,为了屏幕可以显示更多的代码行,我需要将输出的内容进行折叠。

    砸漏
  • python实现汽车管理系统

    本文实例为大家分享了python实现汽车管理系统的具体代码,供大家参考,具体内容如下

    砸漏
  • 全面深入理解Python面向对象编程

    面向过程编程最易被初学者接受,其往往用一长段代码来实现指定功能,开发过程中最常见的操作就是粘贴复制,即:将之前实现的代码块复制到现需功能处。

    顶级程序员
  • 一个Python3和Python2的range差异

    Python 3 中执行100000000 in range(100000001)会比Python 2快的非常多。

    用户1416054

扫码关注云+社区

领取腾讯云代金券