前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-ResNet网络构建(下)

CIFAR10数据集实战-ResNet网络构建(下)

作者头像
用户6719124
发布2020-02-24 18:07:15
8770
发布2020-02-24 18:07:15
举报

试运行一下:

发现报错

RuntimeError: Given groups=1, weight of size 128 64 3 3, expected input[2, 3, 32, 32] to have 64 channels, but got 3 channels instead

报错原因为维度输入错误。进行更改

tmp = torch.randn(2, 64, 32, 32)

再次运行

输出为

torch.Size([2, 128, 32, 32])

这里注意到由[2, 64, 32, 32]到[2, 128, 32, 32],channel数量翻倍,而长和宽没有变化。这样势必会导致x的维度会越来越大。

因此为实现让长和宽能进行减小后,再运算。我们在卷积层中加入了stride设置。

代码改为

blk = ResBlk(64, 128, stride=2)

同样在定义时加上stride

def __init__(self, ch_in, ch_out, stride=1):

然后在卷积层中设置

self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)

另外

Short cut中的stride也要与第一层卷积层保持一致

改代码为

if ch_out != ch_in:
    self.extra = nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),

这时运行的输出为

torch.Size([2, 128, 16, 16])

这时为检测整个数据是否match,我们对x进行定义

x = torch.randn(2, 3, 32, 32)
model = ResNet()
out = model(x)
print('resnet:', out.shape)

运行后报错:

RuntimeError: size mismatch, m1: [65536 x 32], m2: [1024 x 10] at ..\aten\src\TH/generic/THTensorMath.cpp:961

由报错结果上看,out.layer上出现错误

在out.layer上添加代码

print('after conv:', x.shape)
x = self.outlayer(x)

该段输出

after conv: torch.Size([2, 1024, 32, 32])

首先为减小数据量 我们在4个ResNet单元层添加stride

self.blk1 = ResBlk(64, 128, stride=2)

另外我闷在输出层前加入pooling层

x = F.adaptive_avg_pool2d(x, [1, 1])
print('after pool:', x.shape)
x = x.view(x.size(0), -1)

此时输出为

after conv: torch.Size([2, 1024, 2, 2])
after pool: torch.Size([2, 1024, 1, 1])
resnet: torch.Size([2, 10])

最后输出为10个,对应于十分类问题。

将没用的输出信息注释掉,继续完善代码

回到main.py文件中去

引入工具包处改为

# from LeNet5 import LeNet5
from resnet import ResNet

并将

model = LeNet5().to(device)

改为

model = ResNet().to(device)

其余地方不需要改

运行main.py文件

结果显示为

acc: 0.6112
acc: 0.7071
acc: 0.7558
acc: 0.7831
acc: 0.7832
acc: 0.795
acc: 0.7958
acc: 0.7879
acc: 0.7995

正确率随着epoch的运行持续增加

至此,ResNet和LeNet5都已介绍完毕。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-02-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档