前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CS224w图机器学习(七):Graph Neural Networks

CS224w图机器学习(七):Graph Neural Networks

作者头像
慎笃
发布2021-09-15 10:32:30
5150
发布2021-09-15 10:32:30
举报
文章被收录于专栏:深度学习进阶深度学习进阶

内容简介

本文主要介绍CS224W的第八课,图神经网络。上一篇章的主题是图表征学习,主要在讲Node Embedding,核心步骤包含编码网络和相似性度量。本文则是从图神经网络的角度出发,展开一些编码网络的深度方法。

CS224W Lecture 8: Graph Neural Networks

上图为CS224W第八讲的内容框架,如下链接为第八讲的课程讲义

1 Introduction

我们先简单回顾上一章节讲的Node Embedding,如下图所示,我们期望通过一个编码网络,将图的节点映射到一个Embedding空间中,同时满足节点在图中的相似度与Embedding空间的相似度是类似的。上一章节主要在讲述基于不同随机游走策略的相似度计算方式,这章节则主要讲基于图神经网络的深度编码方法。

上文中也曾提到,当前深度网络结构主要应用于图像、音频等网格/序列数据,如下图所示。而相比于此,图的结构会复杂很多,这就导致现在的深度网络很难应用到图上,因此我们需要图神经网络。

<br> (二维码自动识别)

我们可以先尝试一下,强行把处理图像的卷积神经网络(CNN)用于编码图的信息,看到底存在哪些地方不合理(如果想深入了解深度学习相关的内容,推荐《Dive into Deep Learning-动手学深度学习》,学完这门课程会对深度学习有一个全面的了解,链接如下)。

CNN的结构如下图所示,在CNN的卷积层,有一个滑动滤波器沿着像素滑动(下图中的黑框),以捕捉图片信息。 如果我们在图里面也考虑一个这样的滤波器,去捕捉每个节点接收到的来自于其他节点的信息。图中存在的矩阵,最为常见的是邻接矩阵,那么一个可行的想法便是,将滑动滤波器应用于图的邻接矩阵上,并将处理后的结果作为神经网络的输入,最后训练的输出为节点的表征。 我们再思考下这个想法有哪些不足: 1)输入参数量级为O(N),随着节点个数的增加,参数会越来越多; 2)节点个数发生变化时,输入参数也随之变化,需要重新训练模型; 3)网络节点间的关系发生变化,邻接矩阵也会发生变化,原有的模型可能不再适用。 由此可知,将传统神经网络应用于图中是不太合理的,因此我们需要专属于图的神经网络结构。

将每个节点与其他节点的连接,作为神经网络的输入

2 Basics of deep learning for graphs

回顾下图神经网络要做的事情。

我们希望有一个编码网络,能够自动提取节点的特征,以避免人工手动从节点自身以及与其他节点的连接中提取信息,并最终将提取出来的节点特征用于节点分类等诸多场景。 那么我们不难发现,要想设计自动化提取节点特征的编码网络,还是需要从节点的邻居出发,关键在于我们怎么设计图神经网络才能更好的提取节点特征。

计算图(Computation Graph)

神经网络的核心还是在于节点的局部连接,如节点A的Embedding向量取决于其邻居节点

[公式]
[公式]

。 由此我们可以引入计算图,对于每一个节点来说,它的计算图由邻居节点的数量来决定,如右下图就是节点A的计算图,所有节点的计算图可参见下下图。

左上图中所有节点的计算图详情

再更详细的了解下计算图: 1)模型深度可以是任意值(上图的深度为2); 2)节点每层都有一个Embedding向量(如上图,节点A在第0层和第二层有着不同的Embedding向量); 3)节点在第0层的Embedding向量,是模型输入的特征向量; 4)节点在第

[公式]
[公式]

层的Embedding向量,通过图神经网络(计算图)计算所得。 我们知道每个节点的计算图是不一样的,那么该怎么设计图神经网络呢?过程分两步: 1)先将邻居节点传过来的信息进行整合,可以用平均、求和、取最大值等手段; 2)在上述结果的基础之上,添加神经网络(此时所有节点的输入和输出都一样)。

基于计算图的深度编码Deep Encoder

深度编码的公式详情如下图所示,我们逐个拆解: 1)初始层的Embedding向量为输入特征向量; 2)第

[公式]
[公式]

层的Embedding向量由两部分组成: 一个是节点

[公式]
[公式]

的邻居节点从

[公式]
[公式]

层传入的Embedding向量的平均值(可选)

[公式]
[公式]

另一个是节点

[公式]
[公式]

[公式]
[公式]

层的Embedding向量

[公式]
[公式]

; 再经过非线性的激活函数

[公式]
[公式]

便可得到第

[公式]
[公式]

层的Embedding向量。 3)最后一层的输出,也就是我们期望的特征向量。

深度编码的模型训练

上述公式的矩阵表示为:

[公式]
[公式]
[公式]
[公式]

和传统机器学习一样,我们也需要定义损失函数来帮助我们训练模型。实际上我们可以套入任意常用的损失函数,然后使用随机梯度下降训练模型的权重。接下来我们将分别介绍无监督训练和有监督训练两种模式。 无监督训练: 相似节点有着相似的Embedding向量,在上一章节的基础之上,我们多了深度编码的模型,那么可直接将这个模型带入上节课所讲方法:随机游走(DeepWalk、node2vec等),如此便能得到我们想要的损失函数,并用其来训练权重(附上传送链接)。

有监督训练: 在这一部分最开始的时候,我们回顾了特征向量的用途,它可以用于节点分类,也可以用于其他场景。 以节点分类为例,有监督的训练在于,我们可以把节点类别这个先验知识也引入到损失函数之中,让模型的训练结果,更贴合当前场景的需要。节点分类的损失函数可参照下图(可以理解为:在Embedding向量的基础之上进行节点分类,再将模型预测结果和节点初始类别的交叉熵作为损失函数):

把上述所有过程串起来,就是图神经网络(Graph Network Networks)。最后再对比Introduction部分胡诌的方法,我们综合看下图神经网络有哪些不同:

1)所有节点共享相同的聚合函数,模型参数与网络结构无关; 2)模型具有很强的推广性,针对新图或图中新节点,模型都同样适用。

3 Graph Convolutional Networks and GraphSAGE

聚合函数

在介绍计算图时,我们是先将来自于节点邻居的信息输入进行聚合,再添加神经网络模型并优化其参数。 上一部分所提到的聚合函数是采用均值,那么如果不用均值,改用其他的聚合函数呢? GraphSAGE的思想:如下图所示,模型可采用不同的聚合函数(都写成AGG)来对邻居节点的信息输入进行整合。

常用的聚合函数有: 1)Mean:取邻居节点Embedding向量的均值; 2)Pool:先对邻居节点Embedding向量进行一次非线性变换(下图中的Q),再对变换后的向量进行mean pooling或max pooling(下图中的

[公式]
[公式]

) 3)LSTM:首先LSTM本身用于序列数据,而邻居节点是没有任何顺序的,所以LSTM聚合函数的做法是:先将所有邻居节点的顺序随机打乱,再基于LSTM,使用邻居节点Embedding向量组成的随机序列作为输入,生成最终的聚合结果。

关于图卷积网络更深入的部分,课程没有详细的介绍,仅列出相关论文。

关于GNN的论文

4 Graph Attention Networks

除去聚合函数,图神经网络还有一个待深挖的点,那就是邻居节点传入信息的权重。

[公式]
[公式]

我们先看下这个公式,每个邻接节点的权重都是

[公式]
[公式]

,即节点

[公式]
[公式]

的所有邻接节点传入的信息,对于节点

[公式]
[公式]

来讲是同等重要。 但是对于部分节点来讲,有的邻接节点会更加重要。我们可以采用注意力机制给不同的节点分配权重。

注意力机制(Attention Mechanism)

定义参数

[公式]
[公式]

为节点

[公式]
[公式]

传给节点

[公式]
[公式]

的重要程度。假设注意力函数

[公式]
[公式]

可用于计算节点传输的重要性,那么有节点

[公式]
[公式]

相对于节点

[公式]
[公式]

的重要性

[公式]
[公式]

[公式]
[公式]

再使用softmax函数对所有重要性进行归一化:

[公式]
[公式]

邻居节点的信息输入也就可以表示为:

[公式]
[公式]

。 注意力机制

[公式]
[公式]

应该怎么选择呢? 可以用单层神经网络来作为注意力机制函数,在图神经网络的训练过程,蕴含注意力机制的单层神经网络也会跟着模型一起训练。

章节最后还讲述了注意力机制的一些属性(参见下图)。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 内容简介
  • CS224W Lecture 8: Graph Neural Networks
    • 1 Introduction
      • 2 Basics of deep learning for graphs
        • 3 Graph Convolutional Networks and GraphSAGE
          • 4 Graph Attention Networks
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档