机器学习三人行-手写数字识别实战

前面三个系列我们分别从机器学习入门,洞悉数据,已经数据预处理,回归建模等方面进行了系统的学习。

今天我们根据mnist手写数字的数据集来对0-9共10个数字进行分类,旨在通过这个分类项目的学习,一起学习分类的相关知识和技巧。

由于篇幅有限,代码未全部贴出,文末附关键字,回复该关键字即可下载本系列实战代码。

1. MNIST数据集

首先我们通过scipy中的loadmat方式加载数据集,如下:

可以看出,该数据集中主要有两部分内容,data和label,通过shape查看data可知,该数据集中有70000条数据,每条数据有784个特征。其实每一条数据都是一副28x28的图片,784个特征就是784个像素灰度值(0-255),我们可以通过matplotlib中的imshow()函数来查看其中的样本:如下图,给大家一个感官的感受:

从上图可以看出,手写的数字还是五花八门的,各种姿势都有。很明显,我们的任务是将每个数字分到正确的类中,那么这里有10个类,所以这是一个多分类的问题,万尺高楼平地起,多分类问题是由多个二分类组成的,比方说,我们可以做10个二分类分类器:0和非0,1和非1等等。这样比如判断5这个数字的时候10个分类器的结果理想情况下就会是,这样就能够把5给分出来。那么问题又来了,二分类分类器怎么构建呢?在构建二分类之前,如系列二所述我们需要对数据集进行分组,这里选取60000个样本作为训练集,10000个样本作为测试集。

2. 二分类分类器

2.1. 训练分类器

比如我们以5和非5为例,来说明二分类分类器的构建及评估。首先第一步我们需要对5和非5进行一个label的重置,5的label为true,非5的label为false,如下:

处理好数据集后,我们就可以建立二分类分类器了,这里以SGD(Stochastic Gradient Descent)随机梯度下降为例来说明如何实现二分类问题,从sklearn中直接调用SGD并对其中的某一样本进行分类,代码如下:

这里我们可以看出我们的分类器对某一样本的分类结果是True,那就说明该分类器认为该样本是5,事实如此。一般选择一种分类器之后,我们希望看一下分类器的性能如何,然后再决定是否将该方法用在测试集中进行测试。

2.2. 分类器评估

a. 交叉验证

最常用的方法就是交叉验证,关于交叉验证的解释详见系列三。如下调用,从下面5折交叉验证结果上来看,平均准确率在96%左右,貌似还不错。

聪明的你可能已经发现问题了,有些人可能会说在10个数字中,随机给出一个数字判断是不是为5,只要我一直猜非5也会有90%的概率猜准确的,事实如此。也就是说现在还不能评判这个SGD分类器到底怎么样,分类器总得有个好坏吧,那么该怎样评判呢?

b. 混淆矩阵

一种常见的评判方法是查看该分类器对分类结果的混淆矩阵,二分类问题的混淆矩阵是一个2x2的矩阵表示,对于该5和非5的问题,如下图:

其中TN表示分正确负样本的数量

FN表示将正样本错误的分为负样本的数量

TP表示分正确的正样本的数量

FP表示将负样本错误的分为正样本的数量

来,我们一起看一下SGD分类器的混淆矩阵如何:

从上面混淆矩阵结果可以看出,有1492个“5”分成了“非5”,有672个“非5”分成了“5”。混淆矩阵是从量上面进行的一个直观的观察,一般情况下,我们喜欢从一个比率的角度来评判一个东西。

c. precision,recall和F1-score

对于混淆矩阵,我们一般从s三个角度进行一个评判:分别是precision(准确率),recall(召回率),F1-score(F1得分)

precision = TP/(TP+FP)

recall = TP/(TP+FN)

precision的意思是在预测为正类的样本中,有多少真正的正样本的比例。而recall表示在真正的正样本中又有多少被预测正确的比例。我们先来看一下该分类器的pecision和recall分别是多少,如下:

也就是说在5421个数字“5”中,有将近72.5%的“5”被正确分类,在被分为“5”的4601个样本中,有85.4%的样本是分正确的。我们都希望两个比例都比较高。其实我们可以根据每个样本的得分来把precision和recall的曲线画出来,如下图:

可以看出两个标准是互斥的,其实我们从他们的公式中也可以推断出来。这里留一个思考题(为什么precision曲线到后面会有波动?欢迎文末留言交流!)两个指标互斥的,我们可以自由选择,比方说我们选择precision高的时候就必须得接受recall的低。我们可以手动选择上面precision和recall的阈值,来获得不同的precision和recall,如下所示:

当然,我们也可以从一个指标上来进行综合评判,那就是F1-score,F1-score其实是precision和recall的调和平均数,如下:

如上可知,对于该分类器的F1得分为0.7841,这样我们就可以从F1得分上来进行一个分类器的综合比较。此外,还有一个可视化的评判标准。

d. roc曲线

对于评判分类器的评判,我们更喜欢可视化的评判,那么一个传说中的大神级的评判标准ROC曲线,ROC曲线是根据TPR(true positive rate)和FPR(false positive rate)得到的,如下图:

上图ROC曲线中的对角线上的虚线为纯粹随机的分类结果,蓝色曲线为上面SGD方法的分类结果,评判一个ROC曲线的好坏一般以ROC曲线下方的面积AUC(area under the curve)为标准,该值越大,说明性能越好。如上图的的AUC值如下,其中y_scores是每一个样本的得分(详见文末代码):

至此,我们已经介绍了四种常见的评判方法,那么我们来做一个比较,比较SGD和另一种分类器随机森林两者的分类性能,哪个更胜一筹。下面我们先构建随机森林的分类器,如下:

将随机森林的ROC曲线和SGD的ROC曲线进行比较,如下图:

从上图可以看出随机森林的AUC值比SGD的AUC要高,也就是说从ROC曲线上来看,随机森林的性能在这个二分类问题上要比SGD的性能好。

从这个AUC的值也可以看出,随机森林的性能要比SGD的性能更好一些,如果让在两个分类器中进行选择的话,我们会选择随机森林分类器来进行分类。

至此,我们已经从二分类分类器的训练,评估以及性能比较综合了解二分类的相关知识。但是,这只是一个二分类的问题,我们的目标是把每一种数字都分到自己的类中,这是一个多分类的问题。征途的路上,我们一起前行……

3. 多分类分类器

a. OvA多分类策略

如上面所说,对于我们的任务,可以设置10个分类器(0和非0,1和非1等等),然后对于一个新样本来说,将该样本依次通过这10个分类器,看哪个分类器的得分最高,那么这个样本就属于哪个分类器分出来的类。这种多分类分类器的策略称为OvA(one versus All),简单的说就是一个分类器说了算。用例子说话,如下代码:

注意,此时的训练样本的标签不再是5和非5的标签,而是0-9的10类标签,对于某一数字的10个分类器的得分如上图,从上图中可以看出第五个分数是最高的,那么多分类分类器对于这个数字的预测结果就是5。

b. OvO多分类策略

上面说的分类器是基于OvA策略的分类器,还有一种分类器是这样构造的:对于我们的数字分类来说,需要构建一系列二分类前,其中二分类分别区分0和1,0和2,0和3等等总共有N*(N-1)/2个分类器。比方说要对5这个数字进行分类,依次通过这45个分类器后,会发现45个分类器中有最多9个分类器都认为这个数字是5,类似于基于投票机制的分类器。那么对这个样本的最终分类就是5。如下代码:

有一些分类器只能用OvO策略,比方说SVM,有一些分类器没有所谓的OvO或者OvA策略,比方说随机森林,因为随机森林本身能够直接分为多类,可以直接根据类别得分进行确定是分为哪一类,如下图:

从上图可以看出随机森林分类器的结果中5数字的得分是最高的,那么分类器就将该样本分为5。

c. 多分类分类器的评估

多分类分类器的评估,和二分类一样,也可以用交叉验证进行评估,如下:

从上面的结果来看有84%左右的准确率,还是不错的,比随机分类的10%的准确率要高出很多,但是如果我们利用系列三的方法进行一个数据预处理(尺度归一化)的话,可以再提高分类的准确率,如下:

d. 误差分析

有上面的知识我们知道,对分类器的误差分析一般从混淆矩阵下手,那么我们输出一下SGD分类器对mnist数据集的混淆矩阵,如下:

从混淆矩阵的对角线上可以看出,大部分还是分类正确的,关于混淆矩阵的其他元素的说明这里不再赘述,详情参照上面的解释。如果将上面的混淆矩阵转换成图像的形式的话,可以很明显的看出上面的现象。

上面的图显示虽然直观,但是对于误差项的分析仍然不够明朗,这里我们把误差项从误判的数量转换成比例,如下图:

就可以很直观的看出哪些数字容易发生误判,比方说,3和5交叉点比较亮,那么就说明3和5容易误判,我们就可以通过一些措施,例如增加样本量等方法来提高该数字的分类性能等等。

4. 小结

至此,我们已经从查看mnist数据集,二分类分类器的构建,混淆矩阵,ROC曲线等评估方法进行分类器性能的评估,OvA, OvO多分类分类器的构建策略以及实战,多分类分类器的性能评判误差分析等方面进行系统学习。由于小编水平有限,难免会有一些理解有偏差的地方,欢迎大家多多指正,欢迎留言,多多交流。

(如需更好的了解相关知识,欢迎加入智能算法社区,在“智能算法”公众号发送“社区”,即可加入算法微信群和QQ群)

本文来自企鹅号 - 智能算法媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏磐创AI技术团队的专栏

新手入门机器学习十大算法

【磐创AI导读】:对于想要了解机器学习的新手,本文为大家总结了数据科学家最经常使用的十大机器学习算法来帮助大家快速入门。如果喜欢我们的文章,欢迎点赞、评论、转发...

8810
来自专栏机器学习算法工程师

全面直观认识深度神经网络

作者:石文华 编辑:赵一帆 01 深度学习的精准定义 一类通过多层非线性变换对高复杂性数据建模算法的集合。它的两个非常重要的特征是多层性和非线性。俗称多层非线...

35980
来自专栏CreateAMind

RCNN学习笔记(7):Faster R-CNN 英文论文翻译笔记

reference link:http://blog.csdn.net/liumaolincycle/article/details/48804687

34030
来自专栏云时之间

对交叉验证的一些补充(转)

交叉验证是一种用来评价一个统计分析的结果是否可以推广到一个独立的数据集上的技术。主要用于预测,即,想要估计一个预测模型的实际应用中的准确度。它是一种统计学上将数...

42890
来自专栏AI科技评论

裴健团队KDD新作:革命性的新方法,准确、一致地解释深度神经网络

AI 科技评论按:你有没有想过,深度神经网络是依据什么来准确识别有猫的图片的?随着深度神经网络在金融、医疗及自动驾驶等领域的广泛应用,深度神经网络无法明确解释...

14630
来自专栏企鹅号快讯

机器学习之白话adaboost元算法

阅读本文大约需要5分钟 作为(曾)被认为两大最好的监督分类算法之一的adaboost元算法(另一个为前几节介绍过的SVM算法),该算法以其简单的思想解决复杂的分...

20560
来自专栏深度学习计算机视觉

计算机视觉中的物体检测方法

本文适合刚入门物体检测的人群学习,不涉及公式推理。 目录 *摘要 *相关物体检测数据集介绍 *现有的主流物体检测算法 *物体检测的难点与挑战 *相关术语介绍 *...

30540
来自专栏AI研习社

英伟达教你用深度学习做图像修复,确定不试一下?

深度学习是人工智能中发展最快的领域之一,被广泛的应用在很多领域,尤其是实时目标检测、图像识别和视频分类。深度学习通常采用卷积神经网络、深度信念网络、循环神经网络...

15820
来自专栏技术小站

吴恩达深度学习笔记 course4 week 4 特殊应用:人脸识别与神经风格转换

output:如果图片是对应的K人中的一人,则输出此人ID,否则验证不通过    ,人脸识别比人脸验证更难一些,如果一个人脸验证系统的正确率为99%,即错误率为...

19620
来自专栏hadoop学习笔记

非局部神经网络,打造未来神经网络基本组件

将非局部计算作为获取长时记忆的通用模块,提高神经网络性能在深度神经网络中,获取长时记忆(long-range dependency)至关重要。对于序列数据(例如...

14500

扫码关注云+社区

领取腾讯云代金券