首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何提前停止自回归模型与停止词列表?

如何提前停止自回归模型与停止词列表?
EN

Stack Overflow用户
提问于 2021-10-01 09:30:21
回答 1查看 266关注 0票数 1

我正在使用transformers的GPT模型来生成文本.因为我使用的提示符是以'{'开头的,所以我想在生成paring '}'之后停止这个句子。我发现源代码中有一个StoppingCriteria方法,但没有进一步说明如何使用它。有没有人找到了一种方法来尽早阻止这一代的模特?谢谢!

以下是我尝试过的:

代码语言:javascript
运行
复制
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torch_dtype=dtype).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids in self.keywords:
            return True
        return False

stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w) for w in stop_words]
stop_ids.append(tokenizer.eos_token_id)
stop_criteria = KeywordsStoppingCriteria(stop_ids)

model.generate(
    text_inputs='some text:{', 
    StoppingCriteria=stop_criteria
)
EN

回答 1

Stack Overflow用户

发布于 2022-04-25 17:22:15

我已经能让你的代码适应工作了。此外,确保您使用的是最新版本的变压器,您可能需要升级。

代码语言:javascript
运行
复制
import torch
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids[0][-1] in self.keywords:
            return True
        return False


stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)


inputs = tokenizer.encode('some text: {', add_special_tokens=False, return_tensors='pt')

output = model.generate(
    inputs,
    do_sample=True,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),

)
print(tokenizer.decode(*output))
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69403613

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档