专栏首页AutoML(自动机器学习)Pytorch中Module,Parameter和Buffer的区别

Pytorch中Module,Parameter和Buffer的区别

下文都将torch.nn简写成nn

  • Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。
  • Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数 示例如下:
class MyModel(nn.Module):
	def __init__(self):
		super(MyModel, self).__init__()
		self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
		self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
		self.my_param = nn.Parameter(torch.randn(1))
	def forward(self, x):
		return x	

model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
  • Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weightbias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
import torch.nn as nn
fc = nn.Linear(2,2)

# 读取参数的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
              tensor([[0.4142, 0.0424],
                      [0.3940, 0.0796]], requires_grad=True)),
             ('bias', Parameter containing:
              tensor([-0.2885,  0.5825], requires_grad=True))])
			  
# 读取参数的方式二(推荐这种)
for n, p in fc.named_parameters():
	print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

# 读取参数的方式三
for p in fc.parameters():
	print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
        [0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885,  0.5825], requires_grad=True)

通过上面的例子可以看到,nn.parameter.Paramterrequires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种(这两种的区别可参阅Pytorch: parameters(),children(),modules(),named_*区别),因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

另外需要介绍的是_parametersnn.Module__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

class Module(object):
	...
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

import torch.nn as nn
class MyModel(nn.Module):
	def __init__(self):
		super(MyModel, self).__init__()
		# 下面两种定义方式均可
		self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
		print(self._parameters)
		self.p2 = nn.Paramter(torch.tensor(2.0))
		print(self._parameters)
  • 首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。
  • self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下,:
def __setattr__(self, name, value):
	...
	params = self.__dict__.get('_parameters')
	if isinstance(value, Parameter):
		if params is None:
			raise AttributeError(
				"cannot assign parameters before Module.__init__() call")
		remove_from(self.__dict__, self._buffers, self._modules)
		self.register_parameter(name, value)
	...

__setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

def register_parameter(self,name,param):
	...
	self._parameters[name]=param

下面我们实例化这个模型看结果怎样

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

结果和上面分析的一致。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 理清Pytorch基本概念

    nn.ModuleList的作用就是wrap pthon list,这样其中的参数会被注册,因此可以返回可训练参数(ParameterList)。

    marsggbo
  • Pytorch Sampler详解

    其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只...

    marsggbo
  • pyqt4实现tab界面切换

    de ># -*- coding: utf-8 -*- from PyQt4.QtGui import * from PyQt4.QtCore import...

    marsggbo
  • Github项目推荐 | PyTorch代码规范最佳实践和样式指南

    AI 科技评论按,本文不是 Python 的官方风格指南。本文总结了使用 PyTorch 框架进行深入学习的一年多经验中的最佳实践。本文分享的知识主要是以研究的...

    AI科技评论
  • python ftp测试

    #!/usr/bin/env python import time from ftplib import FTP local_dir_update="**...

    py3study
  • 小白学PyTorch | 4 构建模型三要素与权重初始化

    第一行是初始化,往后定义了一系列组件。nn.Conv2d就是一般图片处理的卷积模块,然后池化层,全连接层等等。

    机器学习炼丹术
  • 孟德尔随机化系列1

    孟德尔随机化(Mendelian Randomization, MR)是近几年流行起来的用来进行因果推断的有效方法,它以遗传变异为工具变量来推导结局和暴露的因果...

    生信与临床
  • Linux备份工具简介

    备份涵盖的范围很广,我们可以备份出一个重要文件的副本,也可以备份出一个完整的磁盘的快照。许多桌面应用程序和操作系统会自动进行数据备份。相比之下,腾讯云是一个灵活...

    风研雨墨
  • Mozilla推出新功能Hubs,网络浏览器秒变VR社交体验

    VRPinea
  • 掠夺性开放获取版权(Open Access)期刊文章被引用的频率

    原文题目: How Frequently are Articles in Predatory Open Access Journals Cited

    吴亚芳

扫码关注云+社区

领取腾讯云代金券