前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >IBM研究院提出Graph2Seq,基于注意力机制的图到序列学习

IBM研究院提出Graph2Seq,基于注意力机制的图到序列学习

作者头像
崔庆才
发布2019-09-04 15:34:59
2.2K0
发布2019-09-04 15:34:59
举报
文章被收录于专栏:进击的Coder

介绍

Seq2Seq(序列到序列)及其变体在机器翻译、自然语言生成、语音识别、新药发现之类的领域表现非常出色。大多数Seq2Seq模型都属于编码器-解码器家族,其中编码器将输入序列编码为固定维度的连续向量表示,而解码器则解码向量得到目标序列。

然而,Seq2Seq有一个限制,它只能应用于输入表示为序列的问题。而在许多问题中,输入为更复杂的结构,比如图(graph)。对于这类图到序列(graph-to-sequence)问题,如果要应用Seq2Seq,就需要将图转换为序列。然而,将图精确地转换为序列是一项艰巨的挑战,因为在将图这种比较复杂的结构数据转换为序列时,难免会损失不少信息,特别是当输入数据本身适合用图表示的时候。最近的一些研究尝试在输入数据中提取句法特征,例如句子的词组结构(Tree2Seq),或将注意力机制应用于输入集(Set2Seq),或将句子递归地编码为树(Tree-LSTM)。在特定类别问题上,这类方法取得了充满希望的结果,然而,这类方法大多难以推广。

为此,IBM研究院的Kun Xu、Lingfei Wu等提出了Graph2Seq,一个端到端的处理图到序列问题的模型。

Graph2Seq采用与Seq2Seq相似的编码器-解码器架构,包括一个图编码器和一个序列解码器。图编码器部分,通过聚合有向图和无向图中的相邻信息,学习节点嵌入。然后根据学习到的节点嵌入,构建图嵌入。序列解码器部分,论文作者设计了一个基于注意力机制的LSTM网络,使用图嵌入作为初始隐藏状态,输出目标预测。注意力机制用于学习节点和序列元素的对齐,以更好地应对大型图。整个Graph2Seq的设计是模块化的,可扩展性很好。比如,编码器可以换成图卷积网络,解码器可以换成普通的LSTM。

Gragh2Seq模型

在上一节的末尾,我们已经简单介绍了Graph2Seq的架构。这一节我们将具体介绍Graph2Seq模型。下面是Graph2Seq的整体架构示意图。

节点嵌入生成

如前所述,节点嵌入中包含了节点的相邻信息。具体的嵌入生成过程如下:

  1. 通过查询嵌入矩阵We,将节点v的文本属性转换为一个特征向量av
  2. 根据边的方向,将v的邻居分类为前向邻(forward neighbor)N|-(v)和反向邻(backward neighbor)N-|(v)。
  3. 将v的前向邻的前向表示

聚合为单个向量

其中k为迭代索引。注意,在迭代k时,聚合仅仅使用k-1时生成的表示。每个节点的初始化前向表示为其特征向量。

  1. 我们将v的当前前向表示(k-1)和新生成的前向聚合向量(k)连接。连接所得的向量传入一个带非线性激活的全连接层,从而更新v的前向表示,在下一次迭代中使用。
  2. 将上述过程应用于反向表示。
  3. 重复前向表示聚合与反向表示聚合过程K次,连接最终的前向表示和反向表示,作为v的最终表示。

用伪代码表示以上节点嵌入生成过程:

上面我们提到了聚合前向表示和反向表示,却没有提到具体的聚合方法。实际上,论文作者尝试了3种不同的聚合方法。

  1. 均值 这是最简单直接的聚合方式,取分素均值(element-wise mean)。
  2. LSTM 使用LSTM处理节点邻居的单个随机排列(无序集)。
  3. 池化 将每个邻居向量传入一个全连接网络,然后应用分素最大池化(element-wise max-pooling)。

其中,σ为非线性激活函数。

经论文作者试验,总体而言,最简单的均值聚合效果最好。

均值(MA)、LSTM(LA)、池化(PA)聚合在3个合成SDP数据集(有向无环图、有向有环图、序线图)上的精确度

图嵌入生成

论文作者引入了两种基于节点嵌入构造图嵌入的方法。

  1. 基于池化的图嵌入。类似上面基于池化的聚合,论文作者将节点嵌入传给一个全连接神经网络,然后分素应用池化方法。论文作者共试验了三种池化方法,最大池化、最小池化、平均池化,最后发现三种池化方法没有显著差别。因此,论文作者最后选用了最大池化作为默认的池化方法。
  2. 基于节点的图嵌入。这一方法加入了一个超(super)节点vs至输入图,使图中的所有其他节点指向vs。我们使用之前提到的节点嵌入生成算法生成vs嵌入,因而vs嵌入捕获了所有节点的信息,可视为图嵌入。

经论文作者试验,总体而言,基于池化的图嵌入表现较好。

基于注意力的解码器

序列解码器是一个基于注意力的LSTM网络,根据给定的y1,...,yi-1,隐藏状态si(i表示时刻),以及上下文向量ci,预测下一个token,即yi。其中,上下文向量ci取决于前述图编码器根据输入图生成的节点表示集合(z1,...,zv)。具体而言,上下文向量ci通过节点表示的加权和计算得出:

相应的权重aij由下式计算得出:

其中,a为对齐模型(alignment model),为j处的输入节点和i处的输出的匹配程度评分。评分基于LSTM的隐藏状态si-1和输入图的第j个节点表示。对齐模型a为前馈神经网络,和系统的其他部分一起训练。

试验

试验设定

论文作者使用了Adam优化,mini-batch大小为30,学习率为0.001,解码器层dropout率为0.5(避免过拟合)。norm大于20时裁剪梯度。图编码器部分,默认跳(hop)大小为6,节点初始特征向量为40,非线性激活函数为ReLU,聚合器的参数随机初始化。解码器为单层,隐藏状态大小为80. 如前所述,使用了表现最佳的均值聚合和基于池化的图嵌入生成。

试验结果

从上表可以看到,在bAbI Task 19上,LSTM失败了,而Graph2Seq的表现是最好的,超过了GGS-NN和GCN。

而在最短路径任务(Shortest Path Task)上,LSTM同样失败了。尽管GGS-NN、GCN、Graph2Seq在小数据集上(SP-S,节点尺寸=5)上都达到了100%的精确度,但在大数据上(SP-L,节点尺寸=100),得益于解码器部分注意力机制的应用,Graph2Seq的表现超过了GGS-NN和GCN。

最后,论文作者在自然语言生成(Natural Language Generation)任务上评估了Graph2Seq的表现。具体而言,这一任务根据SQL查询语句,生成描述其含义的自然语言。论文作者使用的是WikiSQL数据集,该数据集包含87726对手工标注的自然语言查询问题,SQL查询,以及相应的SQL表。WikiSQL原本是为评测问题回答任务而创建的,这里论文作者逆向使用该数据集,将SQL请求视作输入,将生成正确的英语问题视作目标。WikiSQL的SQL请求分割为训练、验证、测试集,分别包含61297、9145、17284个请求。

从上表可以看出,Graph2SQL的BLEU-4评分显著高于Seq2Seq、Seq2Seq + Copy、Tree2Seq。

作者:IBM Research 编译:weakish

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-06-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 进击的Coder 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Gragh2Seq模型
  • 试验
相关产品与服务
语音识别
腾讯云语音识别(Automatic Speech Recognition,ASR)是将语音转化成文字的PaaS产品,为企业提供精准而极具性价比的识别服务。被微信、王者荣耀、腾讯视频等大量业务使用,适用于录音质检、会议实时转写、语音输入法等多个场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档