前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TorchVision对象检测RetinaNet推理演示

TorchVision对象检测RetinaNet推理演示

作者头像
OpenCV学堂
发布2022-10-10 11:40:40
8210
发布2022-10-10 11:40:40
举报
文章被收录于专栏:贾志刚-OpenCV学堂

点击上方蓝字关注我们

微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识

torchvision对象检测介绍

Pytorch1.11版本以上支持Torchvision高版本支持以下对象检测模型的迁移学习:

代码语言:javascript
复制
- Faster-RCNN- Mask-RCNN- FCOS- RetinaNet- SSD- KeyPointsRCNN

其中基于COCO的预训练模型mAP对应关系如下:

最近一段时间本人已经全部亲测,都可以转换为ONNX格式模型,都可以支持ONNXRUNTIME框架的Python版本与C++版本推理,本文以RetinaNet为例,演示了从模型下载到导出ONNX格式,然后基于ONNXRUNTIME推理的整个流程。

RetinaNet转ONNX

把模型转换为ONNX格式,Pytorch是原生支持的,只需要把通过torch.onnx.export接口,填上相关的参数,然后直接运行就可以生成ONNX模型文件。相关的转换代码如下:

代码语言:javascript
复制
model = tv.models.detection.retinanet_resnet50_fpn(pretrained=True)
dummy_input = torch.randn(1, 3, 1333, 800)
model.eval()
model(dummy_input)
im = torch.zeros(1, 3, 1333, 800).to("cpu")
torch.onnx.export(model, im,
                    "retinanet_resnet50_fpn.onnx",
                    verbose=False,
                    opset_version=11,
                    training=torch.onnx.TrainingMode.EVAL,
                    do_constant_folding=True,
                    input_names=['input'],
                    output_names=['output'],
                    dynamic_axes={'input': {0: 'batch', 2: 'height', 3: 'width'}}
                  )

运行时候控制台会有一系列的警告输出,但是绝对不影响模型转换,影响不影响精度我还没做个仔细的对比。

模型转换之后,可以直接查看模型的输入与输出结构,图示如下:

RetinaNet的ONNX格式推理

基于Python版本的ONNXRUNTIME完成推理演示,这个跟我之前写过一篇文章Faster-RCNN的ONNX推理演示非常相似,大概是去年写的,链接在这里:

代码很简单,只有三十几行,Python就是方便使用,这里最需要注意的是输入图像的预处理必须是RGB格式,需要归一化到0~1之间。对得到的三个输出层分别解析,就可以获取到坐标(boxes里面包含的实际坐标,无需转换),推理部分的代码如下:

代码语言:javascript
复制
import onnxruntime as ort
import cv2 as cv
import numpy as np
import torchvision


coco_names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', '4': 'motorcycle', '5': 'airplane', '6': 'bus',
         '7': 'train', '8': 'truck', '9': 'boat', '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign',
         '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat', '18': 'dog', '19': 'horse', '20': 'sheep',
         '21': 'cow', '22': 'elephant', '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack',
         '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase', '34': 'frisbee', '35': 'skis',
         '36': 'snowboard', '37': 'sports ball', '38': 'kite', '39': 'baseball bat', '40': 'baseball glove',
         '41': 'skateboard', '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass',
         '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl', '52': 'banana', '53': 'apple',
         '54': 'sandwich', '55': 'orange', '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',
         '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', '64': 'potted plant', '65': 'bed',
         '67': 'dining table', '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', '75': 'remote',
         '76': 'keyboard', '77': 'cell phone', '78': 'microwave', '79': 'oven', '80': 'toaster', '81': 'sink',
         '82': 'refrigerator', '84': 'book', '85': 'clock', '86': 'vase', '87': 'scissors', '88': 'teddybear',
         '89': 'hair drier', '90': 'toothbrush'}

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

sess_options = ort.SessionOptions()
src = cv.imread("D:/images/mmc.png")
cv.namedWindow("Retina-Net Detection Demo", cv.WINDOW_AUTOSIZE)
image = cv.cvtColor(src, cv.COLOR_BGR2RGB)
blob = transform(image)
c, h, w = blob.shape
input_x = blob.view(1, c, h, w)
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_x)}
ort_outs = ort_session.run(None, ort_inputs)
#  (N,4) dimensional array containing the absolute bounding-box
boxes = ort_outs[0]
scores = ort_outs[1]
labels = ort_outs[2]
print(boxes.shape, boxes.dtype, labels.shape, labels.dtype, scores.shape, scores.dtype)

index = 0
for x1, y1, x2, y2 in boxes:
    if scores[index] > 0.65:
        cv.rectangle(src, (np.int32(x1), np.int32(y1)),
                     (np.int32(x2), np.int32(y2)), (140, 199, 0), 2, 8, 0)
        label_id = labels[index]
        label_txt = coco_names[str(label_id)]
        cv.putText(src, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 1)
    index += 1
cv.imshow("Retina-Net Detection Demo", src)
cv.imwrite("D:/mmc_result.png", src)
cv.waitKey(0)
cv.destroyAllWindows()

运行结果如下:

扫码获取Pytorch与TorchVision视频教程

扫码查看OpenCV+OpenVIO+Pytorch系统化学习路线图

 推荐阅读 

CV全栈开发者说 - 从传统算法到深度学习怎么修炼

2022入坑深度学习,我选择Pytorch框架!

Pytorch轻松实现经典视觉任务

教程推荐 | Pytorch框架CV开发-从入门到实战

OpenCV4 C++学习 必备基础语法知识三

OpenCV4 C++学习 必备基础语法知识二

OpenCV4.5.4 人脸检测+五点landmark新功能测试

OpenCV4.5.4人脸识别详解与代码演示

OpenCV二值图象分析之Blob分析找圆

OpenCV4.5.x DNN + YOLOv5 C++推理

OpenCV4.5.4 直接支持YOLOv5 6.1版本模型推理

OpenVINO2021.4+YOLOX目标检测模型部署测试

比YOLOv5还厉害的YOLOX来了,官方支持OpenVINO推理

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-10-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
人脸识别
腾讯云神图·人脸识别(Face Recognition)基于腾讯优图强大的面部分析技术,提供包括人脸检测与分析、比对、搜索、验证、五官定位、活体检测等多种功能,为开发者和企业提供高性能高可用的人脸识别服务。 可应用于在线娱乐、在线身份认证等多种应用场景,充分满足各行业客户的人脸属性识别及用户身份确认等需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档