# 小白学PyTorch | 4 构建模型三要素与权重初始化

• 1 模型三要素
• 2 参数初始化
• 3 完整运行代码
• 4 尺寸计算与参数计算

## 1 模型三要素

1. 必须要继承nn.Module这个类，要让PyTorch知道这个类是一个Module
2. 在__init__(self)中设置好需要的组件，比如conv，pooling，Linear，BatchNorm等等
3. 最后在forward(self,x)中用定义好的组件进行组装，就像搭积木，把网络结构搭建出来，这样一个模型就定义好了

def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.pool2 = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)


def forward(self,x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


x为模型的输入，第一行表示x经过conv1，然后经过激活函数relu，然后经过pool1操作 第三行表示对x进行reshape，为后面的全连接层做准备

net = Net()
outputs = net(inputs)


## 2 参数初始化

# 定义权值初始化
def initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
torch.nn.init.normal_(m.weight.data,0,0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()


# self.modules的源码
def modules(self):
for name,module in self.named_modules():
yield module


## 3 完整运行代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()

net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():
print(m)


# 这个是print(net.modules())的输出
<generator object Module.modules at 0x0000023BDCA23258>
# 这个是第一次从net.modules()取出来的东西，是整个网络的结构
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
# 从net.modules()第二次开始取得东西就是每一层了
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)


torch.nn.init.xavier_normal(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()


## 4 尺寸计算与参数计算

net = Net()
net.initialize_weights()
layers = {}
for m in net.modules():
if isinstance(m,nn.Conv2d):
print(m)
break


Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))


【问题1：输入特征图和输出特征图的尺寸计算】

net = Net()
net.initialize_weights()
input = torch.ones((16,3,10,10))
output = net.conv1(input)
print(input.shape)
print(output.shape)


torch.Size([16, 3, 10, 10])
torch.Size([16, 6, 6, 6])


【问题2：这个卷积层中有多少的参数？】输入通道是3通道的，输出是6通道的，卷积核是

,考虑bais的话，就每一个卷积核再增加一个偏置值。（这是一个一般人会忽略的知识点欸）

net = Net()
net.initialize_weights()
for m in net.modules():
if isinstance(m,nn.Conv2d):
print(m)
print(m.weight.shape)
print(m.bias.shape)
break


Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
torch.Size([6, 3, 5, 5])
torch.Size([6])


- END -

0 条评论

• ### 小白学PyTorch | 6 模型的构建访问遍历存储（附代码）

torch.nn.Module是所有网络的基类，在PyTorch实现模型的类中都要继承这个类（这个在之前的课程中已经提到）。在构建Module中，Module是...

• ### 小白学PyTorch | 12 SENet详解及PyTorch实现

上一节课讲解了MobileNet的一个DSC深度可分离卷积的概念，希望大家可以在实际的任务中使用这种方法，现在再来介绍EfficientNet的另外一个基础知识...

• ### Github项目推荐 | PyTorch代码规范最佳实践和样式指南

AI 科技评论按，本文不是 Python 的官方风格指南。本文总结了使用 PyTorch 框架进行深入学习的一年多经验中的最佳实践。本文分享的知识主要是以研究的...

• ### 【动手学深度学习笔记】之构造MLP模型的几种方法

Module类是nn模块里提供的一个模型构造类，通过继承Module实现MLP的程序如下

• ### PyTorch最佳实践，怎样才能写出一手风格优美的代码

虽然这是一个非官方的 PyTorch 指南，但本文总结了一年多使用 PyTorch 框架的经验，尤其是用它开发深度学习相关工作的最优解决方案。请注意，我们分享的...

• ### 私有云架构建设, 你做好准备了吗？

私有云基础架构的构成要素 随着越来越多的企业设定了构建内部云服务的目标，规划和构建企业内部云服务平台就成为IT部门的职责。每个企业都有自己特有的环境和具体的目标...

• ### 盘点与警示：2017年加密货币市场的重大事故

在2017年的最后一天，细数2017年发生在加密货币市场的一些重大事故，这或在2018年加密货币的交易中予以我们一些警示。 CoinDash的ICO黑客事件 支...