前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-ResNet网络构建(中)

CIFAR10数据集实战-ResNet网络构建(中)

作者头像
用户6719124
发布2020-02-24 18:06:55
6210
发布2020-02-24 18:06:55
举报

再定义一个ResNet网络

我们本次准备构建ResNet-18层结构

代码语言:javascript
复制
class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # 紧跟着要进行四次这样的单元
        # 构建辅助函数,使[b, 64, h, w] => [b, 128, h, w]
        self.blk1 = ResBlk(64, 128)
        # 构建辅助函数,使[b, 128, h, w] = > [b, 256, h, w]
        self.blk2 = ResBlk(128, 256)
        # 构建辅助函数,使[b, 256, h, w] = > [b, 512, h, w]
        self.blk3 = ResBlk(256, 512)
        # 构建辅助函数,使[b, 512, h, w] = > [b, 1024, h, w]
        self.blk4 = ResBlk(512, 1024)

接下来构建ResNet-18的forward函数

代码语言:javascript
复制
def forward(self, x):
    x = F.relu(self.conv1(x))
    # [b, 64, h, w] => [b, 1024, h, w]
    x = self.blk1(x)
    x = self.blk2(x)
    x = self.blk3(x)
    x = self.blk4(x)

由于我们要进行10分类问题,要将添加代码

代码语言:javascript
复制
self.outlayer = nn.Linear(1024, 10)

代码语言:javascript
复制
x = self.outlayer(x)
return x

为确定具体维度大小,我们先构建假数据

代码语言:javascript
复制
def main():
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 3, 32, 32)
    out = blk(tmp)
    print(out.shape)

if __name__ == "__main__":
    main()

此时代码为

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        # 完成初始化

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 进行正则化处理,以使train过程更快更稳定
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )



    def forward(self, x):
        # 这里输入的是[b, ch, h, w]
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))


        out = self.extra(x) + out
        # 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加

        return out


class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # 紧跟着要进行四次这样的单元
        # 构建辅助函数,使[b, 64, h, w] => [b, 128, h, w]
        self.blk1 = ResBlk(64, 128)
        # 构建辅助函数,使[b, 128, h, w] = > [b, 256, h, w]
        self.blk2 = ResBlk(128, 256)
        # 构建辅助函数,使[b, 256, h, w] = > [b, 512, h, w]
        self.blk3 = ResBlk(256, 512)
        # 构建辅助函数,使[b, 512, h, w] = > [b, 1024, h, w]
        self.blk4 = ResBlk(512, 1024)

        self.outlayer = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 1024, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = self.outlayer(x)
        return x

def main():
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 3, 32, 32)
    out = blk(tmp)
    print(out.shape)

if __name__ == "__main__":
    main()
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-01-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档