| 导语 BERT模型在多种下游任务表现优异,但庞大的模型结果也带来了训练及推理速度过慢的问题,难以满足对实时响应速度要求高的场景,模型轻量化就显得非常重要。因此,笔者对BERT系列模型进行剪枝,并部署到实际项目中,在满足准确率的前提下提高推理速度。
一. 模型轻量化
模型轻量化是业界一直在探索的一个课题,尤其是当你使用了BERT系列的预训练语言模型,inference速度始终是个绕不开的问题,而且训练平台可能还会对训练机器、速度有限制,训练时长也是一个难题。
目前业界上主要的轻量化方法如下:
我们团队对这些轻量化方法都进行了尝试,简单总结如下:
在这些方法中,剪枝显得非常简单又高效,如果你想快速得对BERT模型进行轻量化,不仅inference快,还希望训练快,模型文件小,效果基本维持,那么剪枝将是一个非常好的选择,本文将介绍如何为BERT系列模型剪枝,并附上代码,教你十分钟剪枝。
二. BERT剪枝
本节先重温BERT[1]及其变体AL-BERT[2]的模型结构,分析在哪里地方参数量大,再介绍如何为这类结构进行剪枝。
1. BERT模型主要组件
按照默认的维度配置,得到的模型参数大小如下(此处仅展示一层):
可以看到BERT模型的参数维度都比较大,都是768起步,而在每一层的结构中,全连接层的3072维,是造成该层参数爆炸的主要原因。单层的参数量已经比普通模型大了许多,当该层参数量再乘以12,杀伤指数更是暴增。
海量的参数加上海量的无监督训练数据,BERT模型取得奇效,但我们在训练我们的下游任务时,是否真的需要这么大的模型呢?
可以看到,AL-BERT对Embedding参数进行了因式分解,分解成了2个小矩阵,先将Embedding矩阵投射到一个更小的矩阵E,再投影到隐藏空间H中,减少了参数量(注:同时AL-BERT进行了跨层参数共享,所以保存的参数量少,得到的模型文件非常小),大大加快了模型的训练速度,但遗憾的是AL-BERT并没有提高inference速度。
2. 剪枝方法
基于以上分析,针对BERT系列模型的结构,可采取的剪枝方法如下:
1)层数剪枝
在BERT模型的应用中,我们一般取第12层的hidden向量用于下游任务。而低层向量基本上包含了基础信息,我们可以取低层的输出向量接到任务层,进行微调。
(跟许老板讨论过一个论文,BERT的低层向量可以学习到一些基础的词法信息,高层向量可以学到更多跟任务相关的特征,暂时找不到这篇论文了,找到会补上)
2)维度剪枝
接下来对每一层的维度进行剪枝,ok,全连接层的3072维,在一堆768中成功引起了我们的注意:
intermediate层的参数量 =(768+1)*3072 *2 = 4724736
假设我们剪到768维,全连接层的参数量可以减少75%,假如剪到384维,全连接的参数量可以减少87.5%!
3)Attention剪枝
在12头注意力中,每头维度是64,最终叠加注意力向量共768维。
相关研究[3]表明:
因此,我们可以尝试只保留1-2层模型,裁剪ffn维度,减少head个数,在裁剪大量参数的同时维持精度不会下降太多。
三. 工程实现
首先我们看下市面上有没有啥方便的工具可以剪枝:
这些工具都不适合使用,那就让我们自己来动手剪枝吧:
下面进入了超级简单的代码环节!关键代码仅20行!
1)首先,将谷歌pretrain的模型参数预存好,保存到一个json文件中:
2)参数赋值,在model_fn_builder函数中,加载预存的参数进行剪枝赋值:
是的!剪枝就是如此简单!从前笔者为了多方面做对比实验(例如,第一层剪到768维,第2层剪到384维),强行修改了BERT的模型代码,传入一个字典进行剪枝,迁移到另一个BERT变体模型就不太方便。
最后附上部分实验结果(时间可能会有所波动):
模型 | 层数 | ffn维度 | head个数 | hidden size | tes acc | inference时间 |
---|---|---|---|---|---|---|
BERT | 12 | 3072 | 12 | 768 | 0.78 | 1000ms+ |
BERT | 2 | 384 | 6 | 768 | 0.75 | 340ms |
BERT | 1 | 384 | 6 | 384 | 0.701 | 217ms |
AL-BERT | 4 | 1248 | 12 | 312 | 0.771 | 650ms |
AL-BERT | 2 | 312 | 6 | 312 | 0.763 | 388ms |
AL-BERT | 1 | 312 | 6 | 312 | 0.74 | 183ms |
小结:对BERT系列模型来说,剪枝是一个非常不错的轻量化方法,很多下游任务可以不需要这么庞大的模型,也能达到很好的效果。
References