首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >'Seq2SeqModelOutput‘对象没有属性'logits’BART转换器

'Seq2SeqModelOutput‘对象没有属性'logits’BART转换器
EN

Stack Overflow用户
提问于 2021-07-12 06:54:41
回答 1查看 3.8K关注 0票数 3

我试图生成长篇PDF的摘要。所以,我所做的,首先我把我的pdf转换成文字使用pdfminer.six库。接下来,我使用了讨论这里中提供的两个函数。

守则:

代码语言:javascript
运行
复制
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
bart_model = BartModel.from_pretrained("facebook/bart-large", return_dict=True)

# generate chunks of text \ sentences <= 1024 tokens
def nest_sentences(document):
  nested = []
  sent = []
  length = 0
  for sentence in nltk.sent_tokenize(document):
    length += len(sentence)
    if length < 1024:
      sent.append(sentence)
    else:
      nested.append(sent)
      sent = [sentence]
      length = len(sentence)

  if sent:
    nested.append(sent)
  return nested

# generate summary on text with <= 1024 tokens
def generate_summary(nested_sentences):
  device = 'cuda'
  summaries = []
  for nested in nested_sentences:
    input_tokenized = bart_tokenizer.encode(' '.join(nested), truncation=True, return_tensors='pt')
    input_tokenized = input_tokenized.to(device)
    summary_ids = bart_model.to(device).generate(
        input_tokenized,
        length_penalty=3.0,
        min_length=30,
        max_length=100,
    )
    output = [bart_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
    summaries.append(output)
  summaries = [sentence for sublist in summaries for sentence in sublist]
  return summaries

然后,为了得到总结,我做了:

代码语言:javascript
运行
复制
nested_sentences = nest_sentences(text)

其中,text是一个长度在10K左右的字符串,我用pdf库进行了转换。

代码语言:javascript
运行
复制
summary = generate_summary(nested_sentences)

然后,我得到以下错误:

代码语言:javascript
运行
复制
---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-15-d5aa7709bb5f> in <module>()
----> 1 summary = generate_summary(nested_sentences)

3 frames

<ipython-input-11-8554509269e0> in generate_summary(nested_sentences)
     28         length_penalty=3.0,
     29         min_length=30,
---> 30         max_length=100,
     31     )
     32     output = [bart_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]

/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
   1061                 return_dict_in_generate=return_dict_in_generate,
   1062                 synced_gpus=synced_gpus,
-> 1063                 **model_kwargs,
   1064             )
   1065 

/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   1799                 continue  # don't waste resources running the code we don't need
   1800 
-> 1801             next_token_logits = outputs.logits[:, -1, :]
   1802 
   1803             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`

AttributeError: 'Seq2SeqModelOutput' object has no attribute 'logits'

我找不到与这个错误有关的任何东西,所以如果有人能帮忙,或者有什么更好的方法来为长的文本生成摘要,我会非常感激。

提前谢谢你!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-16 04:55:22

这里的问题是BartModel行。将此转换为BartForConditionalGeneration类,问题将得到解决。实际上,生成实用程序假定它是一个可以用于语言生成的模型,在这种情况下,BartModel只是没有LM头的基础。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68343073

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档