前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)

使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)

作者头像
blmoistawinde
发布2021-01-21 10:45:35
5K0
发布2021-01-21 10:45:35
举报

使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)

huggingface的transformers在我写下本文时已有39.5k star,可能是目前最流行的深度学习库了,而这家机构又提供了datasets这个库,帮助快速获取和处理数据。这一套全家桶使得整个使用BERT类模型机器学习流程变得前所未有的简单。

不过,目前我在网上没有发现比较简单的关于整个一套全家桶的使用教程。所以写下此文,希望帮助更多人快速上手。

这里,我们以AGNews新闻分类任务为例,演示整套流程的实现。

代码语言:javascript
复制
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # 在此我指定使用2号GPU,可根据需要调整
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from transformers import pipeline
from datasets import load_dataset

使用datasets读取数据集

下面的代码读取原始数据集的train部分的前40000条作为我们的训练集,40000-50000条作为开发集(只使用这个子集已经可以训出不错的模型,并且可以让训练时间更短),原始的测试集作为我们的测试集。

代码语言:javascript
复制
train_dataset = load_dataset("ag_news", split="train[:40000]")
dev_dataset = load_dataset("ag_news", split="train[40000:50000]")
test_dataset = load_dataset("ag_news", split="test")
print(train_dataset)
print(dev_dataset)
print(test_dataset)
代码语言:javascript
复制
Dataset({
    features: ['text', 'label'],
    num_rows: 40000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 10000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 7600
})

原始数据集包含text和label两个字段

代码语言:javascript
复制
train_dataset.features
代码语言:javascript
复制
{'text': Value(dtype='string', id=None),
 'label': ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], names_file=None, id=None)}

由于bert模型期望得到的标签的字段为labels而原始数据集中的名字是label,所以做一下调整。

下面的代码把label字段复制到labels

代码语言:javascript
复制
train_dataset = train_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
train_dataset[0]
代码语言:javascript
复制
{'label': 2,
 'labels': 2,
 'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."}
代码语言:javascript
复制
dev_dataset = dev_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
test_dataset = test_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)

加载模型,tokenizer,并预处理数据

为了快速实验,我们选择一个较小的bert-tiny模型进行实验。加载对应模型和tokenizer

代码语言:javascript
复制
model_id = 'prajjwal1/bert-tiny'
# note that we need to specify the number of classes for this task
# we can directly use the metadata (num_classes) stored in the dataset
model = AutoModelForSequenceClassification.from_pretrained(model_id, 
            num_labels=train_dataset.features["label"].num_classes)
tokenizer = AutoTokenizer.from_pretrained(model_id)

用bert的方法对数据集做分词预处理,把所有序列补充或截断到256个token

代码语言:javascript
复制
MAX_LENGTH = 256
train_dataset = train_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
dev_dataset = dev_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
test_dataset = test_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)

为了放进pytorch模型训练,还要再声明格式和使用的字段

代码语言:javascript
复制
train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
dev_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

现在我们的训练样本长这样,可以直接放进bert训练了

代码语言:javascript
复制
train_dataset.features
代码语言:javascript
复制
{'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'label': ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], names_file=None, id=None),
 'labels': Value(dtype='int64', id=None),
 'text': Value(dtype='string', id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}
代码语言:javascript
复制
train_dataset[0]
代码语言:javascript
复制
{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([  101,  2813,  2358,  1012,  6468, 15020,  2067,  2046,  1996,  2304,
          1006, 26665,  1007, 26665,  1011,  2460,  1011, 19041,  1010,  2813,
          2395,  1005,  1055,  1040, 11101,  2989,  1032,  2316,  1997, 11087,
          1011, 22330,  8713,  2015,  1010,  2024,  3773,  2665,  2153,  1012,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0]),
 'labels': tensor(2),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

我们可以指定模型训练时,显示的验证指标

代码语言:javascript
复制
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

指定训练参数,使用trainer直接训练

代码语言:javascript
复制
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    learning_rate=3e-4,
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,
    do_train=True,
    do_eval=True,
    no_cuda=False,
    load_best_model_at_end=True,
    # eval_steps=100,
    evaluation_strategy="epoch"
)

trainer = Trainer(
    model=model,                         # the instantiated ? Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=dev_dataset,            # evaluation dataset
    compute_metrics=compute_metrics
)

train_out = trainer.train()

文章中不能显示那个数据表格,但是在训练过程中,或者results/checkpoint-XXX下的trainer_state.json可以看到,这个模型在第二次epoch达到了0.899的F1。

使用pipeline直接对文本进行预测

pipeline可以直接加载训练好的模型和tokenizer,然后直接对文本进行分类预测,无需再自行预处理

首先我们把模型放回cpu来进行预测

代码语言:javascript
复制
model = model.cpu()

sentiment-analysis来指定我们做的是文本分类任务(情感分析是一类代表性的文本分类任务),并指定我们之前训好的模型。

代码语言:javascript
复制
classifier = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)

我们从模型没有见过的test集里挑一个例子来进行预测

代码语言:javascript
复制
test_examples = load_dataset("ag_news", split="test[:10]")
test_examples[0]
代码语言:javascript
复制
{'label': 2,
 'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."}

该文本的类别为2,看看模型能不能做出正确预测?

代码语言:javascript
复制
result = classifier(test_examples[0]['text'])
result
代码语言:javascript
复制
[{'label': 'LABEL_2', 'score': 0.9601152539253235}]

预测正确!

到此我们的huggingface全家桶就大功告成了~

本文的完全代码可以直接在这里找到:https://github.com/blmoistawinde/hello_world/blob/master/huggingface_classification.ipynb

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-01-16 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用huggingface全家桶(transformers, datasets)实现一条龙BERT训练(trainer)和预测(pipeline)
    • 使用datasets读取数据集
      • 加载模型,tokenizer,并预处理数据
        • 使用pipeline直接对文本进行预测
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档