首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

智源FlagAttention:面向多种训练芯片的大模型高性能Triton算子集

随着人工智能产业的高速增长,大模型已成为行业创新的驱动引擎。大模型对计算资源的需求巨大,对各种AI芯片而言是难得的发展机遇。尽管NVIDIA的GPU芯片和CUDA软件生态在市场上占据主导地位,但其他芯片制造商也在加速研发,智源研究院牵头开发的FlagAttention项目,目标是构建一套对多种芯片适配更友好的大模型核心算子集合,我们选择Triton作为开发语言,基于Triton的开放性,FlagAttention不仅支持NVIDIA GPU,面向未来,还可以显著降低不同芯片之间模型适配的成本,提高大模型的训练与推理效率。

FlagAttention

在基于 Transformer 的大语言模型中,为了在训练和推理支持更长上下文,提升模型的性能,往往需要使用内存高效的Attention实现,如FlashAttention。而为了支持长序列建模,也需要对模型进行改造,包括对标准FlashAttention进行自定义修改,此时开发者一般面临两个选择:

使用神经网络框架提供的API组合实现。虽然该方法灵活,但会失去内存高效的优势;

对FlashAttention算子进行定制和扩展,增加新功能而保持内存高效的性质。由于CUDA C语言的特性,开发者修改难度较高,耗时耗力。

自定义CUDA算子虽然可以解决效率的问题,但适配更多的芯片变得困难。目前 NVIDIA的GPU芯片和CUDA是深度学习的主流生态,许多高性能算子率先在 NVIDIA 平台上实现,而其他芯片上的适配进程相对较慢,因此模型层面的创新虽然层出不穷,但往往难以惠及非NVIDIA硬件的用户。

为了让更广泛的模型用户使用到最新的模型优化技术,由智源研究院AI系统研究组牵头开发的FlagAttention项目应时而生。该项目选择Triton作为算子开发语言,Triton以其开放性、轻量级和易开发正在飞速赢得厂商和开发者的支持。FlagAttention希望通过Triton推动模型领域进展的普适化,让大模型领域的最新技术迅速落地到多种硬件平台。

FlagAttention开源仓库

https://github.com/FlagOpen/FlagAttention

本次FlagAttention 目前包含两个算子:

piecewise_attention:分段式Attention算子,支持长文本模型的训练、推理的重要算子。

flash_attention:基于Triton的Multihead Attention高效实现,前向性能优于 CUDA。

piecewise_attention

智源大模型算法团队发现,带旋转位置编码 (RoPE) 的Transformer语言模型生成的序列超出其训练时的最大长度时,存在效果下降的问题,因此提出了新的 NLPE(Non-Linearized Position Embedding) 算法对Attention进行修改。NLPE算法针对qk之间的距离近和远使用两种不同的位置编码,并且用不同的方式计算qk的内积s。

图 1:NLPE 算法中的分段式 Attention 计算方式

在代码、电子书等领域的续写实验显示,NLPE算法最终可以将文本长度4K的 Aquila2-34B 模型外延到16K长度,续写文本的连贯性好于Dynamic-NTK、位置内插(Position Interpolation)等方法,且保持4K长度内的语言模型能力。

图 2:NLPE 与主流 Dynamic-NTK 外延方法在 Base 模型上的能力对比(注:ppl 值越低越好)

此外,在多个长文本评测集上的指令跟随能力测试结果显示,AquilaChat2-7B-NLPE(2K)准确率为 17.2%,远高于AquilaChat2-7B-Dynamic-NTK的准确率 0.4%。

图 3:NLPE 与主流 Dynamic-NTK 外延方法在 SFT 模型上的能力对比

由于NLPE需要以分段的方式计算Attention, 所以FlashAttention算子不适用。piecewise_attention扩展了FlashAttention,支持分段式的Attention计算。和 FlashAttention的关键区别是在正向算子中拼合了s, 而在反向算子中分割了s的梯度。

图 4: piecewise_attention 算子实现关键部分

piecewise_attention算子使用了分块和重计算等技巧,保持了内存高效的特性。运算效率远高于使用Pytorch小算子组合的实现。

图 5:piecewise_attention 算子性能测试

piecewise_attention算子已用于Aquila2-34B语言模型的训练和推理,在Aquila2-34B-16K长文本语言模型训练过程中表现稳定,可用性得到了验证。以下为Aquila2-34B-16K模型(其中使用了piecewise_attention)的训练过程loss曲线。

图 6:Aquila2-34B-16K 模型训练 Loss 曲线

flash_attention

为了方便用户使用,除了piecewise_attention之外,FlagAttention算子集还提供了标准 flash_attention 实现,这是经优化过的标准 Multihead Attention 的 Triton 高效实现版本,其正向算子性能优于CUDA 实现的FlashAttention。在 head dim=64,带causal masking的情况下比CUDA实现的FlashAttention快 12%,在 head dim=128,带causal masking的情况下和CUDA实现速度相当,在推理应用上有一定的优势。后续我们也会进一步优化。

图 7:flash attention 算子 (head dim=64, with causal masking)

图 8: flash attention 算子 (head dim=128, with causal masking)

(注:以上数据测试过程中 batch size 随着 seqlen 的增大而减小,二者乘积保持 32768。图中为 Pytorch 实现数值为 0 的表示因显存不足无法运行。)

通向统一的多芯片 Kernel 编程

神经网络框架提供了深度学习领域最常用的通用算子,芯片厂商通过适配算子支持深度学习框架。这些常用的算子可以组合实现更复杂的算子,但这样的实现效率不足以满足深度学习的训练和推理的所有需求。为了提升效率,为特定领域和用途开发高性能的复杂算子是常见的做法。对于硬件厂商而言,最常用的算法会实现为高性能数学库,比如基本线性代数 (BLAS), 傅里叶变换 (FFT),循环神经网络 (RNN), 卷积神经网络 (CNN) 等。神经网络框架也不同程度地包含对应这些算法的算子。但不同硬件厂商对于这些特定领域的库和复杂算子的支持程度不同。

对于快速发展的大模型研究和生产实践而言,无论是通用的小算子,还是用于特定领域的库和复杂算子,都不能同时满足灵活性和高性能的要求。大模型研究呼唤一种成本更低,更容易适配多芯片的 Kernel 编程方式,使得可以快速针对特定需求进行自定义 Kernel 的开发。

CUDA C 是当前 Kernel 开发使用较多的编程语言,但是开发难度较高。而且 CUDA 生态和 NVIDIA 芯片存在强绑定的关系,其他芯片厂商对 CUDA C 语言和 CUDA 库的移植和适配程度不同,而且难以一直跟随 NVIDIA。因此使用 CUDA C 开发算子成本较高,可移植性也不好。

为了降低开发和适配成本,以统一方式面向多芯片开发 Kernel, FlagAttention 选择了基于 Triton 语言开发自定义算子。Triton 语言是 Python 的一种嵌入式 DSL,提供了类似 Python 的语法,易于学习。Triton 语言提供了更高的抽象层级,将许多常用的优化手段内化到编译器层面,因此可以降低心智负担和开发难度,即使是没有丰富 CUDA 的开发经验,也可以迅速上手,并开发出性能不俗的算子。同时,Triton 提供了良好的 Python 与 PyTorch 集成。不需要使用跨语言的 binding 开发 Python C 扩展,可以简化构建和安装流程。

Triton 语言是开源的,兼容性更强。因此,国内外非 NVIDIA 硬件厂商均在积极自研基于 Triton 语言的后端,来绕过 CUDA 的限制。FlagAttention 基于 Triton 实现,因此与它们天然兼容,仅需要少量修改就可将算子移植到非 NVIDIA 芯片。我们可以通过与硬件厂商合作优化算子实现,适配 Triton 编译器来提升算子库性能。

未来展望

FlagAttention项目起源于大模型中对piecewise_attention算子的需求,但在未来的计划中,将会支持更多功能,包括Attention的更多功能以及Transformer模型相关的更多功能。

支持batch内句子长度不同;

支持推理时使用kv cache;

支持dropout、attention_bias等。

在NVIDIA平台上基于Triton开发的flash_attention正向算子性能优于CUDA版,已经证明使用Triton实现高性能算子的可行性。但反向算子还存在较大优化空间,未来将进一步优化算子的性能。

此外,FlagAttention的算子在天数智芯MR-V100上完成了适配和功能验证,正在与天数智芯团队合作优化算子性能,同时改进Triton编译器。未来也希望让FlagAttention适配更多芯片。

在此,欢迎更多大模型开发者使用FlagAttention算子库,期待大家的反馈和建议!如有问题,可以在GitHub issue中与我们沟通:

https://github.com/FlagOpen/FlagAttention/issues

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OA0JalouW3j2RqEVU_DcFhEg0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

相关快讯

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券