专栏首页进击的CoderIBM研究院提出Graph2Seq,基于注意力机制的图到序列学习

IBM研究院提出Graph2Seq,基于注意力机制的图到序列学习

介绍

Seq2Seq(序列到序列)及其变体在机器翻译、自然语言生成、语音识别、新药发现之类的领域表现非常出色。大多数Seq2Seq模型都属于编码器-解码器家族,其中编码器将输入序列编码为固定维度的连续向量表示,而解码器则解码向量得到目标序列。

然而,Seq2Seq有一个限制,它只能应用于输入表示为序列的问题。而在许多问题中,输入为更复杂的结构,比如图(graph)。对于这类图到序列(graph-to-sequence)问题,如果要应用Seq2Seq,就需要将图转换为序列。然而,将图精确地转换为序列是一项艰巨的挑战,因为在将图这种比较复杂的结构数据转换为序列时,难免会损失不少信息,特别是当输入数据本身适合用图表示的时候。最近的一些研究尝试在输入数据中提取句法特征,例如句子的词组结构(Tree2Seq),或将注意力机制应用于输入集(Set2Seq),或将句子递归地编码为树(Tree-LSTM)。在特定类别问题上,这类方法取得了充满希望的结果,然而,这类方法大多难以推广。

为此,IBM研究院的Kun Xu、Lingfei Wu等提出了Graph2Seq,一个端到端的处理图到序列问题的模型。

Graph2Seq采用与Seq2Seq相似的编码器-解码器架构,包括一个图编码器和一个序列解码器。图编码器部分,通过聚合有向图和无向图中的相邻信息,学习节点嵌入。然后根据学习到的节点嵌入,构建图嵌入。序列解码器部分,论文作者设计了一个基于注意力机制的LSTM网络,使用图嵌入作为初始隐藏状态,输出目标预测。注意力机制用于学习节点和序列元素的对齐,以更好地应对大型图。整个Graph2Seq的设计是模块化的,可扩展性很好。比如,编码器可以换成图卷积网络,解码器可以换成普通的LSTM。

Gragh2Seq模型

在上一节的末尾,我们已经简单介绍了Graph2Seq的架构。这一节我们将具体介绍Graph2Seq模型。下面是Graph2Seq的整体架构示意图。

节点嵌入生成

如前所述,节点嵌入中包含了节点的相邻信息。具体的嵌入生成过程如下:

  1. 通过查询嵌入矩阵We,将节点v的文本属性转换为一个特征向量av
  2. 根据边的方向,将v的邻居分类为前向邻(forward neighbor)N|-(v)和反向邻(backward neighbor)N-|(v)。
  3. 将v的前向邻的前向表示

聚合为单个向量

其中k为迭代索引。注意,在迭代k时,聚合仅仅使用k-1时生成的表示。每个节点的初始化前向表示为其特征向量。

  1. 我们将v的当前前向表示(k-1)和新生成的前向聚合向量(k)连接。连接所得的向量传入一个带非线性激活的全连接层,从而更新v的前向表示,在下一次迭代中使用。
  2. 将上述过程应用于反向表示。
  3. 重复前向表示聚合与反向表示聚合过程K次,连接最终的前向表示和反向表示,作为v的最终表示。

用伪代码表示以上节点嵌入生成过程:

上面我们提到了聚合前向表示和反向表示,却没有提到具体的聚合方法。实际上,论文作者尝试了3种不同的聚合方法。

  1. 均值 这是最简单直接的聚合方式,取分素均值(element-wise mean)。
  2. LSTM 使用LSTM处理节点邻居的单个随机排列(无序集)。
  3. 池化 将每个邻居向量传入一个全连接网络,然后应用分素最大池化(element-wise max-pooling)。

其中,σ为非线性激活函数。

经论文作者试验,总体而言,最简单的均值聚合效果最好。

均值(MA)、LSTM(LA)、池化(PA)聚合在3个合成SDP数据集(有向无环图、有向有环图、序线图)上的精确度

图嵌入生成

论文作者引入了两种基于节点嵌入构造图嵌入的方法。

  1. 基于池化的图嵌入。类似上面基于池化的聚合,论文作者将节点嵌入传给一个全连接神经网络,然后分素应用池化方法。论文作者共试验了三种池化方法,最大池化、最小池化、平均池化,最后发现三种池化方法没有显著差别。因此,论文作者最后选用了最大池化作为默认的池化方法。
  2. 基于节点的图嵌入。这一方法加入了一个超(super)节点vs至输入图,使图中的所有其他节点指向vs。我们使用之前提到的节点嵌入生成算法生成vs嵌入,因而vs嵌入捕获了所有节点的信息,可视为图嵌入。

经论文作者试验,总体而言,基于池化的图嵌入表现较好。

基于注意力的解码器

序列解码器是一个基于注意力的LSTM网络,根据给定的y1,...,yi-1,隐藏状态si(i表示时刻),以及上下文向量ci,预测下一个token,即yi。其中,上下文向量ci取决于前述图编码器根据输入图生成的节点表示集合(z1,...,zv)。具体而言,上下文向量ci通过节点表示的加权和计算得出:

相应的权重aij由下式计算得出:

其中,a为对齐模型(alignment model),为j处的输入节点和i处的输出的匹配程度评分。评分基于LSTM的隐藏状态si-1和输入图的第j个节点表示。对齐模型a为前馈神经网络,和系统的其他部分一起训练。

试验

试验设定

论文作者使用了Adam优化,mini-batch大小为30,学习率为0.001,解码器层dropout率为0.5(避免过拟合)。norm大于20时裁剪梯度。图编码器部分,默认跳(hop)大小为6,节点初始特征向量为40,非线性激活函数为ReLU,聚合器的参数随机初始化。解码器为单层,隐藏状态大小为80. 如前所述,使用了表现最佳的均值聚合和基于池化的图嵌入生成。

试验结果

从上表可以看到,在bAbI Task 19上,LSTM失败了,而Graph2Seq的表现是最好的,超过了GGS-NN和GCN。

而在最短路径任务(Shortest Path Task)上,LSTM同样失败了。尽管GGS-NN、GCN、Graph2Seq在小数据集上(SP-S,节点尺寸=5)上都达到了100%的精确度,但在大数据上(SP-L,节点尺寸=100),得益于解码器部分注意力机制的应用,Graph2Seq的表现超过了GGS-NN和GCN。

最后,论文作者在自然语言生成(Natural Language Generation)任务上评估了Graph2Seq的表现。具体而言,这一任务根据SQL查询语句,生成描述其含义的自然语言。论文作者使用的是WikiSQL数据集,该数据集包含87726对手工标注的自然语言查询问题,SQL查询,以及相应的SQL表。WikiSQL原本是为评测问题回答任务而创建的,这里论文作者逆向使用该数据集,将SQL请求视作输入,将生成正确的英语问题视作目标。WikiSQL的SQL请求分割为训练、验证、测试集,分别包含61297、9145、17284个请求。

从上表可以看出,Graph2SQL的BLEU-4评分显著高于Seq2Seq、Seq2Seq + Copy、Tree2Seq。

作者:IBM Research 编译:weakish

本文分享自微信公众号 - 进击的Coder(FightingCoder)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-06-19

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 深度学习效果不好?试试 Batch Normalization 吧!

    Batch Normalization(简称BN)自从提出之后,因为效果特别好,很快被作为深度学习的标准工具应用在了各种场合。BN大法虽然好,但是也存在一些局...

    崔庆才
  • 自然语言处理基石 Embedding 最新进展汇总

    词嵌入(word embeddings)和句嵌入(sentence embeddings)已经成为任何基于深度学习的自然语言处理系统不可或缺的部分。

    崔庆才
  • 【一】从0开始,用flask+mongo打造分布式服务器监控平台

    在前面一篇文章《 Python 系统资源信息获取工具,你用过没?》中,我们学习了如何使用 Python 中的第三方库 psutil 来获取系统的资源信息,比如 ...

    崔庆才
  • Huber Loss和Focal Loss的原理与实现

    Huber Loss主要用于解决回归问题中,存在奇点数据带偏模型训练的问题;Focal Loss主要解决分类问题中类别不均衡导致的模型训偏问题。

    CodeInHand
  • 完整部署CentOS7.2+OpenStack+kvm 云平台环境(2)--云硬盘等后续配置

    继上一篇博客介绍了完整部署CentOS7.2+OpenStack+kvm 云平台环境(1)--基础环境搭建,本篇继续讲述后续部分的内容  1 虚拟机相关 1.1...

    洗尽了浮华
  • 夯实Java基础系列23:一文读懂继承、封装、多态的底层实现原理

    本系列文章将整理到我在GitHub上的《Java面试指南》仓库,更多精彩内容请到我的仓库里查看

    Java技术江湖
  • 轻松学Pytorch-使用卷积神经网络实现图像分类

    大家好,本篇教程的贡献者来自社区投稿作者【陨星落云】,使用CIFAR-10数据集进行图像分类。该数据集中的图像是彩色小图像,其中被分为了十类。一些示例图像,如下...

    OpenCV学堂
  • CQRS+ES项目解析-Diary.CQRS

    在《当我们在讨论CQRS时,我们在讨论些神马》中,我们讨论了当使用CQRS的过程中,需要关心的一些问题。其中与CQRS关联最为紧密的模式莫过于Event Sou...

    thz
  • 苹果macOS Mojave发布,四年来最大更新

    现在起,登录Mac App Store就能下载更新。新版macOS大小5.7GB,版本号为macOS 10.14。macOS Mojave可以安装在2012年中...

    量子位
  • vector.clear()不能用来清零

    vector.clear()函数并不会把所有元素清零,笔者就曾经这样幻想过这个函数的作用,然而事实证明并不是。

    chenjx85

扫码关注云+社区

领取腾讯云代金券