深度学习笔记13:Tensorflow实战之手写mnist手写数字识别

上一讲笔者和大家一起学习了如何使用构建一个卷积神经网络模型。本节我们将继续利用的便捷性完成 mnist 手写数字数据集的识别实战。mnist 数据集是 Yann Lecun 大佬基于美国国家标准技术研究所构建的一个研究深度学习的手写数字的数据集。mnist 由 70000 张不同人手写的 0-9 10个数字的灰度图组成。本节笔者就和大家一起研究如何利用搭建一个 CNN 模型来识别这些手写的数字。

数据导入

mnist 作为标准深度学习数据集,在各大深度学习开源框架中都默认有进行封装。所以我们直接从中导入相关的模块即可:

快速搭建起一个简易神经网络模型

数据导入之后即可按照的范式创建相应的变量然后创建会话:

定义前向传播过程和损失函数:

进行模型训练:

使用训练好的模型对测试集进行预测:

预测准确率为 0.9,虽然说也是一个很高的准确率了,但对于 mnist 这种标准数据集来说,这样的结果还有很大的提升空间。所以我们继续优化模型结构,为模型添加卷积结构。

搭建卷积神经网络模型

定义初始化模型权重函数:

定义卷积和池化函数:

搭建第一层卷积:

搭建第二层卷积:

搭建全连接层:

设置防止过拟合:

对输出层定义:

训练模型并进行预测:

部分迭代过程和预测结果如下:

经过添加两层卷积之后我们的模型预测准确率达到了 0.9931,模型训练的算是比较好了。

注:本深度学习笔记系作者学习 Andrew NG 的 deeplearningai 五门课程所记笔记,其中代码为每门课的课后assignments作业整理而成。

参考资料:

https://www.coursera.org/learn/machine-learning

https://www.deeplearning.ai/

http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180715B0FHIX00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券