首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >为什么在HuggingFace BART中生成时需要一个decoder_start_token_id?

为什么在HuggingFace BART中生成时需要一个decoder_start_token_id?
EN

Stack Overflow用户
提问于 2020-11-19 11:09:17
回答 1查看 689关注 0票数 3

在HuggingFace的代码中的生成阶段:https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L88-L100

他们传入了一个decoder_start_token_id,我不知道他们为什么需要这个。在BART配置中,decoder_start_token_id实际上是2 (https://huggingface.co/facebook/bart-base/blob/main/config.json),这是句子末尾的标记</s>

我尝试了一个简单的例子:

代码语言:javascript
运行
复制
from transformers import *

import torch
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
input_ids = torch.LongTensor([[0, 894, 213, 7, 334, 479, 2]])
res = model.generate(input_ids, num_beams=1, max_length=100)

print(res)

preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for g in res]
print(preds)

我得到的结果是:

代码语言:javascript
运行
复制
tensor([[  2,   0, 894, 213,   7, 334, 479,   2]])
['He go to school.'] 

虽然它不会影响最终的“标记化解码”结果。但我觉得奇怪的是,我们生成的第一个令牌实际上是2(</s>)。

EN

回答 1

Stack Overflow用户

发布于 2021-11-09 02:45:07

您可以在encoder-decoder models的代码中看到,解码器的输入标记从原始标记向右移动(请参阅函数shift_tokens_right)。这意味着要猜测的第一个标记始终是BOS (句子的开头)。您可以检查您的示例中是否存在这种情况。

为了让解码器理解这一点,我们必须选择始终跟随BOS的第一个令牌,那么它会是哪一个呢?博斯?显然不是因为它后面必须有常规的标记。填充标记?也不是一个好的选择,因为它后面是另一个填充标记或EOS (句子结束)。那么,EOS呢?嗯,这是有道理的,因为它后面永远不会有训练集中的任何东西,所以不会有下一个冲突的标记。此外,句子的开头跟在另一个句子的结尾,这不是很自然吗?

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64904840

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档