前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >轻松学Pytorch之Deeplabv3推理

轻松学Pytorch之Deeplabv3推理

作者头像
OpenCV学堂
发布2023-01-04 17:16:05
8170
发布2023-01-04 17:16:05
举报
文章被收录于专栏:贾志刚-OpenCV学堂

微信公众号:OpenCV学堂

Deeplabv3

Torchvision框架中在语义分割上支持的是Deeplabv3语义分割模型,而且支持不同的backbone替换,这些backbone替换包括MobileNetv3、ResNet50、ResNet101。其中MobileNetv3版本训练数据集是COCO子集,类别跟Pascal VOC的20个类别保持一致。这里以它为例,演示一下从模型导出ONNX到推理的全过程。

ONNX格式导出

首先需要把pytorch的模型导出为onnx格式版本,用下面的脚本就好啦:

代码语言:javascript
复制
model = tv.models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True)
dummy_input = torch.randn(1, 3, 320, 320)
model.eval()
model(dummy_input)
im = torch.zeros(1, 3, 320, 320).to("cpu")
torch.onnx.export(model, im,
                    "deeplabv3_mobilenet.onnx",
                    verbose=False,
                    opset_version=11,
                    training=torch.onnx.TrainingMode.EVAL,
                    do_constant_folding=True,
                    input_names=['input'],
                    output_names=['out', 'aux'],
                    dynamic_axes={'input': {0: 'batch', 2: 'height', 3: 'width'}}
                  )

模型的输入与输出结构如下:

其中out就是我们要解析的语义分割预测结果,input表示支持动态输入格式为NCHW

推理测试

模型推理对图像有个预处理,要求如下:

代码语言:javascript
复制
transform = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
      std=[0.229, 0.224, 0.225])
 ])

意思是转换为0~1之间的浮点数,然后减去均值除以方差。

剩下部分的代码就比较简单,初始化onnx推理实例,然后完成推理,对结果完成解析,输出推理结果,完整的代码如下:

代码语言:javascript
复制
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

sess_options = ort.SessionOptions()
# Below is for optimizing performance
sess_options.intra_op_num_threads = 24
# sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session = ort.InferenceSession("deeplabv3_mobilenet.onnx", providers=['CUDAExecutionProvider'], sess_options=sess_options)
# src = cv.imread("D:/images/messi_player.jpg")
src = cv.imread("D:/images/master.jpg")
image = cv.cvtColor(src, cv.COLOR_BGR2RGB)
blob = transform(image)
c, h, w = blob.shape
input_x = blob.view(1, c, h, w)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_x)}
ort_outs = ort_session.run(None, ort_inputs)
t1 = ort_outs[0]
t2 = ort_outs[1]
labels = np.argmax(np.squeeze(t1, 0), axis=0)
print(labels.dtype, labels.shape)
red_map = np.zeros_like(labels).astype(np.uint8)
green_map = np.zeros_like(labels).astype(np.uint8)
blue_map = np.zeros_like(labels).astype(np.uint8)
for label_num in range(0, len(label_color_map)):
    index = labels == label_num
    red_map[index] = np.array(label_color_map)[label_num, 0]
    green_map[index] = np.array(label_color_map)[label_num, 1]
    blue_map[index] = np.array(label_color_map)[label_num, 2]
segmentation_map = np.stack([blue_map, green_map, red_map], axis=2)
cv.addWeighted(src, 0.8, segmentation_map, 0.2, 0, src)
cv.imshow("deeplabv3", src)
cv.waitKey(0)
cv.destroyAllWindows()

运行结果如下:

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-12-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档