前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习算法(第21期)----RNN中的Dropout技术

深度学习算法(第21期)----RNN中的Dropout技术

作者头像
智能算法
发布2019-07-25 14:57:16
7200
发布2019-07-25 14:57:16
举报
文章被收录于专栏:智能算法

上期我们一起学习了如何训练RNN并预测时序信号, 深度学习算法(第20期)----创意RNN和深度RNN的简单实现 今天我们一起简单学习下RNN中的Dropout的实现。

前几期学过,我们知道在CNN中,为了防止过拟合,我们常用DropOut技术。同样,如果我们想创建一个很深的RNN网络,那么它很有可能会产生过拟合,在RNN中该怎样应用DropOut技术来防止过拟合呢?

我们可以简单的在RNN之前或之后加一个DropOut层,但是如果我们想在RNN层中间加上DropOut的话,就得用DropoutWrapper了。下面代码在每个RNN层的输入都应用Dropout,对每个输入有50%的概率丢弃。

代码语言:javascript
复制
keep_prob = 0.5

cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
cell_drop = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob)
multi_layer_cell = tf.contrib.rnn.MultiRNNCell([cell_drop] * n_layers)
rnn_outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)

当然,我们也可以通过设置output_keep_prob来对输出进行dropout。 其实,细心的童鞋可能已经发现,上面的代码是有问题的,因为我们在前面CNN中应用Dropout的时候是有一个is_training的placeholder来区分是在training还是testing应用的。但是上面代码并没有。确实,上面代码的最大问题就是在testing的时候,也会应用Dropout,当然,这并不是我们想要的。不幸的是,DropoutWrapper并不支持is_training的placeholder,因此,我们要么自己重写一个DropoutWapper类,要么我们有两个计算图,一个是用来training,另一个用来testing。这里我们看下两个计算图是怎么实现的,如下:

代码语言:javascript
复制
import sys
is_training = (sys.argv[-1] == "train")
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs])
cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
if is_training:
    cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=keep_prob)
multi_layer_cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layers)
rnn_outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
[...] # build the rest of the graph
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    if is_training:
        init.run()
        for iteration in range(n_iterations):
            [...] # train the model
        save_path = saver.save(sess, "/tmp/my_model.ckpt")
    else:
        saver.restore(sess, "/tmp/my_model.ckpt")
        [...] # use the model

好了,至此,今天我们简单学习了RNN中DropOut技术,希望有些收获,下期我们将一起学习下RNN中的LSTM模块,欢迎留言或进社区共同交流,喜欢的话,就点个“在看”吧,您也可以置顶公众号,第一时间接收最新内容。


智能算法,与您携手,沉淀自己,引领AI!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-07-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 智能算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档