首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >当要查找`start`分数最高的令牌时,torch.argmax()中的TypeError

当要查找`start`分数最高的令牌时,torch.argmax()中的TypeError
EN

Stack Overflow用户
提问于 2021-09-19 02:52:23
回答 2查看 103关注 0票数 1

我想运行这段代码,使用拥抱面孔转换器回答问题。

代码语言:javascript
运行
复制
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''Why was the student group called "the Methodists?"'''

paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
            A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
            They focused on Bible study, methodical study of scripture and living a holy life.
            Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
            Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''
            
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph)

inputs = encoding['input_ids']  #Token embeddings
sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

但是我在最后一行得到了这个错误:

代码语言:javascript
运行
复制
Exception has occurred: TypeError
argmax(): argument 'input' (position 1) must be Tensor, not str
  File "D:\bert\QuestionAnswering.py", line 33, in <module>
    start_index = torch.argmax(start_scores)

我不知道出了什么问题。有谁可以帮我?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-09-19 03:17:18

BertForQuestionAnswering返回一个QuestionAnsweringModelOutput对象。

由于您将BertForQuestionAnswering的输出设置为start_scores, end_scores,因此将强制将返回的QuestionAnsweringModelOutput对象转换为字符串元组('start_logits', 'end_logits'),从而导致类型不匹配错误。

下面的代码应该可以工作:

代码语言:javascript
运行
复制
outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)
票数 1
EN

Stack Overflow用户

发布于 2021-09-19 03:38:34

Huggingface转换器提供了一种运行模型的简单的高级方法,如此guide所示

代码语言:javascript
运行
复制
from transformers import pipeline

nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
print(nlp(question=question, context=paragraph, topk=5))

topk允许选择几个得分最高的答案。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69239925

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档