基于tensorflow+RNN的MNIST数据集手写数字分类

2018年9月25日笔记

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 RNN是recurrent neural network的简称,中文叫做循环神经网络。 MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做美国国家标准与技术研究所数据库。 此文在上一篇文章《基于tensorflow+DNN的MNIST数据集手写数字分类预测》的基础上修改模型为循环神经网络模型,模型准确率从98%提升到98.5%,错误率减少了25% 《基于tensorflow+DNN的MNIST数据集手写数字分类预测》文章链接:https://www.jianshu.com/p/9a4ae5655ca6

0.编程环境

操作系统:Win10 tensorflow版本:1.6 tensorboard版本:1.6 python版本:3.6

1.致谢声明

本文是作者学习《周莫烦tensorflow视频教程》的成果,感激前辈; 视频链接:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/

2.配置环境

使用循环神经网络模型要求有较高的机器配置,如果使用CPU版tensorflow会花费大量时间。 读者在有nvidia显卡的情况下,安装GPU版tensorflow会提高计算速度50倍。 安装教程链接:https://blog.csdn.net/qq_36556893/article/details/79433298 如果没有nvidia显卡,但有visa信用卡,请阅读我的另一篇文章《在谷歌云服务器上搭建深度学习平台》,链接:https://www.jianshu.com/p/893d622d1b5a

3.下载并解压数据集

MNIST数据集下载链接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密码: wa9p 下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹不要选择解压到MNIST_data。 文件夹结构如下图所示:

image.png

4.完整代码

此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。 如果下面一段代码运行成功,则说明安装tensorflow环境成功。 想要了解代码的具体实现细节,请阅读后面的章节。 完整代码中定义函数RNN使代码简洁,但在后面章节中为了易于读者理解,本文作者在第6章搭建神经网络将此部分函数改写为只针对于该题的顺序执行代码。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

def RNN(X_holder):
    reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
    cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    last_cell = cell_list[-1]
    Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
    biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
    predict_Y = tf.matmul(last_cell, Weights) + biases
    return predict_Y
predict_Y = RNN(X_holder)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.train.next_batch(3000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print(step, "{:.4f}".format(test_accuracy))

上面一段代码的运行结果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz Extracting MNIST_data\train-labels-idx1-ubyte.gz Extracting MNIST_data\t10k-images-idx3-ubyte.gz Extracting MNIST_data\t10k-labels-idx1-ubyte.gz 100 0.852 200 0.888 300 0.939 400 0.930 500 0.946 600 0.959 700 0.953 800 0.948 900 0.956 1000 0.958

5.数据准备

第1行代码导入库warnings; 第2行代码表示不打印警告信息; 第3行代码导入库tensorflow,取别名tf; 第4行代码从tensorflow.examples.tutorials.mnist库中导入input_data方法; 第6行代码表示重置tensorflow图 第7行代码加载数据库MNIST赋值给变量mnist; 第8-13行代码定义超参数学习率learning_rate、批量大小batch_size、步数n_steps、输入层大小n_inputs、隐藏层大小n_hidden_units、输出层大小n_classes。 第14、15行代码中placeholder中文叫做占位符,将每次训练的特征矩阵X和预测目标值Y赋值给变量X_holder和Y_holder。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
learing_rate = 0.001
batch_size =100
n_steps = 28
n_inputs = 28
n_hidden_units = 128
n_classes = 10
X_holder = tf.placeholder(tf.float32)
Y_holder = tf.placeholder(tf.float32)

6.搭建神经网络

本文作者将此章中使用tensorflow库的所有方法的API链接总结成下表,访问需要vpn。

方法

链接

tf.reshape

https://www.tensorflow.org/api_docs/python/tf/manip/reshape

tf.nn.rnn_cell.LSTMCell

https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicLSTMCell

tf.nn.dynamic_rnn

https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

tf.transpose

https://www.tensorflow.org/api_docs/python/tf/transpose

tf.unstack

https://www.tensorflow.org/api_docs/python/tf/unstack

tf.Variable

https://www.tensorflow.org/api_docs/python/tf/Variable

tf.truncated_normal

https://www.tensorflow.org/api_docs/python/tf/truncated_normal

tf.matmul

https://www.tensorflow.org/api_docs/python/tf/matmul

tf.reduce_mean

https://www.tensorflow.org/api_docs/python/tf/reduce_mean

tf.nn.softmax_cross_entropy_with_logits

https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits

tf.train.AdamOptimizer

https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer

第1行代码reshape中文叫做重塑形状,将输入数据X_holder重塑形状为模型需要的; 第2行代码调用tf.nn.rnn_cell.LSTMCell方法实例化LSTM细胞对象; 第3行代码调用tf.nn.dynamic_rnn方法实例化rnn模型对象; 第4、5行代码取得rnn模型中最后一个细胞的数值; 第6、7行代码定义在训练过程会更新的权重Weights、偏置biases; 第8行代码表示xW+b的计算结果赋值给变量predict_Y,即预测值; 第9行代码表示交叉熵作为损失函数loss; 第10行代码表示AdamOptimizer作为优化器optimizer; 第11行代码定义训练过程,即使用优化器optimizer最小化损失函数loss。

reshape_X = tf.reshape(X_holder, [-1, n_steps, n_inputs])
lstm_cell = tf.nn.rnn_cell.LSTMCell(n_hidden_units)
outputs, state = tf.nn.dynamic_rnn(lstm_cell, reshape_X, dtype=tf.float32)
cell_list = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
last_cell = cell_list[-1]
Weights = tf.Variable(tf.truncated_normal([n_hidden_units, n_classes]))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))
predict_Y = tf.matmul(last_cell, Weights) + biases
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict_Y, labels=Y_holder))
optimizer = tf.train.AdamOptimizer(learing_rate)
train = optimizer.minimize(loss)

7.参数初始化

对于神经网络模型,重要是其中的W、b这两个参数。 开始神经网络模型训练之前,这两个变量需要初始化。 第1行代码调用tf.global_variables_initializer实例化tensorflow中的Operation对象。

image.png

第2行代码调用tf.Session方法实例化会话对象; 第3行代码调用tf.Session对象的run方法做变量初始化。

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

8.模型训练

第1行代码tf.argmax方法中的第2个参数为1,即求出矩阵中每1行中最大数的索引; 如果argmax方法中的第1个参数为0,即求出矩阵中每1列最大数的索引; tf.equal方法可以比较两个向量的在每个元素上是否相同,返回结果为向量,向量中元素的数据类型为布尔bool; 第2行代码

isCorrect = tf.equal(tf.argmax(predict_Y, 1), tf.argmax(Y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))
for i in range(1000):
    X, Y = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:X, Y_holder:Y})
    step = i + 1
    if step % 100 == 0:
        test_X, test_Y = mnist.test.next_batch(10000)
        test_accuracy = session.run(accuracy, feed_dict={X_holder:test_X, Y_holder:test_Y})
        print(step, "{:.4f}".format(test_accuracy)) 

上面一段代码的运行结果如下:

100 0.8272 200 0.9071 300 0.9334 400 0.9441 500 0.9459 600 0.9585 700 0.9548 800 0.9664 900 0.9654 1000 0.9671

文章篇幅所限,只打印查看1000次训练的结果,训练5000次即可达到98.5%的准确率。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏社区的朋友们

深度学习入门实战(二):用TensorFlow训练线性回归

上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能是要看源代码才能理解某个方法的含义,所以今天我们就介绍一下 Te...

6.2K1
来自专栏Deep Learning 笔记

CNN+MNIST+INPUT_DATA数字识别

TALK IS CHEAP,SHOW ME THE CODE,先从MNIST数据集下载脚本Input_data开始

4423
来自专栏MelonTeam专栏

深度学习入门实战(二)

导语:上一篇文章我们介绍了MxNet的安装,但MxNet有个缺点,那就是文档不太全,用起来可能是要看源代码才能理解某个方法的含义,所以今天我们就介绍一下Te...

25110
来自专栏深度学习入门与实践

【深度学习系列】PaddlePaddle之数据预处理

  上篇文章讲了卷积神经网络的基本知识,本来这篇文章准备继续深入讲CNN的相关知识和手写CNN,但是有很多同学跟我发邮件或私信问我关于PaddlePaddle如...

2658
来自专栏人工智能LeadAI

人脸识别 | 卷积深度置信网络工具箱的使用

本文主要以ORL_64x64人脸数据库识别为例,介绍如何使用基于matlab的CDBN工具箱。至于卷积深度置信网络(CDBN,Convolutional Dee...

4745
来自专栏IT派

【深度学习入门系列】TensorFlow训练线性回归

作者:董超 来源:腾讯云技术社区「腾云阁」 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能是要看源代码才能理...

3343
来自专栏lhyt前端之路

js随机数生成器的扩展0.前言1.扩展+分区2.二进制法3. 总结

给你一个能生成随机整数1-7的函数,就叫他生成器get7吧,用它来生成一个1-11的随机整数,不能使用random,而且要等概率。

971
来自专栏java初学

MD5算法

3016
来自专栏简书专栏

基于tensorflow+DNN的MNIST数据集手写数字分类预测

DNN是deep neural network的简称,中文叫做深层神经网络,有时也叫做多层感知机(Multi-Layer perceptron,MLP)。 从...

3593
来自专栏ATYUN订阅号

面向纯新手的TensorFlow.js速成课程

本课程由CodingTheSmartWay.com出品,在本系列的第一部分中,你将学到:

9074

扫码关注云+社区

领取腾讯云代金券