在HuggingFace的代码中的生成阶段:https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L88-L100
他们传入了一个decoder_start_token_id
,我不知道他们为什么需要这个。在BART配置中,decoder_start_token_id
实际上是2
(https://huggingface.co/facebook/bart-base/blob/main/config.json),这是句子末尾的标记</s>
。
我尝试了一个简单的例子:
from transformers import *
import torch
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
input_ids = torch.LongTensor([[0, 894, 213, 7, 334, 479, 2]])
res = model.generate(input_ids, num_beams=1, max_length=100)
print(res)
preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for g in res]
print(preds)
我得到的结果是:
tensor([[ 2, 0, 894, 213, 7, 334, 479, 2]])
['He go to school.']
虽然它不会影响最终的“标记化解码”结果。但我觉得奇怪的是,我们生成的第一个令牌实际上是2
(</s>
)。
发布于 2021-11-09 02:45:07
您可以在encoder-decoder models的代码中看到,解码器的输入标记从原始标记向右移动(请参阅函数shift_tokens_right
)。这意味着要猜测的第一个标记始终是BOS (句子的开头)。您可以检查您的示例中是否存在这种情况。
为了让解码器理解这一点,我们必须选择始终跟随BOS的第一个令牌,那么它会是哪一个呢?博斯?显然不是因为它后面必须有常规的标记。填充标记?也不是一个好的选择,因为它后面是另一个填充标记或EOS (句子结束)。那么,EOS呢?嗯,这是有道理的,因为它后面永远不会有训练集中的任何东西,所以不会有下一个冲突的标记。此外,句子的开头跟在另一个句子的结尾,这不是很自然吗?
https://stackoverflow.com/questions/64904840
复制相似问题