解析Tensorflow官方PTB模型的demo

RNN 模型作为一个可以学习时间序列的模型被认为是深度学习中比较重要的一类模型。在Tensorflow的官方教程中,有两个与之相关的模型被实现出来。第一个模型是围绕着Zaremba的论文Recurrent Neural Network Regularization,以Tensorflow框架为载体进行的实验再现工作。第二个模型则是较为实用的英语法语翻译器。在这篇博客里,我会主要针对第一个模型的代码进行解析。在之后的随笔里我会进而解析英语法语翻译器的机能。

论文以及Tensorflow官方教程介绍:

Zaremba设计了一款带有regularization机制的RNN模型。该模型是基于RNN模型的一个变种,叫做LSTM。论文中,框架被运用在语言模型,语音识别,机器翻译以及图片概括等应用的建设上来验证架构的优越性。作为Tensorflow的官方demo,该模型仅仅被运用在了语言模型的建设上来试图重现论文中的数据。官方已经对他们的模型制作了一部教程,点击这里https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/tutorials/recurrent/index.md

查看官方教程(英语版)。

代码解析:

代码可以在github找到,这里先放上代码地址。点击这里https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/ptb/ptb_word_lm.py查看代码。

代码框架很容易理解,一开始,PTB模型被设计入了一个类。该类的init函数为多层LSTM语言模型的架构,代码如下:

上面的代码注释已就框架进行了解释。但我有意的留下了一个最为关键的部分没有解释,即variable_scope以及reuse_variable函数。该类函数有什么特殊意义呢?我们这里先卖个关子,下面的内容会就这个问题深入探究。

模型建立好后该类还有其他如assign_lr(self,session,lr_value)以及property函数如input_data(self). 这些函数浅显易懂,就不在这里解释了。

之后,官方代码设计了小模型(原论文中没有regularized的模型)外,还原了论文里的中等模型以及大模型。这些模型是基于同样的框架,不过不同在迭代数,神经元数以及dropout概率等地方。另有由于小模型的keep_prob概率被设计为1,将不会运用dropout。

另外,由于系统的运行是在terminal里输入”python 文件名 --参数 参数值“格式,名为get_config()的函数的意义在于把用户输入,如small,换算成运用SmallConfig()类。

最后,我们来看一看main函数以及run_epoch函数。首先来看下run_epoch:

还记得之前卖的关子么?这个重要的variable_scope函数的目的其实是允许我们在保留模型权重的情况下运行多个模型。首先,从RNN的根源上说,因为输入输出有着时间关系,我们的模型在训练时每此迭代都要运用到之前迭代的结果,所以如果我们直接使用(cell_output, state) = cell(inputs[:, time_step, :], state)我们可能会得到一堆新的RNN模型,而不是我们所期待的前一时刻的RNN模型。再看main函数,当我们训练时,我们需要的是新的模型,所以我们在定义了一个scope名为model的模型时说明了我们不需要使用以存在的参数,因为我们本来的目的就是去训练的。而在我们做validation和test的时候呢?训练新的模型将会非常不妥,所以我们需要运用之前训练好的模型的参数来测试他们的效果,故定义reuse=True。这个概念有需要的朋友可以参考Tensorflow的官方文件对共享变量的描述。

好了,我们了解了这个模型代码的架构以及运行的机制,那么他在实际运行中效果如何呢?让我们来实际测试一番。由于时间问题,我只运行了小模型,也就是不用dropout的模型。运行方式为在ptb_word_lm.py的文件夹下输入python ptb_word_lm.py --data_path=/tmp/simple-examples/data/ --model small。这里需要注意的是你需要下载simple-examples.tar.gz包,下载地址点击这里http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 运行结果如下:

这里简便的放入了最后结果,我们可见,在13个epoch时,我们的测试perplexity为117.605, 对应了论文里non-regularized LSTM的114.5,运行时间约5到6小时。

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2017-07-12

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能头条

模仿人类智慧——“多任务学习”动手实践

973
来自专栏计算机视觉战队

稀疏&集成的卷积神经网络学习(续)

昨天跟大家详细的说了分类,定位的一些相关知识,今天把剩下的最后一点知识给大家补充完整,也感谢大家一直的支持,谢谢! 昨天的推送告诉大家了分类方案,我们再温习一...

3176
来自专栏深度学习那些事儿

TensorFlow中滑动平均模型介绍

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)*本次采样值+a*上次滤波结果,采用此算法的目的是:

5079
来自专栏图形学与OpenGL

实验四 二维几何变换

1062
来自专栏机器之心

心中无码:这是一个能自动脑补漫画空缺部分的AI项目

本文将简要介绍这项研究与 DeepCreamPy 实现项目,读者可下载项目代码或预构建的二进制文件,并尝试修复漫画图像或马赛克。这一个项目可以直接使用 CPU ...

1243
来自专栏PaddlePaddle

转载|在TensorFlow和PaddleFluid中使用多块GPU卡进行训练

前四篇文章我们介绍了 PaddleFluid 和 TensorFlow 的设计原理基本使用概念,分别通过在两个平台上实现完全相同的模型完成图像分类,语言模型和序...

1743
来自专栏贾志刚-OpenCV学堂

使用Tensorflow Object Detection API实现对象检测

Tensorflow Object Detection API自从发布以来,其提供预训练模型也是不断更新发布,功能越来越强大,对常见的物体几乎都可以做到实时准确...

1633
来自专栏王嘉的专栏

安全 AI 的智能对抗系统之架构实现篇

在AI的浪潮下,在现有的安全系统的基础上,SNG业务安全中心将机器学习应用到业务安全对抗中,自研建设并搭建了 – 安全AI的智能对抗系统。智能对抗系统现已应用在...

1.1K0
来自专栏AI科技大本营的专栏

TensorSpace:一套用于构建神经网络3D可视化应用的框架

今天要为大家推荐一套超酷炫的,用于构建神经网络 3D 可视化应用的框架——TensorSpace。

1171
来自专栏SnailTyan

非极大值抑制(Non-Maximum Suppression)

博客:noahsnail.com  |  CSDN  |  简书 |   云+社区

7770

扫码关注云+社区

领取腾讯云代金券