首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何使用map_location='cpu‘,原因是"RuntimeError:试图在CUDA设备上反序列化对象,但torch.cuda.is_available()是假的“

如何使用map_location='cpu‘,原因是"RuntimeError:试图在CUDA设备上反序列化对象,但torch.cuda.is_available()是假的“
EN

Stack Overflow用户
提问于 2020-05-15 18:43:11
回答 1查看 3.4K关注 0票数 1

我试图在塔可龙下载以下模型

代码语言:javascript
运行
复制
import torch
tacotron2 = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_tacotron2')

我收到了:

代码语言:javascript
运行
复制
>>> import torch
>>> tacotron2 = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_tacotron2')
Using cache found in .cache\torch\hub\nvidia_DeepLearningExamples_torchhub
...
  File "Anaconda3\envs\env3_pytorch\lib\site-packages\torch\serialization.py", line 79, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

因此,我在map_location='cpu'中使用了下面的代码,但仍然得到了相同的错误。

代码语言:javascript
运行
复制
>>> tacotron2 = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_tacotron2', map_location='cpu')

torch.version.cuda显示了一个版本,但是torch.cuda.is_available()是假的。

代码语言:javascript
运行
复制
>>> import torch
>>> torch.version.cuda
'9.0'
>>> torch.cuda.is_available()
False
  • 如何避免这个与map_location相关的错误,因为我准确地使用了要求使用的命令错误消息。
  • 我如何在加载模型时使用cuda?
EN

回答 1

Stack Overflow用户

发布于 2020-05-16 17:24:04

torch.hub.load不专门支持map_location,它只将额外的参数转发给模型的加载,因此它的实现取决于是否支持它。

在这种情况下,不支持加载,加载是在NVIDIA/DeepLearningExamples:torchhub - hubconf.py中实现的,并且不将任何map_location传递给当加载检查点时

这意味着您需要手动加载检查点并将其应用于模型。值得庆幸的是,可以通过设置pretrained=Falsetorch.hub.load_state_dict_from_url分别加载模型,而无需加载检查点就可以加载相同的配置,后者支持map_location。状态dict只需要一个小的更改,因为它是用DistributedDataParallel训练的,它封装了模块,因此模型中的每个layer都变成了module.layer。在状态判断中,需要删除module.前缀。

代码语言:javascript
运行
复制
tacotron2 = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'nvidia_tacotron2', pretrained=False)

checkpoint = torch.hub.load_state_dict_from_url('https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/1/files/nvidia_tacotron2pyt_fp32_20190306.pth', map_location="cpu")

# Unwrap the DistributedDataParallel module
# module.layer -> layer
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}

# Apply the state dict to the model
tacotron2.load_state_dict(state_dict)
票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61826246

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档