Seq2seq模型的一个变种网络:Pointer Network的简单介绍

Pointer Network(为方便起见以下称为指针网络)是seq2seq模型的一个变种。他们不是把一个序列转换成另一个序列, 而是产生一系列指向输入序列元素的指针。最基础的用法是对可变长度序列或集合的元素进行排序。

seq2seq的基础是一个LSTM编码器加上一个LSTM解码器。在机器翻译的语境中, 最常听到的是: 用一种语言造句, 编码器把它变成一个固定大小的陈述。解码器将他转换成一个句子, 可能和之前的句子长度不同。例如, “como estas?”-两个单词-将被翻译成 “how are you?”-三个单词。

当“注意力”增强时模型效果会更好。这意味着解码器在输入的前后都可以访问。就是说, 它可以从每个步骤访问编码器状态, 而不仅仅是最后一个。思考一下它怎样帮助西班牙语让形容词在名词之前: “neural network”变成 “red neuronal”

在专业术语中,“注意力”(至少是这种特定的 基于内容的注意力) 归结为加权平均值均数。简而言之,编码器状态的加权平均值转换为解码器状态。注意力只是权重的分配。

想知道更多可以访问:https://medium.com/datalogue/attention-in-keras-1892773a4f22

在指针网络中, 注意力更简单:它不考虑输入元素,而是在概率上指向它们。实际上,你得到了输入的排列。有关更多细节和公式, 请参阅论文:

https://arxiv.org/abs/1506.03134

注意, 不需要使用所有的指针。例如, 给定一段文本, 网络可以通过指向两个元素来标记摘录: 它的起始位置和结束位置。

实验

我们从顺序数字开始?换句话说,一个深入的argsort:

In [3]: np.argsort([10,30,20 ])
Out[3]: array([0,2,1], dtype=int64)

In [4]: np.argsort([40,10,30,20 ])
Out[4]: array([1,3,2,0], dtype=int64)

令人惊讶的是,作者在论文中没有继续进行完成任务。相反的,他们使用两个奇特的问题:旅行推销员和凸包(参考README), 虽然结果是好的。但为什么不按照数字顺序呢?

原来,数字排序很难做到。他们在后续文件中提到了这个问题(Order Matters: Sequence to sequence for sets)。重点是顺序不能错。也就是说,我们讨论的是输入元素的顺序。作者发现,它对结果影响很大, 这不是我们想要的。因为本质上我们处理的是集合作为输入, 而不是序列。集合没有固定的顺序,所以元素是如何排列在理论上不应该影响结果。

因此, 本文介绍了一种改进的架构, 它们通过连接到另一个LSTM的前馈网络来替换LSTM编码器。这就是说,LSTM重复运行,以产生一个置换不变的嵌入给输入。解码器同样是一个指针网络。

让我们回到数字排列。较长的集合更难去排列。对于5个数字,他们报告的准确度范围是81%-94%, 具体取决于模型 (这里提到的准确度是指正确排序序列的百分比)。当处理15数字时, 这个范围变成了0%-10%。

在我们的研究中,对于五个数字,我们几乎达到了100%的准确度。请注意, 这是Keras所报告的 “分类精度”, 意思是在正确位置上元素的百分比。例如, 这个例子是50%准确度,即前两个元素不动, 但最后两个被调换:

4 3 2 1 ->3 2 0 1

对于有八元素的序列, 分类精度下降到大约33%。我们还尝试了一个更具挑战性的任务, 按它们的和对一个集合进行排序:

[1 2] [3 4] [2 3]->0 2 1

网络处理它就像处理简单的(un)标量数字。

我们注意到的一个意想不到的事情是, 网络倾向于重复指针, 尤其是在训练的早期。这是令人失望的:显然它不记得它不久之前的预测。

y_test: [2 0 1 4 3]
p:      [2 2 2 2 2]

在训练的早期, 人们聚集在一起, 构想指针网络的输出。

y_test: [2 0 1 4 3]
p:      [2 0 2 4 3]

同时, 训练有时会被某种准确度所困。而一个对少量数字进行训练的网络并不能概括更大的, 比如:

981,66,673
856,10,438
884,808,241

为了帮助网络使用数字, 我们添加一个 ID (1,2, 3…) 到序列的每个元素。这个假设是因为注意力是基于内容的, 也许它可以使用内容中明确编码的位置。此ID是一个数字 (train_with_positions) 或独热向量 (train_with_positions_categorical)。这看起来有点效果,但没有解决根本问题。

实验代码在GitHub可以使用。与original repo相比, 我们添加了一个数据生成脚本, 并更改了训练脚本以从生成的文件中加载数据。我们还将优化算法改成RMSPro, 因为它在处理学习率的过程中似乎收敛得很好。

数据结构

3D数组中的数据。第一个维度 (行) 是像往常一样的例子。第二个维度“列”通常是特征(属性), 但带序列的特征进入第三个维度。第二个维度由给定序列的元素组成。下面是三个序列示例, 每个都有三个元素 (步骤), 每个元素都有两个特征:

array([[[8, 2],
        [3, 3],
        [10, 3]],

       [[1, 4],
        [19,12],
        [4,10]],

       [[19, 0],
        [15,12],
        [8, 6]],

目标是按特征的和对元素进行排序, 因此相应的目标将是:

array([[1,0,2],
       [0,2,1],
       [2,0,1],

并且,它们将被明确编码:

array([[[0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]],

       [[0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]],

这里有一个问题,我们一直在讨论循环网络如何处理可变长度的序列,但实际上数据是3D数组,如上所示。换句话说,序列长度是固定的。

处理这一问题的方法是在最大可能的序列长度上固定维度, 并用零填充未使用的位置。

但它有可能搞乱代价函数,因此我们更好地掩盖那些零, 确保他们在计算损失时被省略。Keras官方的做法似乎是embdedding layer。相关参数为mask_zero:

mask_zero: 无论输入值0是否是一个特殊的 “padding” 值, 都应该被屏蔽掉。当使用可变长度输入的循环层时这很有用。如果它为“True”,那么模型中的所有后续层都需要支持掩蔽, 否则将引发异常。如果 mask_zero设置为True, 那么作为一个序列,词汇表中不能使用索引0(input_dim应等于词汇量“+1”)。

关于实现

我们使用了一个Keras执行的指针网络。GitHub上还有一些其他的, 大部分用Tensorflow。

附录A:指针网络的实现

  • https://github.com/keon/pointer-networks 幻灯片
  • https://github.com/devsisters/pointer-network-tensorflow
  • https://github.com/vshallc/PtrNets
  • https://github.com/ikostrikov/TensorFlow-Pointer-Networks
  • https://github.com/Chanlaw/pointer-networks
  • https://github.com/devnag/tensorflow-pointer-networks
  • https://github.com/udibr/pointer-generator
  • https://github.com/JerrikEph/SentenceOrdering_PTR
  • https://github.com/pradyu1993/seq2set-keras

附录B:seq2seq的一些注意力的实现

  • https://github.com/philipperemy/keras-attention-mechanism
  • https://github.com/tensorflow/models/tree/master/textsum
  • https://github.com/tensorflow/tensor2tensor
  • Translation with a Sequence to Sequence Network and Attention (PyTorch tutorial)
  • https://github.com/MaximumEntropy/Seq2Seq-PyTorch
  • https://github.com/rowanz/pytorch-seq2seq
  • https://github.com/chainer/chainer/tree/seq2seq-europal/examples/seq2seq

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-09-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Vamei实验室

Python标准库12 数学与随机数 (math包,random包)

我们已经在Python运算中看到Python最基本的数学运算功能。此外,math包补充了更多的函数。当然,如果想要更加高级的数学功能,可以考虑选择标准库之外的n...

2178
来自专栏素质云笔记

聚类︱python实现 六大 分群质量评估指标(兰德系数、互信息、轮廓系数)

之前关于聚类题材的博客有以下两篇: 1、 笔记︱多种常见聚类模型以及分群质量评估(聚类注意事项、使用技巧) 2、k-means+python︱sciki...

1.5K9
来自专栏深度学习自然语言处理

基于attention的seq2seq机器翻译实践详解

理理思路 文本处理,这里我是以eng_fra的文本为例,每行是english[tab]french,以tab键分割。获取文本,清洗。 分别建立字典,一个engl...

5306
来自专栏数据小魔方

左手用R右手Python系列——因子变量与分类重编码

今天这篇介绍数据类型中因子变量的运用在R语言和Python中的实现。 因子变量是数据结构中用于描述分类事物的一类重要变量。其在现实生活中对应着大量具有实际意义的...

3995
来自专栏互联网大杂烩

算法岗面试

快速排序由于排序效率在同为O(N*logN)的几种排序方法中效率较高,因此经常被采用,再加上快速排序思想----分治法也确实实用,因此很多软件公司的笔试面试,包...

812
来自专栏余林丰

12.高斯消去法(1)——矩阵编程基础

对于一阶线性方程的求解有多种方式,这里将介绍利用高斯消去法解一阶线性方程组。在介绍高斯消去法前需要对《线性代数》做一下温习,同时在代码中对于矩阵的存储做一个简...

2427
来自专栏深度学习之tensorflow实战篇

Python中map函数

python中的map()函数 map(function, iterable, ...) 1.对可迭代函数'iterable'中的每一个元素应用‘functi...

3584
来自专栏菜鸟程序员

Java中在特定区间产生随机数

722
来自专栏C语言及其他语言

【优质题解】题号1174:【计算直线的交点数】 (C语言描述)

题号1174,原题见下图: ? 解题思路: 将n条直线排成一个序列,直线2和直线1最多只有一个交点,直线3和直线1,2最多有两个交点,……,直线n 和其他n...

2946
来自专栏机器之心

教程 | 入门Python神经机器翻译,这是一篇非常精简的实战指南

传统意义上来说,机器翻译一般使用高度复杂的语言知识开发出的大型统计模型,但是近来很多研究使用深度模型直接对翻译过程建模,并在只提供原语数据与译文数据的情况下自动...

2711

扫码关注云+社区

领取腾讯云代金券