首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >小数据福音!BERT 在极小数据下带来显著提升的开源实现

小数据福音!BERT 在极小数据下带来显著提升的开源实现

作者头像
崔庆才
发布2019-09-04 16:05:25
7510
发布2019-09-04 16:05:25
举报

转载来源

本文授权转载自学术平台 PaperWeekly,公众号ID:paperweekly

阅读本文大概需要 8 分钟。

标注数据,可以说是 AI 模型训练里最艰巨的一项工作了。自然语言处理的数据标注更是需要投入大量人力。相对计算机视觉的图像标注,文本的标注通常没有准确的标准答案,对句子理解也是因人而异,让这项工作更是难上加难。

但是,谷歌最近发布的 BERT [1] 大大地解决了这个问题!根据我们的实验,BERT 在文本多分类的任务中,能在极小的数据下带来显著的分类准确率提升。并且,实验主要对比的是仅仅 5 个月前发布的 State-of-the-Art 语言模型迁移学习模型 – ULMFiT [2],结果有着明显的提升。我们先看结果:

图1. 实验结果对比,BERT在极少的数据集上表现非常出色

从上图我们可以看出,在不同的数据集中,BERT 都有非常出色的表现。我们用的实验数据分为 1000、 6700 和 12000 条,并且各自包含了测试数据,训练测试分割为 80%-20%。数据集从多个网页来源获得,并经过了一系列的分类映射。但 Noisy 数据集带有较为显著的噪音,抽样统计显示噪音比例在 20% 左右。

实验对比了几个模型,从最基础的卷积网络作为 Baseline,到卷积网络加上传统的词向量 Glove embedding, 然后是 ULMFiT 和 BERT。为了防止过拟合,CNN 与 CNN+Glove 模型训练时加入了 Early stopping。

值得注意的是,这里用的 BERT 模型均为基础版本,“BERT-Base, Uncased”,12 层,110M 参数,对比的是 ULMFiT 调整过的最优化参数。可见 BERT 在此任务中的强大。

然而,在 12000 条样本的数据集上,BERT 的结果相对 6700 条并没有显著的提升。数据分类不平衡可能是导致此结果的一大因素。

BERT 开源的多个版本的模型:

图2. 开源的多个版本的BERT模型

接下来,我们直奔主题 – 如何在自己的机器上实现 BERT 的文本 25 分类任务。教程分为以下几部分:

  • 运行环境
  • 硬件配置
  • 下载模型
  • 输入数据准备
  • 实现细节

运行环境

TensorFlow 版本为 Windows 1.10.0 GPU,具体安装教程可以参考此链接:

  • https://www.tensorflow.org/install/pip?lang=python3

Anaconda 版本为 1.9.2。

硬件配置

实验用的机器显卡为 NVIDIA GeoForce GTX 1080 Ti,BERT base 模型占用显存约为 9.5G。

下载模型

所有的运行环境设置好后,在这里可以下载到我们实验用的 BERT base:

  • https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip

下载完后,放在 BERT_BASE_DIR 中。

输入数据准备

我们需要将文本数据分为三部分:

  • Train: train.tsv
  • Evaluate: dev.tsv
  • Test: test.tsv

下面可以看到每个文件的格式,非常简单,一列为需要做分类的文本数据,另一列则是对应的 Label。

图3. 输入文本格式样板

并将这三个文件放入 DATA_DIR 中。

实现细节

首先我们 Clone 官方的 BERT Github repo:

  • https://github.com/google-research/bert

由于我们要做的是文本多分类任务,可以在 run_classifier.py 基础上面做调整。

这里简单介绍一下这个脚本本来的任务,也就是 BERT 示范的其中一个任务。这个例子是在 Microsoft Research Paraphrase Corpus (MRPC) corpus 数据集上面做微调,数据集仅包含 3600 个样本,在 GPU 上面几分钟就可完成微调。

此数据集可以用以下脚本下载:

  • https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e

注意运行的时候要用 --tasks all 参数来下载。

图4. 运行脚本下载MRPC数据集

可以打开看一下输入数据的结构,都是以 tsv 的形式保存:

图5. MRPC数据集输入数据样本

图6. MRPC数据集结构样本

这里 MRPC 的任务是 Paraphrase Identification,输入为两个句子,然后判断二者是否表示相同的意思,输出为二分类:是和不是。我们的分类任务只需要一个输入,而不是一对句子,这个在读取的过程中可以自动识别,并调整相应的 Sentence Embedding 如下图所示:

图7. BERT Sentence Embedding自动调整过程

run_classifier.py 的脚本中,由于输入格式和之前有少许不同,我们需要更改 _create_examples 函数里面的读取顺序,原本的读取位置为:

图8. MRPC数据集输入文本读取方式

我们需要让 text_a 读取被分类的文本,而 label 读取我们的标注:

图9. 在文本多分类的任务中,读取输入的方式

同时由于没有 text_b,我们需要在后面制作 example 的时候将他赋值为 None:

图10. 由于没有Sentence Pair输入,在这里需要将text_b定义为None

接下来,相对于原本的二分类,我们需要针对多分类做出一些调整。代码中原本将标签列表手动设置为 0 和 1:

图11. 原本直接将标注列表定义为0和1

这里我们加入一个新的输入,然后将输出调整如下:

图12. 调整get_labels的输入和输出

这里 labels 输入为新添加的所有训练样本的 Label,然后通过 set() 返回所有 25 个标签的列表。调整之后,代码可以自动根据分类数量定义标签列表,可以适配多种多分类任务。

同时,在 _create_examples 中,我们增加两个返回值,labels 和 labels_test:

图13. _create_examples函数增加两个返回值,labels和label_test

labels 返回的是所有训练样本的 label,用来输入到之前提到的 get_labels()。Labels 的定义如下图所示:

图14. 新添加的变量labels

接下来我们需要调整 main() function 里面的一些顺序,因为现在的 get_labels() 需要额外的输入(读取的完整 label list),我们需要把读取训练集的步骤放到前面。原来的顺序:

1. 获取 label_list;

图15. 第一步

2. 如果在训练模式,再读取训练集。

图16. 第二步

现在需要调整为:

1. 无论什么模式都读取训练集,因为需要用到训练标签,注意新添加的输出变量 train_labels;

图17. 第一步

2. 然后再获取 label_list,用前面的 train_labels。

图18. 第二步

最后,我们在开头设置好参数,可以直接输入默认值来运行。下面拿 DATA_DIR 来举例:

图19. 原始参数

调整后的输入参数:

图20. 调整后的参数

1000 条样本数据 10 分类,BERT 运行结果如下:

图21. 1000条样本数据10分类BERT结果

总结

本文介绍了如何实现 BERT 的文本多分类任务,并对比了 Baseline 以及不久前的 State-of-the-Art 模型 ULMFiT。实验结果可以看出 BERT 在此任务中,可以轻松打败先前的 SOTA。

这里附上本教程的开源代码:

  • https://github.com/Socialbird-AILab/BERT-Classification-Tutorial

我们依然会在 BERT 的基础上不断尝试,继续努力研究,也欢迎大家积极交流。

参考文献

[1] Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv preprint arXiv:1810.04805.

[2] Jeremy Howard and Sebastian Ruder. 2018. Universal language model fine-tuning for text classification. In ACL. Association for Computational Linguistics.

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-12-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 进击的Coder 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 运行环境
  • 硬件配置
  • 下载模型
    • 输入数据准备
      • 实现细节
        • 总结
          • 参考文献
          相关产品与服务
          NLP 服务
          NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档