前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torchvision中怎么加载本地模型实现训练与推理

torchvision中怎么加载本地模型实现训练与推理

作者头像
OpenCV学堂
发布2023-09-22 15:38:38
3360
发布2023-09-22 15:38:38
举报
Torchvision介绍

Torchvision是基于Pytorch的视觉深度学习迁移学习训练框架,当前支持的图像分类、对象检测、实例分割、语义分割、姿态评估模型的迁移学习训练与评估。支持对数据集的合成、变换、增强等,此外还支持预训练模型库下载相关的模型,直接预测推理。

预训练模型使用

Torchvision从0.13版本开始预训练模型支持多源backbone设置,以图像分类的ResNet网络模型为例:

支持多个不同的数据集上不同精度的预训练模型,下载模型,转化为推理模型

对输入图像实现预处理

本地加载模型

Torchvision中支持的预训练模型当你使用的时候都会加载模型的预训练模型,然后才可以加载你自己的权重文件,如果你不想加载torchvision的预训练模型,只想从本地加载pt或者pth文件实现推理或者训练的时候,一定要通过下面的方式完成,以Faster-RCNN为例:

代码语言:javascript
复制
# Load the model from local host
num_classes = len(self.labels)
self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
                                                             num_classes=num_classes, pretrained_backbone=False)
self.model.load_state_dict(torch.load(self.model_file))
self.model.eval()
self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# 使用GPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    self.model.cuda()

就这样解锁了在torchvision框架下如何从本地加载预训练模型文件或者定义训练模型文件。

推荐阅读

用注意力机制魔改YOLOv5涨点

一个教程搞定YOLOv8 对象检测+跟踪+分割

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2023-09-21 18:57,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档