首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >仅用pth文件加载pth Tacotron2模型的问题

仅用pth文件加载pth Tacotron2模型的问题
EN

Stack Overflow用户
提问于 2022-06-27 07:28:57
回答 1查看 115关注 0票数 1

我使用Mozilla在自定义数据集上训练了一个Tacotron2模型。训练器输出一个pth文件和一个config.json文件。我很难将经过训练的模型加载到PyTorch中。

代码语言:javascript
运行
复制
from torchaudio.models.tacotron2 import Tacotron2
tacotron2 =Tacotron2()
tacotron2.load_state_dict(torch.load('models/best_model.pth'))

RuntimeError:在为Tacotron2加载state_dict时出错:state_dict中缺少键:"embedding.weight“、”encoder.卷积s.0.0.权重“、”encoder.卷积s.0.0.偏向“、”encoder.卷积s.0.1.权重“、”encoder.卷积s.0.1.偏向“、"encoder.convolutions.0.1.running_mean”、"encoder.convolutions.0.1.running_var“、”encoder.卷积s.1.0.权重“,"encoder.convolutions.1.1.running_mean“、"encoder.convolutions.1.1.running_var”、“encoder.卷积s.2.0.权重”、“Encoder.卷积s.2.0.偏倚”、“encoder.卷积s.2.1偏倚”、“encoder.卷积s.2.1偏倚”、“encoder.卷积s.2.1偏倚”,"encoder.convolutions.2.1.running_mean“、"encoder.convolutions.2.1.running_var”、"encoder.lstm.weight_ih_l0“、"encoder.lstm.weight_hh_l0”、"encoder.lstm.bias_ih_l0“、"encoder.lstm.bias_hh_l0”、"encoder.lstm.weight_ih_l0_reverse“、"encoder.lstm.weight_hh_l0_reverse”、"encoder.lstm.bias_ih_l0_reverse““"decoder.attention_layer.memory_layer.weight",”、"decoder.attention_rnn.weight_ih“、"decoder.attention_rnn.weight_hh”、"decoder.attention_rnn.bias_ih“、"decoder.attention_rnn.bias_hh”、"decoder.attention_layer.query_layer.weight“、”decoder.attention_rnn.bias_ih“"decoder.attention_layer.v.weight","decoder.decoder_rnn.weight_ih","decoder.decoder_rnn.weight_hh","decoder.decoder_rnn.bias_ih","decoder.decoder_rnn.bias_hh","decoder.linear_projection.weight","decoder.linear_projection.bias","decoder.gate_layer.weight“、"decoder.gate_layer.bias”、“postnet.卷积s.0.0.权重”、“postnet.卷积s.0.0.偏向”、“postnet.卷积s.0.1.权重”、“postnet.卷积s.0.1.偏向”、"postnet.convolutions.0.1.running_mean“、"postnet.convolutions.0.1.running_var”、“postnet.卷积s.1.0.权重”、“postnet.卷积s.1.0偏见”,“postnet.卷积s.1.1.重量”、“postnet.卷积s.1.1.偏向”、"postnet.convolutions.1.1.running_mean“、"postnet.convolutions.1.1.running_var”、“postnet.卷积s.2.0.权重”、“postnet.卷积s.2.0.偏向”、“postnet.卷积s.2.1.权重”、“postnet.卷积s.2.1.偏向”、"postnet.convolutions.2.1.running_mean","postnet.convolutions.2.1.running_var“、”Postnet.卷积s.3.0.权重“、”postnet.卷积s.3.0.偏倚“、”postnet.卷积s.3.1重量“、”postnet.卷积s.3.1.偏向“、"postnet.convolutions.3.1.running_mean”、"postnet.convolutions.3.1.running_var“、”postnet.卷积s.4.0.权重“、”postnet.卷积s.4.0.偏向“,“postnet.卷积s.4.1权重”、“postnet.卷积s.4.1.偏向”、"postnet.convolutions.4.1.running_mean“、"postnet.convolutions.4.1.running_var”。State_dict中的意外键:“配置”、“模型”、“优化器”、“标量器”、“步骤”、“时代”、“日期”、“model_loss”。

EN

回答 1

Stack Overflow用户

发布于 2022-06-27 10:58:36

根据错误消息,load_state_dict()命令所期望的显然是一个字典,其中的密钥被命名为"decoder.attention_rnn.bias_hh“等网络参数,即经过训练的参数以及识别它们的方法。然而,pth检查点似乎是一个二进制python字典,包含了恢复培训所必需的所有内容(而不仅仅是使用模型)。我猜:

“配置”是construction

  • "model“中传递给模型的参数--经过训练的weights

  • "optimizer”状态优化器状态

  • “scaler”

  • 步骤“培训步骤,

  • ”纪元“训练期”,

H 115“date”and

  • "model_loss“ditto。H 218f 219

也许试一试

代码语言:javascript
运行
复制
checkpoint = torch.load('models/best_model.pth')
tacotron2.load_state_dict(checkpoint["model"])

看看会发生什么。如果它不起作用,检查嵌套字典checkpoint["model"]的键,然后四处浏览。

如果在培训期间传递了任何非默认参数,则在初始化加载时需要复制它们(提示:使用配置)。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72768404

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档