前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch.jit.trace与torch.jit.script的区别

torch.jit.trace与torch.jit.script的区别

作者头像
三更两点
发布2022-08-07 12:37:03
5.6K0
发布2022-08-07 12:37:03
举报

文章目录

术语

  1. Tochscript:狭义概念导出图形的表示/格式;广义概念为导出模型的方法;
  2. (Torch)Scriptable:可以用torch.jit.script导出模型
  3. Traceable:可以用torch.jit.trace导出模型

什么时候用torch.jit.trace(结论:首选)

  1. torch.jit.trace一种导出方法;它运行具有某些张量输入的模型,并“跟踪/记录”所有执行到图形中的操作。
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.trace
  3. 支持python的预处理和动态行为;
  4. torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
  5. 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;

优点

  1. 不会损害代码质量;
  2. 2.它的主要限制可以通过与torch.jit.script混合来解决

什么时候用torch.jit.script(结论:必要时)

  1. 定义:一种模型导出方法,其实编译python的模型源码,得到可执行的图;
  2. 在模型内部的数据类型只有张量,且没有for if while等控制流,也可以选择torch.jit.script
  3. 不支持python的预处理和动态行为;
  4. 必须做一下类型标注;
  5. torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。

错误举例

代码语言:javascript
复制
import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, return_b=False):
        super().__init__()
        self.return_b = return_b

    def forward(self, x):
        a = x + 2
        if self.return_b:  #属于静态控制
            b = x + 3
            return a, b
        return a


model = MyModule(return_b=True)

# Will work  成功
traced = torch.jit.trace(model, (torch.randn(10, ), ))

# Will fail 失败
scripted = torch.jit.script(model)
  • 总结:控制流是静态的,torch.jit.trace将正常工作

动态控制

  1. if x[0] == 4: x += 1 is a dynamic control flow.
代码语言:javascript
复制
model: nn.Sequential = ...
for m in model:  # 动态控制
  x = m(x) 

输入和输出有丰富类型的模型需要格外注意

代码语言:javascript
复制
outputs = model(inputs)   # inputs/outputs are rich structure
# torch.jit.trace(model, inputs)  # FAIL! unsupported format
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # Can now trace the model

# Traced model can only produce flattened outputs (tuple of tensors):
flattened_outputs = traced(*adapter.flattened_inputs)
# Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
new_outputs = adapter.outputs_schema(flattened_outputs)

QA

    1. JIT要求python的代码要是低级的;详情 因为更多动态高级的python语法,jit不支持.具体哪些支持哪些没支持官方也没有详细的列表; JIT should not force users to write ugly code #48108
    1. 错误示例:动态控制流:对于动态控制流torch.jit.trace只会编译一个分支,在其他分支处理的时候会报错;
代码语言:javascript
复制
def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
m = torch.jit.trace(f, torch.tensor(3))
print(m.code) # 可以打印出trace的情况
#--------------------------------------------
def f(x: Tensor) -> Tensor:
  return torch.sqrt(x)
    1. 错误示例:将变量视为常量
代码语言:javascript
复制
import torch

a, b = torch.rand(1), torch.rand(2)
print(a,b)

def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))
result = torch.jit.trace(f1, a)(b)
print(result)

result =torch.jit.trace(f2, a)(b) # TracerWarning
print(result) #

print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
cuX01R
cuX01R
  • 错误示例:获取设备

解决错误的方法

    1. 严格消除警告信息,才C++运行的时候会报错
    1. 局部单元测试
    • 单元测试一样要做在导出模型后,这样避免在应用模型的时候(C++运行)出错;
代码语言:javascript
复制
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
    1. 避免非必要的动态控制,例如:
代码语言:javascript
复制
if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W))  # Create empty outputs
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-07-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 术语
  • 什么时候用torch.jit.trace(结论:首选)
    • 优点
    • 什么时候用torch.jit.script(结论:必要时)
    • 错误举例
      • 动态控制
        • 输入和输出有丰富类型的模型需要格外注意
        • QA
        • 解决错误的方法
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档