前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-ResNet网络构建(上)

CIFAR10数据集实战-ResNet网络构建(上)

作者头像
用户6719124
发布2020-01-14 10:51:06
9590
发布2020-01-14 10:51:06
举报

本部分介绍如何采用ResNet解决CIFAR10分类问题。

之前讲到过,ResNet包含了短接模块(short cut)。本节主要介绍如何实现这个模块。

先建立resnet.py文件。

如图

先引入相关包

import torch
import torch.nn as nn

准备构建resnet单元

class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self):
        super(ResBlk, self).__init__()
        # 完成初始化

由ResNet特点可知,需要传入channel_in和channel_out才能进行运算,因此在定义中需要加入两个变量。

def __init__(self, ch_in, ch_out):

接下来像之前一样,写入其原先的卷积层。

self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
# 进行正则化处理,以使train过程更快更稳定
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)

Resnet 模块的左侧的部分写好了,

先不急着写右侧,先写左侧的forward代码

先引入工具包

import torch.nn.functional as F

书写代码

def forward(self, x):
    # 这里输入的是[b, ch, h, w]
    out = F.relu(self.bn1(self.conv1(x)))
    out = F.relu(self.bn2(self.conv2(out)))

下面开始写short cut代码

out = x + out
# 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加

同时要考虑,若两元素中的ch_in和ch_out不匹配,则运行时会报错。因此需要在前面指定添加if函数

if ch_out != ch_in:
    self.extra = nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(ch_out),
    )

这段代码的意思即为实现[b, ch_in, h, w] => [b, ch_out, h, w]的转化

写好后,将element.wise add部分的x替换

out = self.extra(x) + out

这里也要考虑若ch_in和ch_out原先就相匹配的情况,则需要先进行定义。

self.extra = nn.Sequential()

最后在定义后,返回结果out

至此resnet block模块构建完毕

现代码为

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        # 完成初始化

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 进行正则化处理,以使train过程更快更稳定
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )



    def forward(self, x):
        # 这里输入的是[b, ch, h, w]
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))


        out = self.extra(x) + out
        # 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加

        return out
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-01-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档