前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >教程 | 如何使用贪婪搜索和束搜索解码算法进行自然语言处理

教程 | 如何使用贪婪搜索和束搜索解码算法进行自然语言处理

作者头像
机器之心
发布2018-05-10 11:06:03
1.8K0
发布2018-05-10 11:06:03
举报
文章被收录于专栏:机器之心机器之心

选自MachineLearningMastery

作者:Jason Brownlee

机器之心编译

参与:程耀彤、路雪

本文介绍了贪婪搜索解码算法和束搜索解码算法的定义及其 Python 实现。

自然语言处理任务如图像描述生成和机器翻译,涉及生成一系列的单词。通常,针对这些问题开发的模型的工作方式是生成在输出词汇表上的概率分布,并通过解码算法对概率分布进行采样以生成可能性最大的单词序列。在本教程中,你将学习可用于文本生成问题的贪婪搜索和束搜索解码算法。

完成本教程,你将了解:

  • 文本生成问题中的解码问题;
  • 贪婪搜索解码算法及其在 Python 中的实现;
  • 束搜索解码算法及其在 Python 中的实现。

文本生成解码器

在自然语言处理任务中,如图像描述生成、文本摘要和机器翻译等,需要预测的是一连串的单词。通常,针对此类问题开发的模型会输出每个单词在对应的输出序列词汇表上的概率分布,然后解码器将其转化为最终的单词序列。

当你使用循环神经网络解决以文本作为输出的 NLP 任务时,你很可能会遇到这种情况。神经网络模型的最后一层对输出词汇表中的每个单词都有对应的一个神经元,同时 softmax 激活函数被用来输出词汇表中每个单词成为序列中下一个单词的可能性。

解码最有可能的输出序列包括根据它们的可能性搜索所有可能的输出序列。词汇表中通常含有成千上万个单词,甚至上百万个单词。因此,搜索问题根据输出序列的长度呈指数级变化,并且很难做到完全搜索(NP-complete)。

实际上,对于给定的预测,可以用启发式搜索方法返回一或多个逼近或「足够好」的解码输出序列。

由于搜索图的范围是根据源语句长度呈指数级的,所以我们必须使用近似来有效地找到解决方案。 — Page 272, Handbook of Natural Language Processing and Machine Translation, 2011.

候选单词序列的分数是根据它们的可能性评定的。通常,使用贪婪搜索或束搜索定位文本的候选序列。本文将研究这两种解码算法。

每个单独的预测都有一个关联的分数(或概率),我们对最大分数(或最大概率)的输出序列感兴趣。一种流行的近似方法是使用贪婪预测,即在每个阶段采用得分最高的项。虽然这种方法通常是有效的,但显然不是最佳的。实际上,用束搜索作为近似搜索通常比用贪婪搜索要好得多。 — Page 227, Neural Network Methods in Natural Language Processing, 2017.

贪婪搜索解码器

一个简单的近似方法是使用贪婪搜索,即在输出序列的每一步中选择最有可能的单词。该方法的优点是非常快,但最终输出序列的质量可能远非最佳。

我们可以用 Python 中的一个小例子来展示贪婪搜索的解码方式。我们从一个包含 10 个单词的序列的预测问题开始。每个单词的预测是其在五个单词组成的词汇表上的概率分布。

代码语言:javascript
复制
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)

我们假定单词是整数编码的,这样,列索引就可以用来查找词汇表中的相关单词。因此,解码任务就变成从概率分布中选择整数序列的任务。argmax() 数学函数可用于选择具有最大值的数组的索引。我们可以用该函数选择在序列每个步骤中最有可能的单词索引。这个函数是直接在 numpy 中提供的。

下面的 greedy_decoder() 函数用 argmax 函数实现了这个解码思路。

代码语言:javascript
复制
# greedy decoder
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]

下面展示了贪婪解码器的完整示例。

代码语言:javascript
复制
from numpy import array
from numpy import argmax

# greedy decoder
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]

# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = greedy_decoder(data)
print(result)

运行这个示例会输出一系列整数,然后这些整数可以映射回词汇表中的单词。

代码语言:javascript
复制
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

束搜索解码器

另一种流行的启发式算法是在贪婪搜索的基础扩展而来的束搜索,它返回的是可能性最大的输出序列列表。相对于在构建序列时就贪婪地选择最有可能的下一步,束搜索选择扩展所有可能的下一步,并保持 k 是最有可能的,k 是用户指定的参数,它通过一系列概率控制束或并行搜索的数量。

本地束搜索算法跟踪 k 个状态,而不仅仅只跟踪一个。它从 k 个随机生成的状态开始,在每一步中都生成所有 k 个状态的所有后继者。如果这其中的任何一个后继者是目标,那么算法就会停止。否则,它将从完整列表中选择k个最佳后继者并不断重复。 — Pages 125-126, Artificial Intelligence: A Modern Approach (3rd Edition), 2009.

我们不需要从随机状态开始;相反 ,我们以k个最可能的单词开始,作为序列的第一步。对于贪婪搜索,常见的束宽度值为 1,对于机器翻译中常见的基准问题,它的值为 5 或 10。由于多个候选序列增加了更好地匹配目标序列的可能性,所以较大的束宽度会使模型性能提高。性能的提高会导致解码速度降低。

在 NMT 中,新的句子通过一个简单的束搜索解码器被翻译,该解码器可以找到一个近似最大化已训练 NMT 模型的条件概率的译文。束搜索从左到右逐词完成翻译,同时在每一步中都保持固定数目(束)的活跃候选者。增大束尺寸可以提高翻译性能,但代价是解码器的速度显著降低。 — Beam Search Strategies for Neural Machine Translation, 2017.

搜索过程可以通过达到最大长度、到达序列结束标记或到达阈值可能性来分别停止每个候选项。

让我们用一个例子来具体说明这个问题。

我们可以定义一个函数来执行给定序列概率和束宽度参数k的束搜索。在每一步中,每个候选序列都被扩展为所有可能的后续步骤。每个候选步骤的分数通过概率相乘得到。选择具有最大概率的k个序列,并删去其他候选项。然后重复该过程直到序列结束。

概率是很小的数,而把小的数相乘就会得到更小的数。为了避免浮点数的下溢,可将概率的自然对数相乘,这样使得到的数字更大、更易于管理。此外,通过最小化分数来进行搜索也是很常见的,因此,可以将概率的负对数相乘。这个最后的调整使我们能够按照分数对所有候选序列进行升序排序,并选择前k个序列作为可能性最大的候选序列。

下面的 beam_search_decoder() 函数实现了束搜索解码器。

代码语言:javascript
复制
# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences

我们可以将它与上一节的样本数据结合在一起,这次返回的是 3 个可能性最大的序列。

代码语言:javascript
复制
from math import log
from numpy import array
from numpy import argmax

# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences

# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

运行该示例将输出整数序列及其对数似然函数值。

试用不同的 k 值。

代码语言:javascript
复制
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]

扩展阅读

如果你想更深入的了解,本节将提供更多的关于该主题的资源。

  • Argmax on Wikipedia(https://en.wikipedia.org/wiki/Arg_max)
  • Numpy argmax API(https://docs.scipy.org/doc/numpy-1.9.3/reference/generated/numpy.argmax.html)
  • Beam search on Wikipedia(https://en.wikipedia.org/wiki/Beam_search)
  • Beam Search Strategies for Neural Machine Translation, 2017.(https://arxiv.org/abs/1702.01806)
  • Artificial Intelligence: A Modern Approach (3rd Edition), 2009.(http://amzn.to/2x7ynhW)
  • Neural Network Methods in Natural Language Processing, 2017.(http://amzn.to/2fC1sH1)
  • Handbook of Natural Language Processing and Machine Translation, 2011.(http://amzn.to/2xQzTnt)
  • Pharaoh: a beam search decoder for phrase-based statistical machine translation models, 2004.(https://link.springer.com/chapter/10.1007%2F978-3-540-30194-3_13?LI=true)

本文为机器之心编译,转载请联系本公众号获得授权。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-02-03,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器之心 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
机器翻译
机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档