前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch nn.Module

PyTorch nn.Module

作者头像
mathor
发布2020-02-14 20:10:00
1.1K0
发布2020-02-14 20:10:00
举报
文章被收录于专栏:mathor

本节将介绍在pytorch中非常重要的类:nn.Module。在实现自己设计的网络时,必须要继承这个类,示例写法如下

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

# 先定义自己的类
class MyNN(nn.Module):
    def __init__(self, inp, outp):
        # 初始化自己定义的类
        super(MyNN, self).__init__()

        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))
        
    # 定义前向传播
    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x

那么nn.Module这个类有哪些功能?

  • nn.Module提供了很多已经编写好的功能,如LinearReLUSigmoidConv2dConvTransposed2dDropout...
  • 书写代码方便。例如我们要定义一个基本的CNN结构,代码如下
代码语言:javascript
复制
self.net = nn.Sequential(
    # .Sequential()相当于设置了一个容器(Container)
    # 将需要进行forward的函数写在其中
    
    nn.Conv2d(1, 32, 5, 1, 1),
    nn.MaxPool2d(2, 2),
    nn.ReLU(True),
    nn.BatchNorm2d(32),
    
    nn.Conv2d(32, 64, 3, 1, 1),
    nn.ReLU(True),
    nn.BatchNorm2d(64),
    
    nn.Conv2d(64, 64, 3, 1, 1),
    nn.MaxPool2d(2, 2),
    nn.ReLU(True),
    nn.BatchNorm2d(64),
    
    nn.Conv2d(64, 128, 3, 1, 1),
    nn.ReLU(True),
    nn.BatchNorm2d(128)
)

或者需要将自己设计的层连接在一起的情况

代码语言:javascript
复制
class Faltten(nn.Module):
    def __init__(self):
        super(Faltten, self).__init__()
    
    def forward(self, input):
        return input.view(inputt.size(0), -1)

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(1*14*14, 10)
        )
    
    def forward(self, x):
        return self.net(x)
  • 使用nn.Module可以对网络中的参数进行有效的管理
代码语言:javascript
复制
net = nn.Sequential(
    nn.Linear(in_features=4, out_features=2),
    nn.Linear(in_features=2, out_features=2)
)

# 隐藏层的编号是从0开始的
list(net.parameters())[0] # [0]是layer0的w
list(net.parameters())[3].shape # [3]是layer1的b
dict(net.named_parameters()).items() # 返回所有层的参数
 
optimizer = optim.SGD(net.parameters(), lr=1e-3)

输出

代码语言:javascript
复制
torch.Size([2, 4])
torch.Size([2])
dict_items([('0.weight', Parameter containing:
tensor([[ 0.0195,  0.4698, -0.4913, -0.3336],
        [ 0.1422,  0.2908, -0.2469,  0.0583]], requires_grad=True)), ('0.bias', Parameter containing:
tensor([-0.4704, -0.1133], requires_grad=True)), ('1.weight', Parameter containing:
tensor([[-0.6511,  0.2442],
        [ 0.5658,  0.4419]], requires_grad=True)), ('1.bias', Parameter containing:
tensor([ 0.0114, -0.5664], requires_grad=True))])
  • 可以很方便的将所有运算都转入到GPU上去,使用.device()函数
代码语言:javascript
复制
device = torch.device('cuda')
net = Net()
net.to(device)
  • 可以很方便的进行save和load,以防止突然发生的断点和系统崩溃现象
代码语言:javascript
复制
torch.save(net.state_dict(), 'ckpt.mdl')
net.load_state_dict(torch.load('ckpt.mdl'))
  • 还可以很方便的切换train和test的状态
代码语言:javascript
复制
# train
net.train()

# test
net.eval()
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
容器服务
腾讯云容器服务(Tencent Kubernetes Engine, TKE)基于原生 kubernetes 提供以容器为核心的、高度可扩展的高性能容器管理服务,覆盖 Serverless、边缘计算、分布式云等多种业务部署场景,业内首创单个集群兼容多种计算节点的容器资源管理模式。同时产品作为云原生 Finops 领先布道者,主导开源项目Crane,全面助力客户实现资源优化、成本控制。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档