首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

手把手教你用DGL框架进行批量图分类

机器之心专栏

作者:DGL团队

图分类(预测图的标签)是图结构数据里一类重要的问题。它的应用广泛,可见于生物信息学、化学信息学、社交网络分析、城市计算以及网络安全。随着近来学界对于

图神经网络

的热情持续高涨,出现了一批用图神经网络做图分类的工作。比如训练图神经网络来预测蛋白质结构的性质,根据社交网络结构来预测用户的所属社区等(Ying et al., 2018, Cangea et al., 2018, Knyazev et al., 2018, Bianchi et al., 2019, Liao et al., 2019, Gao et al., 2019)。

在这个教程里,我们将一起学习:

如何使用 DGL 批量化处理大小各异的图数据

训练图神经网络完成一个简易的图分类任务

简易图分类任务

这里我们设计了一个简单的图分类任务。在 DGL 里我们实现了一个迷你图分类数据集(MiniGCDataset)。它由以下 8 类图结构数据组成。每一类图包含同样数量的随机样本。任务目标是训练图神经网络模型对这些样本进行分类。

以下是使用 MiniGCDataset 的示例代码。我们先创建了一个拥有 80 个样本的数据集。数据集中每张图随机有 10 到 20 个节点。DGL 中所有的数据集类都符合 Sequence 的抽象结构——既可以使用 dataset[i] 来访问第 i 个样本。这里每个样本包含图结构以及它对应的标签。

运行以上代码后可以画出数据集中第一个样本的图结构以及它对应的标签:

打包一个图的小批量

为了更高效地训练神经网络,一个常见的做法是将多个样本打包成小批量(mini-batch)。打包尺寸相同的张量样本非常简单。比如说打包两个尺寸为 2828 的图片会得到一个 22828 的张量。相较之下,打包图面临两个挑战:

图的边比较稀疏

图的大小、形状各不相同

DGL 提供了名为 dgl.batch 的接口来实现打包一个图批量的功能。其核心思路非常简单。将 n 张小图打包在一起的操作可以看成是生成一张含 n 个不相连小图的大图。下图的可视化从直觉上解释了 dgl.batch 的功能。

可以看到通过 dgl.batch 操作,我们生成了一张大图,其中包含了一个环状和一个星状的连通分量。其邻接矩阵表示则对应为在对角线上把两张小图的邻接矩阵拼接在一起(其余部分都为 0)。

以下是使用 dgl.batch 的一个实际例子。我们定义了一个 collate 函数来将 MiniGCDataset 里多个样本打包成一个小批量。

正如打包 N 个张量得到的还是张量,dgl.batch 返回的也是一张图。这样的设计有两点好处。首先,任何用于操作一张小图的代码可以被直接使用在一个图批量上。其次,由于 DGL 能够并行处理图中节点和边上的计算,因此同一批量内的图样本都可以被并行计算。

图分类器

这里使用的图分类器和应用在图像或者语音上的分类器类似——先通过多层神经网络计算每个样本的表示(representation),再通过表示计算出每个类别的概率,最后通过向后传播计算梯度。一个常见的图分类器由以下几个步骤构成:

通过图卷积(Graph Convolution)层获得图中每个节点的表示。

使用「读出」操作(Readout)获得每张图的表示。

使用 Softmax 计算每个类别的概率,使用向后传播更新参数。

下图展示了整个流程:

之后我们将分步讲解每一个步骤。

图卷积

我们的图卷积操作基本类似图卷积网络 GCN(具体可以参见我们的关于 GCN 的教程)。图卷积模型可以用以下公式表示:

在这个例子中,我们对这个公式进行了微调:

我们将求和替换成求平均可用来平衡度数不同的节点,在实验中这也带来了模型表现的提升。

此外,在构建数据集时,我们给每个图里所有的节点都加上了和自己的边(自环)。这保证节点在收集邻居节点表示进行更新时也能考虑到自己原有的表示。以下是定义图卷积模型的代码。这里我们使用PyTorch作为 DGL 的后端引擎(DGL 也支持 MXNet 作为后端)。

首先,我们使用 DGL 的内置函数定义消息传递:

其次,我们定义消息累和函数。这里我们对收到的消息进行平均。

之后,我们对收到的消息应用线性变换和激活函数。

最后,我们把所有的小模块串联起来成为 GCNLayer。

读出和分类

读出(Readout)操作的输入是图中所有节点的表示,输出则是整张图的表示。在 Google 的 Neural Message Passing for Quantum Chemistry(Gilmer et al. 2017) 论文中总结过许多不同种类的读出函数。在这个示例里,我们对图中所有节点表示取平均以作为图的表示:

DGL 提供了许多读出函数接口,以上公式可以很方便地用 dgl.mean(g) 完成。最后我们将图的表示输入分类器。分类器对图表示先做了一个线性变换然后得到每一类在 softmax 之前的 logits。具体代码如下:

准备和训练

阅读到这边的读者可以长舒一口气了。因为之后的训练过程和其他经典的图像,语音分类问题基本一致。首先我们创建了一个包含 400 张节点数量为 10~20 的合成数据集。其中 320 张图作为训练数据集,80 张图作为测试集。

其次我们创建一个刚刚定义的图神经网络模型对象。

训练过程则是经典的反向传播和梯度下降。

下图是以上模型训练的学习曲线:

在训练完成后,我们在测试集上验证模型的表现。出于部署教程的考量,我们限制了模型训练的时间。如果你花更多时间训练模型,应该能得到更好的表现(80%-90%)。

我们还制作了一个动画来展示训练好的模型预测每张图真实标签的概率。可以看到我们刚刚定义的图神经网络能够较为准确地预测出图样本的对应标签:

为了更好地理解模型学到的节点和图的表示,我们使用了 t-SNE 来进行降维和可视化。

顶部的两张小图分别可视化了做完 1 层和 2 层图卷积后的节点表示。不同颜色代表属于不同类别的图的节点。可以看到,经过训练后,属于同一类别的节点表示更加接近。并且,经过两层图卷积后这一聚类效果更明显。其原因是因为两层卷积后每个节点能接收到 2 度范围内的邻居信息。

底部的大图可视化了每张图在做 softmax 前的 logits,也就是图表示。可以看到通过读出函数后,图表示能非常好地各自区分开来。这一区分度比节点表示更加明显。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20190129A0GJO100?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券