本文针对图神经网络中存在的假死现象以及过平滑的问题,提出了GRAPH-BERT, 这种方法不需要依赖卷积、聚合的操作就可以实现图表示学习。主要的思路是将原始图分解成以每一个节点为中心的多个子图,只利用attention机制在子图上进行表征学习,然后利用attention去学习结点表征,而不考虑子图中的边信息;另一方面也解决了大规模图的效率问题。这里提出三种计算Distance的方法,结合之前普渡大学Prof. Lipan的工作,可以看出来distance在解决GNN问题的重要作用。
注:假死现象是说随着模型层数加深,对输入没有了什么反应;过平滑是指节点的表示非常相似,没有了区分度。
题目: GRAPH-BERT : Only Attention is Needed for Learning Graph Representations
论文地址:https://arxiv.org/abs/2001.05140
代码:
https://github.com/jwzhanggy/Graph-Bert
本文提出目前图神经网络(GNN)的一些问题,比如suspended animation(这里翻译成假死)和过度平滑问题。假死现象是说随着模型层数加深,对输入没有了什么反应;过平滑是指节点的表示非常相似,没有了区分度。引起这两个问题的主要原因是邻居一次又一次地聚合,每一个节点的特征趋于相似,最后迷失了自我,无法区分。严谨的说,community或者class内部的聚合是有利于节点表示,但是community或者class之间的消息传递聚合对该类型的数据集是非常容易造成过拟合。因为不同的community一般属于不同的类,特征混合在一起了,还怎么分类呢?那么如何解决这个问题呢?
在本文中,作者将引入一种新的图神经网络,即GRAPH-BERT(基于图的BERT),它完全基于注意力机制而没有任何图卷积或聚合算子。另外,该方法不用完整的大的输入图来训练Graph-BERT,而是在局部抽取的无连边的子图来训练Graph-BERT。那么存储起来以及进行并行运算就非常容易了。小编最后有个问题,该怎么解决呢?
跟NLP中的Bert类似,也是『pretrain+fintune』范式,不过是应用于图网络中,所以预训练的Graph-BERT可以直接转移到其他应用任务中,或者进行少量的微调就OK。
1. 本文的主要贡献
2. 模型图
第一个子图(左)是说随机抽取子图,第二个子图是关键-下方四个横条条对应的是Graph-Bert四个输入特征,后面的跟Bert基本一致了,咱们看一下中间子图Bert的输入是什么,从下往上依次是:
3.1 子图分解
原始图分解为多个子图(方便并行化),进行训练。每个子图包含一个中心节点和固定数量的邻居节点,邻居节点代表了中心节点的上下文信息,因此如何为中心节点选取邻居节点是这一步的关键。方法有很多种比如,最短距离, Jaccard’s coefficienty Adamic/Adar, Katz,pagerankd等等,这些都是跟图的结构相关的相似度度量指标,跟节点的特征没有一点关系。本文采用的是「top-k intimacy」方法,类似KNN算法,计算每个节点与其他所有节点的关联(intimacy)程度,然后选取关系值最大的前K个点作为邻接节点。计算方法采用pagerank算法:
其中alpha是一个[0,1]的超参数。值得一提的是,这里得到的子图并不局限与近距离的节点,也包含了远距离但是intimacy值大的点。换句话说,这里并不是直接采样一个局部的子图,而是通过计算当前节点与图上所有节点的intimacy值,设定一个阈值得到的子图。
3.2.1原始图的节点的特征raw embedding
Embed函数可以是CNN, LSTM,FC, MLP等,这一步主要将原始特征嵌入到一个新的特征空间。
3.2.2 WL绝对位置嵌入(全局)
注:利用global node role的结构信息
Weisfeiler-Lehman算法是用来判断两个图是否同构,基本思路是不断迭代聚合邻居节点来判断当前中心节点的独立性(Identity),这个操作是全图的操作,只要找好hash函数就好,这里也只是利用图的结构信息,与节点特征无关。
3.2.3 基于亲密度的相对位置嵌入(局部)
注:intimacy based,利用采样的子图的结构信息
相对位置嵌入是用来捕捉局部信息,
P是一个位置函数,对于中心节点vi,P(vi)=1,对于其他的周围节点vj,和vi关联值越大,P(vj)越小。
3.2.4 基于原始图上点对最短距离嵌入(全局)
这一步的计算也是全图的,只是把抽样出来的子图中点对之间的最短距离拿过来。这一步跟上一步不太一样。
上面的四种信息可以采用多种花式操作进行组合拼接等,本文采用简单的加操作,然后输入到Transformer中。Transformer模型是一种基于全attention的sequence to sequence模型,在NLP任务上取得了比LSTM更好的成绩。
Transformer模型
QKV都是输入H_(l-1),这里额外加了一个mask矩阵M,可以学习到子图的上下文信息,类别Bert中的mask操作。整个Transoformer包含2K层,前K层是编码器,后K层是解码器,通过自回归来监督训练,最终编码器的输出即为学习到的节点特征表示zi。本文还在编码器输入和输入之间添加了res连接。
4 任务与实验
在将Graph-Bert应用在新任务中时,既可以直接使用模型学习到的图特征表示,也可以根据实际情况做一些必要的调整。这里主要考虑节点分类和图聚类任务,其中图聚类任务可以直接使用学习到的特征表示,而节点分类还需要做一些调整。
4.1 模型的预训练
4.1.1 通过对节点原始特征重建进行预训练
如何从学习到的节点特征表示中还原原始输入特征是采用预训练模型时需要关注的点:
在监督时自然采用和编码-解码器相同的监督方式:
考虑到节点的特征可能非常稀疏,这里额外加入了一个Mask矩阵:
4.1.2 通过图的结构重建进行预训练
通过图结构重建来保证模型能学习到图的结构信息,该任务依然是作为预训练任务。图的结构可以用一个标签向量y来表示,yij表示节点i和j之间的连接关系,包含两个值(i->j, j->i),因此我们需要预测的是:
损失函数采用多类交叉熵损失函数即可:
考虑到图比较稀疏时,大部分节点之间都没有连接(例如对于一个总节点数为100,平均1阶邻居节点数为5的图,其邻接矩阵就有95%的地方为0),需要对负例(即无连接的节点对)进行采样来保证正负例平衡。
通过将上述两种不同的学习任务与变量正则化项结合,可以定义正式训练前的目标函数。通过优化该目标函数将有助于对Graph-Bert进行有效的预训练。
4.2预训练之后应用
4.2.1节点分类:
通过额外添加FC层和softmax来进行节点分类,通过交叉熵进行优化,这里需只需要对FC层进行微调。
4.2.2图上的聚类
对于图聚类任务,主要目标是将图中的节点划分为几个不同的cluster,直接将学习到的节点特征上进行聚类即可,不需要其他的分类器等等
不过,该函数中仍然包含了一些变量,可以用EM算法来高效地学习,而不是反向传播。
5实验分析与结果