前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tochvision轻松支持十种图像分类模型迁移学习

tochvision轻松支持十种图像分类模型迁移学习

作者头像
OpenCV学堂
发布2022-10-09 10:18:10
4890
发布2022-10-09 10:18:10
举报

点击上方蓝字关注我们

微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识

torchvision分类介绍

Torchvision高版本支持各种SOTA的图像分类模型,同时还支持不同数据集分类模型的预训练模型的切换。使用起来十分方便快捷,Pytroch中支持两种迁移学习方式,分别是:

代码语言:javascript
复制
- Finetune模式基于预训练模型,全链路调优参数- 冻结特征层模式这种方式只修改输出层的参数,CNN部分的参数冻结

上述两种迁移方式,分别适合大量数据跟少量数据,前一种方式计算跟训练时间会比第二种方式要长点,但是针对大量自定义分类数据效果会比较好。

自定义分类模型修改与训练

加载模型之后,feature_extracting 为true表示冻结模式,否则为finetune模式,相关的代码如下:

代码语言:javascript
复制
def set_parameter_requires_grad(model, feature_extracting):     if feature_extracting:         for param in model.parameters():             param.requires_grad = False

以resnet18为例,修改之后的自定义训练代码如下:

代码语言:javascript
复制
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 5.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 5)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

数据集是flowers-dataset,有五个分类分别是:

代码语言:javascript
复制
daisydandelionrosessunflowerstulips

全链路调优,迁移学习训练CNN部分的权重参数

代码语言:javascript
复制
Epoch 0/24
----------
train Loss: 1.3993 Acc: 0.5597
valid Loss: 1.8571 Acc: 0.7073
Epoch 1/24
----------
train Loss: 1.0903 Acc: 0.6580
valid Loss: 0.6150 Acc: 0.7805
Epoch 2/24
----------
train Loss: 0.9095 Acc: 0.6991
valid Loss: 0.4386 Acc: 0.8049
Epoch 3/24
----------
train Loss: 0.7628 Acc: 0.7349
valid Loss: 0.9111 Acc: 0.7317
Epoch 4/24
----------
train Loss: 0.7107 Acc: 0.7669
valid Loss: 0.4854 Acc: 0.8049
Epoch 5/24
----------
train Loss: 0.6231 Acc: 0.7793
valid Loss: 0.6822 Acc: 0.8049
Epoch 6/24
----------
train Loss: 0.5768 Acc: 0.8033
valid Loss: 0.2748 Acc: 0.8780
Epoch 7/24
----------
train Loss: 0.5448 Acc: 0.8110
valid Loss: 0.4440 Acc: 0.7561
Epoch 8/24
----------
train Loss: 0.5037 Acc: 0.8170
valid Loss: 0.2900 Acc: 0.9268
Epoch 9/24
----------
train Loss: 0.4836 Acc: 0.8360
valid Loss: 0.7108 Acc: 0.7805
Epoch 10/24
----------
train Loss: 0.4663 Acc: 0.8369
valid Loss: 0.5868 Acc: 0.8049
Epoch 11/24
----------
train Loss: 0.4276 Acc: 0.8504
valid Loss: 0.6998 Acc: 0.8293
Epoch 12/24
----------
train Loss: 0.4299 Acc: 0.8529
valid Loss: 0.6449 Acc: 0.8049
Epoch 13/24
----------
train Loss: 0.4256 Acc: 0.8567
valid Loss: 0.7897 Acc: 0.7805
Epoch 14/24
----------
train Loss: 0.4062 Acc: 0.8559
valid Loss: 0.5855 Acc: 0.8293
Epoch 15/24
----------
train Loss: 0.4030 Acc: 0.8545
valid Loss: 0.7336 Acc: 0.7805
Epoch 16/24
----------
train Loss: 0.3786 Acc: 0.8730
valid Loss: 1.0429 Acc: 0.7561
Epoch 17/24
----------
train Loss: 0.3699 Acc: 0.8763
valid Loss: 0.4549 Acc: 0.8293
Epoch 18/24
----------
train Loss: 0.3394 Acc: 0.8788
valid Loss: 0.2828 Acc: 0.9024
Epoch 19/24
----------
train Loss: 0.3300 Acc: 0.8834
valid Loss: 0.6766 Acc: 0.8537
Epoch 20/24
----------
train Loss: 0.3136 Acc: 0.8906
valid Loss: 0.5893 Acc: 0.8537
Epoch 21/24
----------
train Loss: 0.3110 Acc: 0.8901
valid Loss: 0.4909 Acc: 0.8537
Epoch 22/24
----------
train Loss: 0.3141 Acc: 0.8931
valid Loss: 0.3930 Acc: 0.9024
Epoch 23/24
----------
train Loss: 0.3106 Acc: 0.8887
valid Loss: 0.3079 Acc: 0.9024
Epoch 24/24
----------
train Loss: 0.3143 Acc: 0.8923
valid Loss: 0.5122 Acc: 0.8049
Training complete in 25m 34s
Best val Acc: 0.926829

冻结CNN部分,只训练全连接分类权重

代码语言:javascript
复制
Params to learn:
         fc.weight
         fc.bias
Epoch 0/24
----------
train Loss: 1.0217 Acc: 0.6465
valid Loss: 1.5317 Acc: 0.8049
Epoch 1/24
----------
train Loss: 0.9569 Acc: 0.6947
valid Loss: 1.2450 Acc: 0.6829
Epoch 2/24
----------
train Loss: 1.0280 Acc: 0.6999
valid Loss: 1.5677 Acc: 0.7805
Epoch 3/24
----------
train Loss: 0.8344 Acc: 0.7426
valid Loss: 1.1053 Acc: 0.7317
Epoch 4/24
----------
train Loss: 0.9110 Acc: 0.7250
valid Loss: 1.1148 Acc: 0.7561
Epoch 5/24
----------
train Loss: 0.9049 Acc: 0.7346
valid Loss: 1.1541 Acc: 0.6341
Epoch 6/24
----------
train Loss: 0.8538 Acc: 0.7465
valid Loss: 1.4098 Acc: 0.8293
Epoch 7/24
----------
train Loss: 0.9041 Acc: 0.7349
valid Loss: 0.9604 Acc: 0.7561
Epoch 8/24
----------
train Loss: 0.8885 Acc: 0.7468
valid Loss: 1.2603 Acc: 0.7561
Epoch 9/24
----------
train Loss: 0.9257 Acc: 0.7333
valid Loss: 1.0751 Acc: 0.7561
Epoch 10/24
----------
train Loss: 0.8637 Acc: 0.7492
valid Loss: 0.9748 Acc: 0.7317
Epoch 11/24
----------
train Loss: 0.8686 Acc: 0.7517
valid Loss: 1.0194 Acc: 0.8049
Epoch 12/24
----------
train Loss: 0.8492 Acc: 0.7572
valid Loss: 1.0378 Acc: 0.7317
Epoch 13/24
----------
train Loss: 0.8773 Acc: 0.7432
valid Loss: 0.7224 Acc: 0.8049
Epoch 14/24
----------
train Loss: 0.8919 Acc: 0.7473
valid Loss: 1.3564 Acc: 0.7805
Epoch 15/24
----------
train Loss: 0.8634 Acc: 0.7490
valid Loss: 0.7822 Acc: 0.7805
Epoch 16/24
----------
train Loss: 0.8069 Acc: 0.7644
valid Loss: 1.4132 Acc: 0.7561
Epoch 17/24
----------
train Loss: 0.8589 Acc: 0.7492
valid Loss: 0.9812 Acc: 0.8049
Epoch 18/24
----------
train Loss: 0.7677 Acc: 0.7688
valid Loss: 0.7176 Acc: 0.8293
Epoch 19/24
----------
train Loss: 0.8044 Acc: 0.7514
valid Loss: 1.4486 Acc: 0.7561
Epoch 20/24
----------
train Loss: 0.7916 Acc: 0.7564
valid Loss: 1.0575 Acc: 0.8049
Epoch 21/24
----------
train Loss: 0.7922 Acc: 0.7647
valid Loss: 1.0406 Acc: 0.7805
Epoch 22/24
----------
train Loss: 0.8187 Acc: 0.7647
valid Loss: 1.0965 Acc: 0.7561
Epoch 23/24
----------
train Loss: 0.8443 Acc: 0.7503
valid Loss: 1.6163 Acc: 0.7317
Epoch 24/24
----------
train Loss: 0.8165 Acc: 0.7583
valid Loss: 1.1680 Acc: 0.7561
Training complete in 20m 7s
Best val Acc: 0.829268

测试结果:

零代码训练演示

我已经完成torchvision中分类模型自定义数据集迁移学习的代码封装与开发,支持基于收集到的数据集,零代码训练,生成模型。图示如下:

轻松支持十种主流的CNN模型

代码语言:javascript
复制
self.models_combox.addItem("resnet18")self.models_combox.addItem("resnet34")self.models_combox.addItem("resnet50")self.models_combox.addItem("resnet101")self.models_combox.addItem("inception")self.models_combox.addItem("densenet")self.models_combox.addItem("wide_resnet50")self.models_combox.addItem("wide_resnet101")self.models_combox.addItem("resnext50_32x4d")self.models_combox.addItem("resnext101_32x8d")

扫码获取YOLOv5 TensorRT INT8量化脚本与视频教程

扫码查看OpenCV+OpenVIO+Pytorch系统化学习路线图

 推荐阅读 

CV全栈开发者说 - 从传统算法到深度学习怎么修炼

2022入坑深度学习,我选择Pytorch框架!

Pytorch轻松实现经典视觉任务

教程推荐 | Pytorch框架CV开发-从入门到实战

OpenCV4 C++学习 必备基础语法知识三

OpenCV4 C++学习 必备基础语法知识二

OpenCV4.5.4 人脸检测+五点landmark新功能测试

OpenCV4.5.4人脸识别详解与代码演示

OpenCV二值图象分析之Blob分析找圆

OpenCV4.5.x DNN + YOLOv5 C++推理

OpenCV4.5.4 直接支持YOLOv5 6.1版本模型推理

OpenVINO2021.4+YOLOX目标检测模型部署测试

比YOLOv5还厉害的YOLOX来了,官方支持OpenVINO推理

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-10-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
人脸识别
腾讯云神图·人脸识别(Face Recognition)基于腾讯优图强大的面部分析技术,提供包括人脸检测与分析、比对、搜索、验证、五官定位、活体检测等多种功能,为开发者和企业提供高性能高可用的人脸识别服务。 可应用于在线娱乐、在线身份认证等多种应用场景,充分满足各行业客户的人脸属性识别及用户身份确认等需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档