本文主要介绍CS224W的第十课,图的深度生成模型。第九课是图神经网络的实战,该部分内容会和其他场景的实战进行统一整合,做一个不断更新的实战篇章,并按场景将其梳理到同一个代码库下。回顾第八章,课程主要讲述了图神经网络,以及引入聚合函数和注意力机制的图卷积网络,通过图神经网络可自动化生成Embedding。本课程则主要讲述图的生成模型,如何生成一张图。
上图为CS224W第十讲的内容框架,如下链接为第十讲的课程讲义
我们可先带着如下两个问题,开始本章图生成模型的学习。
1)生成模型应该怎么设计,我们才能用它来生成图? 2)如何评价图的生成模型?什么样的模型才是好的生成模型?
首先我们了解下,学习图的生成对我们都有哪些帮助。
1)了解网络的生成过程(如人际关系网的发展、公共交通网络的扩张等); 2)异常检测:通过生成模型的结果来检测异常行为; 3)预测:预测网络未来演化的趋势; 4)仿真:对全新的网络结构进行模拟(如新合成化学分子的分析结构模拟); 5)对不完善的网络进行补全(如知识图谱的补全等)
看下图生成任务中需要关注的难点。
1)生成模型的输出空间巨大且可变。我们一般用邻接矩阵来表征网络的节点和边,对于有
个节点的网络来说,其输出空间就是
的矩阵。且输出会随着网络本身的变化而变化。 2)网络表征的不确定性。如下图所示,同一个网络的邻接矩阵并不是固定的,与节点的编号顺序有关。
个节点的邻接矩阵共有
种。
图1
3)图生成过程中的复杂依赖。比如我们想生成一个苯环(苯环有6个碳原子组成),如下图所示,我们在生成图的过程中,得记录历史信息,这样才能准确的生成图。
图2
上一部分简单引入了图生成的实际应用和难点。现在我们则开始介绍图生成的机器学习思想。
假设我们想通过一组数据
来学习网络的生成模型。
是数据的真实分布(我们通常不知道数据的真实分布),我们只有基于该分布进行采样,所得到的采样样本
。 生成模型的目标就是: 1)找到参数为
的模型,让该模型估计出来的数据分布
,尽可能接近数据的真实分布
。 通常使用极大似然来估计参数:
2)我们可对数据分布
进行采样,并生成网络。复杂分布的采样可拆解为如下两步: i)噪音分布采样:
ii)基于生成模型对噪音分布进行转换:
目前常见的深度生成模型如下,如生成对抗网络、变分自编码器等,本文主要讲自回归模型。
图3
思想:联合分布可以看成是一堆条件概率分布的乘积。
。 在图的生成模型中,
可以理解为第
次action(新增节点与边的action)。如此我们便可将图的生成问题转换为序列生成问题,不断地给模型加节点和边。
上一部分主要讲图生成的思想,这部分则开始介绍图生成算法以及如何通过算法生成真实的网络。
图4
正如上部分所讲,图的生成问题可以转换成如上图所示的序列生成问题。给定节点顺序
,图
可以映射为节点+边的序列
。 序列
有两层: 1)节点层:新增节点(每次新增一个节点); 2)边层:新增边(在新增节点和已有节点之间新增边)。新增边也是一个序列过程,如
的序列为
。 所以图的生成模型一般包含两个过程: 1)节点序列生成(Node-level Sequence); 2)对于新增的节点,生成与其相关联的边(Edge-level Sequence)。
背景已经讲了很多,现在我们开始学习图的深度生成模型GraphRNN,RNN生成图的整体流程可参考下图。
图5
再深入了解GraphRNN之前,我们先看下循环神经网络(Recurrent Neural Networks, RNN),如下图所示,RNN通常包括:
时刻的状态
,输入
,输出
。
其中
为参数矩阵,
为非线性算子。
图6
如下图
,将
时刻的输出看作是
时刻的输入,便可利用上述RNN的基本单元来生成序列。其中序列的起始和结束都定义一个标识符,其实标识符为
,结束标识符为
。 由于模型参数是固定的,图
最终得到的结果也是确定的。而我们不仅要能生成图,还需要生成数据分布。因此如下图
,我们要给RNN添加随机性。
:
时刻的输入为
时刻输出的采样。 具体实现可参考下图
,
时刻的输出
为概率向量,
时刻的输入
是在
概率分布的基础之上的采样结果。 比如
代表新增节点和当前节点之间有
的概率存在边连接,最终的采样结果
代表两者之间存在边连接。
图7a
图7b,在上图7a的基础之上添加随机性,x(t)是y(t-1)的采样
图7c,t-1时刻输出y(t-1)为概率向量,t时刻输入为基于y(t-1)概率分布生成的结果
将前述知识都串起来,也就可以理解GraphRNN是如何生成网络的。
如下图所示
和
,GraphRNN由Node-level RNN和Edge-level RNN两部分组成。 Node-level RNN用以生成新的节点,Edge-level RNN用以生成边(输入为边存在的概率),最终基于概率分布的采样结果,就是模型邻接矩阵的新增向量。 最后一个Edge-level RNN的输出
的含义是:下一个节点是否与当前节点存在链接,如果没有,则模型停止并输出我们生成的图。
图8a,GraphRNN生成图的逻辑
图8b
紧接着,我们看下GraphRNN是怎么训练的? 首先关注损失函数。如上图
,图
的邻接矩阵是确定的,即我们有节点和边的ground truth。我们期望模型Edge-level RNN生成的边概率尽可能接近真实图形。所以损失函数可使用二元交叉熵:
其中
为真实结果,
为RNN预测边是否存在的概率。 其次模型训练阶段所有梯度都是随时间逐步累积的,因此我们通过反向传播来训练模型,如下图9所示。
图9
上述逻辑存在一个问题,每新生成一个节点,需要判断该节点与所有历史节点是否存在边。当节点数量很大时,上述方法的复杂度贼高。因此我们需要进一步的模型优化。
通过广度优先搜索(BFS)来给节点编号,可优化模型复杂度。 如下图10,基于广度优先搜索的节点顺序为
,由于搜索方式是广度优先搜索,那么一定有如下结论:节点4和节点1之间不存在边连接,那么节点5与节点1之间也一定不存在边连接。
图10
综合来看,BFS对节点编号的帮助有: 1)节点顺序固定,避免
种可能。 2)减少Edge-level RNN阶段的复杂度(Reducing number of previous nodes to look at)。
有了上述手段,我们可以生成图了。那么我们怎么评价生成图的好坏呢?
目标:定义衡量图相似性的指标。 挑战:没有有效的图同构检验能满足我们的需求。 解决方法:Visual similarity(如下图11)、Graph statistics similarity
图11