首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

使用Tensor2Tensor和10行代码训练尖端语言翻译神经网络

编译:yxy

出品:ATYUN订阅号

有许多库可以帮助人们构建深度学习应用程序,但如果想使用最新架构的最先进模型和最少的代码,有这样一个API脱颖而出:Google的Tensor2Tensor。我通过这个库来使用高级的新神经网络架构(特别是Transformer)进行翻译,几乎不需要任何代码。

即使我不会说法语,我可以用T2T听懂法国团队和客户的话。

每周都有来自大学教授、谷歌和其他大型科技公司的研究人员、甚至是对深度学习有浓厚兴趣的开发人员发布新的神经网络架构和新的人工智能研究论文。

不幸的是,对于没有博士学位或者在反向传播,线性代数或计算数学方面了解不深的人,在没有高级API(如Keras)的情况下实现这些新的深度学习技术既困难又费时。

但好在,Google Brain团队认识到AI社区普遍存在的这些问题,随后创建了一个开源库来帮助解决它。

Tensor2Tensor,简称T2T,是一个深度学习模型和数据集库,旨在使深度学习更容易实现并加速ML研究。T2T由Google Brain团队和用户社区的研究人员和工程师积极使用和维护。

引自Github上的tensor2tensor介绍

深度学习和Tensor2Tensor

尽管深度学习并不总是人们在数据科学领域所期望的灵丹妙药,但它对于自然语言处理(NLP)任务来说非常有用。例如,使用词嵌入已经彻底改变了语言理解技术的有效性。

我想使用当前最先进的技术为我的团队和客户制作一个离线的法语到英语翻译器,也就是Transformer架构。T2T为快速、简单的训练和模型制作提供了一个框架,不需要从头开始编写和训练这个神经网络。

架构论文:https://arxiv.org/abs/1706.03762

Tensor2Tensor API概述

T2T库旨在与shell脚本一起使用,但你可以轻松地将其打包以供Python使用。API是多模块化的,这意味着任何内置模型都可以与各种类型的数据(文本,图像,音频等)一起使用。而API的作者为特定任务(如翻译,文本摘要,语音识别等)提供了推荐的数据集和模型。

GitHub:https://github.com/tensorflow/tensor2tensor#suggested-datasets-and-models

有时,你可能想要使用Tensor2Tensor的预编码模型之一,并将其应用于你自己的数据集和超参数组合。或者,你也可能想使用他们的简单框架来试验你自己的模型架构。通过定义一些新的子类可以很容易地做到这一点(我稍后会详细说明)。

T2T库有详细的说明文档,但为了深入了解,我们将逐步介绍其API的核心部分,并使用T2T开始你的第一个项目。

定义Tensor2Tensor问题

想要使用Tensor2Tensor(T2T),你要做的第一件事就是确定你要用它做什么,即问题是什么。这定义了你解决的任务,你使用的数据集,以及词汇表(如果可用)。这与模型架构和训练超参数无关。

你需要首先选择T2T中可找到的许多问题之一。你可以使用命令行查看API中已经内置的所有问题(使用命令t2t-datagen),也可以使用Python:

summarize_cnn_dailymail32:使用具有32k词汇量的CNN Daily Mail数据集的文本摘要神经网络

img2img_celeba:超分辨率的图像到图像转换(8×8到32×32)

sentiment_imdb:使用IMBD数据集的情绪分析模型

生成训练数据

选择并命名要解决的问题后,你需要为其选择正确的数据。如果你使用预置的问题,Tensor2Tensor会自动下载和准备用于训练的数据。

你首先需要选择一个目录来存储T2T将为你下载的未处理数据。目录名为tmp_dir。很多相同的问题都下载相同的数据,因此可以在T2T中重复使用此目录来解决多个问题,尤其是如果这些问题位于同一个任务或问题系列中。

在生成最终训练数据之前,你还需要确定存储预处理数据的目录。Tensor2Tensor中名为data_dir。同样,你可以在适当时重用目录。

可以认为tmp_dir是internet上的zip文件存储的位置,而data_dir是在从tmp_dir中读取数据之后,针对特定的T2T问题进行适当的预处理的位置。例如,如果进行NLP,在预处理期间,T2T将使用数字对每个单词进行编码,分割训练和测试集,创建词汇表等。

如果你想使用自己的数据集并使用T2T的预编码神经网络训练模型,则需要创建一个新的问题子类。

初始化这些目录后,你可以使用命令行生成数据,如下:

也可以用Python:

模型选择和超参数

你可以通过t2t-trainer在命令行中调用,也可以使用Python调用来查看所有可用的模型

例如,Transformer 模型最适合翻译。

当然,你还可以在模型中自定义多个超参数集。例如,在Transformer python文件的底部,你可以看到所有可以进行训练的超参数(见下图)。但通常最好先从基础参数集开始,然后根据需要进行调整。

值得注意的是,用于Tensor2Tensor的hparams和模型参数一起定义了训练参数。这意味着在测试新模型时,你可以非常轻松地调整网络的大小、批尺寸,学习率,优化器类型等。

训练你最先进的神经网络

现在,你已准备好用几行代码训练你的神经网络。

使用命令行,你需要做的就是通过设置相应的变量来执行以下脚本:

output_dir参数是为此模型运行存储模型文件检查点的位置,这样你可以通过预加载该目录中的模型文件来获取之前的训练。

你可以通过在上面的shell脚本中添加额外的标志来更改任何超参数。

要在Python中设置训练,需要花费更多精力,但同样可行。

使用逆向工程Notebook构建翻译器

首先,你必须设置所需的T2T变量,目录,预处理数据的位置以及模型文件存储位置。

接下来,你需要初始化hparam对象并重置一些变量。如果你的VRAM(显存)有限,你需要减少批尺寸(例如,从4096减小到1024),以便在训练时可以适应内存。随后,你将需要调整学习率和学习率准备步骤,以针对修改的批尺寸优化模型的收敛。接下来,你可以使用隐藏层来确定这是否有助于提高特定情况下的模型性能。

要开始训练模型,你需要初始化Tensorflow的run_config和实验对象。最后,打电话tensorflow_exp_fn.train_and_evaluate()实施训练。

跟踪模型训练和表现

现在你的模型训练已经开始,你可以看到损失和准确性指标的变化:

初始化Tensorflow实验对象时设置了train_steps参数。这是训练停止前的训练次数。你可以使用save_checkpoints_steps(默认为1000)控制执行评估的频率。初始化run_config对象时,将其设置为可选的hypermeter。

虽然Tensor2Tensor与CPU完全兼容,但GPU和分布式训练也有很多选择。比如使用哪一个,在每个GPU中限制多少内存,等等。如果你有兴趣学习如何集成GPU以训练T2T模型,访问下方链接。

链接:https://github.com/tensorflow/tensor2tensor/blob/master/docs/distributed_training.md

Tensorboard

要激活Tensorboard,首先需要转到命令行并输入tensorboard — logdir /。要直接访问Tensorboard二进制文件,你可能首先要在bash shell中激活包含Tensorflow的python环境。

详细操作:https://stackoverflow.com/questions/14604699/how-to-activate-virtualenv

激活后,你将能够在http://:6006 /实时跟踪模型性能。Tensorboard可用于比较周期的训练和评估指标,请参阅Tensorflow模型图。

最常用于此任务准确性的指标是BLEU分数。对于法语到英语的翻译,这个模型的BLEU得分大约为28,这是最先进水平。

使用Tensor2Tensor模型进行评分

要使用新训练的模型进行评分,你可以使用t2t-decoder二进制文件:

但是,这只能读取文本文件,并将结果输出到文本文件,这并不总是你需要的。所以我已经移植了它,以便在Python中更轻松地访问模型。你可以查看下面的代码,看看它是如何实现的。

你可能已经看到上面的代码中有两个函数,名为encode()和decode()。它们用于获取常规文本数据并将其编码为适合模型的格式。类似地,在相应的输出格式中对模型输出进行解码。

这意味着他们可以在批尺寸(1024)上输入许多输入序列,并且可以更快地翻译长段落,而不必对模型进行1024次翻译调用以翻译1024个句子。

让我的Tensor2Tensor模型投入生产

我做法语翻译器的主要原因之一是因为我在一家法国公司工作。很多人在团队聊天中讲法语。不幸的是,我根本不知道他们在说什么。

我最终使用Dataiku创建REST API端点,以使用我制作的Tensorflow模型执行翻译。我使用名为Errbot的聊天机器人API将REST端点连接到公司的Hipchat上。

Dataiku:https://www.dataiku.com/learn/guide/tutorials/deploy-scoring.html

errbot:http://errbot.io/en/latest/

现在,无论同事说些什么,我都可以轻松看懂。

完整项目:https://github.com/alexwolf22/tensor2tensor_translator

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181103B0KUH900?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券