前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >将Pytorch模型移植到C++详细教程(附代码演练)

将Pytorch模型移植到C++详细教程(附代码演练)

作者头像
磐创AI
发布2023-08-29 08:19:50
1.6K0
发布2023-08-29 08:19:50
举报
文章被收录于专栏:磐创AI技术团队的专栏

说明

在本文中,我们将看到如何将Pytorch模型移植到C++中。Pytorch通常用于研究和制作新模型以及系统的原型。该框架很灵活,因此易于使用。主要的问题是我们如何将Pytorch模型移植到更适合的格式C++中,以便在生产中使用。

我们将研究不同的管道,如何将PyTrac模型移植到C++中,并使用更合适的格式应用到生产中。

1) TorchScript脚本

2) 开放式神经网络交换

3) TFLite(Tensorflow Lite)

TorchScript脚本

TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以在高性能环境(例如C ++)中运行。它有助于创建可序列化和可优化的模型。在Python中训练这些模型之后,它们可以在Python或C++中独立运行。因此,可以使用Python轻松地在PyTorch中训练模型,然后通过torchscript将模型导出到无法使用Python的生产环境中。它基本上提供了一个工具来捕获模型的定义。

跟踪模块:

代码语言:javascript
复制
class DummyCell(torch.nn.Module):    def __init__(self):        super(DummyCell, self).__init__()        self.linear = torch.nn.Linear(4, 4)    def forward(self, x):        out = self.linear(x)        return out
dummy_cell = DummyCell()x =  torch.rand(2, 4)traced_cell = torch.jit.trace(dummy_cell, (x))
# Print Traced Graphprint(traced_cell.graph)
# Print Traced Codeprint(traced_cell.code)

在这里,torchscript调用了模块,将执行的操作记录到称为图的中间表示中。traced_cell.graph提供了一个非常低级的表示,并且图形中的大部分信息最终对用户没有用处。traced_cell.code 提供了更多的python语法解释代码。

上述代码的输出(traced_cell.graph和traced_cell.code) :

代码语言:javascript
复制
graph(%self.1 : __torch__.DummyCell,      %input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)):  %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)  %18 : Tensor = prim::CallMethod[name="forward"](%16, %input)  return (%18)def forward(self,    input: Tensor) -> Tensor:  return (self.linear).forward(input, )
TorchScript的优点

1) TorchScript代码可以在自己的解释器中调用。所保存的图形也可以在C++中加载用于生产。

2) TorchScript为我们提供了一种表示,在这种表示中,我们可以对代码进行编译器优化,以提供更高效的执行。

ONNX(开放式神经网络交换)

ONNX是一种开放格式,用于表示机器学习模型。ONNX定义了一组通用的操作符、机器学习和深度学习模型的构建块以及一种通用的文件格式,使AI开发人员能够将模型与各种框架、工具、运行时和编译器一起使用。它定义了一个可扩展的计算图模型,以及内置操作符和标准数据类型的定义。

可以使用以下代码将上述DummyCell模型导出到onnx:

代码语言:javascript
复制
torch.onnx.export(dummy_cell, x, "dummy_model.onnx", export_params=True, verbose=True)

输出:

代码语言:javascript
复制
graph(%input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu),      %linear.weight : Float(4, 4, strides=[4, 1], requires_grad=1, device=cpu),      %linear.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):  %3 : Float(2, 4, strides=[4, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%input, %linear.weight, %linear.bias)  return (%3)

它将模型保存到文件名“dummy_model.onnx“中,可以使用python模块onnx加载该模型。为了在python中进行推理,可以使用ONNX运行时。ONNX运行时是一个针对ONNX模型的以性能为中心的引擎,它可以跨多个平台和硬件高效地进行推断。查看此处了解有关性能的更多详细信息。

https://cloudblogs.microsoft.com/opensource/2019/05/22/onnx-runtime-machine-learning-inferencing-0-4-release/

C++中的推理

为了从C++中执行ONNX模型,首先,我们必须使用tract库在Rust中编写推理代码。现在,我们有了用于推断ONNX模型的rust库。我们现在可以使用cbindgen将rust库导出为公共C头文件。

tract:https://github.com/sonos/tract

cbindgen:https://github.com/eqrion/cbindgen

现在,此头文件以及从Rust生成的共享库或静态库可以包含在C ++中以推断ONNX模型。在从rust生成共享库的同时,我们还可以根据不同的硬件提供许多优化标志。Rust也可以轻松实现针对不同硬件类型的交叉编译。

Tensorflow Lite

Tensorflow Lite是一个用于设备上推理的开源深度学习框架。它是一套帮助开发人员在移动、嵌入式和物联网设备上运行Tensorflow模型的工具。它使在设备上的机器学习推理具有低延迟和小二进制大小。它有两个主要组成部分:

1) Tensorflow Lite解释器:它在许多不同的硬件类型上运行特别优化的模型,包括移动电话、嵌入式Linux设备和微控制器。

2) Tensorflow Lite转换器:它将Tensorflow模型转换为一种有效的形式,供解释器使用。

将PyTorch模型转换为TensorFlow lite的主管道如下:

1) 构建PyTorch模型

2) 以ONNX格式导模型

3) 将ONNX模型转换为Tensorflow(使用ONNX tf)

在这里,我们可以使用以下命令将ONNX模型转换为TensorFlow protobuf模型:

代码语言:javascript
复制
!onnx-tf convert -i "dummy_model.onnx" -o  'dummy_model_tensorflow'

4) 将Tensorflow模型转换为Tensorflow Lite(tflite)

TFLITE模型(Tensorflow Lite模型)现在可以在C++中使用。这里请参考如何在C++中对TFLITE模型进行推理。

https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c

尾注

我希望你觉得这篇文章有用。我们试图简单地解释一下,我们可以用不同的方式将PyTorch训练过的模型部署到生产中。

参考文献

1)TorchScript简介:https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

2)在C++ 中加载TorchScript模型:https://pytorch.org/tutorials/advanced/cpp_export.html

3)将Pytorch模型导出到ONNX:https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

4)Rust中的Tract神经网络推理工具包:https://github.com/sonos/tract

5)在C++中的TfLite模型上运行推理:https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c

6)Colab - 在Android设备上进行Pytorch训练的模型:https://colab.research.google.com/drive/1MwFVErmqU9Z6cTDWLoTvLgrAEBRZUEsA

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-06-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 说明
  • TorchScript脚本
    • TorchScript的优点
    • ONNX(开放式神经网络交换)
    • C++中的推理
    • Tensorflow Lite
    • 尾注
    • 参考文献
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档