首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

鹅厂专家讲透AI文本生成解码策略与代码实现

腾小云导读

本文以 huggingface-transformers 的文本生成解码代码为例,对文本生成常用的五种解码策略 greedy search、beam search、sample、sample and rank & beam sample、group beam search 进行逐行解读。每一小节首先会介绍对应解码策略的原理,接着给出供大家快速上手的代码示例,并逐层介绍调用过程,最后给出所使用到的所有类之间调用的时序图。由简到繁再到简,帮助大家建立起一个整体的认识,并且能够快速应用。干货较多,欢迎阅读并进行实践尝试。

目录

1 总体介绍

2 greedy search

  2.1 原理介绍

  2.2 快速上手

2.3 代码解读

2.4 整体流程

3 beam search

  3.1 原理介绍

  3.2 快速上手

  3.3 代码解读

  3.4 整体流程

4 sample

  4.1 原理介绍

  4.2 快速上手

  4.3 代码解读

  4.4 整体流程

5 sample and rank & beam sample

  5.1 原理介绍

  5.2 快速上手

  5.3 代码解读

  5.4 整体流程

6 group beam search

  6.1 原理介绍

  6.2 快速上手

  6.3 代码解读

  6.4 整体流程

7 总结

8 主流模型方案

01

总体介绍

在 T5/GPT 等自回归模型中,解码策略直接影响到模型输出的效果。在解码第 t 个 token w 时,模型依赖前面的 t-1 个 token,计算概率分布 P(w∣w1:t−1 )。根据该概率分布,研究者们设计了各式各样的解码策略,每一种解码策略都对应了一个或多个相关的参数,多种参数糅合在一起,容易让人摸不着头脑。在对应官网提供的 API 中,我们可以看到也提供了一些用于调整解码策略的参数,如 temperature、top_p 等。

02

greedy search

2.1 原理介绍

最简单的策略就是 greedy decoding,即每步选择概率最大的 token:。如上图所示,从单词 The 开始,该策略每步都会选择下一步概率最大的词,最后会得到输出序列 The nice woman,总概率是 0.5 * 0.4 = 0.2。greedy decoding 速度最快,也有如下几个缺点:

2.2 快速上手

快速上手的代码参考:Generation,更详细的参数介绍也可从中获取。

链接:https://huggingface.co/docs/transformers/main_classes/text_generation

2.3 代码解读

主要针对快速上手的第30-32行代码调用的 greedy_search 方法进行解读。

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

2.3.1 基本设置,对后续需要使用的变量进行初始化

2.3.2 从 bos_token 开始解码

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

此处会调用__call__方法,参数 input_ids 为已生成的序列,scores 为下一步预测 token 的得分。

这里介绍快速上手中使用的两种预处理方法最小长度和重复词惩罚对应的 processor。

· 最小长度

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

· 重复词惩罚

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

2.3.3 解码结束,返回结果

若需要返回生成过程中的详细结果,则根据架构为 encoder-decoder 和 decoder-only 分别返回对应 dict,否则直接返回预测序列;

2.4 整体流程

整体流程如下面的时序图所示

03

beam search

3.1 原理介绍

为了解决 greedy decoding 可能错过全局最大概率序列的问题,beam search 策略经常会被采用,即维护 beam=n,保留当前最佳的n个序列,并且对于每个序列,都在计算最好的 n 个 next token,然后再从 n*n 个结果中,保留 n 个概率乘积最大的序列。比如上图中,假设 beam=2,从 The 开始,会保留[The dog, The nice]两个序列,接着每个序列选取2个最佳的next token,得到4个序列,再从中选择2个最佳序列[The dog has, The nice woman]。然而,beam Search 有以下缺点:

3.2 快速上手

3.3 代码解读

主要针对快速上手的第45行代码调用的 beam_search 方法进行解读

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

3.3.1 基本设置,对后续需要使用的变量进行初始化

这一步与 greedy search 基本一致,区别在于需要额外初始化一些用于 beam search 的变量。

3.3.2 从 bos_token 开始解码

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

3.3.3 解码结束,返回结果

这一步的逻辑与 greedy search 基本一致;

3.4 整体流程

04

sample

4.1 原理介绍

4.1.1 Random sampling

随机采样策略根据当前的概率来抽签选择 next token,即。如上图所示,任何词都有一定概率被选择。该方案生成的序列充满了创造性,也相对较少出现重复序列循环问题。但是它生成的语句却很可能不通顺。

这里一般会引入 temperature,来改变生成 next token 的概率分布,使其更偏向于 high probability token。具体做法是在 softmax 中引入 t,取值范围(0, 1]。t 趋近于0,就变成了 greedy search。通过调整 t 的大小,可以避免 sample from tail distribution。

4.1.2 Top-k sampling

Fan et. al (2018) 提出了 Top-K 采样策略。该策略会在采样之前缩减采样空间,只保留概率最高的 k 个词,然后重新进行归一化得到新的概率分布。比如上图中,取 k=6,则只在6个词中进行采样,这6个词总概率有可能不高(左图),但也可能非常接近1(右图)。这会造成两个问题:

a.  左图中的 people, big, house 等词实际上可能是合理的输出,但是却不在候选里,这就限制了模型的创造性和多样性。

b.  右图中,down, a 的概率很小,但是仍被放在了候选中,这就有可能让模型输出不通顺的垃圾信息。

4.1.3 Top-p(Nucleus)sampling

为了解决上述 top-k 采样的问题,Holtzman et al. (2019) 提出了 top-p 采样策略(nucleus sampling)。给定一个概率阈值 p,从解码词候选集中选择一个最小集 Vp,使得它们出现的概率和大于等于 p。然后再对 Vp 做一次 re-scaling,本时间步仅从 Vp 集合中解码。

比如上图中,将阈值 p 设为0.9,左图会从9个候选词中筛选,右图会从3个候选词中筛选。

从本质上看,Top-p Sampling 和 Top-k Sampling 都是从缩小的候选 token 集合中 sample token,区别在于如何缩小候选集合。在实际使用中,top-k 和 top-p 有时也会同时使用,来避免采样到非常低概率的词,同时保证结果的多样性。

从上表中可以看出,top-p (nucleus)策略的结果是与 human 结果最相近的。并且有较低的重复率 repetition%

4.2 快速上手

4.3 代码解读

主要针对快速上手的第41-46行代码调用的 sample 方法进行解读.

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

4.3.1 基本设置,对后续需要使用的变量进行初始化

这一步与 greedy search 基本相同,唯一区别在于初始化了一个 logits_warper;

4.3.2 从bos_token开始解码

这里介绍快速上手中使用的两个采样方法 top-k 和 top-p 对应的 wraper。

top-k

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

top-p

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

4.3.3 解码结束,返回结果

这一步的逻辑与 greedy search 基本一致;

4.4 整体流程

整体流程如下面的时序图所示:

05

sample and rank & beam sample

5.1 原理介绍

Adiwardana et al., 2020 提出了 sample-and-rank 解码策略,该方法在对话领域效果很好。其思想是先通过 random sampling(结合temperature调整概率分布)生成出 N 个 sentence,然后再从 n 个 sentence 中选择概率乘积最大的。

这种方式通过 random sampling 保留了生成结果的多样性和创造性,后又通过 rank 过滤掉了不通顺的序列。下面两个表格对比了 sample 的结果和 beam search 的结果。明显地,sample 结果多样性会更好。

beam sample 方法是 sample and rank 的改进,原理上类似,相比 sample and rank 在最后才对结果排序去获得最佳的 n 个序列,beam sample在每一步保留当前最佳的 n 个序列,既保证了多样性和创造性,又可以减少在 rank 阶段需要过滤掉的句子

5.2 快速上手

5.3 代码解读

主要针对快速上手的第46-48行代码调用的 beam_sample 方法进行解读。

代码地址:transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

5.3.1 基本设置,对后续需要使用的变量进行初始化

这一步与 beam search 相同。

5.3.2 从bos_token开始解码

5.3.3 解码结束,返回结果

这一步的逻辑与 greedy search 基本一致;

5.4 整体流程

整体流程如下面的时序图所示:

06

group beam search

6.1 原理介绍

group beam search 同样是为了解决 beam search 多样性不足的问题,如上图所示,可以发现 beam search 生成的图像描述几乎是重复的,这是由于在搜索树中具有相似的共享路径,导致最终的变化很小。相比之下,group(diverse) beam search 生成的结果则更多样化,也更加类似描述图像的人际差异。

group beam search 主要思路是通过将 beam search 中的候选路径进行分组,在各组内去寻找最优解。如上图所示,beam search 的候选路径有6条,group beam search 将这6条候选路径两两作为一组,分为三组。每一步都在各组内的词表空间下去取 top-2 的结果作为当前预测的 token,对于当前组来说,通过对之前组已生成的 token 进行惩罚,来保证当前组生成的 token 与之前组不同的概率更大,从而更具多样性

6.2 快速上手

6.3 代码解读

主要针对快速上手的第47-49行代码调用的 group beam search 方法进行解读。

代码地址:transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

6.3.1 基本设置,对后续需要使用的变量进行初始化

这一步与 beam search 基本一致,区别在于需要额外初始化一些用于 group beam search 的变量。

6.3.2 从 bos_token 开始解码

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

6.3.3 解码结束,返回结果

这一步的逻辑与 greedy search 基本一致;

6.4 整体流程

整体流程如下面的时序图所示:

07

总结

通过前面的介绍,相信大家已经发现了,各种解码策略无非是通过调整 logits(即模型对每个 token 的预测得分)和 batch_size,来获得不同的生成结果。

对 logits 做调整一般又可分为是用于预处理还是采样,对用于预处理的最小长度、重复词惩罚这些功能,抽象出基类 Processor 类,对用于采样的 top-k、top-p 这些功能,抽象出基类 Warper。而所有对 logits 做调整的功能类都可以又加入到 LogitsProcessList,组成一个 pipeline,每次想用哪一个对其进行初始化并加入即可。

对 batch_size 做调整主要在需要生成多个候选或是需要返回多个结果的情况下,对于 beam search 系列的解码策略,通过将 batch_size 扩大候选路径的个数倍,来获得不同的候选序列。对 sample 系列的解码策略,通过将 batch_size 扩大返回结果个数倍,来采样得到不同的结果。

08

主流模型方案

以上方案被主流模型所采用。下面表格罗列了从公开论文中梳理出的解码方案:

以上就是本篇文章的全部分享,看完文章的开发者可以收藏一下,跟着文章步骤实机进行操作。

参考文献

Holtzman A, Buys J, Du L, et al. The curious case of neural text degeneration[J]. arXiv preprint arXiv:1904.09751, 2019.

Fan A, Lewis M, Dauphin Y. Hierarchical neural story generation[J]. arXiv preprint arXiv:1805.04833, 2018.

Adiwardana D, Luong M T, So D R, et al. Towards a human-like open-domain chatbot[J]. arXiv preprint arXiv:2001.09977, 2020.

Radford A, Wu J, Child R, et al. Language models are unsupervised multitask learners[J]. OpenAI blog, 2019, 1(8): 9.

Brown T, Mann B, Ryder N, et al. Language models are few-shot learners[J]. Advances in neural information processing systems, 2020, 33: 1877-1901.

Thoppilan R, De Freitas D, Hall J, et al. Lamda: Language models for dialog applications[J]. arXiv preprint arXiv:2201.08239, 2022.

Touvron H, Lavril T, Izacard G, et al. LLaMA: Open and Efficient Foundation Language Models[J]. arXiv preprint arXiv:2302.13971, 2023.

Ashwin K V, Michael C, et al. diverse beam search: decoding diverse soulutions from neural sequence models[J]. arXiv preprint arXiv:1610.02424, 2016.

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20230601A08YS200?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

相关快讯

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券