专栏首页贾志刚-OpenCV学堂轻松学Pytorch-使用ResNet50实现图像分类

轻松学Pytorch-使用ResNet50实现图像分类

微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识

Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。

模型

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:

AlexNet VGG ResNet SqueezeNet DenseNet Inception v3 GoogLeNet ShuffleNet v2 MobileNet v2 ResNeXt Wide ResNet MNASNet

这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:

model = torchvision.models.resnet50(pretrained=True).eval().cuda()
tf = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )])

使用模型实现图像分类

这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:

with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg
image = cv.resize(src, (224, 224))
image = np.float32(image) / 255.0
image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
image = image.transpose((2, 0, 1))
input_x = torch.from_numpy(image).unsqueeze(0)
print(input_x.size())
pred = model(input_x.cuda())
pred_index = torch.argmax(pred, 1).cpu().detach().numpy()
print(pred_index)
print("current predict class name : %s"%labels[pred_index[0]])
cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
cv.imshow("input", src)
cv.waitKey(0)
cv.destroyAllWindows()

运行结果如下:

转ONNX支持

在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:

with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]
net = cv.dnn.readNetFromONNX("resnet.onnx")
src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg
image = cv.resize(src, (224, 224))
image = np.float32(image) / 255.0
image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)
net.setInput(blob)
probs = net.forward()
index = np.argmax(probs)
cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
cv.imshow("input", src)
cv.waitKey(0)
cv.destroyAllWindows()

运行结果见上图,这里就不再贴了。

本文分享自微信公众号 - OpenCV学堂(CVSCHOOL),作者:gloomyfish

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-07-19

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 教程 | OpenCV Grabcut对象分割

    Grabcut是基于图割(graph cut)实现的图像分割算法,它需要用户输入一个bounding box作为分割目标位置,实现对目标与背景的分离/分割,这个...

    OpenCV学堂
  • Python OpenCV像素操作

    Python OpenCV像素操作 环境声明 : Python3.6 + OpenCV3.3 + PyCharm IDE 首先要引入OpenCV和Numpy支持...

    OpenCV学堂
  • Tensorflow Object Detection API 终于支持tensorflow1.x与tensorflow2.x了

    基于tensorflow框架构建的快速对象检测模型构建、训练、部署框架,是针对计算机视觉领域对象检测任务的深度学习框架。之前tensorflow2.x一直不支持...

    OpenCV学堂
  • 安装Windows Performance Toolkit进行0.1微秒级CPU监控

    我研究了WPR,它的最小时间单位是0.1微秒,即10000个单位是1毫秒,精细度非常高

    我爱你的一诺
  • Matplotlib入门

    qiangbo.space/2018-04-06/matplotlib_l1/ 入门代码示例 import matplotlib.pyplot as plt ...

    林清猫耳
  • PHP开发——yii2多图上传组件的使用

    最近在使用yii2开发一个表单页面的时候,有多图上传的需求,稍微找了找这方面的组件,基本都安利fileInput这个组件,于是就尝试着使用这个库来完成后端表单页...

    Originalee
  • 使用Python处理Word文档

    1. 前言2. 使用Document对象创建文档3. 在word文档中使用标题4. 在word文档中使用段落5. 在word文档中使用列表6. 在word文档中...

    LogicPanda
  • bootstrap实战 作品展示站点

    用户5760343
  • 基于linux的嵌入IPv4协议栈的内容过滤防火墙系统(6)-系统效果

    下图是本程序所使用的系统:redhat7.2,这是它的一个图形界面,叫Gnome。

    源哥
  • python中For循环

    py3study

扫码关注云+社区

领取腾讯云代金券