一步步提高手写数字的识别率(1)

手写数字识别是机器学习领域中的一个经典应用,很多机器学习算法以这个问题作为示例,其地位相当于程序界的hello world。这个问题具有以下两个特点:

  1. 问题复杂度适中。手写识别是一门很深的学问,但这里将问题域限制在手写数字的识别,具体说就是识别0 - 9一共十个数字。相对于识别手写汉字,其复杂度低了很多。另一方面这个问题又不是太简单,可以很好的展现算法的特点。
  2. 完善的数据集。这个问题的研究历史悠久,有着完善的样本和分类数据,而且提供免费下载。具体说就是MNIST(Mixed National Institue of Standards and Technology database)数据集,它由几万张28像素 x 28像素的手写数字组成,这些图片只包含灰度值信息。有了这个完善的数据集,我们就可以免去繁琐的收集数据、整理数据、处理数据之苦。

通过机器学习识别手写数字并非难事,然而要做到完美识别手写数字并不容易。在这篇文章中我们使用简单的softmax回归算法来训练一个手写数字识别模型,并测试其正确率,在后续的文章中,我们将采用深度学习、卷积神经网络等算法一步步改进我们的算法,逐步提高手写数字的识别率。

在本系列文章中,你将学习到:

  • 经典机器学习算法、深度神经网络、卷积神经网络在手写识别系统中的应用。
  • Tensorflow的编程技巧,包括Tensorflow编程的基本流程、如何使用Tensorflow内建的函数快速实现softmax回归、深度神经网络、卷积神经网络等算法。

本文将不会深入探讨算法本身,比如softmax、梯度递减、卷积运算等等,在Tensorflow中这都由内建函数实现,通常我们并不会从头写代码来实现,也不用深入算法细节。

在开始Tensorflow编程之前,我们先回顾一下Tensorflow实现机器学习算法的一般流程,通常流程分如下4个步骤:

  1. 加载数据集
  2. 定义算法公式,也就是前向计算的计算图
  3. 定义损失函数(loss function),选定优化器,并指定优化器优化损失函数
  4. 对数据进行迭代训练
  5. 在测试集或交叉验证数据集上进行准确率评估。

接下来将详细展开整个过程。

加载MNIST数据集

MNIST数据集包含55000个训练样本,10000个测试样本,另外还有5000个交叉验证数据样本。每个样本都有对应的标签信息,即label。

TensorFlow为我们提供了一个封装函数,可以直接加载MNIST数据集,并转换为我们期望的格式:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tfmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

如果是第一次读取数据,read_data_sets函数将从网络下载数据,并保存在本地的MNIST_data目录下。

加载的每个手写数字图像是28 x 28像素大小的灰度图像,但并非我们通常采用的jpg、png或bmp格式,而是像素矩阵,如图1所示:

图1 手写数字灰度信息示例

在机器学习中,有一个很重要的概念,就是特征信息,即能够表明对象特点的信息。对于图像数据而言,我们通常选取所有像素点作为特征,为了简化起见,我们将28x28的像素点展开为一维数据(shape=784)。考虑到训练样本数为55000个,所以训练数据的特征为一个55000 x 784的Tensor,如图2所示:

图2 MNIST训练样本的特征

训练数据标签(label)为55000x10的Tensor,这里的标签采用了one-hot编码,具体就是每个标签对应一个长度为10的向量,取值只有0和1,只有对应数字的位为1,其余为0,比如数值0对应的one-hot编码是[1,0,0,0,0,0,0,0,0,0],而数值5对应的编码就是[0,0,0,0,1,0,0,0,0,0]。如图3所示:

图3 MNIST训练样本的标签

前向计算公式

处理多分类任务,通常采用Softmax模型,具体来说,公式为:

y = softmax(Wx + b)

其中W为权值矩阵,b为偏置向量,x为输入特征矩阵,也就是我们从数据集中读取的矩阵。用比较形象的图形可以表示如下(为了简化起见,假设输入特征值只有3个):

图4 softmax计算图

通过梯度递减迭代,我们计算出W和b。我们先给W和b一个初始值,通过梯度递减迭代逐步更新W和b,最后达到接近正确值。这在TensorFlow中只需几行代码即可做到:

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)

placeholder是数据输入的地方,placeholder的第一个参数是数据类型,通常我们选择float32,第二个参数[None, 784]代表tensor的shape,也就是数据的维度,None代表不限条数的输入。对于简单模型而言,weights和bias的初始值并不重要,所以可以简单的初始化为0。

softmax是tf.nn下的一个函数,其实tf.nn下包含了大量的神经网络组件。如果你学习过机器学习的课程就知道,梯度递减回归算法还有一个反向计算过程,而TensorFlow的优秀之处就在于可以自动求导,并进行梯度更新,完成softmax回归模型参数的自动学习。

要让Tensorflow进行梯度递减回归,我们还需要定义一个损失函数(loss function)。

定义loss,选择优化器

为了训练模型,我们需要定义一个损失函数来描述优化目标,损失函数值越小,代表模型的分类结果与真实值的偏差越小,也就是说模型的准确率越高。我们给权重矩阵和偏置矩阵填充了全零的初始值,模型计算出一个初始的损失值,而训练的目的是不断将这个损失值减小,直到到达一个全局最优或局部最优解。

对于多分类问题,通常采用交叉熵(cross-entropy)作为损失函数。交叉熵的定义如下,其中y是预测的概率分布,y’是真实的概率分布:

在TensorFlow中,定义交叉熵很容易:

y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))

这里y_就是前面公式中的y’,对应这训练样本的标签。tf.reduce_mean用来对每个批次(batch)的数据结果求均值。

接下来还要选择优化器,通常采用的优化器有随机梯度下降(Stochastic Gradient Descent, SGD)。定义好优化算法之后,TensorFlow就可以根据我们定义的计算图自动求导,并反向传播(Back Propagation)进行训练,每一轮迭代更新参数,减少损失值。这其中还有一个很重要的超参数:学习率,这个值的选择也很重要,不过在这里不详细探讨该如何选择学习率,我们选择一个常用的学习率0.5即可。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

当然,TensorFlow中也有很多其他的优化器,通常只需修改函数名即可替换。

迭代训练

接下来开始迭代执行训练操作train_step。这里每次都随机从训练集中抽取100条样本构成一个mini-batch,并喂给placeholder。使用一小部分样本进行训练称为批量梯度下降法,与每次使用全样本的全梯度下降算法相比,具有收敛速度快的特点,在训练样本很大的情况下,经常采用。

tf.global_variables_initializer().run()for i in range(1000):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   train_step.run({x: batch_xs, y_: batch_ys})

训练结束后,我们可以得到W和b的值,这样通过简单的前向计算即可预测手写数字识别的结果。不过在投入使用之前,通常我们需要使用测试数据集或交叉验证数据集对模型进行评估,评估其准确率是否满足产品需求,是否需要进一步优化。

模型评估

我们将测试数据样本和对应的标签输入评估流程,计算模型在测试集上的准确率。代码如下:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

使用Softmax递归算法对MNIST数据进行分类识别,在测试集上平均准确率在92%左右。这是一个比较不错的结果,但还谈不上实用的程度。在后续的文章中,我们将采用深度网络、卷积神经网络来提升手写数字识别准确度。

参考

  1. TensorFlow实战,黄文坚、唐源著,电子工业出版社。

原文发布于微信公众号 - 云水木石(ourpoeticlife)

原文发表时间:2018-07-13

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券