众所周知,尽管基于 Attention 机制的 Transformer 类模型有着良好的并行性能,但它的空间和时间复杂度都是 O(n2)\mathcal {O}(n^2) 级别的,nn 是序列长度,所以当 nn 比较大时 Transformer 模型的计算量难以承受。近来,也有不少工作致力于降低 Transformer 模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改 Attention 结构,使得其复杂度能降低到 O(nlogn)\mathcal {O}(nlogn) 甚至 O(n)\mathcal {O}(n)
论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》当中提到一种线性化 Attention(Linear Attention)的方法,由此引发了我的兴趣,继而阅读了一些相关博客,有一些不错的收获,最后将自己对线性化 Attention 的理解汇总在此文中
当前最流行的 Attention 机制当属 Scaled-Dot Attention,即
(1)Attention(Q,K,V)=softmax(QK⊤)V
\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{equation}
这里的 Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv\boldsymbol {Q}\in \mathbb {R}^{n\times d_k}, \boldsymbol {K}\in \mathbb {R}^{m\times d_k}, \boldsymbol {V}\in \mathbb {R}^{m\times d_v},简单起见我就没显示的写出 Attention 的缩放因子 1d\frac {1}{\sqrt {d}} 了。本文我们主要关心 Self Attention 的场景,所以为了介绍上的方便,统一设 Q,K,V∈Rn×d\boldsymbol {Q},\boldsymbol {K},\boldsymbol {V}\in \mathbb {R}^{n\times d}
读者也许想不到,制约 Attention 性能的关键因素,其实是定义里边的 Softmax!事实上,简单地推导一下就可以得到这个结论。QKTQK^T 这一步我们得到一个 n×nn\times n 的矩阵,之后还要做一个 Softmax
对一个 1×n1\times n 的行向量进行 Softmax,时间复杂度是 O(n)O (n),但是对一个 n×nn\times n 矩阵的每一行做一个 Softmax,时间复杂度就是 O(n2)O (n^2)
如果没有 Softmax,那么 Attention 的公式就变为三个矩阵连乘 QK⊤V\boldsymbol {QK^{\top} V},而矩阵乘法是满足结合率的,所以我们可以先算 K⊤V\boldsymbol {K^{\top} V},得到一个 d×dd\times d 的矩阵(这一步的时间复杂度是 O(d2n)O (d^2n)),然后再用 QQ 左乘它(这一步的时间复杂度是 O(d2n)O (d^2n)),由于 d≪nd \ll n,所以这样算大致的时间复杂度只是 O(n)O (n)
对于 BERT base 来说,d=64d=64 而不是 768,why?因为 768 实际上是通过 Multi-Head 拼接得到的,而每个 head 的 d=64d=64
也就是说,去掉 Softmax 的 Attention 复杂度可以降到最理想的线性级别 O(n)\mathcal {O}(n)!这显然就是我们的终极追求:Linear Attention
问题是,直接去掉 Softmax 还能算是 Attention 吗?他还能有标准的 Attention 的效果吗?为了回答这个问题,我们先将 Scaled-Dot Attention 的定义等价的改写为(本文的向量都是列向量)
(2)Attention(Q,K,V)i=∑j=1neqi⊤kjvj∑j=1neqi⊤kj
\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{equation}
这里稍微解释下,首先我们知道 Q,K∈Rn×d\boldsymbol {Q},\boldsymbol {K}\in \mathbb {R}^{n\times d},令 M=Q×K⊤\boldsymbol {M} = \boldsymbol {Q}\times \boldsymbol {K^{\top}},由矩阵乘法法则可知,M\boldsymbol {M} 的第一行是由 Q\boldsymbol {Q} 的第一行乘以 K⊤\boldsymbol {K^{\top}} 的所有列得到的 Attention(Q,K,V)iAttention (\boldsymbol {Q},\boldsymbol {K},\boldsymbol {V})_i 表示最终输出结果矩阵的第 ii 行 qi⊤\boldsymbol {q}_i^{\top} 表示 Q∈Rn×d\boldsymbol {Q}\in \mathbb {R}^{n\times d} 矩阵的第 ii 行(行向量) kj\boldsymbol {k}_j 表示 K⊤∈Rd×n\boldsymbol {K^{\top}}\in \mathbb {R}^{d\times n} 矩阵的第 jj 列(列向量) vj\boldsymbol {v}_j 表示 V⊤∈Rd×nV^{\top}\in \mathbb {R}^{d\times n} 矩阵的的第 jj 列(列向量)
所以,Scaled-Dot Attention 其实就是以 eqi⊤kje^{\boldsymbol {q}_i^{\top}\boldsymbol {k}_j} 为权重对 vj\boldsymbol {v}_j 做加权平均。所以我们可以提出一个 Attention 的一般化定义
(3)Attention(Q,K,V)i=∑j=1nsim(qi,kj)vj∑j=1nsim(qi,kj)
\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{equation}
也就是把 eqi⊤kje^{\boldsymbol {q}_i^{\top}\boldsymbol {k}_j} 换成 qi,ki\boldsymbol {q}_i,\boldsymbol {k}_i 的一般函数 sim(qi,kj)\text {sim}(\boldsymbol {q}_i,\boldsymbol {k}_j),为了保留 Attention 相似的分布特性,我们要求 sim(qi,kj)≥0\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)\geq 0 恒成立。也就是说,我们如果要定义新的 Attention,必须要保留式 (3) 的形式,并且满足 sim(qi,kj)≥0\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)\geq 0
这种一般形式的 Attention 在 CV 中也被称为 Non-Local 网络,出自论文《Non-local Neural Networks》
如果直接去掉 Softmax,那么就是 sim(qi,kj)=qi⊤kj\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j) = \boldsymbol {q}_i^{\top}\boldsymbol {k}_j,问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们介绍几种可取的方案
值得一提的是,下面介绍的这几种 Linear Attention,前两种来自 CV 领域,第三种是苏剑林大佬构思的(除了下面的介绍外,还有 EMANet 等 CV 领域对 Attention 的改进工作)
一个自然的想法是:如果 qi,kj\boldsymbol {q}_i, \boldsymbol {k}_j 的每个元素都是非负的,那么内积自然也是非负的。为了完成这点,我们可以给 qi,kj\boldsymbol {q}_i, \boldsymbol {k}_j 各自加个激活函数 ϕ,φ\phi,\varphi,即
(4)sim(qi,kj)=ϕ(qi)⊤φ(kj)
\begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{equation}
其中 ϕ(⋅),φ(⋅)\phi (\cdot), \varphi (\cdot) 是值域非负的激活函数。本文开头提到的论文《Transformers are RNNs》选择的是 ϕ(x)=φ(x)=elu(x)+1\phi (x)=\varphi (x)=\text {elu}(x)+1,其中
elu(x)={xif x>0α(ex−1)if x<0
\text{elu}(x)=\begin{cases}x& \text{if} \ x>0\\ \alpha (e^x-1) & \text{if}\ x<0\end{cases}
常见的 α\alpha 取值为 [0.1,0.3][0.1, 0.3]
非要讲故事的话,式 (4) 可以联想到 "核方法",尤其是 ϕ=φ\phi=\varphi 时,ϕ\phi 就相当于一个核函数,而 ⟨ϕ(qi),ϕ(kj)⟩\langle \phi (\boldsymbol {q}_i), \phi (\boldsymbol {k}_j)\rangle 就是通过核函数所定义的内积。这方面的思考可以参考论文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此处不做过多延伸
另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 QK⊤\boldsymbol {QK^{\top}} 中,Q,K∈Rn×d\boldsymbol {Q},\boldsymbol {K}\in \mathbb {R}^{n\times d},如果 “Q\boldsymbol {Q} 在 dd 那一维是归一化的,并且 K\boldsymbol {K} 在 nn 那一维是归一化的”,那么 QK⊤\boldsymbol {QK^{\top}} 就是自动满足归一化了,所以它给出的选择是
(5)Attention(Q,K,V)=softmax2(Q)softmax1(K)⊤V
\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{equation}
其中 softmax1softmax_1、softmax2softmax_2 分别表示在第一个 (n)(n)、第二个维度 (d)(d) 进行 Softmax 运算。也就是说,这时候我们是各自给 Q,K\boldsymbol {Q},\boldsymbol {K} 加 Softmax,而不是算完 QK⊤\boldsymbol {QK^{\top}} 之后再加 Softmax
其实可以证明这个形式也是式 (4)的一个特例,此时对应于 ϕ(qi)=softmax(qi),φ(kj)=ekj\phi (\boldsymbol {q}_i)=softmax (\boldsymbol {q}_i),\varphi (\boldsymbol {k}_j)=e^{\boldsymbol {k}_j},读者可以自行推导一下
在这里,苏神给出了一种构思。这个构思的出发点不再是式 (4),而是源于我们对原始定义 (2)的泰勒展开。由泰勒展开我们有
(6)eqi⊤kj≈1+qi⊤kj
\begin{equation}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{equation}
如果 qi⊤kj≥−1\boldsymbol {q}_i^{\top}\boldsymbol {k}_j\geq -1,那么就可以保证右端的非负性,从而可以让 sim(qi,kj)=1+qi⊤kj\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)=1 + \boldsymbol {q}_i^{\top}\boldsymbol {k}_j。到这里读者可能已经想到了,想要保证 qi⊤kj≥−1\boldsymbol {q}_i^{\top}\boldsymbol {k}_j\geq -1,只需要分别对 qi,kj\boldsymbol {q}_i,\boldsymbol {k}_j 做 l2l_2 归一化。所以,苏神最终提出的方案就是:
(7)sim(qi,kj)=1+(qi‖qi‖)⊤(kj‖kj‖)
\begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{equation}
若 x=[x1,x2,...,xn]\boldsymbol {x}=[x_1,x_2,...,x_n],则 ・・・‖x‖=x12+x22+・・・+xn2\Vert x\Vert=\sqrt {x_1^2+x_2^2+・・・+x_n^2}
这不同于式 (4),但理论上它更加接近原始的 Scaled-Dot Attention
这里主要是针对苏神所提出的方法进行实现,但是由于笔者本人水平有限,因此最终实现的代码当中其实存在一些问题,主要是:
代码实现主要是针对 BERT 的 PyTorch 实现这篇文章的代码,更具体的说,其实仅修改了 ScaledDotProductAttention
这个函数,因此下面只放出这部分代码
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, attn_mask):
Q = F.normalize(Q, dim=3)
K = F.normalize(K, dim=3)
M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
M_sum = torch.sum(M, dim=3)
M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
context = torch.matmul(attn, V)
return context
如果您有更好的实现方法,还望不吝赐教