用TensorFlow进行手写数字识别

作者 : 陈龙,腾讯即通产品部Android开发工程师,负责Android QQ的开发与维护。热衷于机器学习的研究与分享。

对于人类来说,识别手写的数字是一件非常容易的事情。我们甚至不用思考,就可以看出下面的数字分别是5,0,4,1。

但是想让机器识别这些数字,则要困难得多。

如果让你用传统的编程语言(如Java)写一个程序去识别这些形态各异的数字,你会怎么写?写很多方法去检测横、竖、圆这些基本形状,然后计算它们的相对位置?我想你很快就会陷入绝望之中。即使你花很多时间写出来了程序,精确度也一定不高。当你在传统编程方法的小黑屋里挣扎时,机器学习这种更高阶的方法却为我们打开了一扇窗。

为了找到识别手写数字的方法,机器学习界的大师Yann LeCun利用NIST(National Institute of Standards and Technology 美国国家标准技术研究所)的手写数字库构建了一个便于机器学习研究的子集MNIST。MNIST由70000张手写数字(0~9)图片(灰度图)组成,由很多不同的人写成,其中60000张是训练集,另外10000张是测试集,每张图片的大小是28 x 28像素,数字的大小是20 x 20,位于图片的中心。更详细的信息可以参考Yann LeCun的网站:http://yann.lecun.com/exdb/mnist/

已经有很多研究人员利用该数据集进行了手写数字识别的研究,也提出了很多方法,比如KNN、SVM、神经网络等,精度已经达到了人类的水平。

抛开这些研究成果,我们从头开始,想想怎样用机器学习的方法来识别这些手写数字。因为数字只包含0~9,对于任意一张图片,我们需要确定它是0~9中的哪个数字,所以这是一个分类问题。对于原始图片,我们可以将它看作一个28 x 28的矩阵,或者更简单地将它看作一个长度为784的一维数组。将图片看作一维数组将不能理解图片里的二维结构,我们暂且先这么做,看能够达到什么样的精度。这样一分析,我们很自然地就想到可以用Softmax回归来解决这个问题。关于Softmax Regression可以参考下面的文章:

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

我们的模型如下:

对于一张图片,我们需要算出它分别属于0~9的概率,哪个概率最大,我们即认为这张图片上是那个数字。我们给每个像素10个权重,对应于0~9,这样我们的权重矩阵的大小就是784 x 10。将上图的模型用公式表示,即为:

写成向量的形式:

softmax函数将n个非负的值归一化为0~1之间的值,形成一个概率分布。

模型有了,我们的代价函数是什么呢?怎样评估模型输出的结果和真实值之间的差距呢?我们可以将数字表示成一个10维的向量,数字是几则将第几个元素置为1,其它都为0,如下图所示:

比如1可表示为:[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]。这样我们就可以用交叉熵来衡量模型输出结果和真实值之间的差距了,我们将其定义为:

其中,y是模型输出的概率分布,y一撇是真实值,也可以看作概率分布,只不过只有一个值为1,其它都为0。将训练集里每个样本与真实值的差距累加起来,就得到了成本函数。这个函数可以通过梯度下降法求解其最小值。

关于交叉熵(cross-entropy)可以参考下面这篇文章:

http://colah.github.io/posts/2015-09-Visual-Information/

模型和成本函数都有了,接下来我们用TensorFlow来实现它,代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Import data
mnist = input_data.read_data_sets('input_data/', one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 10])

# Define loss and optimizer

# The raw formulation of cross-entropy,
#
#   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
#                                 reduction_indices=[1]))
#
# can be numerically unstable.
#
# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
# outputs of 'y', and then average across the batch.

cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

# Train
for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))

下面我来解释一些比较重要的代码:

 mnist = input_data.read_data_sets('input_data/', one_hot=True)

这行代码用来下载(如果没有下载)和读取MNIST的训练集、测试集和验证集(验证集暂时可以先不用管)。TensorFlow的安装包里就有了input_data这个module,所以我们直接import进来就好了。将数据读到内存后,我们就可以直接通过mnist.test.imagesmnist.test.labels来获得测试集的图片和对应的标签了。TensorFlow提供的方法从训练集里取了5000个样本作为验证集,所以训练集、测试集、验证集的大小分别为:55000、10000、5000。

'input_data/'是你存放下载的数据集的文件夹名。

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

这里简单的将参数都初始化为0。在复杂的模型中,初始化参数有很多的技巧。

y = tf.matmul(x, W) + b

这一行是我们建立的模型,很简单的一个模型。tf.matmul表示矩阵相乘。

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

这一行等价于:

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), reduction_indices=[1]))

y_ * tf.log(tf.nn.softmax(y)做的事情就是计算每个样本的cross-entropy,因为这样算数值不稳定,所以换成了tf.nn.softmax_cross_entropy_with_logits这个方法。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

这一行就是用TensorFlow提供的梯度下降法在成本函数最小化的过程中调整参数,学习率设置为0.5。

for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

这3行代码是进行训练,因为训练集有55000个样本,太大了,所以采用了Stochastic Gradient Descent(随机梯度下降),这样做大大降低了计算量,同时又能有效的训练参数,使其收敛。mnist.train.next_batch方法就是从训练集里随机取100个样本来训练。迭代的次数为1000。

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

这3行是在模型参数训练完之后,在测试集上检测其准确度。

运行我们的代码,其结果如下:

可以看到,这样一个简单的模型就可以达到92%的准确度。Amazing~

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

longchen的专栏

1 篇文章1 人订阅

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏PPV课数据科学社区

自创数据集,使用TensorFlow预测股票入门

机器之心编译 参与:蒋思源、李亚洲、刘晓坤 STATWORX 团队近日从 Google Finance API 中精选出了 S&P 500 数据,该数据集包含 ...

3417
来自专栏机器之心

学界 | 谷歌论文新突破:通过辅助损失提升RNN学习长期依赖关系的能力

选自arXiv 机器之心编译 参与:李诗萌、黄小天 本文提出了一种简单的方法,通过在原始函数中加入辅助损失改善 RNN 捕捉长期依赖关系的能力,并在各种设置下评...

3565
来自专栏机器学习算法与Python学习

基于TensorFlow实现自编码器(附源码)

关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第二 【Python】:排名第三 【算法】:排名第四 AE简介 传统的机器学习很大程度上依...

9379
来自专栏人工智能LeadAI

机器学习实战 | 数据探索(缺失值处理)

点击“阅读原文”直接打开【北京站 | GPU CUDA 进阶课程】报名链接 接着上一篇:《机器学习实战-数据探索》介绍,机器学习更多内容可以关注github项目...

3546
来自专栏新智元

机器翻译新突破!“普适注意力”模型:概念简单参数少,性能大增

目前,最先进的机器翻译系统基于编码器-解码器架构,首先对输入序列进行编码,然后根据输入编码生成输出序列。两者都与注意机制接口有关,该机制基于解码器状态,对源令牌...

924
来自专栏文武兼修ing——机器学习与IC设计

CapsNet学习笔记理论学习代码阅读(PyTorch)参考资料

理论学习 胶囊结构 胶囊可以看成一种向量化的神经元。对于单个神经元而言,目前的深度网络中流动的数据均为标量。例如多层感知机的某一个神经元,其输入为若干个标量,...

3649
来自专栏小樱的经验随笔

最小二乘法多项式曲线拟合原理与实现

概念 最小二乘法多项式曲线拟合,根据给定的m个点,并不要求这条曲线精确地经过这些点,而是曲线y=f(x)的近似曲线y= φ(x)。 原理 [原理部分由个人根据互...

4495
来自专栏有趣的Python

2- 深度学习之神经网络核心原理与算法-提高神经网络学习效率

上一章我们介绍了基本的前馈神经网络的实现。 本节我们来介绍一些可以提高神经网络学习效率的方法。 并行计算 加快神经网络训练最直接的方式。我们需要得到的是一个网络...

56613
来自专栏ATYUN订阅号

伯克利人工智能研究项目:为图像自动添加准确的说明

人类可以很容易地推断出给定图像中最突出的物体,并能描述出场景内容,如物体所处于的环境或是物体特征。而且,重要的是,物体与物体之间如何在同一个场景中互动。视觉描述...

3415
来自专栏数据派THU

自创数据集,用TensorFlow预测股票教程 !(附代码)

来源:机器之心 本文长度为4498字,建议阅读8分钟 本文非常适合初学者了解如何使用TensorFlow构建基本的神经网络。 STATWORX 团队近日从 Go...

4617

扫码关注云+社区