前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【pytorch】onnx

【pytorch】onnx

作者头像
JNingWei
发布2021-12-06 21:15:23
7440
发布2021-12-06 21:15:23
举报
文章被收录于专栏:JNing的专栏

t7 / pth -> onnx

pytorch任意形式的model(.t7、.pth等等)转.onnx全都可以采用固定格式。

完整实现:

代码语言:javascript
复制
def pth2onnx(self, simplify_onnx_sw=True):
    import torch
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    model = torch.nn.DataParallel(self.model)
    _state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
    model.load_state_dict(_state_dict, strict=True)
    model.eval()
    torch.onnx.export(model.module,
                      torch.randn(batch_size, *C.input_shape),
                      pure_onnx_path,
                      input_names=["input"],
                      output_names=["output"]
                      )

    if simplify_onnx_sw:
        os.system('python -m onnxsim {} {}'.format(pure_onnx_path, simplified_onnx_path))
        print('\n Simplified onnx has been save to {}\n'.format(simplified_onnx_path))
        os.remove(pure_onnx_path)
    else:
        print('\n Pure onnx has been save to {}\n'.format(pure_onnx_path))

实验举例:

代码语言:javascript
复制
model_dir = './'
pth_path = model_dir + 'A.pth'
onnx_path = model_dir + 'A.onnx'
batch_size = 1
input_shape = (3, 112, 112)

cfg = Config()
cfg.load_from_file(args.model_cfg_file)

model = PFLD_SE3_eval(cfg.model_conf.layer_cfg, cfg.model_conf.num_points)

model.load(pth_path)
model.eval()
torch.onnx.export(model,
                  torch.randn(batch_size, *input_shape),
                  onnx_path,
                  input_names=["input"],
                  output_names=["output_0", "output_1"],
                  )

print('\n\n onnx has been save to {}\n\n'.format(onnx_path))
如在mac下执行,还需要加上这行环境配置: 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

可能的报错:

代码语言:javascript
复制
ImportError: cannot import name 'get_all_providers' from 'onnxruntime.capi._pybind_state' 

mac下的通用解决方法:

代码语言:javascript
复制
brew install libomp

如果还是报相同错误,则可能是版本问题。换版本即可。例如我是执行:

代码语言:javascript
复制
pip install onnxruntime==1.2.0
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/09/22 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • t7 / pth -> onnx
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档