I trained a Tacotron2 model on a custom dataset using Mozilla. The trainer outputs a pth file and a config.json file. I am having trouble loading the trained model into PyTorch.
import torch
from torchaudio.models.tacotron2 import Tacotron2

tacotron2 = Tacotron2()
tacotron2.load_state_dict(torch.load('models/best_model.pth'))
RuntimeError: Error(s) in loading state_dict for Tacotron2: Missing key(s) in state_dict: "embedding.weight", "encoder.convolutions.0.0.weight", "encoder.convolutions.0.0.bias", "encoder.convolutions.0.1.weight", "encoder.convolutions.0.1.bias", "encoder.convolutions.0.1.running_mean", "encoder.convolutions.0.1.running_var", "encoder.convolutions.1.0.weight", "encoder.convolutions.1.1.running_mean", "encoder.convolutions.1.1.running_var", "encoder.convolutions.2.0.weight", "encoder.convolutions.2.0.bias", "encoder.convolutions.2.1.bias", "encoder.convolutions.2.1.running_mean", "encoder.convolutions.2.1.running_var", "encoder.lstm.weight_ih_l0", "encoder.lstm.weight_hh_l0", "encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0", "encoder.lstm.weight_ih_l0_reverse", "encoder.lstm.weight_hh_l0_reverse", "encoder.lstm.bias_ih_l0_reverse", "decoder.attention_layer.memory_layer.weight", "decoder.attention_rnn.weight_ih", "decoder.attention_rnn.weight_hh", "decoder.attention_rnn.bias_ih", "decoder.attention_rnn.bias_hh", "decoder.attention_layer.query_layer.weight", "decoder.attention_layer.v.weight", "decoder.decoder_rnn.weight_ih", "decoder.decoder_rnn.weight_hh", "decoder.decoder_rnn.bias_ih", "decoder.decoder_rnn.bias_hh", "decoder.linear_projection.weight", "decoder.linear_projection.bias", "decoder.gate_layer.weight", "decoder.gate_layer.bias", "postnet.convolutions.0.0.weight", "postnet.convolutions.0.0.bias", "postnet.convolutions.0.1.weight", "postnet.convolutions.0.1.bias", "postnet.convolutions.0.1.running_mean", "postnet.convolutions.0.1.running_var", "postnet.convolutions.1.0.weight", "postnet.convolutions.1.0.bias", "postnet.convolutions.1.1.weight", "postnet.convolutions.1.1.bias", "postnet.convolutions.1.1.running_mean", "postnet.convolutions.1.1.running_var", "postnet.convolutions.2.0.weight", "postnet.convolutions.2.0.bias", "postnet.convolutions.2.1.weight", "postnet.convolutions.2.1.bias", "postnet.convolutions.2.1.running_mean", "postnet.convolutions.2.1.running_var", "postnet.convolutions.3.0.weight", "postnet.convolutions.3.0.bias", "postnet.convolutions.3.1.weight", "postnet.convolutions.3.1.bias", "postnet.convolutions.3.1.running_mean", "postnet.convolutions.3.1.running_var", "postnet.convolutions.4.0.weight", "postnet.convolutions.4.0.bias", "postnet.convolutions.4.1.weight", "postnet.convolutions.4.1.bias", "postnet.convolutions.4.1.running_mean", "postnet.convolutions.4.1.running_var". Unexpected key(s) in state_dict: "config", "model", "optimizer", "scaler", "step", "epoch", "date", "model_loss".
Posted on 2022-06-27 10:58:36
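Judging from the error message, load_state_dict() expects a dictionary whose keys are named after the network's parameters, such as "decoder.attention_rnn.bias_hh", i.e. the trained parameters together with a way to identify them. The pth checkpoint, however, appears to be a serialized Python dictionary containing everything needed to resume training, not just to use the model. My guess:
"config" holds the construction parameters
"date" and "model_loss" are likewise metadata rather than weights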
Maybe try
checkpoint = torch.load('models/best_model.pth')
tacotron2.load_state_dict(checkpoint["model"])
and see what happens. If that doesn't work, inspect the keys of the nested dictionary checkpoint["model"] and poke around.
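One way to do that poking around is to diff the two sets of key names yourself before calling load_state_dict(). The sketch below is only an illustration: compare_keys is a made-up helper, and the toy lists stand in for tacotron2.state_dict().keys() and checkpoint["model"].keys(). The name "decoder.stopnet.weight" is hypothetical and not taken from a real checkpoint.

```python
def compare_keys(expected_keys, found_keys):
    """Return (missing, unexpected) key lists, similar to load_state_dict's report."""
    expected, found = set(expected_keys), set(found_keys)
    missing = sorted(expected - found)      # the model wants these, the checkpoint lacks them
    unexpected = sorted(found - expected)   # the checkpoint has these, the model does not
    return missing, unexpected


# Toy stand-ins for tacotron2.state_dict().keys() and checkpoint["model"].keys().
# The second list uses a hypothetical parameter name purely for illustration.
expected = ["embedding.weight", "decoder.gate_layer.weight", "decoder.gate_layer.bias"]
found = ["embedding.weight", "decoder.stopnet.weight"]

missing, unexpected = compare_keys(expected, found)
print("missing:", missing)
print("unexpected:", unexpected)
```

If both lists come back empty for the real model and checkpoint, the names line up. If they differ, the two implementations probably name or structure their layers differently, and loading will need more than unwrapping the dictionary.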
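If any non-default parameters were passed during training, you will need to pass the same ones when constructing the model for loading (hint: use the config).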
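As a rough sketch of the "use the config" hint, the helper below keeps only the config entries that a constructor accepts, with an optional rename table for settings stored under a different name. Everything here is an assumption for illustration: build_kwargs, ToyModel, the rename table, and the config contents are made up, and the real config keys need to be checked against your own config.json and the Tacotron2 constructor signature.

```python
import inspect


def build_kwargs(config, constructor, rename=None):
    """Pick the config entries that match the constructor's keyword arguments."""
    rename = rename or {}
    accepted = set(inspect.signature(constructor).parameters) - {"self"}
    kwargs = {}
    for key, value in config.items():
        name = rename.get(key, key)  # translate the config key if a rename is given
        if name in accepted:
            kwargs[name] = value
    return kwargs


class ToyModel:
    """Stand-in for a model class; in practice you would pass Tacotron2 instead."""

    def __init__(self, n_mels=80, n_symbol=148):
        self.n_mels = n_mels
        self.n_symbol = n_symbol


# Hypothetical config contents and key names, for illustration only.
config = {"num_mels": 40, "run_name": "demo"}
kwargs = build_kwargs(config, ToyModel, rename={"num_mels": "n_mels"})
model = ToyModel(**kwargs)
print(kwargs)
```

Entries the constructor does not know about (here "run_name") are silently dropped, so it is worth printing the resulting kwargs and checking that nothing important was left out.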
https://stackoverflow.com/questions/72768404