前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIKM2022: LTE4G:图神经网络中的长尾专家

CIKM2022: LTE4G:图神经网络中的长尾专家

作者头像
Houye
发布2023-03-01 15:33:24
5530
发布2023-03-01 15:33:24
举报
文章被收录于专栏:图与推荐

LTE4G: Long-Tail Experts for Graph Neural Networks

CIKM2022

01

本文介绍

尽管GNNs最近取得了显著的成功,但大多数现有GNN工作都基于一种平衡的假设,即每个类的节点数量和每个节点的邻居数量(即节点度数)是平衡的。然而,大多数现实世界的图在类别和节点度两个方面都遵循长尾分布。在类别方面,一部分类别中的节点数量要超过其他类别的节点数量,导致GNN更偏向于样本丰富的类别(头类),而不能很好地泛化样本较少的类(尾类);在节点度方面,少数高度节点(头节点)往往拥有大部分链接,而大多数低度数(尾节点)的链接数较少,导致GNN更偏向于头节点。图中的类别/度长尾分布如图1所示。

目前有部分工作关注到了上述两个挑战,然而这些工作只专注于两个挑战中的一个,而忽视它们之间的联系。图1(c)展示了不同分布节点子集下的分类准确率,其中HH表示共同属于头类别以及头节点的节点,其他三类节点表示类似,从而证明了联合考虑两种长尾分布的重要性。图2则展示了图1(c)中准确率分布趋势的细节。本文联合考虑到类别与度长尾分布,提出了LTE4G模型,以提升GNN的整体泛化能力。

02

本文方法

本文方法的整体框架如图3所示。

2.1 预训练阶段

为了获得良好的初始节点嵌入,本文首先在原始图上预训练GCN编码器以获得节点嵌入。编码器表示为:

,预训练GNN的输出为:

由于通过传统的交叉熵损失对不平衡数据进行训练会导致编码器预测偏向于头部类别,故本文利用了焦点损失函数,对错误分类的样本赋予比正确分类的样本更高的权重,从而减少偏差。编码器的焦点损失函数表示为:

2.2 长尾专家

本文考虑到类分布和节点度分布的长尾性,以平衡的方式在图中分割节点。首先计算每个类中的节点数量,并根据类基数对类进行排序,top-p%的类被看作是头类别,其余的被认为是尾类别;其次,将度大于5的节点看作是头节点,剩余节点看作是尾节点。最终得到四个节点子集:HH,HT,TH,TT,如图1(c)所示。

得到以上四个相对平衡的子集之后,本文为每一个子集分配一个基于GNN的专家并训练这些专家。基于GNN的专家定义如下:

其中∗∈{HH,HT,TH,TT}。由图1(c)可得头节点在头尾类别上的表现都比尾节点要好,故可以通过微调WHH以及WTH来得到WHT以及WTT,各个专家的损失可表示为:

其中V∗以及C∗表示属于∗∈{HH,HT,TH,TT}的节点以及类别集合。虽然上述损失可以在每个专家负责的类别以及节点度上提供准确的分类结果,但剩下的挑战是但剩下的挑战是如何利用专家的知识来获得最终的节点分类结果。

2.3 将专家的知识蒸馏给学生模型

需要注意的是,当为某个专家分配的节点数量不够时,上一步获得的知识有时可能会有噪声。为了缓解这种情况,并进一步利用专家的知识,本文引入了两个学生,即头类学生和尾类学生,每个学生负责分类属于头类和尾类的节点。学生的定义如下表示:

在定义两个学生之后,本文利用知识蒸馏从学习到的专家中提取知识。头类学生向HH、HT专家学习,而尾类学生向TH、TT专家学习。本文利用学生与相关专家之间的KL-散度进行知识蒸馏,头类学生以及尾类学生的蒸馏过程可分别表示为:

然而,需要注意的是,由于分配给同一学生的两个节点子集在节点度上存在差异,分配给节点度高的子集的专家的性能要优于对应节点度低的专家。因此,对于学生来说,尾度专家(例如HT)中包含的知识比头度专家中包含的知识更难获得。故本文采用课程学习以减少头度专家与尾度专家之间的差距,更准确地说,本文希望学员在训练前期多向头度专家学习,在训练后期多向尾度专家学习,从而使训练过程呈现一个由易向难的趋势。其损失可定义为:

其中

,e和E分别表示当前的epoch以及总的训练epoch。

除了蒸馏损失之外,本文同样计算了头类与尾类的分类损失,表示为:

学生模型的整体损失表示为:

模型的整体损失表示为:

2.4 基于类原型的推理

由于LTE4G基于头尾类学生执行节点分类,因此推理阶段的主要挑战是如何确定测试节点应该发送给头类还是尾类学生。为此,本文设计了一个基于类原型的推理方法,其主要思想是根据每个测试节点与类原型的相似性将其分配给一个学生。即对于给定的测试节点,需要找到原型与测试节点最相似的类,然后将测试节点分配给相应的学生。计算类原型最简单的方法是计算训练数据中属于每个类的标记节点的预训练嵌入的平均值,可表示为:

在得到所有pc之后,计算给定测试节点的嵌入与类原型之间的相似性,以确定相似度最大的类c:

其中sim()表示余弦相似度。

然而,由于类分布通常是不平衡的,用于计算类原型的节点数量会有很大的变化。例如,在极端情况下,具有单个标记节点的类必须依赖该单个节点来计算类原型。这意味着一些类原型的质量比其他的要低。因此,为了扩展类原型计算的候选节点,本文额外考虑了它们的相邻节点。然而尽管存在同质性假设,但并非所有邻近节点都与目标节点共享相同的标签,这可能会导致类原型计算中的噪声。因此,本文利用预先计算的类概率来选择与目标节点属于同一类的具有高置信度的相邻节点。此外,由于节点度也呈现长尾分布,大多数节点只有少数相邻节点。在这方面,我们进一步扩展候选节点,以包含具有相似特征的节点。

03

实验

本文实验中的数据集的具体统计如表1、表2所示。

本文模型与基线模型的节点分类对比如表3、4、5所示,本文模型在不同不平衡率下基本都取得了最先进的结果。

本文的消融实验如表6、7所示。

本文模型的参数敏感性以及复杂性分析图6、7所示。

04 — 结论

本文提出了一种新的基于gnn的节点分类方法,该方法同时考虑了类长尾性和节点度长尾性。LTE4G首先考虑类和度分布,将图中的节点分成四个平衡的子集,并在每个平衡的子集上训练一个专家。然后,采用了头度、尾度学生学习的知识蒸馏技术。对于推理,本文设计了一种基于类原型的推理,不仅利用训练集中的标记节点,还利用它们的邻近节点以及相似节点。通过对各种人工和自然不平衡设置的大量实验,实证地证明了其优越性。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-02-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 图神经网络与推荐系统 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档