前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >让大家久等了,BERT推理加速终于开源了

让大家久等了,BERT推理加速终于开源了

作者头像
godweiyang
发布2021-08-12 14:25:01
8870
发布2021-08-12 14:25:01
举报
文章被收录于专栏:算法码上来

作者 | 韦阳

出品 | 公众号:算法码上来(ID:GodNLP)

- BEGIN -

前几个月一直有不少小伙伴问我要「LightSeq的BERT推理加速代码」,当时内部已经使用了,但是一直没空整理开源。

现在代码终于整理好了,写了一个简单的样例,大家有需要的可以使用起来了。

实现原理

这里我直接使用预训练好的BERT模型,用户只需要输入一个带有[MASK]标记的句子,就可以自动预测出完整的句子。

例如我输入“巴黎是[MASK]国的首都”,那么模型就会输出“巴黎是法国的首都。”。

LightSeq已经「完美支持了BERT模型的快速推理」,代码近期已经开源:https://github.com/bytedance/lightseq

BERT推理使用样例可以参考examples/inference/python目录下的ls_bert.py文件。我们用LightSeq来加速BERT推理试试。

首先需要安装LightSeq和Hugging Face:

代码语言:javascript
复制
pip install lightseq transformers

然后需要将Hugging Face的BERT模型导出为LightSeq支持的HDF5模型格式,运行examples/inference/python目录下的hf_bert_export.py文件即可,运行前将代码的第167-168两行修改为下面这样,指定使用中文版本的BERT预训练模型。

代码语言:javascript
复制
output_lightseq_model_name = "lightseq-bert-base-chinese"
input_huggingface_bert_model = "bert-base-chinese"

然后就会在运行目录下生成一个lightseq-bert-base-chinese.hdf5模型文件,导出就成功啦。

最后使用LightSeq进行推理即可:

代码语言:javascript
复制
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import lightseq.inference as lsi

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
hf_model.to("cuda:0")
ls_model = lsi.Bert("lightseq-bert-base-chinese.hdf5", 128)

while True:
    raw_text = input("请输入中文句子,要预测的字符用#代替:\n> ")
    input_text = raw_text.replace("#", "[MASK]")
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    mask = inputs["attention_mask"]

    outputs = ls_model.infer(input_ids, mask)
    logits = hf_model.cls(torch.Tensor(outputs).to(dtype=torch.float, device="cuda:0"))
    output_ids = logits.argmax(axis=2)
    res_text = tokenizer.batch_decode(output_ids)

    res_text = res_text[0][1:-1].replace(" ", "")
    output_text = list(raw_text)
    for i in range(len(raw_text)):
        if raw_text[i] == "#":
            output_text[i] = res_text[i]
    print("> " + "".join(output_text))

效果演示

给大家看看效果,运行我写好的代码,我们来看看会输出什么结果:

代码语言:javascript
复制
请输入中文句子,要预测的字符用#代替:
> 巴黎是#国的首都。
> 巴黎是法国的首都。

代码地址

https://github.com/bytedance/lightseq

就在上周,首位外部贡献者出现了,修复了LightSeq的词嵌入表示的bug。

在这里我们非常欢迎感兴趣的同学来贡献自己的代码,包括但不局限于:修复bug、提供训练和推理样例、支持更多模型结构。

- END -

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-08-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法码上来 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 实现原理
  • 效果演示
  • 代码地址
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档