写在前面
上一期魔改Transformer好像是两周之前了哈哈,今天继续!最近刚关注的同学感兴趣可以翻翻历史记录补补课
另外,关于Transformer基础和细节可以查看
关于Transformer,面试官们都怎么问 enjoy!来看看今天的主角吧:
- 「Set Transformer from OX,ICML2019」
- 「Star-Transformer from FDU,NAACL 2019」
- 「BP-Transformer from FDU&AWS」
Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks[1]
NLP的输入基本都是有顺序性的,但是还有很多机器学习任务的输入都是顺序无关的「set-structure data」,比如多示例学习(multiple instance learning)、3D shape recognition、sequence ordering、meta-learning等等,处理这类问题的模型有两个特性:
- 「permutation invariant」:即输出结果与输入顺序无关;
- 能够处理任何大小的输入
这篇的重点就是将transformer应用到上述集合型输入问题上,并对transformer的复杂度进行优化。下面我们来看看具体的模型设计。
1.1 Multihead Attention Block(MAB)
其实就是没有「position embedding」和「dropout」的transformer encoder
1.2 Set Attention Block(SAB)
输入集合本身
的MAB
1.3 Induced Set Attention Block (ISAB)
SAB的问题就是transformer的传统问题,复杂度太高。引入「inducing points 矩阵
」,类似于矩阵分解,将原来的一步attention拆分成两步attention:首先用
对输入
做self-attention,接着用得到的结果对输入做attention。将复杂度从
优化为
1.4 整体流程
上述几个子模块和整体流程如下图,
为输入,
为经过ENCODER之后的特征表示
1.5 reference
- Code Here[2]
- Open Review[3]
Star-Transformer[4]
在Transformers Assemble第一趴中已经提到过,vanilla Transformer最大的问题就是模型复杂度非常高,达序列长度的平方。这篇论文提出的一种轻量级的结构「Star-Transformer(ST)」 可以有效缓解这个问题。
具体做法就是将全连接拓扑结构优化成了一种星型结构,两者对比如下图所示。将输入的token看成是图中的一个个节点,左图vanilla transformer中所有的节点两两全连接,在改进的star-transformer中通过一个共享中继节点实现两两节点的连接。这一结构优化可以将复杂度从平方降低到线性,同时保留了transformer捕获局部和全局信息的能力。
2.1 Model Architecture
既然ST模型是一个图结构,那么来看看其节点和边分别都是啥?
- 「Satellite Node」:N个卫星节点,代表输入序列的每个token;
- 「Relay Node」:一个中继节点,共享,连接所有卫星节点整合全局信息;
- 「Radical Connections」:辐射连接,卫星节点同中继节点之间的边,上图左图中所有的非相邻节点之间的链接都可以通过
卫星节点-->中继节点-->卫星节点
的形式完成,从而捕获非局部信息; - 「Ring Connections」:环形连接,将相邻的节点连接起来,捕获局部信息,作用类似于CNN或双向RNN;
通过以上的结构优化,可以实现复杂度的降低,理论上ST可以实现与vanilla transformer相同的任意
->
之间的连接:
->
->
,在实际任务中加速4.5倍。
2.2 Model Implementation
模型训练过程跟vanilla transformer中类似,流程如下:
首先对输入文本做Embedding得到
,以此初始化所有卫星节点的初始值
,中继节点的初始值通过平均得到
,接下去进行
次迭代更新,在每一轮执行两步:
「Update satellite node」
类似于RNN,
代表第
个节点的上下文信息,通过相邻节点
、
,中继节点
,自身的token embedding
来更新;
为第
个节点表示,通过其自身前一步状态和当前上下文信息更新
「Update relay node」
看公式就非常清晰明了
实验部分也很充足,有四五页,没仔细看,反正效果很好的样子~
2.3 Reference
BP-Transformer: Modelling Long-Range Context via Binary Partitioning[6]
跟上一篇一样,通过改变token之间attention的方式来解决vanilla transformer应用于长文本计算量大的问题。「Binary Partitioning Transformer(BPT)」 可以看成是一种图神经网络,将输入序列递归地二分得到多组partition作为图的节点,维护一组「hierarchical representations」,对于每个词来说,比较近的邻居查询细粒度的表示,比较远的邻居查询粗粒度的表示,一共查询
个节点来近似全连接的self-attention,时空复杂度都可以被降到
。
3.1 建图
以下图为例,来看看图中节点和边的定义及建立方式。
「节点的建立」
对于长度为
的输入序列,经过binary partition之后可以得到
个节点(即一棵完美二叉树)。将节点分为两类:
个;
个;
我们定义
为第
层第
个节点,注意由下往上方式定义层的大小,叶子节点所在的为第0层。则span node
包含的token为:
「边的建立」
定义了两种边:
所包含的token到该节点的有向连接,即上图中的虚线连接。作用是缩短span node和token node之间的距离,因此span node的表示可以通过聚合token node的表示来获得;
- 「Contextual Edges」:为了减少计算复杂度,同时保留捕获远程上下文的能力,使用
fine-to-coarse
策略对上下文进行建模,即上图中实现连接,是本文的亮点。
简单起见,只介绍对节点
右侧Contextual Edges的建立,左侧边的建立同理。对于叶子节点
,多粒度地建立入边:比较近的邻居查询细粒度的表示,比较远的邻居查询粗粒度的表示。所有入边可以表示为:
其中
是预先设定的超参数用于定义图的链接密度,即每一层的入边为
条。当
的值大于输入句子长度时,就是原本的Transformer。
是第
层入边的起始节点,可以递归地计算,
特别地, 当下标
为奇数时,
3.2 图的更新
对于节点
,
表示
的邻居节点,如果为token node,则邻居节点为其attend的多粒度span;如果为span node,则邻居节点为其包含的token node。按照这种策略,任意两个token至多通过两条边即可实现连接。「Graph Self-Attention (GSA)」 可以表示为,
所有span nodes都初始化为零向量,而token nodes则使用词向量初始化。
3.3 Relative Position
也正常操作加入了相对位置信息,
:如果
;
or
:如果
是
的第
层第
个左/右邻居;
:假如
是
第
层的祖先节点
加入位置信息后,上述公式需要进行修改,
、
、
、
都是可训练的;
值(比如4)就可以达到较好的效果;
- 应用于下游任务,可以取最后一层的root node(文本分类、NLI等任务),也可以取最后一层所有的token nodes(语言建模、机器翻译等任务);
- 可同时用于单向和双向建模,在长文本的语言模型,分类和翻译上有着不错的性能;由于引入了更多的节点(*2)增大了FFN上的开销,对于短文本速度没有优势。
3.4 reference
敬请期待可能会有的PART 6
本文参考资料
[1]
Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks: http://proceedings.mlr.press/v97/lee19d/lee19d.pdf
[2]
Code Here: https://github.com/juho-lee/set_transformer
[3]
Open Review: https://openreview.net/forum?id=Hkgnii09Ym
[4]
Star-Transformer: https://arxiv.org/abs/1902.09113
[5]
这个: https://github.com/fastnlp/fastNLP/blob/master/fastNLP/modules/encoder/star_transformer.py
[6]
BP-Transformer: Modelling Long-Range Context via Binary Partitioning: https://arxiv.org/pdf/1911.04070.pdf
[7]
Code Here: https://github.com/yzh119/BPT
- END -