专栏首页python pytorch AI机器学习实践CIFAR10数据集实战-LeNet5神经网络(上)

CIFAR10数据集实战-LeNet5神经网络(上)

上次课我们讲解了对于CIFAR10数据读取部分代码的编写,本节讲解如何编写经典的LeNet5神经网络。

首先创建python文件,命名LeNet5。

开始写代码

先引入相关工具包、完成类的初始化

import torch
from torch import nn


class LeNet5(nn.Module):
    # 将所有的类都继承给nn.module
    def __init__(self):
        super(LeNet5, self).__init__()
        # 调用类的初始化

        self.conv_unit = nn.Sequential(
           # 把网络结构放在Sequential中十分方便
            
        )

下面把网络结构放在Sequential中

注意因为CIFAR10均为彩色图片,即为RGB三个通道

这里的输入部分的size为x: [batch, 3, 32, 32]

按照其提示书写

这里的input_channel为3,看图片发现out_channel为6

而kernel_size并没有标明,这里先假设为5。stride也没有标明,先假设为1。同样padding假设为0。

第一层为

nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0)
# 第一层完成了[b, 3, 32, 32] => [b, 6, size_h, size_w]
# 这里由于使用了stride,暂时不确定size

下一层为pooling层

nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
# 由于没有相关参数资料,部分只能靠猜测

这种pooling后,长宽各减半

第二个卷积层和pooling层

nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
# 图中显示原来的6个channel变为16个channel
nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
# 假设第二层的pooling部分与第一层的相同

由原图发现下一层为16个channel通道的全连接输出,因此这里进行打平操作,并接入全连接层

self.fc_unit = nn.Sequential(
        # fc即指代full connect全连接
    nn.Linear(2, 120),
    # 这里只知道由一个维度变到了120,再由120变到64,
    # 由于不清楚是什么维度变到了120,这里暂时先写2
    nn.ReLU(),
    nn.Linear(120, 64),
    nn.ReLU(),
    nn.Linear(84, 10),
)

为能知道具体的维度信息,这里可以先构建一个随机的假数据,代入其中先行进行计算。

tmp = torch.randn(2, 3, 32, 32)
out = self.conv_unit(tmp)
# 将假数据带入到out中
print('conv:', out.shape)

再加入相关代码,使之能单独运行

def main():

    net = LeNet5()


if __name__ == '__main__':
    main()

加入后整体代码为

import torch
from torch import nn


class LeNet5(nn.Module):
    # 将所有的类都继承给nn.module
    def __init__(self):
        super(LeNet5, self).__init__()
        # 调用类的初始化

        self.conv_unit = nn.Sequential(
           # 把网络结构放在Sequential中十分方便
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
            # 第一层完成了[b, 3, 32, 32] => [b, 6, size_h, size_w]
            # 这里由于使用了stride,暂时不确定size

            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # 由于没有相关参数资料,部分只能靠猜测

            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            # 图中显示原来的6个channel变为16个channel
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # 假设第二层的pooling部分与第一层的相同
        )

        self.fc_unit = nn.Sequential(
                # fc即指代full connect全连接
            nn.Linear(2, 120),
            # 这里只知道由一个维度变到了120,再由120变到84,
            # 由于不清楚是什么维度变到了120,这里暂时先写2
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        # 将假数据带入到out中
        print('conv:', out.shape)


def main():

    net = LeNet5()


if __name__ == '__main__':
    main()

开始运行

输出结果为

conv: torch.Size([2, 16, 5, 5])

由此得知打平操作后的数据变为16*5*5。

对原代码进行更改

nn.Linear(16*5*5, 120),

本文分享自微信公众号 - python pytorch AI机器学习实践(gh_a7878fd5de90),作者:王某某搞AI

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-12-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Pytorch-nn.Module

    (1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

    用户6719124
  • pytorch基础知识-Batch Norm(下)

    上图是对前节课所讲的小结,通过Normalize将[6, 3, 784]分为3个通道的[6, 784]数据。使得数据结果整体分布于(0~正负1)区间内。

    用户6719124
  • pytorch基础知识-维度变换-(上)

    维度变换是pytorch中的重要操作,尤其是在图片处理中。本文对pytorch中的维度变换进行讲解。

    用户6719124
  • 无监督学习神经网络——自编码

    自编码是一种无监督学习的神经网络,主要应用在特征提取,对象识别,降维等。自编码器将神经网络的隐含层看成是一个编码器和解码器,输入数据经过隐含层的编码和解码,到达...

    企鹅号小编
  • 卷积神经网络之-NiN 网络(Network In Network)

    Network In Network 是发表于 2014 年 ICLR 的一篇 paper。当前被引了 3298 次。这篇文章采用较少参数就取得了 Alexne...

    机器视觉CV
  • 模型层

    torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类,具备参数管理功能。

    lyhue1991
  • LSTM实现详解

    前言 在很长一段时间里,我一直忙于寻找一个实现LSTM网络的好教程。它们似乎很复杂,而且在此之前我从来没有使用它们做过任何东西。在互联网上快速搜索并没有什么帮助...

    CSDN技术头条
  • LSTM实现详解

    用户1737318
  • torch.nn.Parameter

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

    于小勇
  • pytorch

    pip3 install https://download.pytorch.org/whl/cpu/torch-1.0.1-cp35-cp35m-win_amd...

    sofu456

扫码关注云+社区

领取腾讯云代金券