前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习算法(第18期)----用RNN也能玩分类

深度学习算法(第18期)----用RNN也能玩分类

作者头像
智能算法
发布2019-05-14 10:18:15
5520
发布2019-05-14 10:18:15
举报
文章被收录于专栏:智能算法智能算法

上期我们一起学习了RNN是怎么处理变化长度的输入输出的, 深度学习算法(第17期)----RNN如何处理变化长度的输入和输出? 我们知道之前学过CNN在处理分类问题上的强大能力,今天我们看下前几期介绍的RNN是如何玩分类的。

MNIST数据集,我们都已经很熟悉了,是一个手写数字的数据集,之前我们用它来实战CNN分类器和机器学习的方法(在公众号中回复“MNIST”,即可免费下载)。今天我们就用RNN来对MNIST数据集进行一个预测。 这个时候,我们需要将每一张数据图像当成一个28x28的序列信号(图像的大小为28x28pixels)。对于整个网络框架,我们使用一个150个循环神经元外加一个有10个神经元的全连接层(每个类对应一个),最后接一个softmax层。如下:

整个模型的构建阶段,也很直接,跟我们前几期学的dnn构建方法非常类似,这里只是用了没有展开的RNN代替了之前的隐藏层,需要注意的是最后的全连接层连接的是RNN的状态tensor,该状态tensor仅仅包含了RNN的最后一个状态,并且y是目标类别。

代码语言:javascript
复制
from tensorflow.contrib.layers import fully_connected
n_steps = 28
n_inputs = 28
n_neurons = 150
n_outputs = 10
learning_rate = 0.001
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()

接下来,我们加载数据集,并对数据集进行reshape,如下:

代码语言:javascript
复制
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels

现在,我们将对上面的RNN进行training,在执行阶段跟之前的dnn也是非常类似的,如下:

代码语言:javascript
复制
n_epochs = 100
batch_size = 150
with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            X_batch = X_batch.reshape((-1, n_steps, n_inputs))
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

输出的结果如下:

代码语言:javascript
复制
0 Train accuracy: 0.713333 Test accuracy: 0.7299
1 Train accuracy: 0.766667 Test accuracy: 0.7977
...
98 Train accuracy: 0.986667 Test accuracy: 0.9777
99 Train accuracy: 0.986667 Test accuracy: 0.9809

最终得到了98%的准确率,还挺不错的,如果我们调整下超参数或者RNN权重初始化的方式,训练的更久一些,或者加一些正则化的方法,结果应该还会更好。学习了RNN的分类玩法,下一期我们将实战下RNN在时序信号上的预测能力。

今天我们主要从我们熟悉的MNIST数据集出发,来更深层次的学习了下RNN在分类方面的知识,希望有些收获,欢迎留言或进社区共同交流,喜欢的话,就点个“在看”吧,您也可以置顶公众号,第一时间接收最新内容。


智能算法,与您携手,沉淀自己,引领AI!

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-05-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 智能算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档