前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >C++中使用pytorch保存的tensor

C++中使用pytorch保存的tensor

作者头像
王云峰
发布2023-10-21 16:27:50
4500
发布2023-10-21 16:27:50
举报
文章被收录于专栏:Yunfeng's Simple Blog

概述

最近在学习Libtorch——即Pytorch的C++版本,需要使用 Pytorch 导出的 tensor 以便对模型进行 debug。下面是转换代码,总体原理是将 tensor 转换为二进制数据,再在 C++ 里面读入。

下面是 Pytorch 中的导出 tensor 示例:

代码语言:javascript
复制
import io

import torch


def save_tensor(device):
    my_tensor = torch.rand(3, 3).to(device);
    print("[python] my_tensor: ", my_tensor)
    f = io.BytesIO()
    torch.save(my_tensor, f, _use_new_zipfile_serialization=True)
    with open('my_tensor_%s.pt' % device, "wb") as out_f:
        # Copy the BytesIO stream to the output file
        out_f.write(f.getbuffer())


if __name__ == '__main__':
    save_tensor('cpu')

这里以导出 cpu tensor 为例,cuda tensor 也是同理。

在 C++ 中的调用示例如下:

代码语言:javascript
复制
#include <iostream>
#include <torch/torch.h>

std::vector<char> get_the_bytes(std::string filename) {
    std::ifstream input(filename, std::ios::binary);
    std::vector<char> bytes(
        (std::istreambuf_iterator<char>(input)),
        (std::istreambuf_iterator<char>()));

    input.close();
    return bytes;
}

int main()
{
    std::vector<char> f = get_the_bytes("my_tensor_cpu.pt");
    torch::IValue x = torch::pickle_load(f);
    torch::Tensor my_tensor = x.toTensor();
    std::cout << "[cpp] my_tensor: " << my_tensor << std::endl;

    return 0;
}

注意事项:

  1. torch的Python和C++版本需要保持一致,否则转换可能不成功.

题外话

最近在学习Libtorch——即Pytorch的C++版本,发现使用起来异常的丝滑,写C++有了Python的体验,妙不可言。 后面会更新一些关于libtorch使用的文章,敬请关注。

参考

  1. https://discuss.pytorch.org/t/how-to-load-python-tensor-in-c/88813
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021-03-212,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 概述
  • 题外话
  • 参考
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档