用TensorFlow进行手写数字识别

更多腾讯海量技术文章,请关注云加社区:https://cloud.tencent.com/developer

作者:陈龙

对于人类来说,识别手写的数字是一件非常容易的事情。我们甚至不用思考,就可以看出下面的数字分别是5,0,4,1。

但是想让机器识别这些数字,则要困难得多。

如果让你用传统的编程语言(如Java)写一个程序去识别这些形态各异的数字,你会怎么写?写很多方法去检测横、竖、圆这些基本形状,然后计算它们的相对位置?我想你很快就会陷入绝望之中。即使你花很多时间写出来了程序,精确度也一定不高。当你在传统编程方法的小黑屋里挣扎时,机器学习这种更高阶的方法却为我们打开了一扇窗。

为了找到识别手写数字的方法,机器学习界的大师Yann LeCun利用NIST(National Institute of Standards and Technology 美国国家标准技术研究所)的手写数字库构建了一个便于机器学习研究的子集MNIST。MNIST由70000张手写数字(0~9)图片(灰度图)组成,由很多不同的人写成,其中60000张是训练集,另外10000张是测试集,每张图片的大小是28 x 28像素,数字的大小是20 x 20,位于图片的中心。更详细的信息可以参考Yann LeCun的网站:http://yann.lecun.com/exdb/mnist/

已经有很多研究人员利用该数据集进行了手写数字识别的研究,也提出了很多方法,比如KNN、SVM、神经网络等,精度已经达到了人类的水平。

抛开这些研究成果,我们从头开始,想想怎样用机器学习的方法来识别这些手写数字。因为数字只包含0~9,对于任意一张图片,我们需要确定它是0~9中的哪个数字,所以这是一个分类问题。对于原始图片,我们可以将它看作一个28 x 28的矩阵,或者更简单地将它看作一个长度为784的一维数组。将图片看作一维数组将不能理解图片里的二维结构,我们暂且先这么做,看能够达到什么样的精度。这样一分析,我们很自然地就想到可以用Softmax回归来解决这个问题。关于Softmax Regression可以参考下面的文章:

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

我们的模型如下:

对于一张图片,我们需要算出它分别属于0~9的概率,哪个概率最大,我们即认为这张图片上是那个数字。我们给每个像素10个权重,对应于0~9,这样我们的权重矩阵的大小就是784 x 10。将上图的模型用公式表示,即为:

写成向量的形式:

softmax函数将n个非负的值归一化为0~1之间的值,形成一个概率分布。

模型有了,我们的代价函数是什么呢?怎样评估模型输出的结果和真实值之间的差距呢?我们可以将数字表示成一个10维的向量,数字是几则将第几个元素置为1,其它都为0,如下图所示:

比如1可表示为:[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]。这样我们就可以用交叉熵来衡量模型输出结果和真实值之间的差距了,我们将其定义为:

其中,y是模型输出的概率分布,y一撇是真实值,也可以看作概率分布,只不过只有一个值为1,其它都为0。将训练集里每个样本与真实值的差距累加起来,就得到了成本函数。这个函数可以通过梯度下降法求解其最小值。

关于交叉熵(cross-entropy)可以参考下面这篇文章:

http://colah.github.io/posts/2015-09-Visual-Information/

模型和成本函数都有了,接下来我们用TensorFlow来实现它,代码如下:

下面我来解释一些比较重要的代码:

mnist = input_data.read_data_sets('input_data/', one_hot=True)

'input_data/'是你存放下载的数据集的文件夹名。

W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))

这里简单的将参数都初始化为0。在复杂的模型中,初始化参数有很多的技巧。

y = tf.matmul(x, W) + b

这一行是我们建立的模型,很简单的一个模型。

tf.matmul

表示矩阵相乘。

这一行等价于:

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

这一行就是用TensorFlow提供的梯度下降法在成本函数最小化的过程中调整参数,学习率设置为0.5。

for _ in range(1000):

sess.run(train_step, feed_dict=)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(sess.run(accuracy, feed_dict=))

这3行是在模型参数训练完之后,在测试集上检测其准确度。

运行我们的代码,其结果如下:

可以看到,这样一个简单的模型就可以达到92%的准确度。Amazing~

  • 发表于:
  • 原文链接:http://kuaibao.qq.com/s/20180102A0D4XI00?refer=cp_1026

同媒体快讯

相关快讯

扫码关注云+社区