首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >简单变压器模型的保存与加载

简单变压器模型的保存与加载
EN

Data Science用户
提问于 2021-04-06 10:49:03
回答 3查看 5.4K关注 0票数 1

我已经训练文本分类器使用simpleTranformer.ai,我正在努力保存和加载模型的码头集装箱。请让我知道如何保存培训的模型,然后顺利地加载到不同的环境。我正在使用这个库:https://simpletransformers.ai/来使用以下命令来训练文本模型

代码语言:javascript
运行
复制
 model = ClassificationModel('xlmroberta', 'xlm-roberta-base',use_cuda=cuda_available, num_labels=78, args={'learning_rate':1e-5, 'num_train_epochs': 1,'train_batch_size':256,'eval_batch_size':1048, 'n_gpu':4, 'reprocess_input_data': True, 
'overwrite_output_dir':True, 'overwrite_output_dir': True})

model.train_model(train_df)

我使用pytorch函数来保存经过训练的模型:

代码语言:javascript
运行
复制
torch.save(model, 'classifier')

但是当我试图从不同的虚拟机加载这个模型时,它显示了一些丢失的文件的错误。因此,我正在寻找保存和加载simpleTransformer模型的最佳替代方案。

EN

回答 3

Data Science用户

发布于 2023-02-24 00:31:26

只需将每个区段保存到目标路径即可。

代码语言:javascript
运行
复制
model.model.save_pretrained('model1')
model.tokenizer.save_pretrained('model1')
model.config.save_pretrained('model1/')

然后你就可以像这样装载

代码语言:javascript
运行
复制
model = ClassificationModel('roberta','model1')
票数 1
EN

Data Science用户

发布于 2021-06-25 11:09:33

在PyTorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的参数中(用model.parameters()访问)。state_dict只是一个Python字典对象,它将每个层映射到其参数张量。所以,你可以用不同的方式保存一个模型,

  1. 第一种方法是像存储torch.save(model.state_dict(), PATH)那样存储模型,并将相同的模型加载到不同的机器或不同的位置,然后首先创建该模型的实例,然后将该模型分配给像这样的模型参数。model = TheModelClass(*args,**kwargs)

在创建类的实例之后,然后像这样加载模型。

代码语言:javascript
运行
复制
model.load_state_dict(torch.load(PATH))

这是一种保存和加载模型的推荐方法。如果要保存模型,那么在将模型加载到另一台机器之前,首先生成该模型的实例,然后可以运行该模型。

  1. 另一种方法是使用泡菜pickle.dump(model.state_dict(), open(filename, 'wb'))保存model.state_dict(),然后通过pickle.load(open(filename, 'rb'))加载模型,但是,这不是保存和加载模型的标准方法--最推荐的方法是第一种。

此外,如果您的错误无法解决,那么您可以在注释部分询问。

票数 0
EN

Data Science用户

发布于 2021-06-25 08:14:39

您的模型存储在名称为pytorch_model.bin的输出目录中。去看看那个。

票数 -1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/92630

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档