写在前面
上一期魔改Transformer好像是两周之前了哈哈,今天继续!最近刚关注的同学感兴趣可以翻翻历史记录补补课
另外,关于Transformer基础和细节可以查看
关于Transformer,面试官们都怎么问 enjoy!来看看今天的主角吧:
- 「Set Transformer from OX,ICML2019」
- 「Star-Transformer from FDU,NAACL 2019」
- 「BP-Transformer from FDU&AWS」
Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks[1]
NLP的输入基本都是有顺序性的,但是还有很多机器学习任务的输入都是顺序无关的「set-structure data」,比如多示例学习(multiple instance learning)、3D shape recognition、sequence ordering、meta-learning等等,处理这类问题的模型有两个特性:
- 「permutation invariant」:即输出结果与输入顺序无关;
- 能够处理任何大小的输入
这篇的重点就是将transformer应用到上述集合型输入问题上,并对transformer的复杂度进行优化。下面我们来看看具体的模型设计。
1.1 Multihead Attention Block(MAB)
其实就是没有「position embedding」和「dropout」的transformer encoder
\operatorname{MAB}(X, Y)=\text { LayerNorm }(H+\operatorname{rFF}(H))
H=\text {LayerNorm} (X+\text { Multihead }(X, Y, Y ; \omega))
1.2 Set Attention Block(SAB)
输入集合本身
X 的MAB
\operatorname{SAB}(X):=\operatorname{MAB}(X, X)
1.3 Induced Set Attention Block (ISAB)
SAB的问题就是transformer的传统问题,复杂度太高。引入「inducing points 矩阵
I\in \mathbb{R}^{m \times d}」,类似于矩阵分解,将原来的一步attention拆分成两步attention:首先用
I 对输入
X 做self-attention,接着用得到的结果对输入做attention。将复杂度从
O(N^{2}) 优化为
O(mn)\operatorname{ISAB}_{m}(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d}
\text { where } H=\operatorname{MAB}(I, X) \in \mathbb{R}^{m \times d}
1.4 整体流程
上述几个子模块和整体流程如下图,
X\in \mathbb{R}^{n \times d_{x}} 为输入,
Z\in \mathbb{R}^{n \times d} 为经过ENCODER之后的特征表示
\begin{aligned}
&\text { Encoder }(X)=\operatorname{SAB}(\operatorname{SAB}(X))\\
&\text { Encoder }(X)=\operatorname{ISAB}_{m}\left(\operatorname{ISAB}_{m}(X)\right)
\end{aligned}
\begin{aligned}
&\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d}\\
&\text { where } \operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d}
\end{aligned}
1.5 reference
- Code Here[2]
- Open Review[3]
Star-Transformer[4]
在Transformers Assemble第一趴中已经提到过,vanilla Transformer最大的问题就是模型复杂度非常高,达序列长度的平方。这篇论文提出的一种轻量级的结构「Star-Transformer(ST)」 可以有效缓解这个问题。
具体做法就是将全连接拓扑结构优化成了一种星型结构,两者对比如下图所示。将输入的token看成是图中的一个个节点,左图vanilla transformer中所有的节点两两全连接,在改进的star-transformer中通过一个共享中继节点实现两两节点的连接。这一结构优化可以将复杂度从平方降低到线性,同时保留了transformer捕获局部和全局信息的能力。
2.1 Model Architecture
既然ST模型是一个图结构,那么来看看其节点和边分别都是啥?
- 「Satellite Node」:N个卫星节点,代表输入序列的每个token;
- 「Relay Node」:一个中继节点,共享,连接所有卫星节点整合全局信息;
- 「Radical Connections」:辐射连接,卫星节点同中继节点之间的边,上图左图中所有的非相邻节点之间的链接都可以通过
卫星节点-->中继节点-->卫星节点
的形式完成,从而捕获非局部信息; - 「Ring Connections」:环形连接,将相邻的节点连接起来,捕获局部信息,作用类似于CNN或双向RNN;
通过以上的结构优化,可以实现复杂度的降低,理论上ST可以实现与vanilla transformer相同的任意
h_{i}->
h_{j} 之间的连接:
h_{i}->
s->
h_{j},在实际任务中加速4.5倍。
2.2 Model Implementation
模型训练过程跟vanilla transformer中类似,流程如下:
首先对输入文本做Embedding得到
\mathbf{E}=\left[\mathbf{e}_{1} ; \cdots ; \mathbf{e}_{n}\right],以此初始化所有卫星节点的初始值
\mathbf{H}^{0}=\mathbf{E},中继节点的初始值通过平均得到
s^{0}=\text average (\mathbf{E}),接下去进行
T 次迭代更新,在每一轮执行两步:
「Update satellite node」
类似于RNN,
\mathbf{C}_{i}^{t} 代表第
i 个节点的上下文信息,通过相邻节点
h_{i-1}、
h_{i+1},中继节点
s^{t-1},自身的token embedding
e^{i} 来更新;
h_{i} 为第
i 个节点表示,通过其自身前一步状态和当前上下文信息更新
\begin{aligned}
&\mathbf{C}_{i}^{t}=\left[\mathbf{h}_{i-1}^{t-1} ; \mathbf{h}_{i}^{t-1} ; \mathbf{h}_{i+1}^{t-1} ; \mathbf{e}^{i} ; \mathbf{s}^{t-1}\right]\\
&\mathbf{h}_{i}^{t}=\text { MultiAtt }\left(\mathbf{h}_{i}^{t-1}, \mathbf{C}_{i}^{t}\right)
\end{aligned}
「Update relay node」
看公式就非常清晰明了
\begin{aligned}
&\mathbf{s}^{t}=\text { Multi } \operatorname{Att}\left(\mathbf{s}^{t-1},\left[\mathbf{s}^{t-1} ; \mathbf{H}^{t}\right]\right)\\
&\mathbf{s}^{t}=\text { LayerNorm }\left(\operatorname{ReLU}\left(\mathbf{s}^{t}\right)\right)
\end{aligned}
实验部分也很充足,有四五页,没仔细看,反正效果很好的样子~
2.3 Reference
BP-Transformer: Modelling Long-Range Context via Binary Partitioning[6]
跟上一篇一样,通过改变token之间attention的方式来解决vanilla transformer应用于长文本计算量大的问题。「Binary Partitioning Transformer(BPT)」 可以看成是一种图神经网络,将输入序列递归地二分得到多组partition作为图的节点,维护一组「hierarchical representations」,对于每个词来说,比较近的邻居查询细粒度的表示,比较远的邻居查询粗粒度的表示,一共查询
log(n) 个节点来近似全连接的self-attention,时空复杂度都可以被降到
O(d \cdot n \log n)。
3.1 建图
以下图为例,来看看图中节点和边的定义及建立方式。
「节点的建立」
对于长度为
N 的输入序列,经过binary partition之后可以得到
2N-1 个节点(即一棵完美二叉树)。将节点分为两类:
N 个;
N-1 个;
我们定义
u_{l,m} 为第
l 层第
m 个节点,注意由下往上方式定义层的大小,叶子节点所在的为第0层。则span node
u_{l,m} 包含的token为:
u_{0,2^{l} * m+1}, \cdots, u_{0,2^{l} *(m+1)}「边的建立」
定义了两种边:
u_{l,m} 所包含的token到该节点的有向连接,即上图中的虚线连接。作用是缩短span node和token node之间的距离,因此span node的表示可以通过聚合token node的表示来获得;
- 「Contextual Edges」:为了减少计算复杂度,同时保留捕获远程上下文的能力,使用
fine-to-coarse
策略对上下文进行建模,即上图中实现连接,是本文的亮点。
简单起见,只介绍对节点
u_{0,i} 右侧Contextual Edges的建立,左侧边的建立同理。对于叶子节点
u_{0,m},多粒度地建立入边:比较近的邻居查询细粒度的表示,比较远的邻居查询粗粒度的表示。所有入边可以表示为:
\begin{array}{c}
{u_{0, p_{0}}, \cdots, u_{0, p_{0}+k-1}} \\
{u_{1, p_{0}}, \cdots, u_{1, p_{1}+k-1}} \\
{\cdots} \\
{u_{l, p_{l}}, \cdots, u_{1, p_{l}+k-1}} \\
{\cdots}
\end{array}
其中
k 是预先设定的超参数用于定义图的链接密度,即每一层的入边为
k 条。当
k 的值大于输入句子长度时,就是原本的Transformer。
p_{l} 是第
l 层入边的起始节点,可以递归地计算,
p_{l}=\operatorname{parent}\left(p_{l-1}+k\right)
p_{0}=i+1
特别地, 当下标
p_{l}+k-1 为奇数时,
p_{l}=\operatorname{parent}\left(p_{l-1}+k+1\right)
3.2 图的更新
对于节点
u,
\mathcal{A}(u) 表示
u 的邻居节点,如果为token node,则邻居节点为其attend的多粒度span;如果为span node,则邻居节点为其包含的token node。按照这种策略,任意两个token至多通过两条边即可实现连接。「Graph Self-Attention (GSA)」 可以表示为,
\mathbf{A}^{u}=\operatorname{concat}\left(\left\{\mathbf{h}_{v} | v \in \mathcal{A}(u)\right\}\right)
\mathbf{Q}_{i}^{u}=\mathbf{H}_{k} \mathbf{W}_{i}^{Q}, \mathbf{K}_{i}^{u}=\mathbf{A}^{u} \mathbf{W}_{i}^{K}, \mathbf{V}_{i}^{u}=\mathbf{A}^{u} \mathbf{W}_{i}^{V}
\operatorname{head}_{i}^{u}=\operatorname{softmax}\left(\frac{\mathbf{Q}_{i}^{u} \mathbf{K}_{i}^{u T}}{\sqrt{d}}\right) \mathbf{V}_{i}^{u}
\operatorname{GSA}\left(\mathcal{G}, \mathbf{h}_{u}\right)=\left[\operatorname{head}_{1}^{u}, \cdots, \operatorname{head}_{h}^{u}\right] \mathbf{W}^{O}
所有span nodes都初始化为零向量,而token nodes则使用词向量初始化。
3.3 Relative Position
也正常操作加入了相对位置信息,
r_{v, u}=r^{\text {self }}:如果
u=v;
r_{v, u}=r_{j, i}^{\text {left }} or
r_{j, i}^{\text {right }}:如果
v 是
u 的第
j 层第
i 个左/右邻居;
r_{v, u}=r_{j}^{\text {anc }}:假如
v 是
u 第
j 层的祖先节点
加入位置信息后,上述公式需要进行修改,
\mathbf{R}^{u}=\operatorname{concat}\left(\left\{r_{v, u} | v \in \mathcal{A}(u)\right\}\right)
\text { head }_{i}^{u}=\operatorname{softmax}\left(\frac{\mathbf{Q}_{i}^{u}\left(\mathbf{K}_{i}^{u}+\mathbf{R}^{u}\right)^{T}}{\sqrt{d}}\right) \mathbf{V}_{i}^{u}
r^{\text {self }}、
r_{j, i}^{\text {left }}、
r_{j, i}^{\text {right }}、
r_{j}^{\text {anc }} 都是可训练的;
k 值(比如4)就可以达到较好的效果;
- 应用于下游任务,可以取最后一层的root node(文本分类、NLI等任务),也可以取最后一层所有的token nodes(语言建模、机器翻译等任务);
- 可同时用于单向和双向建模,在长文本的语言模型,分类和翻译上有着不错的性能;由于引入了更多的节点(*2)增大了FFN上的开销,对于短文本速度没有优势。
3.4 reference
敬请期待可能会有的PART 6
本文参考资料
[1]
Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks: http://proceedings.mlr.press/v97/lee19d/lee19d.pdf
[2]
Code Here: https://github.com/juho-lee/set_transformer
[3]
Open Review: https://openreview.net/forum?id=Hkgnii09Ym
[4]
Star-Transformer: https://arxiv.org/abs/1902.09113
[5]
这个: https://github.com/fastnlp/fastNLP/blob/master/fastNLP/modules/encoder/star_transformer.py
[6]
BP-Transformer: Modelling Long-Range Context via Binary Partitioning: https://arxiv.org/pdf/1911.04070.pdf
[7]
Code Here: https://github.com/yzh119/BPT
- END -