开发 | Facebook开源 PyTorch版 fairseq,准确性最高、速度比循环神经网络快9倍

AI科技评论按:今年5月,FacebookAI研究院(FAIR)发表了他们的研究成果fairseq,在fairseq中,他们使用了一种新型的卷积神经网络来做语言翻译,比循环神经网络的速度快了9倍,而且准确性也是现有模型中最高的。此外,他们在GitHub公布了fair序列建模工具包的源代码和训练好的系统,其他的研究者可以在此基础上建立自己的关于翻译、文本总结和其他任务的模型。

详情可参见:快9倍!Facebook开源机器学习翻译项目fairseq 一文。

日前,FacebookAI研究团队又在GitHub上开源了fairseqPyTorch版本。

相关介绍

fairseq是FacebookAI研究院发布的一个序列到序列的学习工具,它的原作者(排名不分先后)是SergeyEdunov、MyleOtt和SamGross。该工具包能实现Convolutional Sequence to Sequence Learning(地址:https://arxiv.org/abs/1705.03122)中描述的全卷积模型,并能在一台机器上进行多GPU训练,也能在CPU和GPU上快速产生束搜索(beamsearch)。在开源的数据中,他们提供了英译法和英译德的预训练模型。

引用

如果你的论文中用了FAIR的相关代码,可以这样引用:

@inproceedings{ gehring2017convs2s, author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, title = "{Convolutional Sequence to Sequence Learning}", booktitle = {Proc. of ICML}, year = 2017, }

工具和安装

  • macOS或是Linux系统的电脑
  • 要是想训练新的模型,需要用到NVIDIA GPU和NCCL(https://github.com/NVIDIA/nccl)
  • Python 3.6
  • 安装PyTorch(http://pytorch.org/)

目前的fairseq-py需要从GitHub库中获得PyTorch,有多种方式安装它。我们建议利用Miniconda3,执行如下的步骤。

1、安装Miniconda3,激活 Python 3环境

https://conda.io/miniconda.html

2、安装PyTorch

conda install gcc numpy cudnn nccl conda install magma-cuda80 -c soumith pip install cmake pip install cffi git clone https://github.com/pytorch/pytorch.git cd pytorch git reset --hard a03e5cb40938b6b3f3e6dbddf9cff8afdff72d1b git submodule update --init pip install -r requirements.txt NO_DISTRIBUTED=1 python setup.py install

3、在GitHub中复制和执行如下代码来安装fairseq-py

pip install -r requirements.txt python setup.py build python setup.py develop

快速开始

你将需要使用到如下的命令:

  • python preprocess.py: 数据预处理: 构造词汇和二进制训练数据
  • python train.py: 在一个或多个GPU上训练新的模型
  • python generate.py: 用训练好的模型翻译预处理之后的数据
  • python generate.py -i:用训练好的模型翻译新的文本
  • python score.py: 通过与参考译文对比,给出生成译文的BLEU分数

评估预训练模型:

首先,下载预训练好的模型和词汇:

$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf -

模型中用的是BPE词汇(https://arxiv.org/abs/1508.07909),用户必须在翻译之前将编码应用到源文本。可以用apply_bpe.py 脚本中的wmt14.en-fr.fconv-cuda/bpecodes文件。@@是延续标记,原始文本可以通过sed s/@@ //g来恢复,此外把 --remove-bpe 标记传递到generate.py也有同样的作用。在生成BPE词汇之前。输入文本需要用mosesdecoder中的tokenizer.perl来标记。

下面是利用python generate.py -i产生翻译的例子, beam size为5:

$ MODEL_DIR=wmt14.en-fr.fconv-py $ python generate.py -i \ --path $MODEL_DIR/model.pt $MODEL_DIR \ --beam 5 | [en] dictionary: 44206 types | [fr] dictionary: 44463 types | model fconv_wmt_en_fr | loaded checkpoint /private/home/edunov/wmt14.en-fr.fconv-py/model.pt (epoch 37) > Why is it rare to discover new marine mam@@ mal species ? S Why is it rare to discover new marine mam@@ mal species ? O Why is it rare to discover new marine mam@@ mal species ? H -0.08662842959165573 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ? A 0 1 3 3 5 6 6 10 8 8 8 11 12

训练新模型

数据预处理

fairseq-py工具包中包含用于IWSLT 2014德转英语料库的一个预处理脚本样例。先将数据进行预处理和二进制编码:

$ cd data/ $ bash prepare-iwslt14.sh $ cd .. $ TEXT=data/iwslt14.tokenized.de-en $ python preprocess.py --source-lang de --target-lang en \ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ --thresholdtgt 3 --thresholdsrc 3 --destdir data-bin/iwslt14.tokenized.de-en

这将会得到能够用于训练模型的二进制数据。

训练

用python train.py来训练新的模型,下面是能很好的适于 IWSLT 2014数据集中的一些样例设置。

$ mkdir -p checkpoints/fconv $ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \ --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

默认情况下,python train.py会占用电脑中所有可用的GPU,可以用CUDA_VISIBLE_DEVICES环境来选择特定的GPU,或者改变使用的GPU数目。

有一点需要注意,batch大小是基于每个batch的最大token数来设置的,你需要基于系统中可用的GPU内存,选取一个稍小的值。

生成翻译

模型训练好之后就能利用python generate.py(用于二进制数据)或python generate.py -i(用于未处理文本)生成翻译了。

$ python generate.py data-bin/iwslt14.tokenized.de-en \ --path checkpoints/fconv/checkpoint_best.pt \ --batch-size 128 --beam 5 | [de] dictionary: 35475 types | [en] dictionary: 24739 types | data-bin/iwslt14.tokenized.de-en test 6750 examples | model fconv | loaded checkpoint trainings/fconv/checkpoint_best.pt S-721 danke . T-721 thank you . ...

如果只想用一个CPU,加入--cpu标记。可以通过--remove-bpe移除掉BPE标记。

训练好的模型

目前开源的全卷积序列到序列模型如下:

  • wmt14.en-fr.fconv-py.tar.bz2(https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2): 用于WMT14英译法的模型,包含词汇
  • wmt14.en-de.fconv-py.tar.bz2(https://s3.amazonaws.com/fairseq-py/models/wmt14.en-de.fconv-py.tar.bz2): 用于WMT14英译德的模型,包含词汇

针对以上模型,已经预处理和编码过的测试集如下:

  • wmt14.en-fr.newstest2014.tar.bz2(https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2): 用于WMT14英译法的newstest2014测试集
  • wmt14.en-fr.ntst1213.tar.bz2(https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.ntst1213.tar.bz2): 用于WMT14英译法的newstest2012和newstest2013测试集
  • wmt14.en-de.newstest2014.tar.bz2(https://s3.amazonaws.com/fairseq-py/data/wmt14.en-de.newstest2014.tar.bz2): 用于WMT14英译德的newstest2014测试集

下面是在一块GTX-1080ti上利用测试集产生结果的样例(英译德),运行在batch模式下:

$ curl https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin $ curl https://s3.amazonaws.com/fairseq-py/data/wmt14.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin $ python generate.py data-bin/wmt14.en-fr.newstest2014 \ --path data-bin/wmt14.en-fr.fconv-py/model.pt \ --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out ... | Translated 3003 sentences (95451 tokens) in 81.3s (1174.33 tokens/s) | Generate test with beam=5: BLEU4 = 40.23, 67.5/46.4/33.8/25.0 (BP=0.997, ratio=1.003, syslen=80963, reflen=81194) # Scoring with score.py: $ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys $ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref BLEU4 = 40.23, 67.5/46.4/33.8/25.0 (BP=0.997, ratio=1.003, syslen=80963, reflen=81194)

via:GitHub(https://github.com/facebookresearch/fairseq-py)

原文发布于微信公众号 - AI科技评论(aitechtalk)

原文发表时间:2017-09-19

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏宏伦工作室

一次不成功的深度学习实践 - 微信跳一跳

1814
来自专栏IT技术精选文摘

机器学习在启动耗时测试中的应用及模型调优(一)

启动耗时自动化方案在关键帧识别时,常规的图像对比准确率很低。本文详细介绍了采用scikit-learn图片分类算法在启动耗时应用下的模型调优过程。在之后的续篇中...

1334
来自专栏数据科学与人工智能

Python玩机器学习简易教程

本文介绍利用Python和Python的机器学习库scikit-learn完成一个端到端的机器学习项目。 俗话说,“师傅领进门,修行在个人”。本文就是扮演领进门...

3877
来自专栏PaddlePaddle

【AI核心技术】课程七:计算机视觉深入认知

UAI与PaddlePaddle联合推出的【AI核心技术掌握】系列课程持续更新中!

913
来自专栏机器之心

资源 | 下一代PS工具:Adobe照片级图像风格转换的Torch实现

选自arxiv 作者:栾福军等 机器之心编译 参与:李泽南、微胖 康奈尔大学与 Adobe 的研究者们最近发布了一项通过卷积神经网络进行照片风格迁移的研究。随后...

34411
来自专栏疯狂的小程序

微信跳一跳之深度实践

最近微信的跳一跳小程序火了一把,所以前天也更新了微信玩了几局,最多手动到200左右就不行了。

22910
来自专栏大数据杂谈

【Excel系列】Excel数据分析:时间序列预测

移动平均 18.1 移动平均工具的功能 “移动平均”分析工具可以基于特定的过去某段时期中变量的平均值,对未来值进行预测。移动平均值提供了由所有历史数据的简单的平...

3419
来自专栏企鹅号快讯

一次不成功的深度学习实践-微信跳一跳

最近微信的跳一跳小程序火了一把,所以前天也更新了微信玩了几盘,最多手动到200左右就不行了。 ? 后来准备用代码写个辅助工具,上Github一查,已经有人做出来...

1985
来自专栏目标检测和深度学习

教程 | 先理解Mask R-CNN的工作原理,然后构建颜色填充器应用

选自matterport 作者:Waleed Abdulla 机器之心编译 参与:刘晓坤 上年 11 月,matterport 开源了 Mask R-CNN 实...

2195
来自专栏ATYUN订阅号

【教程】使用TensorFlow对象检测接口标注数据集

当为机器学习对象检测和识别模型构建数据集时,为数据集中的所有图像生成标注非常耗时。而这些标注是训练和测试模型所必需的,并且标注必须是准确的。因此,数据集中的所有...

4467

扫码关注云+社区