前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >C+实现神经网络之六—实战手写数字识别

C+实现神经网络之六—实战手写数字识别

作者头像
企鹅号小编
发布2018-01-05 17:28:16
7950
发布2018-01-05 17:28:16
举报
文章被收录于专栏:人工智能

之前的五篇博客讲述的内容应该覆盖了如何编写神经网络的大部分内容,在经过之前的一系列努力之后,终于可以开始实战了。试试写出来的神经网络怎么样吧。

数据准备

有人说MNIST手写数字识别是机器学习领域的Hello World,所以我这一次也是从手写字体识别开始。我是从Kaggle找的手写数字识别的数据集。数据已经被保存为csv格式,相对比较方便读取。

数据集包含了数字0-9是个数字的灰度图。但是这个灰度图是展开过的。展开之前都是28x28的图像,展开后成为1x784的一行。csv文件中,每一行有785个元素,第一个元素是数字标签,后面的784个元素分别排列着展开后的184个像素。看起来像下面这样:

也许你已经看到了第一列0-9的标签,但是会疑惑为啥像素值全是0,那是因为这里能显示出来的,甚至不足28x28图像的一行。而数字一般应该在图像中心位置,所以边缘位置当然是啥也没有,往后滑动就能看到非零像素值了。像下面这样:

这里需要注意到的是,像素值的范围是0-255。一般在数据预处理阶段都会归一化,全部除以255,把值转换到0-1之间。

csv文件中包含42000个样本,这么多样本,对于我七年前买的4000元级别的破笔记本来说,单单是读取一次都得半天,更不要提拿这么多样本去迭代训练了,简直是噩梦(兼论一个苦逼的学生几年能挣到换电脑的钱!)。所以我只是提取了前1000个样本,然后把归一化后的样本和标签都保存到一个xml文件中。在前面的一篇博客中已经提到了输入输出的组织形式,偷懒直接复制了:

既然说到了输出的组织方式,那就顺便也提一句输入的组织方式。生成神经网络的时候,每一层都是用一个单列矩阵来表示的。显然第一层输入层就是一个单列矩阵。所以在对数据进行预处理的过程中,我就是把输入样本和标签一列一列地排列起来,作为矩阵存储。标签矩阵的第一列即是第一列样本的标签。以此类推。

把输出层设置为一个单列十行的矩阵,标签是几就把第几行的元素设置为1,其余都设为0。由于编程中一般都是从0开始作为第一位的,所以位置与0-9的数字正好一一对应。我们到时候只需要找到输出最大值所在的位置,也就知道了输出是几。”

这里只是重复一下,这一部分的代码在中:

这是我最近用ReLU的时候的代码,标签是几就把第几位设为几,其他为全设为0。最后都是找到最大值的位置即可。

在代码中的作用是,检验下转换后的矩阵和标签是否对应正确这里是把col(3),也就是第四个样本从一行重新变成28x28的图像,看上面的第一张图的第一列可以看到,第四个样本的标签是4。那么它转换回来的图像时什么样呢?是下面这样:

这里也证明了为啥第一张图看起来像素全是0。边缘全黑能不是0吗?

然后在使用的时候用前面提到过的get_input_label()获取一定数目的样本和标签。

实战数字识别

没想到前面数据处理说了那么多。。。。

废话少说,直接说训练的过程:

给定每层的神经元数目,初始化神经网络和权值矩阵

从inputlabel1000.xml文件中取前800个样本作为训练样本,后200作为测试样本。

这是神经网络的一些参数:训练时候的终止条件,学习率,激活函数类型

前800样本训练神经网络,直到满足loss小于阈值loss_threshold,停止。

后200样本测试神经网络,输出正确率。

保存训练得到的模型。

以sigmoid为激活函数的训练代码如下:

对比前面说的六个过程,代码应该是很清晰的了。参数output_interval是间隔几次迭代输出一次,这设置为迭代两次输出一次。

如果按照上面的参数来训练,正确率是0.855:

在只有800个样本的情况下,这个正确率我认为还是可以接受的。

如果要直接使用训练好的样本,那就更加简单了:

如果激活函数是tanh函数,由于tanh函数的值域是[-1,1],所以在训练的时候要把标签矩阵稍作改动,需要改动的地方如下:

这里不光改了标签,还有几个参数也是需要改以下的,学习率比sigmoid的时候要小一个量级,效果会比较好。这样训练出来的正确率大概在0.88左右,也是可以接受的。

源码链接

你们不要看到送书就忽略代码了,一边读书一边调代码,理论与实践并重才是王道啊。

结语

至此,神经网络系列告一段落。我真诚希望有人能自己从头写一个,或者对我这个做出优化和扩充。

现在既然可以进行手写字体识别,那么显然识别其他东西或者应用到其他地方也没有问题。之后我会整理个总结出来,那时候没有送书吸引大家注意力也许大家就能好好看看内容了。。。

本文来自企鹅号 - CVPy媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文来自企鹅号 - CVPy媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
图像处理
图像处理基于腾讯云深度学习等人工智能技术,提供综合性的图像优化处理服务,包括图像质量评估、图像清晰度增强、图像智能裁剪等。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档