机器之心专栏
机器之心编辑部
位置编码技术是一种能够让神经网络建模句子中 Token 位置信息的技术。在 Transformer 大行其道的时代,由于 Attention 结构无法建模每个 token 的位置信息,位置编码(Position embedding) 成为 Transformer 非常重要的一个组件。研究人员也提出了各种各样的位置编码方案来让网络建模位置信息,Rope 和 Alibi 是目前最被广泛采纳的两种位置编码方案。
然而最近来自百川智能的研究发现,Rope 和 alibi 位置编码的主流实现在低精度(尤其是 bfloat16) 下存在位置编码碰撞的 bug, 这可能会影响模型的训练和推理。而且目前大部分主流开源模型的实现都存在该问题,连 llama 官方代码也中招了。
还得从位置编码算法说起
为了弄清楚这个问题,得先从位置编码的算法原理说起,在 Transformer 结构中,所有 Attention Block 的输入都会先经过位置编码,再输入网络进行后续处理。纯粹的 Attention 结构是无法精确感知到每个 token 的位置信息的,而对于语言的很多任务来说,语句的顺序对语义信息的影响是非常大的,为了建模 token 之间的位置关系,Transfomer 原始论文中引入位置编码来建模位置信息。
图 1 - 施加 Positon Embedding 示意图。
为了让模型更好地建模句子的位置信息,研究人员提出了多种位置编码方案,meta 开源的 llama [4] 模型采用了 Rope [5] 方案,使得 Rope 成为在开源社区被广泛采纳的一种位置编码方案。而 Alibi 编码因其良好的外推性也被广泛应用。
了解低精度下的位置编码碰撞之前,先来回顾一下相关算法原理。
Sinusoidal 位置编码
这是 Transformer 原始论文中提出的位置编码方法。它通过使用不同频率的正弦和余弦函数来为每个位置产生一个独特的编码。选择三角函数来生成位置编码有两个良好的性质:
1)编码相对位置信息,数学上可以证明 PE (pos+k) 可以被 PE (pos) 线性表示, 这意味着位置编码中蕴含了相对位置信息。
图 2- 句子长度为 50 的位置编码,编码维度 128,每行代表一个 Position Embedding。
2)远程衰减:不同位置的 position encoding 点乘结果会随着相对位置的增加而递减 [1]。
图 3 - 不同位置的位置编码点积可视化。
Rope
Rope 是目前开源社区应用最广泛的一种位置编码方案, 通过绝对位置编码的方式实现相对位置编码,在引入相对位置信息的同时保持了绝对位置编码的优势(不需要像相对位置编码一样去操作 attention matrix)。令 f_q, f_k 为 位置编码的函数,m 表示位置,x_m 表示该位置 token 对应的 embedding,我们希望经过位置编码后的 embedding 点积仅和相对位置有关,则可以有公式:
上面公式中 g 是某个函数,表示内积的结果只和 x_m 和 x_n 的值,以及二者位置的相对关系 (m-n) 有关在 2 维的情况下可以推导出(详细推导过程可参考原论文):
因为矩阵乘法线性累加的性质,可以拓展到多维的情况可得:
为了引入远程衰减的特性,Rope 中 \theta 的选取选择了 Transformer 原始论文中 sinusoidal 公式。
Alibi
Alibi 是谷歌发表在 ICLR2022 的一篇工作,Alibi 主要解决了位置编码外推效果差的痛点,算法思想非常的简单,而且非常直观。与直接加在 embedding 上的绝对位置编码不同,Alibi 的思想是在 attention matrix 上施加一个与距离成正比的惩罚偏置,惩罚偏置随着相对距离的增加而增加。在具体实现时,对于每个 head 会有一个超参 m 来控制惩罚偏置随着相对距离增加的幅度(斜率)。
图 4 - Alibi attention bias 示意图
论文结果显示 Alibi 极大的提升了模型的外推性能,16k token 的输入依然可以很好的支持。
图 5 - Alibi 外推效果对比。
混合精度下位置编码的 bug
从上面的算法原理中,不管是 rope 的 cos (m\theta) 还是 alibi 的 i-1(m, i 代表 postion id), 需要为每个位置生成一个整型的 position_id, 在上下文窗口比较大的时候,百川智能发现目前主流的位置编码实现在混合精度下都存在因为低精度(float16/bfloat16) 浮点数表示精度不足导致位置编码碰撞的问题。尤其当模型训练(推理)时上下文长度越来越长,低精度表示带来的位置编码碰撞问题越来越严重,进而影响模型的效果,下面以 bfloat16 为例来说明这个 bug。
浮点数表示精度
浮点数在计算机中表示由符号位(sign),指数位 (exponent),尾数位 (fraction) 三部分组成,对于一个常规的数值表示,可以由如下公式来计算其代表的数值(其中 offset 是指数位的偏置):
由公式可知,尾数位的长度决定了浮点数的表示精度。深度学习中常用的 float32/float16/bfloat16 内存中的表示分别如下图所示:
图 6- bfloat16 的表示格式
图 7- float16 的表示格式
图 8- float32 的表示格式
可以看到 float16 和 bfloat16 相比于 float32 都牺牲了表示的精度,后续以 bfloat16 为例说明位置编码中存在的问题(float16 同理)。下表展示了 bfloat16 在不同数值范围(只截取整数部分)内的表示精度。
可以看到当整数范围超过 256, bfloat16 就无法精确表示每一个整数,可以用代码验证一下表示精度带来的问题。
Rope& Alibi 编码的问题
Meta 开源的 llama 模型采用了 Rope 的位置编码方式, 官方的实现(以及大部分的第三方 llama 系列模型)在 bfloat16 下存在精度问题带来的位置编码碰撞(不同位置的 token 在 bfloat16 下变成同一个数)。Llama 官方代码如下:
Python
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
上面第 18 行核心一句根据输入序列长度生成每个位置的 positon idx 在 bfloat16 下产生位置碰撞。
Python
# self.inv_freq.dtype == torch.bfloat16 when bfloat16 is enabled during training
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
在实际训练时如果开了 bfloat16, self.inv_freq 的 dtype 会被转为 bfloat16, 可以通过简单的代码来看一下位置碰撞的问题。
Python
t = torch.arange(4096, dtype=torch.float32)
plt.scatter(t[-100:], t[-100:].to(torch.bfloat16).float(),s=0.8)
plt.xlabel('position in float32')
plt.ylabel('position in bfloat16'
根据 bfloa16 的表示精度可知,训练(推理)时上下文长度越长,位置编码碰撞的情况越严重,长度为 8192 的上下文推理中,仅有大约 10% 的 token 位置编码是精确的,好在位置编码碰撞有局域性的特质,只有若干个相邻的 token 才会共享同一个 position Embedding, 在更大的尺度上,不同位置的 token 还是有一定的区分性。
图 10- 不同上下文窗口下位置编码精确 token 所占比例。
除了 llama 模型,百川智能发现 alibi 位置编码也存在上述问题,原因依然在于生成整数的位置索引时会在低精度下产生碰撞问题。
修复方案
Rope 修复
Rope 的修复相对简单,只需要保证在生成 position_id 的时候一定在 float32 的精度上即可。注意:
float32 的 tensor register_buffer 后在训练时如果开启了 bfloat16, 也会被转为 bfloat16。
Python
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
Alibi 修复
图 11- 修复前后 alibi attention_bias 对照。
修复效果
百川智能仅在推理阶段对位置编码的精度问题进行修复【注:训练阶段可能也存在问题,取决于训练的具体配置和方法】,可以看到:
a. 在长上下文的推理中,模型的 ppl 要显著优于修复前的 ppl
b.Benchmark 上测试结果显示修复前后区别不大,可能是因为 benchmark 上测试文本长度有限,很少触发 Position embedding 的碰撞
Benchmark 对比
Perplexity
我们在通用的文本数据上对修改前后模型在中英文文本上的困惑度进行测试,效果如下:
[0] Dongxu Zhang, & Dong Wang. (2015). Relation Classification via Recurrent Neural Network.
[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, & Illia Polosukhin. (2023). Attention Is All You Need.
[2] Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le, & Ruslan Salakhutdinov. (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
[3] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, & Peter J. Liu. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.
[4] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, & Guillaume Lample. (2023). LLaMA: Open and Efficient Foundation Language Models.
[5] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, & Yunfeng Liu. (2022). RoFormer: Enhanced Transformer with Rotary Position Embedding.
[6] Ofir Press, Noah A. Smith, & Mike Lewis. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.
[7] Yutao Sun, Li Dong, Barun Patra, Shuming Ma, Shaohan Huang, Alon Benhaim, Vishrav Chaudhary, Xia Song, & Furu Wei. (2022). A Length-Extrapolatable Transformer.
[8] https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
[9] Shouyuan Chen, Sherman Wong, Liangjian Chen, & Yuandong Tian. (2023). Extending Context Window of Large Language Models via Positional Interpolation.
[10] https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/