首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何用Tensorflow完成手写数字识别?

如何用Tensorflow完成手写数字识别?

作者头像
HuangWeiAI
发布2019-08-01 15:16:28
6550
发布2019-08-01 15:16:28
举报
文章被收录于专栏:浊酒清味浊酒清味

深度学习最经典的任务问题就是分类。通过分类,我们可以将照片中的数字,人脸,动植物等等分到它属于的那一类当中,完成识别。接下来,我就带着大家一起完成一个简单的程序,来实现经典问题手写数字识别。

数据集

我们第一步需要收集一堆手写数据,并且将每个手写数字都标号类别,用来做成数据集。对于深度学习而言,一般的数据集大小至少上万起。所以收集数据这个工作还是比较繁琐的。不过呢,有人已经帮我们弄好了数据集,这就是鼎鼎有名的MNIST数据集。

MNIST数据集是一个标准的手写数据集,如上图所示,数据集里面有六万个手写数字且都标记完全。其中有五万个手写数字作为训练集,另外一万作为测试集。

这里有一份传送门:

http://yann.lecun.com/exdb/mnist/

我们并不需要事先下载MNIST数据集,Tensorflow几行代码就可以搞定

搭建网络

准备好了数据集之后,我们开始用Tensorflow搭建神经网络模型:

1.输入输出

tf.placeholder是占位符的意思,先把坑填好,之后会有数据填充进去。其中y_是输入对应的正确的数字标签,x就是手写数字照片。

2.网络主体

我们建立了一个四层全连接网络,每一层的网络宽度都是400。因为MNIST数据集的数字照片都是28*28的,所以第一层网络的权重的形状是[784,400],注意到我们使用了Dropout技术,所以代码中有tf.nn.dropout。对于最后一层我们用softmax技术,将对0-9数字的预测归一化,变成一个概率。

3.损失函数和优化器

对于损失函数,我们选择了平方差函数,其实就是线性规划。而优化器我们选择了Adam,是目前主流的优化器。

训练网络

1.初始化

我们在这里做了两件事情,一个是初始化网络中变量,第二个建立一个存储器,用来存储训练过程的一些变量。

2.训练

第一行的循环是控制循环的次数,我们使用了随机梯度训练,就是每次更新参数的时候并不是一次性把五万张照片一起塞进去,而是从中随机选出来作为一个batch来训练,这样的做的好处是可以大大减轻计算量。我们需要在每一步都在训练集上面训练来更新网络的参数,接着我们一定步骤后在测试集上面看看我们的训练效果。

3.执行程序

才开始训练集和测试集上的准确率是在10%附近,这是因为在网络的参数没有更新的时候,所有参数都是随机的,相当于我们在瞎猜。一共有十个数字,所以猜对的概率是十分之一。之后,随着训练的进行,训练集和测试集上的准确率都在增加。我们同时观察训练集和测试集上的准确率,是防止网络过拟合把我们欺骗了。

训练到一定步时,我们发现训练集的准确率已经接近百分之百了,测试集上的准确率也达到了百分之九十七以上。简简单单的四层就能做到如此之高的准确率,可见神经网络之神奇!

代码下载

需要代码的同学请阅读原文访问我的github页面下载代码fnn-mnist.py。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 浊酒清味 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档