前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch +ResNet34实现 图像分类

PyTorch +ResNet34实现 图像分类

作者头像
大数据技术与机器学习
发布2022-03-29 19:33:33
3.8K0
发布2022-03-29 19:33:33
举报

1、 RestNet网络

1.1、 RestNet网络结构

ResNet在2015年被提出,在ImageNet比赛classification任务上获得第一名,因为它“简单与实用”并存,之后很多方法都建立在ResNet50或者ResNet101的基础上完成的,检测,分割,识别等领域里得到广泛的应用。它使用了一种连接方式叫做“shortcut connection”,顾名思义,shortcut就是“抄近道”的意思,下面是这个resnet的网络结构:

它对每层的输入做一个reference(X), 学习形成残差函数, 而不是学习一些没有reference(X)的函数。这种残差函数更容易优化,能使网络层数大大加深。在上图的残差块中它有二层,如下表达式, 其中σ代表非线性函数ReLU。

然而实验证明x已经足够了,不需要再搞个维度变换,除非需求是某个特定维度的输出,如是将通道数翻倍,如下图所示:

由上图,我们可以清楚的看到“实线”和“虚线”两种连接方式, 实线的Connection部分 (第一个粉色矩形和第三个粉色矩形) 都是执行3x3x64的卷积,他们的channel个数一致,所以采用计算方式:

Y = F(x) + x,虚线的Connection部分 (第一个绿色矩形和第三个绿色矩形) 分别是3x3x64和3x3x128的卷积操作,他们的channel个数不同(64和128),所以采用计算方式:y=F(x)+Wx 。其中W是卷积操作,用来调整x的channel维度。

在计算机视觉里,网络的深度是实现网络好的效果的重要因素,输入特征的“等级”随增网络深度的加深而变高。然而在网络深度不断加深的情况下,梯度弥散/爆炸成为训练深层次的网络的障碍,导致导致网络无法收敛。虽然,归一初始化,各层输入归一化,使得可以收敛的网络的深度提升为原来的十倍。虽然网络收敛了,但网络却开始退化 (增加网络层数却导致更大的误差), 如下图所示:

由上图可知,在一个浅层网络的基础上叠加y=x的层(称identity mappings,恒等映射),可以让网络随深度增加而不退化。这反映了多层非线性网络无法逼近恒等映射网络。

但是,在深度学习中我们希望有更好性能的网络,而网络不退化则不是我们的目的。在 RestNet网络中学习的残差函数是F(x) = H(x) - x, 这里如果F(x) = 0, 那么就是上面提到的恒等映射(H(x) = x)。事实上,RestNet是“shortcut connections”的在connections是在恒等映射下的特殊情况,它没有引入额外的参数和计算的复杂度。假如优化目标函数是逼近一个恒等映射, 而不是0映射(F(x) = 0)或者说恒等映射,那么学习找到对恒等映射的扰动会比重新学习一个映射函数要容易。

1.2、残差块的两种结构

这是文章里面的图,我们可以看到一个“弯弯的弧线“这个就是所谓的”shortcut connection“,也是文中提到identity mapping,这张图也诠释了ResNet的真谛,当然大家可以放心,真正在使用的ResNet模块并不是这么单一,文章中就提出了两种方式:

这两种结构分别针对ResNet34(左图)和ResNet50/101/152(右图),一般称整个结构为一个“building block” 。其中右图又称为“bottleneck design”,目的就是为了降低参数的数目,实际中,考虑计算的成本,对残差块做了计算优化,即将两个3x3的卷积层替换为1x1 + 3x3 + 1x1,如右图所示。新结构中的中间3x3的卷积层首先在一个降维1x1卷积层下减少了计算,然后在另一个1x1的卷积层下做了还原,既保持了精度又减少了计算量。第一个1x1的卷积把256维channel降到64维,然后在最后通过1x1卷积恢复,整体上用的参数数目:1x1x256x64 + 3x3x64x64 + 1x1x64x256 = 69632,而不使用bottleneck的话就是两个3x3x256的卷积,参数数目: 3x3x256x256x2 = 1179648,差了16.94倍。

对于常规ResNet,可以用于34层或者更少的网络中,对于Bottleneck Design的ResNet通常用于更深的如101这样的网络中,目的是减少计算和参数量。

1.3、ResNet50和ResNet101简单讲解

这里把ResNet50和ResNet101特别提出,主要因为它们的使用率很高,所以需要做特别的说明。给出了它们具体的结构:

上表是Resnet不同的结构,上表一共提出了5中深度的ResNet,分别是18,34,50,101和152,首先看表的最左侧,我们发现所有的网络都分成5部分,分别是:conv1,conv2_x,conv3_x,conv4_x,conv5_x,之后的其他论文也会专门用这个称呼指代ResNet50或者101的每部分。例如:101-layer那列,101-layer指的是101层网络,首先有个输入7x7x64的卷积,然后经过3 + 4 + 23 + 3 = 33个building block,每个block为3层,所以有33 x 3 = 99层,最后有个fc层(用于分类),所以1 + 99 + 1 = 101层,确实有101层网络;注:101层网络仅仅指卷积或者全连接层,而激活层或者Pooling层并没有计算在内;我们关注50-layer和101-layer这两列,可以发现,它们唯一的不同在于conv4_x,ResNet50有6个block,而ResNet101有23个block,两者之间差了17个block,也就是17 x 3 = 51层。

本文使用 PyTorch 构建卫星图像分类任务。使用 ResNet34 模型。

本文不做细粒度的分类。使用 Kaggle 的一个数据集,只有四个类(四种类型的卫星图像)。

本文在这里介绍:

首先,看看 Kaggle 卫星图像分类。

使用预训练的 PyTorch ResNet34 模型进行卫星图像分类。

在训练保存训练好的模型后,对来自互联网的图像进行推理。

卫星图像分类数据集

卫星图像分类数据集Satellite Image Classification包含来自传感器和谷歌地图快照的大约 5600 张图像。

它有属于 4 个不同类别的卫星图像。

cloudy:从卫星拍摄的 1500 张云图像。

desert:从卫星拍摄的 1131 张沙漠图像。

green_area:主要是森林覆盖的卫星图像。1500 张图片。

water:1500张湖泊和其他水体的卫星图像。

以下是数据集的目录结构。

有四个目录,每个目录都与类名匹配,这些目录包含 .jpg 格式的相应图像。

看一下数据集中的一些图像。

这里要注意的一件事是沙漠和多云类图像是 256×256 的彩色图像。但是 green_area 和 water class 图像只是 64×64 维图像,它们也是彩色图像。但是如果在增强它们的同时增加它们的图像大小,它们的特征可能不如其他两个类那么清晰。

目录结构

看看这个项目的目录结构。

在父项目目录中有:

包含数据子目录的输入目录,该目录又包含数据集类文件夹。test_data 子目录包含互联网的图像,在训练模型后用于推理。这些是全新的图像,在经过训练的 PyTorch ResNet34 模型中是看不到的。

输出目录包含训练和推理流程生成的图像、图和训练模型。

5 个 Python 文件。稍后介绍这些内容。

PyTorch版本 1.9.0

使用 PyTorch ResNet34 的卫星图像分类

从这里开始编码部分。

有五个 Python 文件。按以下顺序处理它们

  • utils.py
  • datasets.py
  • model.py
  • train.py
  • inference.py 训练完成后有了 PyTorch ResNet34 模型。

辅助函数

两个辅助函数,一个用于保存训练好的模型,另一个用于保存损失和准确度图。

这些函数封装在utils.py文件中

以下代码块包含导入语句和 save_model() 函数。

保存训练的 epoch 、模型状态字典、优化器状态字典, model.pth 中的损失函数。

接下来保存损失和精度图。

save_plots() 函数接受用于训练和验证的相应损失和准确度列表。保存在输出文件夹中。

目前这两个辅助函数足以满足需求。

准备数据集

在准备数据集在datasets.py 文件编写代码。

导入所需的 PyTorch 模块定义一些常量。

使用 20% 的数据进行验证。批大小为 64。如果本地机器上训练面临 GPU 的 OOM(内存不足)问题,那么降低批大小 32 或 16。

训练与验证转换

下一个代码块包含训练和验证转换。

对于训练,除了变换之外增加了图像数量以防止过拟合。在没有增强的情况下,训练准确率很快达到 99% 以上,而验证准确率仍然很低。这些增强主要来自实验,以及最适合该数据集的方法。

此外可以看到正在应用 ImageNet 统计数据进行标准化。因为使用预训练的 ResNet34 模型。

为了验证,调整图像的大小转换为张量,进行标准化。

数据加载器

以上代码为全部datasets.py文件

下一步 准备模型

ResNet34模型

使用 PyTorch ResNet34 模型进行卫星图像分类。

PyTorch 已经为 ResNet34 提供了 ImageNet 预训练模型。只需要使用正确数量的类来更改最后一层。

以下代码在model.py文件

通过 build_model() 函数的参数控制:

是否想要预训练模型。

是否要对中间层进行微调。

类的数量,即 num_classes。

训练脚本

现在准备写训练脚本在train.py文件

上面的代码块导入了所有库模块以及上面编写的模块。还有参数解析器,它控制 --epochs

学习参数,模型与优化器

下一个代码块定义了学习率、计算设备。构建了 ResNet34 模型并定义了优化器和损失函数。

调用 build_model() 函数:

  • pretrained=True
  • fine_tune=False
  • num_classes=len(dataset.classes)

优化器是Adam,学习率为0.001,损失函数是Cross Entropy。

训练与验证函数

训练函数将是 PyTorch 中的标准图像分类训练函数。进行前向传递,计算损失,反向传播梯度,并更新参数。

在每个 epoch 之后,该函数返回该 epoch 的损失和准确度。

接下来是验证函数。

训练循环

接下来编写循环代码

执行train.py开始训练

打开命令行输入

代码语言:javascript
复制
python train.py --epochs 100

训练100个epoch然后打印输出

在每个 epoch 之后,都会打印类别精度。这是验证准确性。到 100 个 epoch 结束时,green_area 和 water 类的准确度低于其他两个类。

图 2. 训练 ResNet34 模型进行卫星图像分类后的准确率。

图 3. 训练 ResNet34 模型后的损失图。

准确率和损失图似乎都有很大的波动。

现在编写执行推理的脚本。

推理脚本

代码在inference.py文件

所有的推理都在 CPU 上。对于图像分类推理,使用 GPU 设备不是强制性的,CPU 就可以了。

加载模型处理转换

下一个代码块定义类名、加载训练好的模型,并定义预处理转换。在inference.py中

对于预处理只需要将图像转换为 PIL 图像格式,调整其大小,将其转换为张量,然后应用归一化。

读取图像与前馈

执行推理脚本

有四个测试图像。运行并查看结果。

执行

代码语言:javascript
复制
python inference.py --input input/test_data/cloudy.jpeg

继续测试下一个

代码语言:javascript
复制
python inference.py --input input/test_data/desert.jpeg

总结

本文构建了一个小型图像分类项目。使用 ResNet34 模型进行卫星图像分类。对新图像进行了推理。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-03-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习入门与实战 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 卫星图像分类数据集
  • 目录结构
  • PyTorch版本 1.9.0
  • 使用 PyTorch ResNet34 的卫星图像分类
  • 辅助函数
    • 准备数据集
      • 训练与验证转换
        • 数据加载器
          • ResNet34模型
            • 训练脚本
              • 学习参数,模型与优化器
                • 训练与验证函数
                  • 训练循环
                    • 执行train.py开始训练
                      • 推理脚本
                        • 加载模型处理转换
                          • 读取图像与前馈
                            • 执行推理脚本
                            • 总结
                            领券
                            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档