首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何理解mbart中的decoder_start_token_id和forced_bos_token_id?

如何理解mbart中的decoder_start_token_id和forced_bos_token_id?
EN

Stack Overflow用户
提问于 2021-07-09 16:03:04
回答 1查看 185关注 0票数 0

当我想使用huggingface的预训练模型进行多语言实验时,参数decoder_start_token_idforced_bos_token_id的含义让我感到困惑。我找到类似这样的代码:

代码语言:javascript
运行
复制
# While generating the target text set the decoder_start_token_id to the target language id. 
# The following example shows how to translate English to Romanian 
# using the facebook/mbart-large-en-ro model.
from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
article = "UN Chief Says There Is No Military Solution in Syria"
inputs = tokenizer(article, return_tensors="pt")
translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

和:

代码语言:javascript
运行
复制
# To generate using the mBART-50 multilingual translation models, 
# eos_token_id is used as the decoder_start_token_id and the target language id is forced as the first generated token. 
# To force the target language id as the first generated token, 
# pass the forced_bos_token_id parameter to the generate method. 
# The following example shows how to translate between Hindi to French and Arabic to English 
# using the facebook/mbart-50-large-many-to-many checkpoint.
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# translate Hindi to French
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."

# translate Arabic to English
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt")
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "The Secretary-General of the United Nations says there is no military solution in Syria."

而这两个参数的注释是:

代码语言:javascript
运行
复制
decoder_start_token_id (:obj:`int`, `optional`): 
If an encoder-decoder model starts decoding with a different token than `bos`, 
the id of that token.

forced_bos_token_id (:obj:`int`, `optional`): 
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where 
the first generated token needs to be the target language token.

对于mbart的不同变体,例如facebook/mbart-large-cc25facebook/mbart-large-50,我们应该指定哪一个来生成特定语言的响应?

EN

回答 1

Stack Overflow用户

发布于 2021-07-14 16:04:25

在标准的序列到序列模型中,解码首先向解码器提供[bos]符号,它生成单词w1,该单词在下一步中作为解码器的输入提供。并且解码器生成单词w2。这将一直持续到生成[eos] (句子结束)标记。

代码语言:javascript
运行
复制
[bos] w_1  w_2  w_3
  ↓    ↓    ↓    ↓
┌──────────────────┐
│     DECODER      │
└──────────────────┘
  ↓    ↓    ↓    ↓
 w_1  w_2  w_2 [eos]

使用mBART,这会更加棘手,因为您需要告诉它目标语言和源语言是什么。对于编码器和训练数据,分词器负责这一点,并在源句的末尾和目标句的开头添加特定于语言的标记。然后,句子按格式排列(假设源有4个单词,目标有3个单词):

消息来源:v1 v2 v3 v4 [src_lng]

  • target:[tgt_lng] w1 w2 w3 [eos]

与训练不同,在推理时,目标句子是未知的,您希望生成它。但是您仍然需要告诉解码器它应该使用什么,而不是通用的[bos]令牌。这就是forced_bos_token_id发挥作用的地方。它仍然是知道特定令牌的It的令牌器。不同的mBART具有不同的标记器,您应该始终使用与模型匹配的标记器的语言ID。

您提到的属性似乎做了同样的事情,但我将坚持使用mBART documentation中提到的forced_bos_token_id。HuggingFace转换器中的方法API非常宽松,其中一些属性只适用于某些模型,而被其他模型忽略。我会避免使用在特定模型的文档中没有明确提到的内容。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68313263

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档