前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >torch.load()

torch.load()

作者头像
狼啸风云
修改2022-09-02 22:05:12
2.1K0
修改2022-09-02 22:05:12
举报
文章被收录于专栏:计算机视觉理论及其实现

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)[source]

从文件中加载一个用torch.save()保存的对象。

load()使用Python的unpickling工具,但是专门处理存储,它是张量的基础。他们首先在CPU上并行化,然后移动到保存它们的设备。如果失败(例如,因为运行时系统没有某些设备),就会引发异常。但是,可以使用map_location参数动态地将存储重新映射到另一组设备。storage参数是存储的初始反序列化,驻留在CPU上。storage参数是存储的初始反序列化,驻留在CPU上。每个序列化存储都有一个与之关联的位置标记,它标识保存它的设备,这个标记是传递给map_location的第二个参数。内置的位置标签是“cpu”为cpu张量和“cuda:device_id”(例如:device_id)。“cuda:2”)表示cuda张量。map_location应该返回None或一个存储。如果map_location返回一个存储,它将被用作最终的反序列化对象,已经移动到正确的设备。否则,torch.load()将退回到默认行为,就好像没有指定map_location一样。如果map_location 是可以调用的,那么对于带有两个参数:存储和位置的序列化存储将被调用一次。如果map_location是一个torch.device对象或一个包含设备标签的字符串,它表示所有张量应该被加载的位置。否则,如果map_location是一个dict,它将用于将文件中出现的位置标记(键)重新映射为指定存储位置的位置标记(值)。用户扩展可以使用torch.serialize.register_package()注册他们自己的位置标签、标记和反序列化方法。

参数:

  • name 类似文件的对象(必须实现read(),:meth ' readline ',:meth ' tell '和:meth ' seek '),或者是包含文件的字符串。
  • map_location – 函数、torch.device或者字典指明如何重新映射存储位置。
  • pickle_module – 用于unpickling元数据和对象的模块(必须匹配用于序列化文件的pickle_module)
  • pickle_load_args – (仅适用于Python 3)传递给pickle_module.load()和pickle_module.Unpickler()的可选关键字参数,例如errors=…

警告:

load()隐式地使用pickle模块,这是不安全的。可以构造恶意pickle数据,在unpickle期间执行任意代码。永远不要加载可能来自不受信任的数据源或可能被篡改的数据。只加载你信任的数据。

注意:

当你在包含GPU张量的文件上调用torch.load()时,默认情况下这些张量会被加载到GPU。你可以调用torch.load(.., map_location='cpu'),然后load_state_dict()以避免在加载一个模型检查点时GPU内存激增。

注意:

默认情况下,我们将字节字符串解码为utf-8。这是为了避免一个常见的错误情况UnicodeDecodeError: 'ascii' codec can't decode byte 0x...在python3中加载由python2保存的文件时。如果这个默认是不正确的,你可以使用一个额外的编码关键字参数指定应该如何加载这些对象,例如,encoding='latin1'中的一个解码字符串使用latin1编码中的一个,和encoding='bytes'让他们作为字节数组可以解码后byte_array.decode (…)。

例:

代码语言:javascript
复制
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/06/09 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档