首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >GMNN: Graph Markov Neural Networks

GMNN: Graph Markov Neural Networks

作者头像
马上科普尚尚
发布2020-05-13 17:36:53
1.2K0
发布2020-05-13 17:36:53
举报

论文提出的背景与相关工作

我们生活在一个互联的世界中,实体通过各种关系相互联系。例如,网页通过超链接相互关联,社交媒体用户通过社交关系相互关联。对这种关系数据建模是机器学习中的一个重要的研究课题。这一课题涵盖了各种应用,如实体分类,链接预测和链接分类。

上述应用可以归结为对相互关联并包含属性的对象进行半监督的对象分类问题,即给定部分对象的标签推断其它相关联的对象的标签。在统计关系学习领域,这个问题已经被广泛地研究,代表的方法有马尔可夫网络和马尔可夫逻辑网络。这些网络通常使用条件随机场建模以模拟对象标签的依赖性。但是这些方法仍然存在一些限制:

(1)这些方法通常将条件随机场中的潜在函数定义为一些手工提取的特征函数的线性组合,这种模型的容量通常不足。

(2)由于对象之间关系结构的复杂性,推断未标记对象的对象标签的后验分布仍然是一个具有挑战性的问题。

图神经网络通过使用非线性神经架构学习有效的对象表示来处理对象分类问题,并且整个框架可以以端到端的方式进行训练,例如GCN。这些方法已被证明可以实现最先进的性能,因为它们可以有效地学习关系数据的对象表示。然而,这类作法是通过对象之间的关系学出向量,通过向量独立预测每个对象的标签,而忽略了对象标签之间的关联,这是这类做法的一个关键限制。

问题定义

这篇论文关注半监督对象分类。给定图 ,其中V是一组对象,E是对象之间的一组边,xV 代表所有对象的属性,E中的边可以具有多种类型,表示对象之间的不同关系。在本文中,为简单起见,我们假设所有边都属于同一类型。给定部分对象 L⊂V的标签yL,目标是预测剩余未标记对象 U=V/L的标签 yU。在统计关系学习(SRL)和图神经网络(GNN)的文献中已经广泛研究了这个问题。本质上,两种类型的方法都旨在模拟以对象属性和图形结构为条件的对象标签的分布,即

。接下来,我们介绍这两种方法的基本概念。为了简化表示,我们在以下公式中省略了E。

2.1统计关系学习

大多数SRL方法使用条件随机场模型 ,采用以下公式:

(i , j)是图G中的边,

是在边上定义的潜在得分,表示两个对象的关联程度。Z是让输出变为概率的调整函数。通常,潜在得分由一些设计的特征函数(如逻辑公式)的线性组合所决定。从公式(1) 可推断无标签对象的标签分布

。由于对象标签之间的复杂结构,精确推断通常是做不到的。因此,经常使用一些近似推断方法,例如置信传播。

2.2 GNN

与SRL方法不同,GNN方法简单地忽略了对象标签的相关性和依赖性,具体而言,标签的联合分布可以表达为:

基于该公式,GNN独立地推断每个对象n的标签分布

。对于每个对象n,GNN以下列方式预测其标签:

其中每个对象属性的维度为d,标签类的数量为K,

表示所有对象的属性,而

表示对象n的属性,

是线性变换矩阵,Cat代表类别分布。GNN专注于为每个对象n学习有用的表示

。具体而言,用每个对象n的特征初始化为

。然后根据结点(即对象)n当前的特征和n的邻居的特征(即

)迭代更新每个

。对于更新函数,可以使用图卷积层(GC)和图注意力层(GAT)或者使用内部消息传递层。在多层更新之后,最终的对象表示会被送到用于标签预测的线性softmax分类器。整个框架可以以端到端的方式进行训练。

GMNN的原理

本节介绍用于半监督对象分类的称为图马尔可夫神经网络(GMNN)的方法。GMNN的目标是结合统计关系学习方法和图神经网络的优点,这样就可以在学习有用的客观表示预测数据标签的同时模拟标签之间的依赖关系。

首先,可以注意到GNN模型同样也可以对网络节点(对象)之间标签的相关性进行建模。GNN网络的本质就是节点之间的信息(特征)传播,我们只需要将节点标签也当作特征输入到GNN网络(命名为GNNφ)中,就可以对节点之间标签的相关性进行建模,从而实现统计关系学习的功能。然而,这种实现方式存在的问题在于:在半监督学习的场景下,有标签的节点往往比较稀少,因而在一个有标签节点的周边邻居可能都没有标签。因此使用上述基于GNN的方法对很难达到对节点之间的标签进行建模的预期效果。

为了解决标签匮乏的问题,作者采用另一个 GNN网络模型(命名为GNNθ )来预测节点标签,目的是将GNNφ预测出的节点标签作为GNNθ 模型的输入特征。值得注意的是GNNφ只通过节点的属性特征作为输入,无需关注节点之间标签的相关性,类似于2.2节提到的传统的GNN模型。上述GNNφ模型被GNNθ 模型所约束,因为GNNφ模型依赖于GNNθ 提供节点标签信息作为输入。然而GNNθ 不受任何约束,这容易导致模型训练的结果不收敛。为了解决这一问题,可以让GNNθ 学出的概率分布尽可能与GNNφ所学出的概率分布尽可能的一致,这样,两个模型相互约束,就可以采用EM算法将两个模型交替更新。这即解决了传统GNN网络无法对标签相关性进行建模的问题,又避免了使用以标签作为特征的GNNφ网络由于有标签节点不够而导致的模型失效问题。GNNθ 计算过程可表示为 ,

为所有对象的属性,

为所有无标签对象的标签。GNNφ以节点属性

和有标签节点的标签(来自数据集提供的标签或

预测的标签)

同时作为特征输入,这样就同时对节点属性和及诶单标签的依赖性进行了建模,计算过程表示为

。后面会具体介绍两个GNN网络互相更新的具体作法和原理。

具体而言,GMNN通过使用条件随机场对以节点属性为条件的节点标签的联合分布进行建模,即

,并使用伪似然变分EM框架进行优化。在E步骤中,使用通过预测节点标签(预测的结果需要尽可能与GNNφ一致)来训练图神经网络GNNθ 。在M步骤中,采用图神经网络GNNφ来刻画节点标签的局部依赖性。接下来,我们将介绍GMNN方法的详细步骤。

3.1总体框架

图1.框架概述。方块是带标签的和无标签的对象,白色网格是属性,柱状图是对象的标签分布。橙色三重圆是对象表示。GMNN通过在E步骤和M步骤之间交替进行训练。

图1给出了框架的图示。对于给定节点, qθ将邻居节点和自身的属性特征作为输入来预测标签,从而能够学出节点的向量表示。与之相对的是,pφ不仅利用邻居节点的属性还利用邻居节点的标签作为输入特征,如果邻居节点没有标签,则采用 qθ预测的标签作为输入特征。在步骤M中, qθ被用来预测无标签节点的标签,并从中采样一部分标签,结合数据集中已有的标签,来训练和更新pφ;在步骤E中,首先,pφ使用节点属性和标签来预测无标签节点的标签,然后将pφ所预测的标签的概率分布作为 qθ的学习目标,从而训练和更新 qθ。

接下来将会具体介绍这种做法的数学原理,pφ和 qθ的训练方法,以及pφ和 qθ互相更新的方法。

3.2伪似然变分EM

模型的思想是通过对象属性与对象标签共同作为对象特征,预测对象标签,即

。为了使用GNN模型模拟对象依赖性的同时避免由于标签过少引起的模型失效,文章使用另一个GNN模型分布

,与

通过变分EM算法互相逼近,以优化证据下界的方式最小化两者的KL散度(即使得分布趋于相同):

由上式很容易得出

,并且当

=

时,KL散度为0,模型达到最优。

根据变分EM算法,可以通过在变分的E步骤和M步骤之间交替来优化这样的下限。

在E步骤中,pφ将对象特征与对象标签作为对象特征,预测无标签对象的标签,分布为

,目标是将这些标签作为变分分布

的目标,用于更新变分分布

以逼近真实的后验分布

在M步骤中,使用有标签对象与pφ的训练结果为目标训练 qθ,使用 qθ训练结果更新pφ来最大化下面的似然函数:

然而,直接优化似然函数可能很困难,因为我们必须处理pφ中的分区函数。为了避免计算分区函数,改为优化下面的伪似然函数:

NB(n)是n的邻居集合,伪似然方法被广泛用于学习马尔可夫网络。接下来,我们将介绍推断和学习步骤的细节。

3.3 推断

推断步骤旨在计算后验分布

。由于对象标签之间的关系结构复杂,精确推断在计算上是不可行的。因此,我们用另一个变分分布

来近似它。具体而言,我们使用均值场方法,其中 qθ为:

其中,n是未标记对象的索引。在变分分布中,假定所有对象标签都是独立的。

为了模拟 qθ中每个对象标签的分布,遵循摊销推理的思想,并用图神经网络(GNN)参数化 ,学习标签预测的有效对象表示:

是标签的分布,基于对象的属性

,通过softmax函数计算每种标签的可能性。

是以xv为特征,θ为参数通过GNN计算出来的,我们将GNN模型表示为

。使用

可以从属性和局部连接中学习对象的有用表示,从而改进推理。此外,通过在不同对象之间共享

,我们可以显著减少推理所需的参数数量。

利用上述平均场公式,最优分布

满足以下不动点条件:

该条件的右侧与 qθ相关。为了进一步简化该条件,我们通过从分布

中采样来估计期望:

在上面的公式中,

的定义如下:对于对象n的每个未标记的邻居k,该论文对

进行采样,并且对于对象n的每个标记的邻居k,将

设置为正确标签。在实践中,该论文发现使用从

分布一次采样的结果与多次采样的结果相当。因此,在论文的实验中,从效率的角度出发,只进行了一次采样。基于公式8和9,可推断出最优

满足:

为了

学习最优 ,我们首先对 qθ进行多轮训练,预测无标签对象属性,更新pφ对象特征并训练pφ,计算

。然后将

的值固定为 qθ的目标值,并更新θ以最小化

与目标

之间的反向KL发散度,产生下面的目标函数:

此外,作者还注意到还可以通过预测有标签对象的标签来训练 qθ。因此,还可以让 qθ最大化以下监督目标函数:

在这里,yn是n的标签。通过将公式11和12相加,可得到了优化θ的总体目标:

3.4学习

在M步骤中,作者希望学习参数

,因此,在训练

后将固定 qθ并进一步更新pφ以最大化公式5。利用目标函数,我们注意到在推断和学习步骤(方程11和5)中只有条件分布

需要用到pφ。因此,不是通过指定势函数来定义对象标签

的联合分布,而是可以简单地通过条件分布的建模来实现。因为GNN的有效性,可用另一个非线性图神经网络模型(GNN)参数化条件分布 :

这里,

的分布由softmax分类器刻画。该分类器将GNN模型(记作

)学出的对象表示

作为输入特征。当学习对象表示

时,

将对象n周围邻居的所有标签

视为特征。因此,

实质上将对象标签的局部依赖性进行了建模。基于上述公式,就可以避免统计关系学习中需要手工设计特征函数的问题。

标签传播方法与该框架有一定的相关性。标签传播方法也是通过组合周围的标签来更新每个对象的标签。然而,标签传播方法以固定和线性方式传播标签,而

是通过可学习和非线性的方式传播的。 值得注意的是,当定义 时,

仅使用对象n的邻居节点的标签

作为特征,但

本身其实足够灵活,可以结合其他特征。例如,

可以仿照现有的SRL方法,将邻居对象的标签

和属性

都作为

的特征。 此外,基于pφ的总体公式,即公式(1),每个对象的标签

应该只依赖于其相邻对象的标签

和属性

,这意味着

不应该有超过一个消息传递层。然而,从GNN相关的文献中可看出,一种常见做法是在训练期间使用多个消息传递层,这可以很好地刻画不同对象之间的长程依赖性。因此,在实验部分,作者还进一步探索使用多个消息传递层来捕获这种长程依赖性。

当通过优化pφ来最大化公式5时,通过从分布 中抽取一个样本来估计pφ的期望值。具体而言,如果n是未标记的对象,那么通过采样

,否则我们将

设置为真实标签。之后,可以通过最大化以下目标函数来优化参数ϕ:

3.5优化

为了优化方法,我们首先使用有标签对象预先推导推理模型qθ。然后我们选择优化pφ和qθ直到收敛。之后,pφ和qθ都可用于推断未标记物体的标签。在实践中,我们发现qθ始终优于pφ,因此我们默认使用qθ来推断对象标签。

3.6具体算法

上述方法的具体算法如下:

实验

在本节中,我们评估GMNN在三个任务上的表现,包括对象分类,无监督节点表示学习和链接分类。

4.1数据集和实验设置

对于对象分类,用Cora,Citeseer,Pubmed三个数据集进行评估,在每个数据集中,每个类中的20个对象被视为标记对象,用准确度用作评估指标。

对于无监督节点表示学习,还使用上述数据集,对象被视为结点,我们不使用任何有标签的结点学习结点的表示。将表示视为特征以在标记的节点上训练线性分类器以评价学习出的表示。然后我们对测试节点进行分类并报告准确性。

对于链接分类,我们分别从比特币Alpha和比特币OTC数据集构建两个数据集。数据集包含比特币用户之间的图表,链接的权重表示连接用户的信任度。我们将权重大于3的链接视为正实例,权重小于-3的链接视为负实例。给定一些带标签的链接,尝试对测试链接进行分类。

4.2参数设计

对象分类:对于GMNN,pφ和qθ由具有16个隐藏单元和ReLU激活函数的卷积层的两个卷积层组成,然后用softmax函数,Dropout率p = 0.5,使用RMSProp优化器,初始学习率为0.05,权重衰减为0.0005。在每次迭代中,两个网络都训练了100个epochs。在实验中报告了超过100次运行的平均准确度。

无监督节点表示学习:对于GMNN方法,pφ和qθ由两个图卷积层组成,后面是线性层和softmax函数。对于Cora和Citeseer,隐藏层的维度设置为512,对于Pubmed,维度设置为256。ReLU为激活函数。我们两个网络Dropout率p = 0.5。Adam SGD优化器用于训练,初始学习率为0.1,权重衰减为0.0005。我们训练qθ 200个epoches。之后,我们训练pφ和qθ进行两次迭代,每次迭代每个网络有100个epochs。报告50次运行的平均准确度。

链接分类:此任务中GMNN的设置与对象分类中的类似,具有以下差异。隐藏层的尺寸设置为128.不使用权重衰减和dropout。在每次迭代中,两个网络都使用Adam优化器训练了5个epochs,学习率为0.01。

4.3 实验结果

1.与基线方法的比较。三张表列出了这三项任务的定量结果。分别为图2,3,4。对于对象分类,GMNN明显优于所有SRL方法。首先,在推理期间,GMNN采用GNN模型,可以更加有效的学习对象表示。其次,在学习过程中,我们使用另一个GNN模拟局部标签依赖关系,比SRL更有效。GMNN也优于标签传播方法,因为GMNN能够以非线性方式使用对象属性和标签进行传播。与GCN相比,GMNN显着优于GCN,并且性能增益主要来自对标签依赖性建模的能力。此外,通过将对象属性合并到学习网络pφ中,我们进一步提高了性能,表明GMNN可以灵活且有效地将附加特征结合到学习网络中。对于链接分类,获得了类似的结果。

参考链接

https://arxiv.org/pdf/1905.06214.pdf

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能前沿讲习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档