前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >只用两行代码,我让Transformer推理加速了50倍

只用两行代码,我让Transformer推理加速了50倍

作者头像
godweiyang
发布2021-04-23 10:58:06
4K0
发布2021-04-23 10:58:06
举报
文章被收录于专栏:算法码上来

最近有学妹问我,我训了一个Transformer模型,但是预测好慢啊,有啥解决方案吗?

我心想,你又想好,又想快,咋不上天?呢?

于是我跟她说,你可以试试lightseq啊,跟闪电⚡️一样快,用了你就可以上天了。

她一脸懵比,lightseq是啥玩意儿啊?咋就能让我的模型起飞?️了呢?

我跟她说,你不需要知道太多细节,你只需要知道它是一个Transformer系列模型推理加速库就行了。

她还是一脸疑惑,那用起来能有huggingface方便吗?你看人家就两行代码。

我不屑一笑,就这?lightseq也只要两行代码就够了!

为了方便,我用了一个bart模型预测句子中mask单词的例子来给她吹了一波。

不懂什么是bart?建议先去看看huggingface的文档: https://huggingface.co/transformers/model_doc/bart.html

huggingface bart

我们平时想用huggingface的bart来预测句子中的mask单词,大体上都会像下面这样写代码:

代码语言:javascript
复制
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

sentences = ["I love that girl, but <mask> does not <mask> me."]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
generated_ids = model.generate(inputs["input_ids"], max_length=50)
res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(res)

当然运行前要先安装一下transformers包:

代码语言:javascript
复制
pip3 install transformers

最后会输出句子“I love that girl, but she does not love me.”,句子中的两个“mask”被预测成了“she”和“love”。

看起来预测的很nice,但是预测的也太慢了,这要是有一堆句子要去预测,不得等到?年?月?

接下来我们来看看lightseq是怎么加速预测的。

lightseq bart

代码我都放在下面地址了,只要两分钟就能跑出结果了: https://github.com/godweiyang/lightseq/tree/python_example/example/python

运行前要先安装一下lightseq包:

代码语言:javascript
复制
pip3 install lightseq

首先lightseq只能接收Protocol Buffer协议定义的模型文件,如果你不知道这是啥也没关系,因为我们帮你写好了模型转换的脚本,就是hf_bart_export.py,它会将huggingface预训练的bart模型转换为transformer_pb2.py定义好的Protocol Buffer格式。

所以直接运行python3 hf_bart_export.py就行了,这里我们用的是bart-base模型。

运行完了会发现执行目录下多出一个lightseq_bart_base.pb文件,这就是转换后的模型文件。

最后直接跟huggingface一样,两行代码就能搞定啦:

代码语言:javascript
复制
import lightseq
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = lightseq.Transformer("lightseq_bart_base.pb", 128)

sentences = ["I love that girl, but <mask> does not <mask> me."]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
generated_ids = model.infer(inputs["input_ids"])
generated_ids = [ids[0] for ids in generated_ids[0]]
res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(res)

看得出来仅仅替换了模型定义和模型推理那两行代码而已,是不是非常简单快速?

这时候她又问了,那我换一个模型,比如bert,要怎么导出pb模型呢?

也很简单,只需要为bert也单独写一个hf_bert_export.py就行了。不过目前还在开发中,之后会慢慢完善常见的一些模型的。

速度到底怎么样?

我写好了一个例子,就在ls_bart.py里,直接运行就行了,当然你也可以加上--user_input参数来手动输入句子。

输入的句子是:

代码语言:javascript
复制
I love that girl, but <mask> does not <mask> me.
She is so <mask> that I can not help glance at <mask>.
Nothing's gonna <mask> my love for you.
Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.

运行结果如下:

代码语言:javascript
复制
=========================lightseq=========================
lightseq generating...
lightseq time: 0.034502994269132614s
lightseq results:
I love that girl, but she does not love me.
She is so beautiful that I can not help glance at her.
Nothing's gonna change my love for you.
Drop everything now. Meet me in the pouring rain. Kiss me on the sidewalk.
=========================huggingface=========================
huggingface generating...
huggingface time: 1.6297104470431805s
huggingface results:
I love that girl, but she does not love me.
She is so beautiful that I can not help glance at her.
Nothing's gonna change my love for you.
Drop everything now. Meet me in the pouring rain. Kiss me on the sidewalk.

可以看出预测的是真的??,最后两句歌词都预测的很完美,能看出是啥歌吗?

再看预测时间,lightseq是huggingface的47倍左右,真是一个天上一个地下啊。

总结

总结一下,想要使用lightseq加速你的模型,只需要两步就行了:

  • 将你的模型转换为pb格式的模型。(lightseq为你写好了转换脚本,不断更新中)
  • 调用lightseq.Transformermodel.infer进行快速推理。

学妹赶紧打住了我,好了好了,我知道很??了。还给你装起来了,我这就去用。

但是源码哪里有?我想学一学。

我又甩给她一串地址: https://github.com/bytedance/lightseq

速度超快!字节跳动开源序列推理引擎LightSeq

好好看,好好学,都是CUDA写的,要是看得迷糊,建议先去看看我之前的入门教程嗷: 熬了几个通宵,我写了份CUDA新手入门代码

从此,世上又多了一位快如⚡️的?。

- END -

我是godweiyang,华东师范大学计算机系本硕专业第一,字节跳动AI Lab NLP算法工程师,秋招斩获上海三家互联网大厂ssp offer,主要研究方向为机器翻译、句法分析、模型压缩与加速。最大特点就是脾气好、有耐心,有任何问题都可以随时咨询我,不管是技术上的还是生活上的。

记得一键③连,今天的你格外的可爱?

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-04-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法码上来 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • huggingface bart
  • lightseq bart
  • 速度到底怎么样?
  • 总结
相关产品与服务
机器翻译
机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档