前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【实战项目代码分享】计算机视觉入门教程&实战项目代码

【实战项目代码分享】计算机视觉入门教程&实战项目代码

作者头像
机器视觉CV
发布2020-07-23 11:21:57
7160
发布2020-07-23 11:21:57
举报
文章被收录于专栏:机器视觉CV机器视觉CV

对理论知识有了了解后,这里介绍两个实战项目,分别是基于keras的多标签图像分类以及基于 Pytorch 的迁移学习教程。

基于 Keras 的多标签图像分类教程

首先是采用的多标签图像数据集,如下所示,6 个类别的一个衣服图像数据集:

整个多标签分类的项目结构如下所示:

代码语言:javascript
复制
├── classify.py
├── dataset
│   ├── black_jeans [344 entries
│   ├── blue_dress [386 entries]
│   ├── blue_jeans [356 entries]
│   ├── blue_shirt [369 entries]
│   ├── red_dress [380 entries]
│   └── red_shirt [332 entries]
├── examples
│   ├── example_01.jpg
│   ├── example_02.jpg
│   ├── example_03.jpg
│   ├── example_04.jpg
│   ├── example_05.jpg
│   ├── example_06.jpg
│   └── example_07.jpg
├── fashion.model
├── mlb.pickle
├── plot.png
├── pyimagesearch
│   ├── __init__.py
│   └── smallervggnet.py
├── search_bing_api.py
└── train.py

准备好数据后,就是开始选择使用的网络结构,这里采用 Keras 搭建一个简化版本的 VGGNet,然后就是训练模型和测试模型的代码,这里需要提前安装好的库是:

代码语言:javascript
复制
pip install keras, scikit-learn, matplotlib, imutils, opencv-python

训练过程的实验图如下所示:

部分测试结果:

代码语言:javascript
复制
Using TensorFlow backend.
[INFO] loading network...
[INFO] classifying image...
black: 0.00%
blue: 3.58%
dress: 95.14%
jeans: 0.00%
red: 100.00%
shirt: 64.02%

具体代码和详细教程可以扫下方二维码关注【算法猿的成长】,后台回复:多标签,即可获取

长按上方二维码 2 秒

基于 Pytorch 的迁移学习教程

第二份实战教程就是使用 Pytorch 实现迁移学习,迁移学习也是计算机视觉里非常常用的一个做法,也就是利用在 ImageNet 上预训练好的模型,在我们自定义的数据集上重新训练得到在自定义数据集上性能很好的模型。

首先是展示我们自定义的一个二分类数据集的图片,分别是蚂蚁和蜜蜂两个类别:

接下来就是加载数据集、训练模型代码的实现,其中最核心的就是迁移学习部分,对网络的微调训练:

代码语言:javascript
复制
# 加载 resnet18 网络模型,并且设置加载预训练模型
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# 修改输出层的输出数量,本次采用的数据集类别为 2
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# 对所有网络层参数进行更新
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 学习率策略,每 7 个 epochs 乘以 0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

最终模型的分类结果如下所示:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器视觉CV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 基于 Keras 的多标签图像分类教程
  • 基于 Pytorch 的迁移学习教程
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档