首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何下载HuggingFace模型‘Transformers.Traine.Trainer’?

如何下载HuggingFace模型‘Transformers.Traine.Trainer’?
EN

Stack Overflow用户
提问于 2021-10-21 14:48:38
回答 1查看 203关注 0票数 0

在第一个代码中,我使用第二个代码中的save_pretrained()函数上传了基于拥抱脸'transformers.trainer.Trainer‘的模型,我想下载这个上传的模型并使用它进行预测。我需要这一步的帮助-如何下载上传的模型,然后进行预测?

创建模型的步骤:

代码语言:javascript
运行
复制
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
model = AutoModelForQuestionAnswering.from_pretrained('xlm-roberta-large)
trainer = Trainer(
model,
args,
train_dataset=tokenized_train_ds,
eval_dataset=tokenized_val_ds,
data_collator=data_collator,
tokenizer=tokenizer,)

#Arguments used above not mentioned here - model, args, tokenized_train_ds, tokenized_val_ds, data_collator, tokenizer
#Below step train the pre-trained model
trainer.train()

然后,我使用下面的命令上传了这个'trainer‘模型:

代码语言:javascript
运行
复制
trainer.save_model('./trainer_sm')

在不同的代码中,我现在想下载这个模型,并使用它进行预测,有人可以建议如何做到这一点吗?我尝试了下面的命令来上传它:

代码语言:javascript
运行
复制
model_sm=AutoModelForQuestionAnswering.from_pretrained("./trainer_sm")

并使用它通过下面这行代码进行预测:

代码语言:javascript
运行
复制
model_sm.predict(test_features)
AttributeError: 'XLMRobertaForQuestionAnswering' object has no attribute 'predict'

我也使用了'use_auth_token=True‘作为from_pretrained的参数,但也不起作用。

同样,type(培训者)是'transformers.trainer.Trainer‘,而type(model_sm)是transformers.models.xlm_roberta.modeling_xlm_roberta.XLMRobertaForQuestionAnswering

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-10-21 18:03:36

您保存的是训练者将要调整的模型,您应该知道预测、训练、评估等是transformers.trainer.Trainer object的实用程序,而不是transformers.models.xlm_roberta.modeling_xlm_roberta.XLMRobertaForQuestionAnswering。根据前面提到的,保持运行的最简单方法是创建训练器的另一个实例。

代码语言:javascript
运行
复制
model_sm=AutoModelForQuestionAnswering.from_pretrained("./trainer_sm")
reloaded_trainer = Trainer(
                    model = model_sm,
                    tokenizer = tokenizer,
                    # other arguments if you have changed the defaults 
                    )
reloaded_trainer.predict(test_dataset)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69664125

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档