首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >HuggingFace: ValueError:预期长度为165的dim 1序列(got 128)

HuggingFace: ValueError:预期长度为165的dim 1序列(got 128)
EN

Stack Overflow用户
提问于 2022-02-17 23:45:30
回答 1查看 4.2K关注 0票数 2

我试图微调伯特语言模型的我自己的数据。我已经看过他们的文档,但是他们的任务似乎不是我所需要的,因为我的最终目标是嵌入文本。这是我的密码:

代码语言:javascript
运行
复制
from datasets import load_dataset
from transformers import BertTokenizerFast, AutoModel, TrainingArguments, Trainer
import glob
import os


base_path = '../data/'
model_name = 'bert-base-uncased'
max_length = 512
checkpoints_dir = 'checkpoints'

tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)


def tokenize_function(examples):
    return tokenizer(examples['text'], padding=True, truncation=True, max_length=max_length)


dataset = load_dataset('text',
        data_files={
            'train': f'{base_path}train.txt',
            'test': f'{base_path}test.txt',
            'validation': f'{base_path}valid.txt'
        }
)

print('Tokenizing data. This may take a while...')
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_dataset['train']
eval_dataset = tokenized_dataset['test']

model = AutoModel.from_pretrained(model_name)

training_args = TrainingArguments(checkpoints_dir)

print('Training the model...')
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()

我得到以下错误:

代码语言:javascript
运行
复制
  File "train_lm_hf.py", line 44, in <module>
    trainer.train()
...
  File "/opt/conda/lib/python3.7/site-packages/transformers/data/data_collator.py", line 130, in torch_default_data_collator
    batch[k] = torch.tensor([f[k] for f in features])
ValueError: expected sequence of length 165 at dim 1 (got 128)

我做错了什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-02-23 06:08:18

我通过将tokenize函数更改为:

代码语言:javascript
运行
复制
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)

(注意padding参数)。此外,我还使用了数据整理器,如下所示:

代码语言:javascript
运行
复制
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset
)
票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71166789

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档