深度学习入门笔记系列 ( 五 )

基于 tensorflow 的手写数字的识别(简单版本)

本系列将分为 8 篇 。本次为第 5 篇 ,结合上一篇的应用实例 ,将前边学到一些基础知识用到手写数字的识别分类上 。

1.关于 MNIST 数据集

首先 ,我们得了解 MNIST 数据集 。这是一个手写数字数据集 ,在深度学习入门学习中极具代表性 。可以从官网下载该数据集 ,但事实上 TensorFlow 中提供了一个类来处理 MNIST 数据 ,这个类会自动下载并转化格式 ,将数据从原始的数据包中解析成训练和测试神经网络时使用的格式 ,具体相关函数在接下来代码中介绍 。

MNIST 数据集被分为训练数据集(60000张手写数字图片)和测试数据集(10000张手写数字图片)。

每一张图片包含 28*28 个像素 ,图片里的某个像素的强度值介于0-1之间。例如 ,数字 1 对应一个 28*28 像素图片 ,其像素强度如下 :

我们把这一个数组展开成一个向量 ,长度是 28*28=784 。因此在MNIST训练数据集中 mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。

2.one-hot 向量和 Softmax 函数

MNIST 数据集标签为 0-9 十个数字 ,我们用 one-hot 向量来表示 。以MNIST 数据集为例 ,one-hot 向量指的是除了某一位数字为 1 ,其他维度都为 0 ,比如数字 1 对应 [0,1,0,0,0,0,0,0,0,0] 。

那么我们就可以得到数据集中对应的标签(labels)是若干个 one-hot 向量组成的矩阵 。以训练集为例 ,是一个 [60000,10] 的数字矩阵 。

另一个重要的知识就是 Softmax 函数 。如果是二分类问题 ,我们可以考虑用 sigmoid 或 tanh 等进行分类 ,即分为是或否 。这里是多分类问题 ,softmax 就很合适了 。这里小詹不知道怎么描述容易让大家理解 ,借鉴一个博客链接给出一段较为生动的描述 。

我们知道 max ,假如说我有两个数 ,a 和 b ,并且 a > b ,如果取 max ,那么就直接取 a ,没有第二种可能 。但有的时候我希望分值大的那一项(a) 经常取到 ,分值小的那一项 (b) 也偶尔可以取到 ,那么我用 softmax 就可以了 。(尊重原创 ,附上这段话链接:https://blog.csdn.net/supercally/article/details/54234115)

3.MNIST 数据集识别实战

以上已经对基本的知识进行了介绍 ,这里进行实战讲解 。我们首先要设计一个网络结构 ,然后根据第四讲中的 “三步走” 步骤进行实现 。这里简单版本先设计一个简单到不能更简单的网络实现手写数字的识别分类 。

训练过程 ,每一张图片输入的可以看成一个长度为 784 的向量 ,输出为 0-9 中的一个 ,即有 10 种可能 ,或者说这就是一个 10 分类问题 。所以我们采取输入层 784 个神经元 ,全连接到输出层 10 个神经元 。( 哪个帅哥写的字 ?这么丑 !哈哈)

首先 ,需要读取 MNIST 数据集 ,利用 TF 框架自带类进行下载读取 。

接下来就是根据之前的 “三步走” 进行实践 。实现上述的最简单的网络结构 ,并依旧选择二次代价函数和梯度下降法 。

再在会话 Session 中执行 。代码如下 :

这里小詹讲一下下面这两行代码如何求出了 accuracy 。

因为这里无论是数据集中的 labels ,还是预测值 prediction 都是以 one-hot 向量形式存在 。tf.argmax 返回一维张量中最大值所在位置 ,若某一张图片数据的 label 和对该图片的预测 最大值在同一个位置(例如数字 3 ,预测结果和 label 对应的 one-hot 向量都为[0,0,0,1,0,0,0,0,0,0]),此时 tf.equal 则返回值为 1 ,反之为 0 。即预测正确为 1 ,错误为 0 。

之后利用 tf.reduce_mean() 函数将所有的correct_prediction 求平均值 ,比如测试 10 张图片 ,上述有 9 张正确(9个1),1 张错误(1个0)。则平均值为 0.9,就是预测精度了 。

那么利用以上网络和代码得到的结果是怎样的呢 ?下面给出结果 。

原文发布于微信公众号 - 小詹学Python(xiaoxiaozhantongxue)

原文发表时间:2018-08-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Coding迪斯尼

深度学习:将新闻报道按照不同话题性质进行分类

1092
来自专栏谭正中的专栏

TensorFlow入门(1):求N元一次方程

今年以来,人工智能成为一个时代热点,同时 TensorFlow 1.0 的发布后,我也想蹭蹭时代的热点,初步学习一下神经网络和机器学习,在这里把成果以初学者的方...

3.7K1
来自专栏自学笔记

聚类算法

p=2时就说平时计算的几何距离,当p趋向于正无穷的时候,其实求的就不是x,y的距离了,而是求x y中最长的一个了。因为如果x大于y,在指数增长下x回远大于y,所...

2902
来自专栏ATYUN订阅号

使用Apache MXNet分类交通标志图像

有许多深度学习的框架,例如TensorFlow、Keras、Torch和Caffe,Apache MXNet由于其在多个GPU上的可伸缩性而受到欢迎。在这篇博文...

6986
来自专栏IT派

实战|TensorFlow 实践之手写体数字识别!

本文的主要目的是教会大家运用google开源的深度学习框架tensorflow来实现手写体数字识别,给出两种模型,一种是利用机器学习中的softmax regr...

1240
来自专栏机器学习之旅

Python:SMOTE算法

17.11.28更新一下:最近把这个算法集成到了数据预处理的python工程代码中了,不想看原理想直接用的,有简易版的python开发:特征工程代码模版 ,进...

2224
来自专栏游戏开发那些事

【Unity3d游戏开发】游戏中的贝塞尔曲线以及其在Unity中的实现

  RT,马三最近在参与一款足球游戏的开发,其中涉及到足球的各种运动轨迹和路径,比如射门的轨迹,高吊球,香蕉球的轨迹。最早的版本中马三是使用物理引擎加力的方式实...

5071
来自专栏AI科技大本营的专栏

如何在Python中用LSTM网络进行时间序列预测

Matt MacGillivray 拍摄,保留部分权利 翻译 | AI科技大本营(rgznai100) 长短记忆型递归神经网络拥有学习长观察值序列的潜力。它似...

9744
来自专栏机器之心

资源 | Python 环境下的自动化机器学习超参数调优

由于机器学习算法的性能高度依赖于超参数的选择,对机器学习超参数进行调优是一项繁琐但至关重要的任务。手动调优占用了机器学习算法流程中一些关键步骤(如特征工程和结果...

1944
来自专栏程序生活

理解LSTM网络(整合)Recurrent Neural Networks长期依赖(Long-Term Dependencies)问题LSTM 网络GRU - Gated Recurrent Unit

1712

扫码关注云+社区

领取腾讯云代金券