首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

谷歌ALBERT模型V2+中文版来了,GitHub热榜第二

比BERT模型参数小18倍,性能还超越了它。

这就是谷歌前不久发布的轻量级BERT模型——ALBERT

不仅如此,还横扫各大“性能榜”,在SQuAD和RACE测试上创造了新的SOTA。

而最近,谷歌开源了中文版本和Version 2,项目还登上了GitHub热榜第二

ALBERT 2性能再次提升

在这个版本中,“no dropout”、“additional training data”、“long training time”策略将应用到所有的模型。

与初代ALBERT性能相比结果如下。

预训练模型

可以使用 TF-Hub 模块:

Base [Tar File]: https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz [TF-Hub]: https://tfhub.dev/google/albert_base/1

Large [Tar File]: https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz [TF-Hub]: https://tfhub.dev/google/albert_large/1

XLarge [Tar File]: https://storage.googleapis.com/albert_models/albert_xlarge_v1.tar.gz [TF-Hub]: https://tfhub.dev/google/albert_xlarge/1

Xxlarge [Tar File]: https://storage.googleapis.com/albert_models/albert_xxlarge_v1.tar.gz [TF-Hub]: https://tfhub.dev/google/albert_xxlarge/1

TF-Hub模块使用示例:

代码语言:javascript
复制
tags=set()ifis_training:tags.add("train")albert_module=hub.Module("https://tfhub.dev/google/albert_base/1",tags=tags,trainable=True)albert_inputs=dict(input_ids=input_ids,input_mask=input_mask,segment_ids=segment_ids)albert_outputs=albert_module(inputs=albert_inputs,signature="tokens",as_dict=True)#Ifyouwanttousethetoken-leveloutput,use#albert_outputs["sequence_output"]instead.output_layer=albert_outputs["pooled_output"] 

预训练说明

要预训练ALBERT,可以使用run_pretraining.py:

代码语言:javascript
复制
pipinstall-ralbert/requirements.txtpython-malbert.run_pretraining\--input_file=...\--output_dir=...\--init_checkpoint=...\--albert_config_file=...\--do_train\--do_eval\--train_batch_size=4096\--eval_batch_size=64\--max_seq_length=512\--max_predictions_per_seq=20\--optimizer='lamb'\--learning_rate=.00176\--num_train_steps=125000\--num_warmup_steps=3125\--save_checkpoints_steps=5000 

GLUE上的微调

要对 GLUE 进行微调和评估,可以参阅该项目中的run_glue.sh文件。

底层的用例可能希望直接使用run_classifier.py脚本。

run_classifier.py可对各个 GLUE 基准测试任务进行微调和评估。

比如 MNLI:

代码语言:javascript
复制
pipinstall-ralbert/requirements.txtpython-malbert.run_classifier\--vocab_file=...\--data_dir=...\--output_dir=...\--init_checkpoint=...\--albert_config_file=...\--spm_model_file=...\--do_train\--do_eval\--do_predict\--do_lower_case\--max_seq_length=128\--optimizer=adamw\--task_name=MNLI\--warmup_step=1000\--learning_rate=3e-5\--train_step=10000\--save_checkpoints_steps=100\--train_batch_size=128 

可以在run_glue.sh中找到每个GLUE任务的default flag。

从TF-Hub模块开始微调模型:

代码语言:javascript
复制
albert_hub_module_handle==https://tfhub.dev/google/albert_base/1 

在评估之后,脚本应该报告如下输出:

代码语言:javascript
复制
*****Evalresults*****global_step=...loss=...masked_lm_accuracy=...masked_lm_loss=...sentence_order_accuracy=...sentence_order_loss=... 

在SQuAD上微调

要对 SQuAD v1上的预训练模型进行微调和评估,请使用 run SQuAD v1.py 脚本:

代码语言:javascript
复制
pipinstall-ralbert/requirements.txtpython-malbert.run_squad_v1\--albert_config_file=...\--vocab_file=...\--output_dir=...\--train_file=...\--predict_file=...\--train_feature_file=...\--predict_feature_file=...\--predict_feature_left_file=...\--init_checkpoint=...\--spm_model_file=...\--do_lower_case\--max_seq_length=384\--doc_stride=128\--max_query_length=64\--do_train=true\--do_predict=true\--train_batch_size=48\--predict_batch_size=8\--learning_rate=5e-5\--num_train_epochs=2.0\--warmup_proportion=.1\--save_checkpoints_steps=5000\--n_best_size=20\--max_answer_length=30 

对于 SQuAD v2,使用 run SQuAD v2.py 脚本:

代码语言:javascript
复制
pipinstall-ralbert/requirements.txtpython-malbert.run_squad_v2\--albert_config_file=...\--vocab_file=...\--output_dir=...\--train_file=...\--predict_file=...\--train_feature_file=...\--predict_feature_file=...\--predict_feature_left_file=...\--init_checkpoint=...\--spm_model_file=...\--do_lower_case\--max_seq_length=384\--doc_stride=128\--max_query_length=64\--do_train\--do_predict\--train_batch_size=48\--predict_batch_size=8\--learning_rate=5e-5\--num_train_epochs=2.0\--warmup_proportion=.1\--save_checkpoints_steps=5000\--n_best_size=20\--max_answer_length=30 

传送门

GitHub项目地址: https://github.com/google-research/ALBERT

本文经AI新媒体量子位(公众号ID:QbitAI)授权转载,转载请联系出处。

  • 发表于:
  • 原文链接http://news.51cto.com/art/202001/608845.htm
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券