用TensorFlow进行手写数字识别

作者 : 陈龙,腾讯即通产品部Android开发工程师,负责Android QQ的开发与维护。热衷于机器学习的研究与分享。

对于人类来说,识别手写的数字是一件非常容易的事情。我们甚至不用思考,就可以看出下面的数字分别是5,0,4,1。

但是想让机器识别这些数字,则要困难得多。

如果让你用传统的编程语言(如Java)写一个程序去识别这些形态各异的数字,你会怎么写?写很多方法去检测横、竖、圆这些基本形状,然后计算它们的相对位置?我想你很快就会陷入绝望之中。即使你花很多时间写出来了程序,精确度也一定不高。当你在传统编程方法的小黑屋里挣扎时,机器学习这种更高阶的方法却为我们打开了一扇窗。

为了找到识别手写数字的方法,机器学习界的大师Yann LeCun利用NIST(National Institute of Standards and Technology 美国国家标准技术研究所)的手写数字库构建了一个便于机器学习研究的子集MNIST。MNIST由70000张手写数字(0~9)图片(灰度图)组成,由很多不同的人写成,其中60000张是训练集,另外10000张是测试集,每张图片的大小是28 x 28像素,数字的大小是20 x 20,位于图片的中心。更详细的信息可以参考Yann LeCun的网站:http://yann.lecun.com/exdb/mnist/

已经有很多研究人员利用该数据集进行了手写数字识别的研究,也提出了很多方法,比如KNN、SVM、神经网络等,精度已经达到了人类的水平。

抛开这些研究成果,我们从头开始,想想怎样用机器学习的方法来识别这些手写数字。因为数字只包含0~9,对于任意一张图片,我们需要确定它是0~9中的哪个数字,所以这是一个分类问题。对于原始图片,我们可以将它看作一个28 x 28的矩阵,或者更简单地将它看作一个长度为784的一维数组。将图片看作一维数组将不能理解图片里的二维结构,我们暂且先这么做,看能够达到什么样的精度。这样一分析,我们很自然地就想到可以用Softmax回归来解决这个问题。关于Softmax Regression可以参考下面的文章:

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

我们的模型如下:

对于一张图片,我们需要算出它分别属于0~9的概率,哪个概率最大,我们即认为这张图片上是那个数字。我们给每个像素10个权重,对应于0~9,这样我们的权重矩阵的大小就是784 x 10。将上图的模型用公式表示,即为:

写成向量的形式:

softmax函数将n个非负的值归一化为0~1之间的值,形成一个概率分布。

模型有了,我们的代价函数是什么呢?怎样评估模型输出的结果和真实值之间的差距呢?我们可以将数字表示成一个10维的向量,数字是几则将第几个元素置为1,其它都为0,如下图所示:

比如1可表示为:[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]。这样我们就可以用交叉熵来衡量模型输出结果和真实值之间的差距了,我们将其定义为:

其中,y是模型输出的概率分布,y一撇是真实值,也可以看作概率分布,只不过只有一个值为1,其它都为0。将训练集里每个样本与真实值的差距累加起来,就得到了成本函数。这个函数可以通过梯度下降法求解其最小值。

关于交叉熵(cross-entropy)可以参考下面这篇文章:

http://colah.github.io/posts/2015-09-Visual-Information/

模型和成本函数都有了,接下来我们用TensorFlow来实现它,代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Import data
mnist = input_data.read_data_sets('input_data/', one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 10])

# Define loss and optimizer

# The raw formulation of cross-entropy,
#
#   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
#                                 reduction_indices=[1]))
#
# can be numerically unstable.
#
# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
# outputs of 'y', and then average across the batch.

cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

# Train
for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))

下面我来解释一些比较重要的代码:

 mnist = input_data.read_data_sets('input_data/', one_hot=True)

这行代码用来下载(如果没有下载)和读取MNIST的训练集、测试集和验证集(验证集暂时可以先不用管)。TensorFlow的安装包里就有了input_data这个module,所以我们直接import进来就好了。将数据读到内存后,我们就可以直接通过mnist.test.imagesmnist.test.labels来获得测试集的图片和对应的标签了。TensorFlow提供的方法从训练集里取了5000个样本作为验证集,所以训练集、测试集、验证集的大小分别为:55000、10000、5000。

'input_data/'是你存放下载的数据集的文件夹名。

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

这里简单的将参数都初始化为0。在复杂的模型中,初始化参数有很多的技巧。

y = tf.matmul(x, W) + b

这一行是我们建立的模型,很简单的一个模型。tf.matmul表示矩阵相乘。

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

这一行等价于:

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), reduction_indices=[1]))

y_ * tf.log(tf.nn.softmax(y)做的事情就是计算每个样本的cross-entropy,因为这样算数值不稳定,所以换成了tf.nn.softmax_cross_entropy_with_logits这个方法。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

这一行就是用TensorFlow提供的梯度下降法在成本函数最小化的过程中调整参数,学习率设置为0.5。

for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

这3行代码是进行训练,因为训练集有55000个样本,太大了,所以采用了Stochastic Gradient Descent(随机梯度下降),这样做大大降低了计算量,同时又能有效的训练参数,使其收敛。mnist.train.next_batch方法就是从训练集里随机取100个样本来训练。迭代的次数为1000。

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

这3行是在模型参数训练完之后,在测试集上检测其准确度。

运行我们的代码,其结果如下:

可以看到,这样一个简单的模型就可以达到92%的准确度。Amazing~

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

longchen的专栏

1 篇文章1 人订阅

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏专知

【干货】卷积神经网络中的四种基本组件

【导读】当今,卷积神经网络在图像识别等领域取得巨大的成功,那么是什么使其高效而快速呢?本文整理John Olafenwa的一篇博文,主要介绍了卷积神经网络采用的...

5386
来自专栏人工智能

具有张量流的混合密度网络

不久之前,Google开源了TensorFlow,这是一个旨在简化图表计算的库。 主要的应用程序是针对深度学习,将神经网络以图形形式显示。 我花了几天的时间阅读...

4446
来自专栏和蔼的张星的图像处理专栏

SAMF

论文:paper 结合了CN和KCF的多尺度扩展,看文章之前就听说很暴力,看了以后才发现原来这么暴力。 论文的前一半讲KCF,后一半讲做的实验,中间一点点大...

922
来自专栏人工智能头条

算法优化之道:避开鞍点

1663
来自专栏PPV课数据科学社区

TensorFlow和深度学习入门教程

关键词:Python,tensorflow,深度学习,卷积神经网络 正文如下: 前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络,并把...

3826
来自专栏人工智能

用TensorFlow生成抽象图案艺术

QQ图片20180204220437.jpg

7305

跨语言嵌入模型的调查

注意:如果您正在查找调查报告,此博客文章也可作为arXiv上的一篇文章。

29510
来自专栏人工智能LeadAI

图像学习-验证码识别

这是去年博主心血来潮实现的一个小模型,现在把它总结一下。由于楼主比较懒,网上许多方法都需要切割图片,但是楼主思索了一下感觉让模型有多个输出就可以了呀,没必要一定...

4394
来自专栏技术墨客

MNIST 机器学习入门(TensorFlow)

本文是为既没有机器学习基础也没了解过TensorFlow的码农、序媛们准备的。如果已经了解什么是MNIST和softmax回归本文也可以再次帮助你提升理解。在阅...

592
来自专栏CVer

TensorFlow和深度学习入门教程

英文原文:https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist/#0 C...

3618

扫码关注云+社区