前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >去掉 Attention 的 Softmax,复杂度降为 O (n)

去掉 Attention 的 Softmax,复杂度降为 O (n)

作者头像
mathor
发布2021-05-12 09:54:08
发布2021-05-12 09:54:08
1.2K00
代码可运行
举报
文章被收录于专栏:mathormathor
运行总次数:0
代码可运行

众所周知,尽管基于 Attention 机制的 Transformer 类模型有着良好的并行性能,但它的空间和时间复杂度都是 O(n2)\mathcal {O}(n^2) 级别的,nn 是序列长度,所以当 nn 比较大时 Transformer 模型的计算量难以承受。近来,也有不少工作致力于降低 Transformer 模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改 Attention 结构,使得其复杂度能降低到 O(nlog⁡n)\mathcal {O}(nlog⁡n) 甚至 O(n)\mathcal {O}(n)

论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》当中提到一种线性化 Attention(Linear Attention)的方法,由此引发了我的兴趣,继而阅读了一些相关博客,有一些不错的收获,最后将自己对线性化 Attention 的理解汇总在此文中

Attention

当前最流行的 Attention 机制当属 Scaled-Dot Attention,即

(1)Attention(Q,K,V)=softmax(QK⊤)V

\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{equation}

这里的 Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv\boldsymbol {Q}\in \mathbb {R}^{n\times d_k}, \boldsymbol {K}\in \mathbb {R}^{m\times d_k}, \boldsymbol {V}\in \mathbb {R}^{m\times d_v},简单起见我就没显示的写出 Attention 的缩放因子 1d\frac {1}{\sqrt {d}} 了。本文我们主要关心 Self Attention 的场景,所以为了介绍上的方便,统一设 Q,K,V∈Rn×d\boldsymbol {Q},\boldsymbol {K},\boldsymbol {V}\in \mathbb {R}^{n\times d}

摘掉 Softmax

读者也许想不到,制约 Attention 性能的关键因素,其实是定义里边的 Softmax!事实上,简单地推导一下就可以得到这个结论。QKTQK^T 这一步我们得到一个 n×nn\times n 的矩阵,之后还要做一个 Softmax

对一个 1×n1\times n 的行向量进行 Softmax,时间复杂度是 O(n)O (n),但是对一个 n×nn\times n 矩阵的每一行做一个 Softmax,时间复杂度就是 O(n2)O (n^2)

如果没有 Softmax,那么 Attention 的公式就变为三个矩阵连乘 QK⊤V\boldsymbol {QK^{\top} V},而矩阵乘法是满足结合率的,所以我们可以先算 K⊤V\boldsymbol {K^{\top} V},得到一个 d×dd\times d 的矩阵(这一步的时间复杂度是 O(d2n)O (d^2n)),然后再用 QQ 左乘它(这一步的时间复杂度是 O(d2n)O (d^2n)),由于 d≪nd \ll n,所以这样算大致的时间复杂度只是 O(n)O (n)

对于 BERT base 来说,d=64d=64 而不是 768,why?因为 768 实际上是通过 Multi-Head 拼接得到的,而每个 head 的 d=64d=64

也就是说,去掉 Softmax 的 Attention 复杂度可以降到最理想的线性级别 O(n)\mathcal {O}(n)!这显然就是我们的终极追求:Linear Attention

一般的定义

问题是,直接去掉 Softmax 还能算是 Attention 吗?他还能有标准的 Attention 的效果吗?为了回答这个问题,我们先将 Scaled-Dot Attention 的定义等价的改写为(本文的向量都是列向量)

(2)Attention(Q,K,V)i=∑j=1neqi⊤kjvj∑j=1neqi⊤kj

\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{equation}

这里稍微解释下,首先我们知道 Q,K∈Rn×d\boldsymbol {Q},\boldsymbol {K}\in \mathbb {R}^{n\times d},令 M=Q×K⊤\boldsymbol {M} = \boldsymbol {Q}\times \boldsymbol {K^{\top}},由矩阵乘法法则可知,M\boldsymbol {M} 的第一行是由 Q\boldsymbol {Q} 的第一行乘以 K⊤\boldsymbol {K^{\top}} 的所有列得到的 Attention(Q,K,V)iAttention (\boldsymbol {Q},\boldsymbol {K},\boldsymbol {V})_i 表示最终输出结果矩阵的第 ii 行 qi⊤\boldsymbol {q}_i^{\top} 表示 Q∈Rn×d\boldsymbol {Q}\in \mathbb {R}^{n\times d} 矩阵的第 ii 行(行向量) kj\boldsymbol {k}_j 表示 K⊤∈Rd×n\boldsymbol {K^{\top}}\in \mathbb {R}^{d\times n} 矩阵的第 jj 列(列向量) vj\boldsymbol {v}_j 表示 V⊤∈Rd×nV^{\top}\in \mathbb {R}^{d\times n} 矩阵的的第 jj 列(列向量)

所以,Scaled-Dot Attention 其实就是以 eqi⊤kje^{\boldsymbol {q}_i^{\top}\boldsymbol {k}_j} 为权重对 vj\boldsymbol {v}_j 做加权平均。所以我们可以提出一个 Attention 的一般化定义

(3)Attention(Q,K,V)i=∑j=1nsim(qi,kj)vj∑j=1nsim(qi,kj)

\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{equation}

也就是把 eqi⊤kje^{\boldsymbol {q}_i^{\top}\boldsymbol {k}_j} 换成 qi,ki\boldsymbol {q}_i,\boldsymbol {k}_i 的一般函数 sim(qi,kj)\text {sim}(\boldsymbol {q}_i,\boldsymbol {k}_j),为了保留 Attention 相似的分布特性,我们要求 sim(qi,kj)≥0\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)\geq 0 恒成立。也就是说,我们如果要定义新的 Attention,必须要保留式 (3) 的形式,并且满足 sim(qi,kj)≥0\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)\geq 0

这种一般形式的 Attention 在 CV 中也被称为 Non-Local 网络,出自论文《Non-local Neural Networks》

几个例子

如果直接去掉 Softmax,那么就是 sim(qi,kj)=qi⊤kj\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j) = \boldsymbol {q}_i^{\top}\boldsymbol {k}_j,问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们介绍几种可取的方案

值得一提的是,下面介绍的这几种 Linear Attention,前两种来自 CV 领域,第三种是苏剑林大佬构思的(除了下面的介绍外,还有 EMANet 等 CV 领域对 Attention 的改进工作)

核函数形式

一个自然的想法是:如果 qi,kj\boldsymbol {q}_i, \boldsymbol {k}_j 的每个元素都是非负的,那么内积自然也是非负的。为了完成这点,我们可以给 qi,kj\boldsymbol {q}_i, \boldsymbol {k}_j 各自加个激活函数 ϕ,φ\phi,\varphi,即

(4)sim(qi,kj)=ϕ(qi)⊤φ(kj)

\begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{equation}

其中 ϕ(⋅),φ(⋅)\phi (\cdot), \varphi (\cdot) 是值域非负的激活函数。本文开头提到的论文《Transformers are RNNs》选择的是 ϕ(x)=φ(x)=elu(x)+1\phi (x)=\varphi (x)=\text {elu}(x)+1,其中

elu(x)={xif x>0α(ex−1)if x<0

\text{elu}(x)=\begin{cases}x& \text{if} \ x>0\\ \alpha (e^x-1) & \text{if}\ x<0\end{cases}

常见的 α\alpha 取值为 [0.1,0.3][0.1, 0.3]

非要讲故事的话,式 (4) 可以联想到 "核方法",尤其是 ϕ=φ\phi=\varphi 时,ϕ\phi 就相当于一个核函数,而 ⟨ϕ(qi),ϕ(kj)⟩\langle \phi (\boldsymbol {q}_i), \phi (\boldsymbol {k}_j)\rangle 就是通过核函数所定义的内积。这方面的思考可以参考论文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此处不做过多延伸

妙用 Softmax

另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 QK⊤\boldsymbol {QK^{\top}} 中,Q,K∈Rn×d\boldsymbol {Q},\boldsymbol {K}\in \mathbb {R}^{n\times d},如果 “Q\boldsymbol {Q} 在 dd 那一维是归一化的,并且 K\boldsymbol {K} 在 nn 那一维是归一化的”,那么 QK⊤\boldsymbol {QK^{\top}} 就是自动满足归一化了,所以它给出的选择是

(5)Attention(Q,K,V)=softmax2(Q)softmax1(K)⊤V

\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{equation}

其中 softmax1softmax_1、softmax2softmax_2 分别表示在第一个 (n)(n)、第二个维度 (d)(d) 进行 Softmax 运算。也就是说,这时候我们是各自给 Q,K\boldsymbol {Q},\boldsymbol {K} 加 Softmax,而不是算完 QK⊤\boldsymbol {QK^{\top}} 之后再加 Softmax

其实可以证明这个形式也是式 (4)​的一个特例,此时对应于 ϕ(qi)=softmax(qi),φ(kj)=ekj\phi (\boldsymbol {q}_i)=softmax (\boldsymbol {q}_i),\varphi (\boldsymbol {k}_j)=e^{\boldsymbol {k}_j},读者可以自行推导一下

苏神的构思

在这里,苏神给出了一种构思。这个构思的出发点不再是式 (4),而是源于我们对原始定义 (2)​的泰勒展开。由泰勒展开我们有

(6)eqi⊤kj≈1+qi⊤kj

\begin{equation}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{equation}

如果 qi⊤kj≥−1\boldsymbol {q}_i^{\top}\boldsymbol {k}_j\geq -1,那么就可以保证右端的非负性,从而可以让 sim(qi,kj)=1+qi⊤kj\text {sim}(\boldsymbol {q}_i, \boldsymbol {k}_j)=1 + \boldsymbol {q}_i^{\top}\boldsymbol {k}_j。到这里读者可能已经想到了,想要保证 qi⊤kj≥−1\boldsymbol {q}_i^{\top}\boldsymbol {k}_j\geq -1,只需要分别对 qi,kj\boldsymbol {q}_i,\boldsymbol {k}_j 做 l2l_2 归一化。所以,苏神最终提出的方案就是:

(7)sim(qi,kj)=1+(qi‖qi‖)⊤(kj‖kj‖)

\begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{equation}

若 x=[x1,x2,...,xn]\boldsymbol {x}=[x_1,x_2,...,x_n],则 ・・・‖x‖=x12+x22+・・・+xn2\Vert x\Vert=\sqrt {x_1^2+x_2^2+・・・+x_n^2}

这不同于式 (4),但理论上它更加接近原始的 Scaled-Dot Attention

实现

这里主要是针对苏神所提出的方法进行实现,但是由于笔者本人水平有限,因此最终实现的代码当中其实存在一些问题,主要是:

  1. 从测试结果来看,改进后的计算速度并没有提升
  2. 无法做到求和为 1

代码实现主要是针对 BERT 的 PyTorch 实现这篇文章的代码,更具体的说,其实仅修改了 ScaledDotProductAttention 这个函数,因此下面只放出这部分代码

代码语言:javascript
代码运行次数:0
运行
复制
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        Q = F.normalize(Q, dim=3)
        K = F.normalize(K, dim=3)
        M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
        M_sum = torch.sum(M, dim=3)
        M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
        attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
        context = torch.matmul(attn, V)
        return context

如果您有更好的实现方法,还望不吝赐教

Reference
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021-04-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Attention
  • 摘掉 Softmax
  • 一般的定义
  • 几个例子
  • 核函数形式
  • 妙用 Softmax
  • 苏神的构思
  • 实现
  • Reference
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档