前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >轻松学Pytorch之torchscript使用!

轻松学Pytorch之torchscript使用!

作者头像
OpenCV学堂
发布2022-05-12 21:41:09
发布2022-05-12 21:41:09
3K00
代码可运行
举报
运行总次数:0
代码可运行

点击上方蓝字关注我们

微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识

TorchScript介绍

TorchScript是PyTorch模型推理部署的中间表示,可以在高性能环境libtorch(C ++)中直接加载,实现模型推理,而无需Pytorch训练框架依赖。torch.jit是torchscript Python语言包支持,支持pytorch模型快速,高效,无缝对接到libtorch运行时,实现高效推理。它是Pytorch中除了训练部分之外,开发者最需要掌握的Pytorch框架开发技能之一。

trace使用

Torchscript使用分为两个部分分别是script跟trace,其中trace是跟踪执行步骤,记录模型调用推理时执行的每个步骤,代码演示如下:

代码语言:javascript
代码运行次数:0
运行
复制
class MyCell(torch.nn.Module):
      def __init__(self):
            super(MyCell, self).__init__()
            self.linear = torch.nn.Linear(4, 4)


      def forward(self, x, h):
            new_h = torch.tanh(self.linear(x) + h)
            return new_h, new_h


my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
print(traced_cell.graph)

运行结果如下:

代码语言:javascript
代码运行次数:0
运行
复制
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

跟踪执行结果

代码语言:javascript
代码运行次数:0
运行
复制
graph(%self.1 : __torch__.MyCell,
      %input : Float(3:4, 4:1, requires_grad=0, device=cpu),
      %h : Float(3:4, 4:1, requires_grad=0, device=cpu)):
  %19 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %21 : Tensor = prim::CallMethod[name="forward"](%19, %input)
  %12 : int = prim::Constant[value=1]() # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
  %13 : Float(3:4, 4:1, requires_grad=1, device=cpu) = aten::add(%21, %h, %12) # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
  %14 : Float(3:4, 4:1, requires_grad=1, device=cpu) = aten::tanh(%13) # D:/python/pytorch_openvino_demo/ch5/faster_rcnn2onnx.py:112:0
  %15 : (Float(3:4, 4:1, requires_grad=1, device=cpu), Float(3:4, 4:1, requires_grad=1, device=cpu)) = prim::TupleConstruct(%14, %14)
  return (%15)

script部分使用

script是导出模型为中间IR格式文件,支持高性能libtorch C++部署,我们以torchvision中Mask-RCNN导出中间格式IR为例,代码演示如下:

代码语言:javascript
代码运行次数:0
运行
复制
import torch
import torchvision as tv

num_classes = 3
model = tv.models.detection.maskrcnn_resnet50_fpn(
    pretrained=False, progress=True,
    num_classes=num_classes,
    pretrained_backbone=True)
im = torch.zeros(1, 3, *(1333, 800)).to("cpu")
model.load_state_dict(torch.load("D:/gaobao_model.pth"))
model = model.to("cpu")
model.eval()
ts = torch.jit.script(model)
ts.save("gaobao.ts")

loaded_trace = torch.jit.load("gaobao.ts")
loaded_trace.eval()
with torch.no_grad():
    print(loaded_trace(list(im)))

最终得到torchscript文件,支持直接通过libtorch部署,其中通过torchscript C++部分加载的代码如下:

代码语言:javascript
代码运行次数:0
运行
复制
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
    if (argc != 2) {
      std::cerr << "usage: example-app <path-to-exported-script-module>\n";
      return -1;
    }

    // Deserialize the ScriptModule from a file using torch::jit::load().
    torch::jit::script::Module module = torch::jit::load(argv[1]);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({4, 8}));
    inputs.push_back(torch::randn({8, 5}));
    torch::Tensor output = module.forward(std::move(inputs)).toTensor();

    std::cout << output << std::endl;
}

上面代码来自官方测试程序,特别说明一下!

扫码查看OpenCV+OpenVIO+Pytorch系统化学习路线图

 推荐阅读 

CV全栈开发者说 - 从传统算法到深度学习怎么修炼

2022入坑深度学习,我选择Pytorch框架!

Pytorch轻松实现经典视觉任务

教程推荐 | Pytorch框架CV开发-从入门到实战

OpenCV4 C++学习 必备基础语法知识三

OpenCV4 C++学习 必备基础语法知识二

OpenCV4.5.4 人脸检测+五点landmark新功能测试

OpenCV4.5.4人脸识别详解与代码演示

OpenCV二值图象分析之Blob分析找圆

OpenCV4.5.x DNN + YOLOv5 C++推理

OpenCV4.5.4 直接支持YOLOv5 6.1版本模型推理

OpenVINO2021.4+YOLOX目标检测模型部署测试

比YOLOv5还厉害的YOLOX来了,官方支持OpenVINO推理

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-05-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档