Pytorch-nn.Module

本节介绍在pytorch中十分重要的“类”:nn.Module。

在实现自己设计的层结构功能时,必须要使用自己继承的类。

类的书写如下

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Module):
    # 先定义自己的类
    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # 初始化自己定义的类

        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        # 定义前向
        x = x @ self.w.t() + self.b
        return x

那么nn.Module类到底是什么?

(1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

(2)还可以进行嵌套,便于书写树形结构

(3)nn.Module提供了很多已经编写好的功能,如Linear、ReLU、Sigmoid、Conv2d、ConvTransposed2d、Dropout等。

最主要的功能是书写代码方便

self.net = nn.Sequential(
    # .Sequential()相当于设定了一个容器,
    # 将需要进行forward的函数代入其中,
    # 但不用每一个步骤都写上,
    # 直接放在容器中,后面再定义一个forward代码即可
    nn.Conv2d(1, 32, 5, 1, 1),
    nn.MaxPool2d(2, 2),
    ...
    
)

使用nn.Module的第三个好处是可以对网络中的参数进行有效的管理

通过.parameters()即可很方便的对参数进行查看

net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 2))
print(list(net.parameters()))[0].shape
# 输出查看第0层的参数

也可用.named_parameters()来输出网络结构编好名字的参数

print(list(net.named_parameters()))[0].shape

后续再加上.item(),来对各种属性进行查看

print(list(net.named_parameters()))[0].item()

另外nn.Module还可以自己定义类的顺序。

也可以很方便的将所有的运算都转入到GPU上去。使用.device函数,

device = torch.device('cuda')
net = Net()
net.to(device)

还可以很方便的进行save和load,以防止突然发生的断点和系统崩溃的现象

net.load_state_dict(torch.load('ckpt.mdl'))
torch.save(net.state_dict(), 'ckpt.mdl')

nn.Modele还可以很方便的切换状态

# 切换到train状态
net.train()
# 切换到test

本文分享自微信公众号 - python pytorch AI机器学习实践(gh_a7878fd5de90),作者:王某某搞AI

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-12-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • CIFAR10数据集实战-LeNet5神经网络(上)

    上次课我们讲解了对于CIFAR10数据读取部分代码的编写,本节讲解如何编写经典的LeNet5神经网络。

    用户6719124
  • pytorch基础知识-Batch Norm(下)

    上图是对前节课所讲的小结,通过Normalize将[6, 3, 784]分为3个通道的[6, 784]数据。使得数据结果整体分布于(0~正负1)区间内。

    用户6719124
  • Pytorch-ResNet(残差网络)-下

    在左图(准确率)的比较中,从AlexNet到GoogleNet再到ResNet,准确率逐渐提高。20层结构是很多网络结构性能提升的分水岭,在20层之前,模型性能...

    用户6719124
  • 模型层

    torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类,具备参数管理功能。

    lyhue1991
  • PyTorch nn.Module

    本节将介绍在pytorch中非常重要的类:nn.Module。在实现自己设计的网络时,必须要继承这个类,示例写法如下

    mathor
  • 【Pytorch 】笔记四:Module 与 Containers 的源码解析

    疫情在家的这段时间,想系统的学习一遍 Pytorch 基础知识,因为我发现虽然直接 Pytorch 实战上手比较快,但是关于一些内部的原理知识其实并不是太懂,这...

    阿泽 Crz
  • LSTM实现详解

    前言 在很长一段时间里,我一直忙于寻找一个实现LSTM网络的好教程。它们似乎很复杂,而且在此之前我从来没有使用它们做过任何东西。在互联网上快速搜索并没有什么帮助...

    CSDN技术头条
  • LSTM实现详解

    用户1737318
  • 卷积神经网络之-NiN 网络(Network In Network)

    Network In Network 是发表于 2014 年 ICLR 的一篇 paper。当前被引了 3298 次。这篇文章采用较少参数就取得了 Alexne...

    机器视觉CV
  • 最完整的PyTorch数据科学家指南(1)

    PyTorch 已经成为现在创建神经网络的事实上的标准之一,我喜欢它的界面。但是,对于初学者来说,要获得它有些困难。

    计算机与AI

扫码关注云+社区

领取腾讯云代金券