C+实现神经网络之六—实战手写数字识别

之前的五篇博客讲述的内容应该覆盖了如何编写神经网络的大部分内容,在经过之前的一系列努力之后,终于可以开始实战了。试试写出来的神经网络怎么样吧。

数据准备

有人说MNIST手写数字识别是机器学习领域的Hello World,所以我这一次也是从手写字体识别开始。我是从Kaggle找的手写数字识别的数据集。数据已经被保存为csv格式,相对比较方便读取。

数据集包含了数字0-9是个数字的灰度图。但是这个灰度图是展开过的。展开之前都是28x28的图像,展开后成为1x784的一行。csv文件中,每一行有785个元素,第一个元素是数字标签,后面的784个元素分别排列着展开后的184个像素。看起来像下面这样:

也许你已经看到了第一列0-9的标签,但是会疑惑为啥像素值全是0,那是因为这里能显示出来的,甚至不足28x28图像的一行。而数字一般应该在图像中心位置,所以边缘位置当然是啥也没有,往后滑动就能看到非零像素值了。像下面这样:

这里需要注意到的是,像素值的范围是0-255。一般在数据预处理阶段都会归一化,全部除以255,把值转换到0-1之间。

csv文件中包含42000个样本,这么多样本,对于我七年前买的4000元级别的破笔记本来说,单单是读取一次都得半天,更不要提拿这么多样本去迭代训练了,简直是噩梦(兼论一个苦逼的学生几年能挣到换电脑的钱!)。所以我只是提取了前1000个样本,然后把归一化后的样本和标签都保存到一个xml文件中。在前面的一篇博客中已经提到了输入输出的组织形式,偷懒直接复制了:

既然说到了输出的组织方式,那就顺便也提一句输入的组织方式。生成神经网络的时候,每一层都是用一个单列矩阵来表示的。显然第一层输入层就是一个单列矩阵。所以在对数据进行预处理的过程中,我就是把输入样本和标签一列一列地排列起来,作为矩阵存储。标签矩阵的第一列即是第一列样本的标签。以此类推。

把输出层设置为一个单列十行的矩阵,标签是几就把第几行的元素设置为1,其余都设为0。由于编程中一般都是从0开始作为第一位的,所以位置与0-9的数字正好一一对应。我们到时候只需要找到输出最大值所在的位置,也就知道了输出是几。”

这里只是重复一下,这一部分的代码在中:

这是我最近用ReLU的时候的代码,标签是几就把第几位设为几,其他为全设为0。最后都是找到最大值的位置即可。

在代码中的作用是,检验下转换后的矩阵和标签是否对应正确这里是把col(3),也就是第四个样本从一行重新变成28x28的图像,看上面的第一张图的第一列可以看到,第四个样本的标签是4。那么它转换回来的图像时什么样呢?是下面这样:

这里也证明了为啥第一张图看起来像素全是0。边缘全黑能不是0吗?

然后在使用的时候用前面提到过的get_input_label()获取一定数目的样本和标签。

实战数字识别

没想到前面数据处理说了那么多。。。。

废话少说,直接说训练的过程:

给定每层的神经元数目,初始化神经网络和权值矩阵

从inputlabel1000.xml文件中取前800个样本作为训练样本,后200作为测试样本。

这是神经网络的一些参数:训练时候的终止条件,学习率,激活函数类型

前800样本训练神经网络,直到满足loss小于阈值loss_threshold,停止。

后200样本测试神经网络,输出正确率。

保存训练得到的模型。

以sigmoid为激活函数的训练代码如下:

对比前面说的六个过程,代码应该是很清晰的了。参数output_interval是间隔几次迭代输出一次,这设置为迭代两次输出一次。

如果按照上面的参数来训练,正确率是0.855:

在只有800个样本的情况下,这个正确率我认为还是可以接受的。

如果要直接使用训练好的样本,那就更加简单了:

如果激活函数是tanh函数,由于tanh函数的值域是[-1,1],所以在训练的时候要把标签矩阵稍作改动,需要改动的地方如下:

这里不光改了标签,还有几个参数也是需要改以下的,学习率比sigmoid的时候要小一个量级,效果会比较好。这样训练出来的正确率大概在0.88左右,也是可以接受的。

源码链接

你们不要看到送书就忽略代码了,一边读书一边调代码,理论与实践并重才是王道啊。

结语

至此,神经网络系列告一段落。我真诚希望有人能自己从头写一个,或者对我这个做出优化和扩充。

现在既然可以进行手写字体识别,那么显然识别其他东西或者应用到其他地方也没有问题。之后我会整理个总结出来,那时候没有送书吸引大家注意力也许大家就能好好看看内容了。。。

本文来自企鹅号 - CVPy媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Python中文社区

ALI的Tensorflow炼成与GAN科普

專 欄 ❈那只猫,Python中文社区专栏作者,Python中文社区新Logo设计人,纯种非CS科班数据分析人,沉迷Keras。在Cambridge做了点小事...

245100
来自专栏机器之心

谷歌终于开源BERT代码:3 亿参数量,机器之心全面解读

今日,谷歌终于放出官方代码和预训练模型,包括 BERT 模型的 TensorFlow 实现、BERT-Base 和 BERT-Large 预训练模型和论文中重要...

23520
来自专栏机器之心

教程 | 用TensorFlow Estimator实现文本分类

86540
来自专栏IT派

教程 | 用TensorFlow Estimator实现文本分类

本文选自介绍 TensorFlow 的 Datasets 和 Estimators 模块系列博文的第四部分。读者无需阅读所有之前的内容,如果想重温某些概念,可以...

18430
来自专栏机器学习算法工程师

朴素贝叶斯实战篇之新浪新闻分类

Python版本: Python3.x 作者:崔家华 运行平台: Windows 编辑:黄俊嘉...

75160
来自专栏FreeBuf

基于机器学习的web异常检测

Web防火墙是信息安全的第一道防线。随着网络技术的快速更新,新的黑客技术也层出不穷,为传统规则防火墙带来了挑战。传统web入侵检测技术通过维护规则集对入侵访问进...

82650
来自专栏人工智能头条

LSTM实现详解

24730
来自专栏IT派

教程 | 用TensorFlow Estimator实现文本分类

本文选自介绍 TensorFlow 的 Datasets 和 Estimators 模块系列博文的第四部分。读者无需阅读所有之前的内容,如果想重温某些概念,可以...

29930

TensorFlow中生成手写笔迹的Demo

这项操作现在在github上已经可以使用了。

56770
来自专栏WeaponZhi

编写你人生中第一个机器学习代码吧!

用 Python 实现第一段机器学习代码,跟我一起来吧! 我们先要学习的机器学习算法是监督学习,那么,何为监督学习呢?要了解监督学习,我们得先回顾下我们平时的编...

38790

扫码关注云+社区

领取腾讯云代金券