专栏首页磐创AI技术团队的专栏【注意力机制】transformers序列化实践

【注意力机制】transformers序列化实践

来源 | Github

作者 | huggingface

编译 | VK

【导读】本节说明如何保存和重新加载微调模型(BERT,GPT,GPT-2和Transformer-XL)。

你需要保存三种文件类型才能重新加载经过微调的模型:

  • 模型本身应该是PyTorch序列化保存的模型(https://pytorch.org/docs/stable/notes/serialization.html#best-practices)
  • 模型的配置文件是保存为JSON文件
  • 词汇表(以及基于GPT和GPT-2合并的BPE的模型)。

这些文件的默认文件名如下:

  • 模型权重文件:pytorch_model.bin
  • 配置文件:config.json
  • 词汇文件:vocab.txt代表BERT和Transformer-XL,vocab.json代表GPT/GPT-2(BPE词汇),
  • 代表GPT/GPT-2(BPE词汇)额外的合并文件:merges.txt

如果使用这些默认文件名保存模型,则可以使用from_pretrained()方法重新加载模型和tokenizer。

这是保存模型,配置和配置文件的推荐方法。词汇到output_dir目录,然后重新加载模型和tokenizer:

from transformers import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"

# 步骤1:保存一个经过微调的模型、配置和词汇表

#如果我们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
#如果使用预定义的名称保存,则可以使用`from_pretrained`加载
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)

# 步骤2: 重新加载保存的模型

#Bert模型示例
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
#GPT模型示例
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)

如果要为每种类型的文件使用特定路径,则可以使用另一种方法保存和重新加载模型:

output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"

# 步骤1:保存一个经过微调的模型、配置和词汇表

#如果我们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)

# 步骤2: 重新加载保存的模型

# 我们没有使用预定义权重名称、配置名称进行保存,无法使用`from_pretrained`进行加载。
# 下面是在这种情况下的操作方法:

#Bert模型示例
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)

#GPT模型示例
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)

原文链接:https://huggingface.co/transformers/serialization.html#serialization-best-practices

- End -

本文分享自微信公众号 - 磐创AI(xunixs),作者:VK

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-04-01

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • transformer快速入门

    我们将通过一些简单的快速启动示例来完成这个快速启动之旅,看看如何实例化和使用这些类。其余的文件分为两部分:

    磐创AI
  • 构建对象检测模型

    我喜欢深度学习。坦率地说,这是一个有大量技术和框架可供倾注和学习的广阔领域。当我看到现实世界中的应用程序,如面部识别和板球跟踪等时,建立深度学习和计算机视觉模型...

    磐创AI
  • PyTorch专栏(五):迁移学习

    实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集]。通常的做法是在一个...

    磐创AI
  • NSA(美国国安局)泄漏文件深度分析(PART 1)

    * 本文原创作者:tom_vodu,本文属FreeBuf原创奖励计划,未经许可禁止转载 一、前言 防火墙是保护内网机器不受网络攻击者侵害的第一道也是非常核心...

    FB客服
  • 白帽黑帽真的有差别么

    下面讲讲我对于他们的理解,众所周知,黑帽无疑是最赚钱的,单从以前接的一个简单的渗透单,虽然本质是违法,但是利润高的让人禁不住诱惑,区区一周就有好几万入账,可能这...

    网e渗透安全部
  • Java描述设计模式(20):命令模式

    知了一笑
  • 为了这个羞羞的机器学习项目,我差点成为“鉴黄师”

    开个玩笑,不过今天文摘菌确实要给大家介绍一个不太一样有点羞羞的机器学习项目,也就是嘿嘿嘿的时候,最加深感情的某种动作的分析,英文俗称“blowjob”或者“or...

    大数据文摘
  • 人生想要开挂,快来学习“画中画”!

    首先请允许我介绍下“画中画”:画中画(英文Picture-in-picture,缩写PiP)是指将一个电视节目(或其他画面)显示在整个画面上,同时将另一个或多个...

    腾讯IVWEB团队
  • 为了这个羞羞的机器学习项目,我差点成为“鉴黄师”

    开个玩笑,不过今天小编确实要给大家介绍一个不太一样有点羞羞的机器学习项目,也就是嘿嘿嘿的时候,最加深感情的某种动作的分析,英文俗称“blowjob”或者“ora...

    磐创AI
  • 为了这个羞羞的机器学习项目,我差点成为“鉴黄师”

    开个玩笑,不过今天确实要给大家介绍一个不太一样有点羞羞的机器学习项目,也就是嘿嘿嘿的时候,最加深感情的某种动作的分析,英文俗称“blowjob”或者“oral ...

    华章科技

扫码关注云+社区

领取腾讯云代金券