前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >NLP | 百度 ERNIE - 从0开始快速上手

NLP | 百度 ERNIE - 从0开始快速上手

作者头像
用户3946442
发布2022-04-11 19:03:37
8490
发布2022-04-11 19:03:37
举报
文章被收录于专栏:程序媛驿站

前文我们已经简单讲解过ERNIE1.0与2.0的构成与区别

NLP | 百度 ERNIE - 简析1.0 与 2.0

一篇简单易懂的好文

本文小媛带来的是【从0开始快速上手百度 ERNIE】

一、前置条件

在使用ERNIE模型之前,用户需要完成如下任务:

  1. 安装Python3.7.5版本。
  2. 安装paddlepaddle 1.8版本,具体安装方法请参见快速安装(文末)。
  3. 执行如下命令从GitHub上获取ERNIE代码库。如果网络速度较慢用户可以跳过此步,因为教程自带有ERNIE的包。
代码语言:javascript
复制
!git clone -b dygraph --single-branch https://github.com/PaddlePaddle/ERNIE.git

使用pip方式安装其他依赖文件。

代码语言:javascript
复制
!pip install -r ERNIE/requirements.txt

依赖文件主要包括:

  1. numpy:Python的一种开源数值计算拓展,可以用来进行大型张量的存储和计算。
  2. scikit-learn:机器学习工具包。
  3. scipy:科学计算库。
  4. six:解决py2和py3代码兼容性的工具包。

二、快速运行

这里以使用情感分析数据集ChnSentiCorp的ERNIE中文预模型为例,展示如何通过简单的三个步骤就可以快速使用ERNIE 1.0中文Base模型实现情感分析场景的推理。

ChnSentiCorp是一个中文情感分析数据集,包含酒店、笔记本电脑和书籍的网购评论。表1对ERNIE1.0/2.0和BERT中文模型在该任务上的效果进行了评测,评测使用的指标为准确率(acc),即在使用验证集或测试集进行推理时,推理正确的数据条目占数据集数据总数的百分比。从表1中可以看到ERNIE具有明显的优势。

使用ERNIE 1.0中文Base模型进行推理分如下三个步骤:

  1. 数据获取。介绍如何下载ChnSentiCorp数据集,以及数据集的结构,这样用户可以参考数据集的结构构造用于Fine-tuning的数据集。此外如果用户希望使用自定义数据集进行训练,则可以仿照ChnSentiCorp数据集的结构,构建自己的数据集。
  2. 运行Fine-tuning。介绍如何设置数据和模型路径的环境变量,以及如何执行脚本进行Fine-tuning。
  3. 执行推理。使用脚本运行Fine-tuning成功的模型进行推理。

表1 ERNIE1.0/2.0和BERT的测评表

数据集

ChnSentiCorp

评估指标

准确率(acc)

验证集(dev)

测试集(test)

BERT Base

94.6

94.3

ERNIE 1.0 Base

95.2 (+0.6)

95.4 (+1.1)

ERNIE 2.0 Base

95.7 (+1.1)

95.5 (+1.2)

ERNIE 2.0 Large

96.1 (+1.5)

95.8 (+1.5)

1. 数据获取

ERNIE在多个中文和英文NLP任务上做过评测,中文任务的数据可以通过下面的命令获取, ChnSentiCorp数据亦包含其中。

代码语言:javascript
复制
!wget https://ernie-github.cdn.bcebos.com/data-chnsenticorp.tar.gz
!tar xvf data-chnsenticorp.tar.gz

在解压出的文件夹“task_data/chnsenticorp”中, 包含了三个文件“train.tsv”、“dev.tsv”、“test.tsv”,分别对应ChnSentiCorp 数据的训练集、验证集和测试集,该任务是一个单句分类任务,数据包含两个字段为“label”和“seg_a”,以“TAB”进行分隔,示例如下:

代码语言:javascript
复制
seg_a label
选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。包的早餐是西式的,还算丰富。服务吗,一般       1
15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错    1
房间太小。其他的都一般。。。。。。。。。0
1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆,手不能摸,一摸一个印. 4.硬盘分区不好办.        0
今天才知道这书还有第6卷,真有点郁闷:为什么同一套书有两种版本呢?当当网是不是该跟出版社商量商量,单独出个第6卷,让我们的孩子不会有所遗憾。1
机器背面似乎被撕了张什么标签,残胶还在。但是又看不出是什么标签不见了,该有的都在,怪    0
呵呵,虽然表皮看上去不错很精致,但是我还是能看得出来是盗的。但是里面的内容真的不错,我妈爱看,我自己也学着找一些穴位。0
这本书实在是太烂了,以前听浙大的老师说这本书怎么怎么不对,哪些地方都是误导的还不相信,终于买了一本看一下,发现真是~~~无语,这种书都写得出来  0
地理位置佳,在市中心。酒店服务好、早餐品种丰富。我住的商务数码房电脑宽带速度满意,房间还算干净,离湖南路小吃街近。1

2. 运行Fine-tuning

运行该脚本即可执行Fine-tuning, 脚本会根据你指定的from_pretrained参数下载预训练模型,运行最大步长max_steps样本数 * epoch数 / 批大小算出。

代码语言:javascript
复制
!export CUDA_VISIBLE_DEVICES=0 
!PYTHONPATH=./ERNIE python ./ERNIE/ernie/finetune_sementic_analysis_dygraph.py \
        --from_pretrained ernie-1.0 \
        --data_dir ./chnsenticorp/ \
        --epoch 10 \
        --lr 5e-5 \
        --bsz 32 \
        --max_steps $((9600*10/32)) \
        --save_dir ./tuned_model

执行结束后输出如下的在验证集和测试集上面的测试结果:

代码语言:javascript
复制
training: 250it [01:39,  2.96it/s]2020-05-15 17:52:21,377-DEBUG: train loss 0.00880 lr 3.585e-05
training: 260it [01:43,  3.00it/s]2020-05-15 17:52:24,743-DEBUG: train loss 0.05025 lr 3.568e-05
training: 270it [01:46,  3.00it/s]2020-05-15 17:52:28,108-DEBUG: train loss 0.06813 lr 3.552e-05
training: 280it [01:49,  3.00it/s]2020-05-15 17:52:31,474-DEBUG: train loss 0.12881 lr 3.535e-05
training: 290it [01:53,  3.00it/s]2020-05-15 17:52:34,840-DEBUG: train loss 0.06156 lr 3.518e-05
2020-05-15 17:52:42,877-DEBUG: acc 0.93250
training: 10it [00:08,  1.88it/s]2020-05-15 17:52:46,317-DEBUG: train loss 0.00679 lr 3.485e-05
training: 20it [00:11,  2.84it/s]2020-05-15 17:52:49,817-DEBUG: train loss 0.13993 lr 3.468e-05
training: 30it [00:15,  2.89it/s]2020-05-15 17:52:53,297-DEBUG: train loss 0.02414 lr 3.452e-05

可以看到准确率(acc)达到了0.95左右,与表1中的测评准确率非常接近,说明训练效果达到了良好水平。

3. 执行推理

Fine-tuning 结束后,如果用户希望使用模型运行推理,可以修改上述命令行,并加入参数--eval进入推理模式,从而利用保存在某个checkpoint (由--save_dir指定)的模型执行推理。

代码语言:javascript
复制
!head ./chnsenticorp/dev/part.0|awk -F"\t" '{print $1}'| PYTHONPATH=./ERNIE  python ./ERNIE/ernie/finetune_sementic_analysis_dygraph.py \
        --from_pretrained ernie-1.0 \
        --data_dir ./chnsenticorp/ \
        --epoch 10 \
        --lr 5e-5 \
        --bsz 32 \
        --eval \
        --max_steps $((9600*10/32)) \
        --save_dir ./tuned_model

输入的预测数据由标准输入管道灌入程序。修改完成后请再次运行脚本执行推理。该命令指向的“chnsenticorp/dev/part.0”文件里的前10句话,程序将对这10句话进行推理:

  • 這間酒店環境和服務態度亦算不錯,但房間空間太小,不宣容納太大件行李,且房間格調還可以,中餐廳的廣東點心不太好吃,要改善之。但算價錢平宜,可接受。西餐廳格調都很好,但吃的味道一般且令人等得太耐了,要改善之。
  • <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道当年我听说这本书的时候花很长时间去图书馆找和借都没能如愿,所以这次一看到当当有,马上买了,红迷们也要记得备货哦!
  • 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货...
  • 2001年来福州就住在这里,这次感觉房间就了点,温泉水还是有的.总的来说很满意.早餐简单了些。
  • 不错的上网本,外形很漂亮,操作系统应该是个很大的 卖点,电池还可以。整体上讲,作为一个上网本的定位,还是不错的。
  • 房间地毯太脏,临近火车站十分吵闹,还好是双层玻璃。服务一般,酒店门口的TAXI讲是酒店的长期合作关系,每月要交费给酒店。从酒店到机场讲得是打表147元,到了后非要200元,可能被小宰30-40元。
  • 本来想没事的时候翻翻,可惜看不下去,还是和张没法比,他的书能畅销大部分还是受张的影响,对这个男人实在是没好感,不知道怎么买的,后悔。
  • 这台机外观十分好,本人喜欢,性能不错,是LED显示屏,无线网卡是: 5100AGN 无线网卡,如果装的是一条2G 800MHZ的内存就无敌了,本本发热很小,总体来说是十分值得买的,前提是这台机是4299买的。
  • 全键盘带数字键的 显卡足够强大.N卡相对A卡,个人偏向N卡 GHOST XP很容易.除了指纹识别外.所有驱动都能装齐全了,指纹识别,非要在XP下使用的朋友,可以用替代驱动.贡献下驱动地址: http://dlsvr01.asus.com/pub/ASUS/nb/F9Dc/Fingerprints_XP_080530.zip (华硕官方地址,放心下吧)。
  • 做工很漂亮,老婆很喜欢。T4200足够了,性价比不错的机器。测试了一下很安逸。今天晚上准备TWOW溜达圈,再看看整机表现如何!

其它分类任务的运行方式类似。同时 ERNIE 还支持阅读理解、语义匹配、序列标注等任务,运行方式可以参考 README 中 Fine-tuning 章节。

三、具体实现过程

开始写代码!

ChnSentiCorp任务运行的shell脚本ERNIE/ernie/run_classifier.py,该文件定义了分类任务Fine-tuning 的详细过程,下面我们将通过如下几个步骤进行详细剖析:

  1. 环境准备。导入相关的依赖,解析命令行参数;
  2. 实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数
  3. 定义辅助函数
  4. 运行训练循环

1. 环境准备

import相关的依赖,解析命令行参数。

代码语言:javascript
复制
import sys
sys.path.append('./ERNIE')
import numpy as np
from sklearn.metrics import f1_score
import paddle as P
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D

from ernie.tokenizing_ernie import ErnieTokenizer
from ernie.modeling_ernie import ErnieModelForSequenceClassification

2. 实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数

设置好所有的超参数,对于ERNIE任务学习率推荐取 1e-5/2e-5/5e-5, 根据显存大小调节BATCH大小, 最大句子长度不超过512.

代码语言:javascript
复制
BATCH=32
MAX_SEQLEN=300
LR=5e-5
EPOCH=10

D.guard().__enter__() # 为了让Paddle进入动态图模式,需要添加这一行在最前面

ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=3)
optimizer = F.optimizer.Adam(LR, parameter_list=ernie.parameters())
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

3. 定义辅助函数

(1)定义函数 make_data,将文本数据读入内存并转换为numpy List存储。

代码语言:javascript
复制
def make_data(path):
    data = []
    for i, l in enumerate(open(path)):
        if i == 0:
            continue
        l = l.strip().split('\t')
        text, label = l[0], int(l[1])
        text_id, _ = tokenizer.encode(text) # ErnieTokenizer 会自动添加ERNIE所需要的特殊token,如[CLS], [SEP]
        text_id = text_id[:MAX_SEQLEN]
        text_id = np.pad(text_id, [0, MAX_SEQLEN-len(text_id)], mode='constant') # 对所有句子都补长至300,这样会比较费显存;
        label_id = np.array(label+1)
        data.append((text_id, label_id))
    return data

train_data = make_data('./chnsenticorp/train/part.0')
test_data = make_data('./chnsenticorp/dev/part.0')

(2)定义函数get_batch_data,用于获取BATCH条样本并按照批处理维度stack到一起。

代码语言:javascript
复制
def get_batch_data(data, i):
    d = data[i*BATCH: (i + 1) * BATCH]
    feature, label = zip(*d)
    feature = np.stack(feature)  # 将BATCH行样本整合在一个numpy.array中
    label = np.stack(list(label))
    feature = D.to_variable(feature) # 使用to_variable将numpy.array转换为paddle tensor
    label = D.to_variable(label)
    return feature, label

4. 运行训练循环

队训练数据重复EPOCH遍训练循环;每次循环开头都会重新shuffle数据。在训练过程中每间隔100步在验证数据集上进行测试并汇报结果(acc)。

代码语言:javascript
复制
for i in range(EPOCH):
    np.random.shuffle(train_data) # 每个epoch都shuffle数据以获得最佳训练效果;
    #train
    for j in range(len(train_data) // BATCH):
        feature, label = get_batch_data(train_data, j)
        loss, _ = ernie(feature, labels=label) # ernie模型的返回值包含(loss, logits);其中logits目前暂时不需要使用
        loss.backward()
        optimizer.minimize(loss)
        ernie.clear_gradients()
        if j % 10 == 0:
            print('train %d: loss %.5f' % (j, loss.numpy()))
        # evaluate
        if j % 100 == 0:
            all_pred, all_label = [], []
            with D.base._switch_tracer_mode_guard_(is_train=False): # 在这个with域内ernie不会进行梯度计算;
                ernie.eval() # 控制模型进入eval模式,这将会关闭所有的dropout;
                for j in range(len(test_data) // BATCH):
                    feature, label = get_batch_data(test_data, j)
                    loss, logits = ernie(feature, labels=label) 
                    all_pred.extend(L.argmax(logits, -1).numpy())
                    all_label.extend(label.numpy())
                ernie.train()
            f1 = f1_score(all_label, all_pred, average='macro')
            acc = (np.array(all_label) == np.array(all_pred)).astype(np.float32).mean()
            print('acc %.5f' % acc)

训练过程中单次迭代输出的日志如下所示:

代码语言:javascript
复制
train 0: loss 0.05833
acc 0.91723
train 10: loss 0.03602
train 20: loss 0.00047
train 30: loss 0.02403
train 40: loss 0.01642
train 50: loss 0.12958
train 60: loss 0.04629
train 70: loss 0.00942
train 80: loss 0.00068
train 90: loss 0.05485
train 100: loss 0.01527
acc 0.92821
train 110: loss 0.00927
train 120: loss 0.07236
train 130: loss 0.01391
train 140: loss 0.01612

包含了当前 batch 的训练得到的Loss(ave loss)和每个Epochde 精度(acc)信息。训练完成后用户可以参考快速运行中的方法使用模型体验推理功能。

其它特性

ERNIE 还提供了混合精度训练、模型蒸馏等高级功能,可以在 README 中获得这些功能的使用方法。

附录:

README:https://github.com/PaddlePaddle/ERNIE/blob/develop/README.zh.md

paddle快速安装:https://www.paddlepaddle.org.cn/install/quick

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-06-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序媛驿站 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、前置条件
  • 二、快速运行
    • 1. 数据获取
      • 2. 运行Fine-tuning
        • 3. 执行推理
        • 三、具体实现过程
          • 1. 环境准备
            • 2. 实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数
              • 3. 定义辅助函数
                • 4. 运行训练循环
                • 其它特性
                相关产品与服务
                TI-ONE 训练平台
                TI-ONE 训练平台(以下简称TI-ONE)是为 AI 工程师打造的一站式机器学习平台,为用户提供从数据接入、模型训练、模型管理到模型服务的全流程开发支持。TI-ONE 支持多种训练方式和算法框架,满足不同 AI 应用场景的需求。
                领券
                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档