专栏首页AI科技大本营的专栏Simple Transformer:用BERT、RoBERTa、XLNet、XLM和DistilBERT进行多类文本分类

Simple Transformer:用BERT、RoBERTa、XLNet、XLM和DistilBERT进行多类文本分类

作者 | Thilina Rajapakse

译者 | Raku

编辑 | 夕颜

【导读】本文将介绍一个简单易操作的Transformers库——Simple Transformers库。它是AI创业公司Hugging Face在Transformers库的基础上构建的。Hugging Face Transformers是供研究与其他需要全面控制操作方式的人员使用的库,简单易操作。

简介

Simple Transformers专为需要简单快速完成某项工作而设计。不必拘泥于源代码,也不用费时费力地去弄清楚各种设置,文本分类应该非常普遍且简单——Simple Transformers就是这么想的,并且专为此实现。

一行代码建立模型,另一行代码训练模型,第三行代码用来预测,老实说,还能比这更简单吗?

所有源代码都可以在Github Repo上找到,如果你有任何问题或疑问,请在这上面自行寻求答案。

GitHub repo:

https://github.com/ThilinaRajapakse/simpletransformers

安装

1、从这里(https://www.anaconda.com/distribution/)安装Anaconda或Miniconda Package Manager。

2、创建一个新的虚拟环境并安装所需的包。

conda create -n transformers python pandas tqdm
conda activate transformers

如果是cuda:

conda install pytorch cudatoolkit=10.0 -c pytorch

其他:

conda install pytorch cpuonly -c pytorch
conda install -c anaconda scipy
conda install -c anaconda scikit-learn
pip install transformers
pip install tensorboardx

3、安装simpletransformers。

pip install simpletransformers

用法

让我们看看如何对AGNews数据集执行多类分类。

对于用Simple Transformers简单二分类,参考这里。

下载并提取数据

1、从Fast.ai(https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz)下载数据集。

2、提取train.csv和test.csv并将它们放在目录data/ 中。

为训练准备数据

import pandas as pd


  train_df = pd.read_csv('data/train.csv', header=None)
  train_df['text'] = train_df.iloc[:, 1] + " " + train_df.iloc[:, 2]
  train_df = train_df.drop(train_df.columns[[1, 2]], axis=1)
  train_df.columns = ['label', 'text']
  train_df = train_df[['text', 'label']]
  train_df['text'] = train_df['text'].apply(lambda x: x.replace('\\', ' '))


  eval_df = pd.read_csv('data/test.csv', header=None)
  eval_df['text'] = eval_df.iloc[:, 1] + " " + eval_df.iloc[:, 2]
  eval_df = eval_df.drop(eval_df.columns[[1, 2]], axis=1)
  eval_df.columns = ['label', 'text']
  eval_df = eval_df[['text', 'label']]
  eval_df['text'] = eval_df['text'].apply(lambda x: x.replace('\\', ' '))
  eval_df['label'] = eval_df['label'].apply(lambda x:x-1

Simple Transformers要求数据必须包含在至少两列的Pandas DataFrames中。你只需为列的文本和标签命名,SimpleTransformers就会处理数据。或者你也可以遵循以下约定:

• 第一列包含文本,类型为str。

• 第二列包含标签,类型为int。

对于多类分类,标签应该是从0开始的整数。如果数据具有其他标签,则可以使用python dict保留从原始标签到整数标签的映射。

模型

from simpletransformers.model import TransformerModel


  # Create a TransformerModel
  model = TransformerModel('roberta', 'roberta-base', num_labels=4)

这将创建一个TransformerModel,用于训练,评估和预测。第一个参数是model_type,第二个参数是model_name,第三个参数是数据中的标签数:

• model_type可以是['bert','xlnet','xlm','roberta','distilbert']之一。

• 有关可用于model_name的预训练模型的完整列表,请参阅“当前预训练模型”(https://github.com/ThilinaRajapakse/simpletransformers#current-pretrained-models)。

要加载以前保存的模型而不是默认模型的模型,可以将model_name更改为包含已保存模型的目录的路径。

model = TransformerModel('xlnet', 'path_to_model/', num_labels=4)

TransformerModel具有dict参数,其中包含许多属性,这些属性提供对超参数的控制。有关每个属性的详细说明,请参阅repo。默认值如下所示:

self.args = {
      'output_dir': 'outputs/',
      'cache_dir': 'cache_dir',

      'fp16': True,
      'fp16_opt_level': 'O1',
      'max_seq_length': 128,
      'train_batch_size': 8,
      'gradient_accumulation_steps': 1,
      'eval_batch_size': 8,
      'num_train_epochs': 1,
      'weight_decay': 0,
      'learning_rate': 4e-5,
      'adam_epsilon': 1e-8,
      'warmup_ratio': 0.06,
      'warmup_steps': 0,
      'max_grad_norm': 1.0,

      'logging_steps': 50,
      'save_steps': 2000,

      'overwrite_output_dir': False,
      'reprocess_input_data': False,
      'process_count': cpu_count() - 2 if cpu_count() > 2 else 1,
      }

在创建TransformerModel或调用其train_model方法时,只要简单地传递包含要更新的键值对的字典,就可以修改这些属性中的任何一个。下面给出一个例子:

# Create a TransformerModel with modified attributes
model = TransformerModel('roberta', 'roberta-base', num_labels=4, 
args={'learning_rate':1e-5, 'num_train_epochs': 2, 
'reprocess_input_data': True, 'overwrite_output_dir': True})

训练

# Train the model
model.train_model(train_df)

这就是训练模型所需要做的全部。你还可以通过将包含相关属性的字典传递给train_model方法来更改超参数。请注意,即使完成训练,这些修改也将保留。

train_model方法将在第n个步骤(其中n为self.args ['save_steps'])的第n个步骤创建模型的检查点(保存)。训练完成后,最终模型将保存到self.args ['output_dir']。

评估

result, model_outputs, wrong_predictions = model.eval_model(eval_df)

要评估模型,只需调用eval_model。此方法具有三个返回值:

• result:dict形式的评估结果。默认情况下,仅对多类分类计算马修斯相关系数(MCC)。

• model_outputs:评估数据集中每个项目的模型输出list。用softmax函数来计算预测值,输出 每个类别的概率而不是单个预测。

• wrong_predictions:每个错误预测的InputFeature list。可以从InputFeature.text_a属性获取文本。(可以在存储库 https://github.com/ThilinaRajapakse/simpletransformers 的utils.py文件中找到InputFeature类)

你还可以包括在评估中要使用的其他指标。只需将指标函数作为关键字参数传递给eval_model方法。指标功能应包含两个参数,第一个是真实标签,第二个是预测,这遵循sklearn标准。

对于任何需要附加参数的度量标准函数(在sklearn中为f1_score),你可以在添加了附加参数的情况下将其包装在自己的函数中,然后将函数传递给eval_model。

from sklearn.metrics import f1_score, accuracy_score

def f1_multiclass(labels, preds):
      return f1_score(labels, preds, average='micro')

result, model_outputs, wrong_predictions = model.eval_model(eval_df, f1=f1_multiclass, acc=accuracy_score

作为参考,我使用这些超参数获得的结果如下:

{'mcc': 0.937104098029913, 'f1': 0.9527631578947369, 'acc': 0.9527631578947369}

考虑到我实际上并没有进行任何超参数调整,效果还不错。感谢RoBERTa!

预测/测试

在实际应用中,我们常常不知道什么是真正的标签。要对任意示例执行预测,可以使用predict方法。此方法与eval_model方法非常相似,不同之处在于,该方法采用简单的文本列表并返回预测列表和模型输出列表。

predictions, raw_outputs = model.predict(['Some arbitary sentence'])

结论

在许多实际应用中,多分类是常见的NLP任务,Simple Transformers是将Transformers的功能应用于现实世界任务的一种简单方法,你无需获得博士学位才能使用它。

关于项目

我计划在不久的将来将“问答”添加到Simple Transformers 库中。敬请关注!

Simple Transformers 库:https://github.com/ThilinaRajapakse/simpletransformers

原文链接:https://medium.com/swlh/simple-transformers-multi-class-text-classification-with-bert-roberta-xlnet-xlm-and-8b585000ce3a

本文分享自微信公众号 - AI科技大本营(rgznai100)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-10-28

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 基于AFN封装的带缓存的网络请求

    git: https://github.com/zhouxihi/NVNetworking

    周希
  • iOS PureLayout使用

    PureLayout是iOS Auto Layout的终端API,强大而简单。由UIView、NSArray和NSLayoutConstraint类别组成。 ...

    周希
  • 一个月真的可以学会一门语言吗?

    知乎上总看到有人提这个问题, 我想转行,我要学多久才能学会,分享下我到经历,以便你评估一下自身来不来得及,别人是没办法帮你评估的。

    王炸
  • GitHub 博客项目学习之接入GitHu登录

    限于篇幅原因源码以上传github: https://github.com/codesbull/community

    cherishspring
  • 【React】377- 实现 React 中的状态自动保存

    移动端中,用户访问了一个列表页,上拉浏览列表页的过程中,随着滚动高度逐渐增加,数据也将采用触底分页加载的形式逐步增加,列表页浏览到某个位置,用户看到了感兴趣的项...

    pingan8787
  • 创建Github远程仓库

    之后在在Repository name 填入 ZXTabBarController(你的远程仓库名) ,其他保持默认设置,

    周希
  • 预训练语言模型关系图+必读论文列表,清华荣誉出品

    Github 项目:https://github.com/thunlp/PLMpapers

    机器之心
  • 关于Git和Github你不知道的十件事

    Git 和 GitHub 都是非常强大的工具。即使你已经使用他们很长时间,你也很有可能不知道每个细节。

    Rookie
  • 高效协同开发

    假设服务机器开通sambas服务端口,并且windows防火墙允许访问。这时候可以在windows打开网盘一样,打开sambas共享的服务器文件夹,把代码工程放...

    mariolu
  • Go 模块存在的意义与解决的问题

    作者:William Kennedy | 原文:Modules Part 01: Why And What

    波罗学

扫码关注云+社区

领取腾讯云代金券