本项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5061984?contributionType=1
相关项目参考:更多资料见主页
关于图计算&图学习的基础知识概览:前置知识点学习(PGL)[系列一] https://aistudio.baidu.com/aistudio/projectdetail/4982973?contributionType=1
图机器学习(GML)&图神经网络(GNN)原理和代码实现(前置学习系列二):https://aistudio.baidu.com/aistudio/projectdetail/4990947?contributionType=1
在图神经网络中,使用的数据集可能是亿量级的数据,而由于GPU/CPU资源有限无法一次性全图送入计算资源,需要借鉴深度学习中的mini-batch思想。
传统的深度学习mini-batch训练每个batch的样本之间无依赖,多层样本计算量固定;而在图神经网络中,每个batch中的节点之间互相依赖,在计算多层时会导致计算量爆炸,因此引入了图采样的概念。
GraphSAGE也是图嵌入算法中的一种。在论文Inductive Representation Learning on Large Graphs 在大图上的归纳表示学习中提出。github链接和官方介绍链接。
与node2vec相比较而言,node2vec是在图的节点级别上进行嵌入,GraphSAGE则是在整个图的级别上进行嵌入。之前的网络表示学习的transductive,难以从而提出了一个inductive的GraphSAGE算法。GraphSAGE同时利用节点特征信息和结构信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射后的结果,而GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出。
现存的方法需要图中所有的顶点在训练embedding的时候都出现;这些前人的方法本质上是transductive,不能自然地泛化到未见过的顶点。文中提出了GraphSAGE,是一个inductive的框架,可以利用顶点特征信息(比如文本属性)来高效地为没有见过的顶点生成embedding。GraphSAGE是为了学习一种节点表示方法,即如何通过从一个顶点的局部邻居采样并聚合顶点特征,而不是为每个顶点训练单独的embedding。 这个算法在三个inductive顶点分类benchmark上超越了那些很强的baseline。文中基于citation和Reddit帖子数据的信息图中对未见过的顶点分类,实验表明使用一个PPI(protein-protein interactions)多图数据集,算法可以泛化到完全未见过的图上。
在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。
图采样算法:顾名思义,图采样算法就是在一张图中进行采样得到一个子图,这里的采样并不是随机采样,而是采取一些策略。典型的图采样算法包括GraphSAGE、PinSAGE等。
文章码源链接:
https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
https://github.com/williamleif/GraphSAGE
前面 GCN 讲解的文章中,我使用的图节点个数非常少,然而在实际问题中,一张图可能节点非常多,因此就没有办法一次性把整张图送入计算资源,所以我们应该使用一种有效的采样算法,从全图中采样出一个子图 ,这样就可以进行训练了。
GraphSAGE与GCN对比:
既然新增的节点,一定会改变原有节点的表示,那么为什么一定要得到每个节点的一个固定的表示呢?何不直接学习一种节点的表示方法。去学习一个节点的信息是怎么通过其邻居节点的特征聚合而来的。 学习到了这样的“聚合函数”,而我们本身就已知各个节点的特征和邻居关系,我们就可以很方便地得到一个新节点的表示了。
GCN等transductive的方法,学到的是每个节点的一个唯一确定的embedding; 而GraphSAGE方法学到的node embedding,是根据node的邻居关系的变化而变化的,也就是说,即使是旧的node,如果建立了一些新的link,那么其对应的embedding也会变化,而且也很方便地学到。
在了解图采样算法前,我们至少应该保证采样后的子图是连通的。例如上图图中,左边采样的子图就是连通的,右边的子图不是连通的。
GraphSAGE的核心:GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射。 GraphSage框架中包含两个很重要的操作:Sample采样和Aggregate聚合。这也是其名字GraphSage(Graph SAmple and aggreGatE)的由来。GraphSAGE 主要分两步:采样、聚合。GraphSAGE的采样方式是邻居采样,邻居采样的意思是在某个节点的邻居节点中选择几个节点作为原节点的一阶邻居,之后对在新采样的节点的邻居中继续选择节点作为原节点的二阶节点,以此类推。
文中不是对每个顶点都训练一个单独的embeddding向量,而是训练了一组aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息(见图1)。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成embedding。
GraphSAGE 是Graph SAmple and aggreGatE的缩写,其运行流程如上图所示,可以分为三个步骤:
邻居采样的优点:
采样的阶段首先选取一个点,然后随机选取这个点的一阶邻居,再以这些邻居为起点随机选择它们的一阶邻居。例如下图中,我们要预测 0 号节点,因此首先随机选择 0 号节点的一阶邻居 2、4、5,然后随机选择 2 号节点的一阶邻居 8、9;4 号节点的一阶邻居 11、12;5 号节点的一阶邻居 13、15
聚合具体来说就是直接将子图从全图中抽离出来,从最边缘的节点开始,一层一层向里更新节点
上图展示了邻居采样的优点,极大减少训练计算量这个是毋庸置疑的,泛化能力增强这个可能不太好理解,因为原本要更新一个节点需要它周围的所有邻居,而通过邻居采样之后,每个节点就不是由所有的邻居来更新它,而是部分邻居节点,所以具有比较强的泛化能力。
在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即也就是对它输入的各种排列,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上。
**a. Mean aggregator **: mean aggregator将目标顶点和邻居顶点的第
层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第
层表示向量。
卷积聚合器Convolutional aggregator: 文中用下面的式子替换算法1中的4行和5行得到GCN的inductive变形:
原始算法1中的第4,5行是
论文提出的均值聚合器Mean aggregator:
聚合的邻居向量
拼接起来
b. LSTM aggregator
文中也测试了一个基于LSTM的复杂的聚合器[Long short-term memory]。和均值聚合器相比,LSTMs有更强的表达能力。但是,LSTMs不是对称的(symmetric),也就是说不具有排列不变性(permutation invariant),因为它们以一个序列的方式处理输入。因此,需要先对邻居节点随机顺序,然后将邻居序列的embedding作为LSTM的输入。
c. Pooling aggregator
pooling聚合器,它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的embedding向量进行一次非线性变换,之后进行一次pooling操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。 一个element-wise max pooling操作应用在邻居集合上来聚合信息:
其中
表示element-wise最大值操作,取每个特征的最大值
是非线性激活函数
在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。
基于图的无监督损失
基于图的损失函数倾向于使得相邻的顶点有相似的表示,但这会使相互远离的顶点的表示差异变大:
其中
为节点
通过GraphSAGE生成的embedding
是节点
随机游走到达的“邻居”
是sigmoid函数
是负采样的概率分布,类似word2vec中的负采样
是负样本的数目
文中输入到损失函数的表示
是从包含一个顶点局部邻居的特征生成出来的,而不像之前的那些方法(如DeepWalk),对每个顶点训练一个独一无二的embedding,然后简单进行一个embedding查找操作得到。
基于图的有监督损失
无监督损失函数的设定来学习节点embedding 可以供下游多个任务使用。监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。
参数学习
通过前向传播得到节点
的embedding
,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数
和聚合函数内的参数。
新节点embedding的生成
这个
就是所谓的dynamic embedding的核心,因为保存下来了从节点原始的高维特征生成低维embedding的方式。现在,如果想得到一个点的embedding,只需要输入节点的特征向量,经过卷积(利用已经训练好的
以及特定聚合函数聚合neighbor的属性信息),就产生了节点的embedding。
有了GCN为啥还要GraphSAGE?
GCN灵活性差、为新节点产生embedding要求 额外的操作 ,比如“对齐”:
GCN是 直推式(transductive) 学习,无法直接泛化到新加入(未见过)的节点;
GraphSAGE是 归纳式(inductive) 学习,可以为新节点输出节点特征。
GCN输出固定:
GCN输出的是节点 唯一确定 的embedding;
GraphSAGE学习的是节点和邻接节点之间的关系,学习到的是一种 映射关系 ,节点的embedding可以随着其邻接节点的变化而变化。
GCN很难应用在超大图上:
无论是拉普拉斯计算还是图卷积过程,因为GCN其需要对 整张图 进行计算,所以计算量会随着节点数的增加而递增。
GraphSAGE通过采样,能够形成 minibatch 来进行批训练,能用在超大图上
GraphSAGE有什么优点?
采用 归纳学习 的方式,学习邻居节点特征关系,得到泛化性更强的embedding;
采样技术,降低空间复杂度,便于构建minibatch用于 批训练 ,还让模型具有更好的泛化性;
多样的聚合函数 ,对于不同的数据集/场景可以选用不同的聚合方式,使得模型更加灵活。
采样数大于邻接节点数怎么办?
设采样数量为K:
若节点邻居数少于K,则采用 有放回 的抽样方法,直到采样出K个节点。
若节点邻居数大于K,则采用 无放回 的抽样。
训练好的GraphSAGE如何得到节点Embedding?
假设GraphSAGE已经训练好,我们可以通过以下步骤来获得节点embedding,具体算法请看下图的算法1。
训练过程则只需要将其产生的embedding扔进损失函数计算并反向梯度传播即可。
对图中每个节点的邻接节点进行 采样 ,输入节点及其n阶邻接节点的特征向量
根据K层的 聚合函数 聚合邻接节点的信息
就产生了各节点的embedding
minibatch的子图是怎么得到的?
那和DeepWalk、Node2vec这些有什么不一样?
DeepWalk、Node2Vec这些embedding算法直接训练每个节点的embedding,本质上依然是直推式学习,而且需要大量的额外训练才能使他们能预测新的节点。同时,对于embedding的正交变换(orthogonal transformations),这些方法的目标函数是不变的,这意味着生成的向量空间在不同的图之间不是天然泛化的,在再次训练(re-training)时会产生漂移(drift)。
与DeepWalk不同的是,GraphSAGE是通过聚合节点的邻接节点特征产生embedding的,而不是简单的进行一个embedding lookup操作得到。
论文仿真结果:
实验对比了四个基线:随机分类,基于特征的逻辑回归(忽略图结构),DeepWalk算法,DeepWork+特征;同时还对比了四种GraphSAGE,其中三种在3.3节中已经说明,GraphSAGE-GCN是GCNs的归纳版本。具体超参数为:K=2,s1=25,s2=10。程序使用TensorFlow编写,Adam优化器。
对于跨图泛化的任务,需要学习节点角色而不是训练图的结构。使用跨各种生物蛋白质-蛋白质相互作用(PPI)图,对蛋白质功能进行分类。在20个图表上训练算法,2个图用于测试,2个图用于验证,平均每图包含2373个节点,平均度为28.8。从实验结果可以看出LSTM和池化方法比Mean和GCN效果更好。
对比不同聚合函数:
如表-1所示,LSTM和POOL方法效果最好,与其它方法相比有显著差异,LSTM和POOL之间无显著差异,但LSTM比POOL慢得多(≈2x),使POOL聚合器在总体上略有优势。
采样
为什么要采样?
采样数大于邻接节点数怎么办?
采样的邻居节点数应该选取多大?
每一跳采样需要一样吗?
适合有向边吗?
采样是随机的吗?
聚合函数
聚合函数的选取有什么要求?
GraphSAGE论文中提供多少种聚合函数?
均值聚合的操作是怎样的?
pooling聚合的操作是怎样的?
使用LSTM聚合时需要注意什么?
均值聚合和其他聚合函数有啥区别?
max-和mean-pooling有什么区别?
这三种聚合方法,哪种比较好?
一般聚合多少层?层数越多越好吗?
什么时候和GCN的聚合形式“等价”?
无监督学习
GraphSAGE怎样进行无监督学习?
GraphSAGE如何定义邻近和远处的节点?
如何计算无监督GraphSAGE的损失函数?
GraphSAGE是怎么随机游走的?
GraphSAGE什么时候考虑边的权重了?
训练
如果只有图、没有节点特征,能否使用GraphSAGE?
训练好的GraphSAGE如何得到节点Embedding?
minibatch的子图是怎么得到的?
增加了新的节点来训练,需要为所有“旧”节点重新输出embeding吗?
GraphSAGE有监督学习有什么不一样的地方吗?
参考链接:https://zhuanlan.zhihu.com/p/184991506
https://blog.csdn.net/yyl424525/article/details/100532849
采样时只能选取真实的邻居节点吗?如果构建的是一个与虚拟邻居相连的子图有什么优点?PinSAGE 算法将会给我们解答,PinSAGE 算法通过多次随机游走,按游走经过的频率选取邻居,上图右侧为进行随机游走得到的节点序列,统计序列的频数可以发现节点5,10,11的频数为2,其余为1,当我们希望采样三个节点时,我们选取5,10,11作为0号节点的虚拟邻居。之后如果希望得到0号节点的二阶虚拟邻居则在已采样的节点继续进行随机游走即可。
回到上述问题,采样时选取虚拟邻居有什么好处?这种采样方式的好处是我们能更快的聚合到远处节点的信息。。实际上如果是按照 GraphSAGE 算法的方式生成子图,在聚合的过程中,非一阶邻居的信息可以通过消息传递逐渐传到中心,但是随着距离的增大,离中心越远的节点,其信息在传递过程中就越困难,甚至可能无法传递到;如果按照 PinSAGE 算法的方式生成子图,有一定的概率可以将非一阶邻居与中心直接相连,这样就可以快速聚合到多阶邻居的信息
和GraphSAGE相比,PinSAGE改进了什么?
PinSAGE使用多大的计算资源?
训练时,PinSAGE使用32核CPU、16张Tesla K80显卡、500GB内存;
推理时,MapReduce运行在378个d2.8xlarge Amazon AWS节点的Hadoop2集群。
PinSAGE和node2vec、DeepWalk这些有啥区别?
node2vec,DeepWalk是无监督训练;PinSAGE是有监督训练;
node2vec,DeepWalk不能利用节点特征;PinSAGE可以;
node2vec,DeepWalk这些模型的参数和节点数呈线性关系,很难应用在超大型的图上;
PinSAGE的单层聚合过程是怎样的?
和GraphSAGE一样,PinSAGE的核心就是一个 局部卷积算子 ,用来学习如何聚合邻居节点信息。
如下图算法1所示,PinSAGE的聚合函数叫做CONVOLVE。主要分为3部分:
PinSAGE是如何采样的?
如何采样这个问题从另一个角度来看就是:如何为目标节点构建邻居节点。
和GraphSAGE的均匀采样不一样的是,PinSAGE使用的是重要性采样。
PinSAGE对邻居节点的定义是:对目标节点 影响力最大 的T个节点。
PinSAGE的邻居节点的重要性是如何计算的?
其影响力的计算方法有以下步骤:
从目标节点开始随机游走;
使用 正则 来计算节点的“访问次数”,得到重要性分数;
目标节点的邻居节点,则是重要性分数最高的前T个节点。
这个重要性分数,其实可以近似看成Personalized PageRank分数。
关于随机游走,可以阅读《Pixie: A System for Recommending 3+ Billion Items to 200+ Million Users in Real-Time》
重要性采样的好处是什么?
和GraphSAGE一样,可以使得 邻居节点的数量固定 ,便于控制内存/显存的使用。
在聚合邻居节点时,可以考虑节点的重要性;在PinSAGE实践中,使用的就是 加权平均 (weighted-mean),原文把它称作 importance pooling 。
采样的大小是多少比较好?
从PinSAGE的实验可以看出,随着邻居节点的增加,而收益会递减;
并且两层GCN在 邻居数为50 时能够更好的抓取节点的邻居信息,同时保持运算效率。
PinSage论文中还介绍了落地过程中采用的大量工程技巧。
聚合函数
PinSAGE的单层聚合过程是怎样的?
为什么要将邻居节点的聚合embedding和当前节点的拼接?
采样
PinSAGE是如何采样的?
PinSAGE的邻居节点的重要性是如何计算的?
重要性采样的好处是什么?
采样的大小是多少比较好?
MiniBatch
PinSAGE的minibatch和GraphSAGE有啥不一样?
batch应该选多大?
训练
PinSAGE使用什么损失函数?
PinSAGE如何定义标签(正例/负例)?
PinSAGE用什么方法提高模型训练的鲁棒性和收敛性?
负采样
PinSAGE如何进行负采样?
训练时简单地负采样,会有什么问题?
如何解决简单负采样带来的问题?
如果只使用“hard”负样本,会有什么问题?
如何解决只使用“hard”负采样带来的问题?
如何区分采样、负采样、”hard“负采样?
推理
直接为使用训练好的模型产生embedding有啥问题?
如何解决推理时重复计算的问题?
下游任务如何应用PinSAGE产生的embedding?
如何为用户进行个性化推荐?
工程性技巧
pin样本的特征如何构建?
board样本的特征如何构建?
如何使用多GPU并行训练PinSAGE?
PinSAGE为什么要使用生产者-消费者模式?
PinSAGE是如何使用生产者-消费者模式?
https://zhuanlan.zhihu.com/p/195735468
https://zhuanlan.zhihu.com/p/133739758?utm_source=wechat_session&utm_id=0
学习大图、不断扩展的图,未见过节点的表征,是一个很常见的应用场景。GraphSAGE通过训练聚合函数,实现优化未知节点的表示方法。之后提出的GAN(图注意力网络)也针对此问题优化。 论文中提出了:传导性问题和归纳性问题,传导性问题是已知全图情况,计算节点表征向量;归纳性问题是在不完全了解全图的情况下,训练节点的表征函数(不是直接计算向量表示)。 图工具的处理过程每轮迭代( 一次propagation)一般都包含:收集信息、聚合、更新,从本文也可以更好地理解,其中聚合的重要性,及优化方法。
GraohSage主要贡献如下:
一般情况下一个节点的表式通过聚合它k跳之内的邻近节点计算,而全图的表示则通过对所有节点的池化计算。GIN使用了WL-test方法,即图同构测试,它是一个区分网络结构的强效方法,也是通过迭代聚合邻居的方法来更新节点,它的强大在于使用了injective(见后)聚合更新方法。而这里要评测GNN是否能达到类似WL-test的效果。文中还使用了多合集multiset的概念,指可能包含重复元素的集合。
GIN主要贡献如下:
在图采样之后,我们需要进行邻居聚合的操作。经典的邻居聚合函数包括取平均、取最大值、求和。
评估聚合表达能力的指标——单射(一对一映射),在上述三种经典聚合函数中,取平均倾向于学习分布,取最大值倾向于忽略重复值,这两个不属于单射,而求和能够保留邻居节点的完整信息,是单射。单射的好处是可以保证对聚合后的结果可区分。
Graph Isomorphic Net(GIN)的聚合部分是基于单射的。
如上图所示,GIN的聚合函数使用的是求和函数,它特殊的一点是在中心节点加了一个自连边(自环),之后对自连边进行加权。
这样做的好处是即使我们调换了中心节点和邻居节点,得到的聚合结果依旧是不同的。所以带权重的自连边能够保证中心节点和邻居节点可区分。
数据源:http://snap.stanford.edu/graphsage/ 斯坦福
使用科学网引文数据集,将学术论文分类为不同的主题。数据集共包含302424个节点,平均度9.15,使用2000-2004年数据作为训练集,2005年数据作为测试集。使用节点的度以及论文摘要的句嵌入作为特征。
https://aistudio.baidu.com/aistudio/datasetdetail/177810
将Reddit帖子归类为属于不同社区。数据集包含232965个帖子,平均度为492。使用现成的300维GloVe Common Crawl单词向量;对于每个帖子,使用特征包含:(1) 帖子标题的平均嵌入 (2) 帖子所有评论的平均嵌入 (3) 帖子的分数 (4)帖子的评论数量
为了对社区进行抽样,根据 2014 年的评论总数对社区进行了排名,并选择了排名 [11,50](含)的社区。省略了最大的社区,因为它们是大型的通用默认社区,大大扭曲了类分布。选择了在这些社区的联合上定义的图中最大的连通分量。
更多数据资料见:
http://files.pushshift.io/reddit/comments/
https://github.com/dingidng/reddit-dataset
最新数据已经更新到2022.10了
https://aistudio.baidu.com/aistudio/datasetdetail/177807
PPI 网络是蛋白质相互作用(Protein-Protein Interaction,PPI)网络的简称,在GCN中主要用于节点分类任务
PPI是指两种或以上的蛋白质结合的过程,通常旨在执行其生化功能。
一般地,如果两个蛋白质共同参与一个生命过程或者协同完成某一功能,都被看作这两个蛋白质之间存在相互作用。多个蛋白质之间的复杂的相互作用关系可以用PPI网络来描述。
PPI数据集共24张图,每张图对应不同的人体组织,平均每张图有2371个节点,共56944个节点818716条边,每个节点特征长度为50,其中包含位置基因集,基序集和免疫学特征。基因本体基作为label(总共121个),label不是one-hot编码。
参考链接:
https://blog.csdn.net/ziqingnian/article/details/112979175
GraphSAGE是一个通用的归纳框架,它利用节点特征信息(例如,文本属性)为以前看不见的数据有效地生成节点嵌入。GraphSAGE 不是为每个节点训练单独的嵌入,而是学习一个函数,该函数通过从节点的本地邻域中采样和聚合特征来生成嵌入。基于PGL,我们重现了GraphSAGE算法,在Reddit Dataset中达到了与论文同等水平的指标。此外,这是PGL中子图采样和训练的一个例子。
超参数
epoch: Number of epochs default (10)
normalize: Normalize the input feature if assign normalize.
sample_workers: The number of workers for multiprocessing subgraph sample.
lr: Learning rate.
symmetry: Make the edges symmetric if assign symmetry.
batch_size: Batch size.
samples: The max neighbors for each layers hop neighbor sampling. (default: [25, 10])
hidden_size: The hidden size of the GraphSAGE models.
parser = argparse.ArgumentParser(description='graphsage')
parser.add_argument(
"--normalize", action='store_true', help="normalize features") # normalize:归一化节点特征
parser.add_argument(
"--symmetry", action='store_true', help="undirect graph") # symmetry:聚合函数的对称性
parser.add_argument("--sample_workers", type=int, default=5) # sample_workers:多线程数据读取器的线程个数
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--hidden_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--samples', nargs='+', type=int, default=[25, 10]) # samples_1:第一级邻居采样时候选择的最大邻居个数(默认25)#,samples_2:第而级邻居采样时候选择的最大邻居个数(默认10)
部分结果展示:
[INFO] 2022-11-18 16:45:44,177 [ train.py: 63]: Batch 800 train-Loss [0.5213774] train-Acc [0.9140625]
[INFO] 2022-11-18 16:45:45,783 [ train.py: 63]: Batch 900 train-Loss [0.65641916] train-Acc [0.875]
[INFO] 2022-11-18 16:45:47,385 [ train.py: 63]: Batch 1000 train-Loss [0.57411766] train-Acc [0.921875]
[INFO] 2022-11-18 16:45:48,977 [ train.py: 63]: Batch 1100 train-Loss [0.68337256] train-Acc [0.890625]
[INFO] 2022-11-18 16:45:50,434 [ train.py: 160]: Runing epoch:9 train_loss:[0.58635516] train_acc:[0.90786038]
[INFO] 2022-11-18 16:45:57,836 [ train.py: 165]: Runing epoch:9 val_loss:0.55885834 val_acc:0.9139818
[INFO] 2022-11-18 16:46:05,259 [ train.py: 169]: Runing epoch:9 test_loss:0.5578749 test_acc:0.91468066
100%|███████████████████████████████████████████| 10/10 [06:02<00:00, 36.29s/it]
[INFO] 2022-11-18 16:46:05,260 [ train.py: 172]: Runs 0: Model: graphsage Best Test Accuracy: 0.918849
目前官网最佳性能是95.7%,我这里没有调参
Aggregator | Accuracy_me_10 epochs | Accuracy_200 epochs | Reported in paper_200 epochs |
---|---|---|---|
Mean | 91.88% | 95.70% | 95.0% |
其余聚合器下官网和论文性能对比:
Aggregator | Accuracy_200 epochs | Reported in paper_200 epochs |
---|---|---|
Meanpool | 95.60% | 94.8% |
Maxpool | 94.95% | 94.8% |
LSTM | 95.13% | 95.4% |
图同构网络(GIN)是一个简单的图神经网络,期望达到Weisfeiler-Lehman图同构测试的能力。基于 PGL重现了 GIN 模型。
超参数
parser.add_argument('--data_path', type=str, default='./gin_data')
parser.add_argument('--dataset_name', type=str, default='MUTAG')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--fold_idx', type=int, default=0)
parser.add_argument('--output_path', type=str, default='./outputs/')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--num_mlp_layers', type=int, default=2)
parser.add_argument('--feat_size', type=int, default=64)
parser.add_argument('--hidden_size', type=int, default=64)
parser.add_argument(
'--pool_type',
type=str,
default="sum",
choices=["sum", "average", "max"])
parser.add_argument('--train_eps', action='store_true')
parser.add_argument('--init_eps', type=float, default=0.0)
parser.add_argument('--epochs', type=int, default=350)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--dropout_prob', type=float, default=0.5)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()
GIN github代码复现含数据集下载:How Powerful are Graph Neural Networks? https://github.com/weihua916/powerful-gnns
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
论文使用 9 个图形分类基准:4 个生物信息学数据集(MUTAG、PTC、NCI1、PROTEINS) 和 5 个社交网络数据集(COLLAB、IMDB-BINARY、IMDB-MULTI、REDDITBINARY 和 REDDIT-MULTI5K)(Yanardag & Vishwanathan,2015)。 重要的是,我目标不是让模型依赖输入节点特征,而是主要从网络结构中学习。因此,在生物信息图中,节点具有分类输入特征,但在社交网络中,它们没有特征。 对于社交网络,按如下方式创建节点特征:对于 REDDIT 数据集,将所有节点特征向量设置为相同(因此,这里的特征是无信息的); 对于其他社交图,我们使用节点度数的 one-hot 编码。
社交网络数据集。
生物信息学数据集。
[INFO] 2022-11-18 17:12:34,203 [ main.py: 98]: eval: epoch 347 | step 2082 | | loss 0.448468 | acc 0.684211
[INFO] 2022-11-18 17:12:34,297 [ main.py: 98]: eval: epoch 348 | step 2088 | | loss 0.393809 | acc 0.789474
[INFO] 2022-11-18 17:12:34,326 [ main.py: 92]: train: epoch 349 | step 2090 | loss 0.401544 | acc 0.8125
[INFO] 2022-11-18 17:12:34,391 [ main.py: 98]: eval: epoch 349 | step 2094 | | loss 0.441679 | acc 0.736842
[INFO] 2022-11-18 17:12:34,476 [ main.py: 92]: train: epoch 350 | step 2100 | loss 0.573693 | acc 0.7778
[INFO] 2022-11-18 17:12:34,485 [ main.py: 98]: eval: epoch 350 | step 2100 | | loss 0.481966 | acc 0.789474
[INFO] 2022-11-18 17:12:34,485 [ main.py: 103]: best evaluating accuracy: 0.894737
结果整合:(这里就不把数据集一一跑一遍了)
MUTAG | COLLAB | IMDBBINARY | IMDBMULTI | |
---|---|---|---|---|
PGL result | 90.8 | 78.6 | 76.8 | 50.8 |
paper reuslt | 90.0 | 80.0 | 75.1 | 52.3 |
原论文所有结果: