机器学习三人行(系列四)----手写数字识别实战(附代码)

今天我们根据mnist手写数字的数据集来对0-9共10个数字进行分类,旨在通过这个分类项目的学习,一起学习分类的相关知识和技巧。

由于篇幅有限,代码未全部贴出,文末附关键字,回复该关键字即可下载本系列实战代码。

言归正传,步入正题!首先我们需要进行数据集的下载,也可以按照系列二中介绍的方法下载数据集,因为该数据集比较大,代码中下载比较费时,所以我给大家下载好了,直接公众号回复“mnist”,即可网盘下载。在进行分类之前,我们第一步是需要了解数据集,一起看一下数据集中都有些什么吧。

1. MNIST数据集

首先我们通过scipy中的loadmat方式加载数据集,如下:

可以看出,该数据集中主要有两部分内容,data和label,通过shape查看data可知,该数据集中有70000条数据,每条数据有784个特征。其实每一条数据都是一副28x28的图片,784个特征就是784个像素灰度值(0-255),我们可以通过matplotlib中的imshow()函数来查看其中的样本:如下图,给大家一个感官的感受:

从上图可以看出,手写的数字还是五花八门的,各种姿势都有。很明显,我们的任务是将每个数字分到正确的类中,那么这里有10个类,所以这是一个多分类的问题,万尺高楼平地起,多分类问题是由多个二分类组成的,比方说,我们可以做10个二分类分类器:0和非0,1和非1等等。这样比如判断5这个数字的时候10个分类器的结果理想情况下就会是{非0,非1,非2,非3,非4,5,非6,非7,非8,非9},这样就能够把5给分出来。那么问题又来了,二分类分类器怎么构建呢?在构建二分类之前,如系列二所述我们需要对数据集进行分组,这里选取60000个样本作为训练集,10000个样本作为测试集。

2. 二分类分类器

2.1. 训练分类器

比如我们以5和非5为例,来说明二分类分类器的构建及评估。首先第一步我们需要对5和非5进行一个label的重置,5的label为true,非5的label为false,如下:

处理好数据集后,我们就可以建立二分类分类器了,这里以SGD(Stochastic Gradient Descent)随机梯度下降为例来说明如何实现二分类问题,从sklearn中直接调用SGD并对其中的某一样本进行分类,代码如下:

这里我们可以看出我们的分类器对某一样本的分类结果是True,那就说明该分类器认为该样本是5,事实如此。一般选择一种分类器之后,我们希望看一下分类器的性能如何,然后再决定是否将该方法用在测试集中进行测试。

2.2. 分类器评估

a. 交叉验证

最常用的方法就是交叉验证,关于交叉验证的解释详见系列三。如下调用,从下面5折交叉验证结果上来看,平均准确率在96%左右,貌似还不错。

聪明的你可能已经发现问题了,有些人可能会说在10个数字中,随机给出一个数字判断是不是为5,只要我一直猜非5也会有90%的概率猜准确的,事实如此。也就是说现在还不能评判这个SGD分类器到底怎么样,分类器总得有个好坏吧,那么该怎样评判呢?

b. 混淆矩阵

一种常见的评判方法是查看该分类器对分类结果的混淆矩阵,二分类问题的混淆矩阵是一个2x2的矩阵表示,对于该5和非5的问题,如下图:

其中TN表示分正确负样本的数量

FN表示将正样本错误的分为负样本的数量

TP表示分正确的正样本的数量

FP表示将负样本错误的分为正样本的数量

来,我们一起看一下SGD分类器的混淆矩阵如何:

从上面混淆矩阵结果可以看出,有1492个“5”分成了“非5”,有672个“非5”分成了“5”。混淆矩阵是从量上面进行的一个直观的观察,一般情况下,我们喜欢从一个比率的角度来评判一个东西。

c. precision,recall和F1-score

对于混淆矩阵,我们一般从s三个角度进行一个评判:分别是precision(准确率),recall(召回率),F1-score(F1得分)

precision = TP/(TP+FP)

recall = TP/(TP+FN)

precision的意思是在预测为正类的样本中,有多少真正的正样本的比例。而recall表示在真正的正样本中又有多少被预测正确的比例。我们先来看一下该分类器的pecision和recall分别是多少,如下:

也就是说在5421个数字“5”中,有将近72.5%的“5”被正确分类,在被分为“5”的4601个样本中,有85.4%的样本是分正确的。我们都希望两个比例都比较高。其实我们可以根据每个样本的得分来把precision和recall的曲线画出来,如下图:

可以看出两个标准是互斥的,其实我们从他们的公式中也可以推断出来。这里留一个思考题(为什么precision曲线到后面会有波动?欢迎文末留言交流!)两个指标互斥的,我们可以自由选择,比方说我们选择precision高的时候就必须得接受recall的低。我们可以手动选择上面precision和recall的阈值,来获得不同的precision和recall,如下所示:

当然,我们也可以从一个指标上来进行综合评判,那就是F1-score,F1-score其实是precision和recall的调和平均数,如下:

如上可知,对于该分类器的F1得分为0.7841,这样我们就可以从F1得分上来进行一个分类器的综合比较。此外,还有一个可视化的评判标准。

d. roc曲线

对于评判分类器的评判,我们更喜欢可视化的评判,那么一个传说中的大神级的评判标准ROC曲线,ROC曲线是根据TPR(true positive rate)和FPR(false positive rate)得到的,如下图:

上图ROC曲线中的对角线上的虚线为纯粹随机的分类结果,蓝色曲线为上面SGD方法的分类结果,评判一个ROC曲线的好坏一般以ROC曲线下方的面积AUC(area under the curve)为标准,该值越大,说明性能越好。如上图的的AUC值如下,其中y_scores是每一个样本的得分(详见文末代码):

至此,我们已经介绍了四种常见的评判方法,那么我们来做一个比较,比较SGD和另一种分类器随机森林两者的分类性能,哪个更胜一筹。下面我们先构建随机森林的分类器,如下:

将随机森林的ROC曲线和SGD的ROC曲线进行比较,如下图:

从上图可以看出随机森林的AUC值比SGD的AUC要高,也就是说从ROC曲线上来看,随机森林的性能在这个二分类问题上要比SGD的性能好。

从这个AUC的值也可以看出,随机森林的性能要比SGD的性能更好一些,如果让在两个分类器中进行选择的话,我们会选择随机森林分类器来进行分类。

至此,我们已经从二分类分类器的训练,评估以及性能比较综合了解二分类的相关知识。但是,这只是一个二分类的问题,我们的目标是把每一种数字都分到自己的类中,这是一个多分类的问题。征途的路上,我们一起前行……

3. 多分类分类器

a. OvA多分类策略

如上面所说,对于我们的任务,可以设置10个分类器(0和非0,1和非1等等),然后对于一个新样本来说,将该样本依次通过这10个分类器,看哪个分类器的得分最高,那么这个样本就属于哪个分类器分出来的类。这种多分类分类器的策略称为OvA(one versus All),简单的说就是一个分类器说了算。用例子说话,如下代码:

注意,此时的训练样本的标签不再是5和非5的标签,而是0-9的10类标签,对于某一数字的10个分类器的得分如上图,从上图中可以看出第五个分数是最高的,那么多分类分类器对于这个数字的预测结果就是5。

b. OvO多分类策略

上面说的分类器是基于OvA策略的分类器,还有一种分类器是这样构造的:对于我们的数字分类来说,需要构建一系列二分类前,其中二分类分别区分0和1,0和2,0和3等等总共有N*(N-1)/2个分类器。比方说要对5这个数字进行分类,依次通过这45个分类器后,会发现45个分类器中有最多9个分类器都认为这个数字是5,类似于基于投票机制的分类器。那么对这个样本的最终分类就是5。如下代码:

有一些分类器只能用OvO策略,比方说SVM,有一些分类器没有所谓的OvO或者OvA策略,比方说随机森林,因为随机森林本身能够直接分为多类,可以直接根据类别得分进行确定是分为哪一类,如下图:

从上图可以看出随机森林分类器的结果中5数字的得分是最高的,那么分类器就将该样本分为5。

c. 多分类分类器的评估

多分类分类器的评估,和二分类一样,也可以用交叉验证进行评估,如下:

从上面的结果来看有84%左右的准确率,还是不错的,比随机分类的10%的准确率要高出很多,但是如果我们利用系列三的方法进行一个数据预处理(尺度归一化)的话,可以再提高分类的准确率,如下:

d. 误差分析

有上面的知识我们知道,对分类器的误差分析一般从混淆矩阵下手,那么我们输出一下SGD分类器对mnist数据集的混淆矩阵,如下:

从混淆矩阵的对角线上可以看出,大部分还是分类正确的,关于混淆矩阵的其他元素的说明这里不再赘述,详情参照上面的解释。如果将上面的混淆矩阵转换成图像的形式的话,可以很明显的看出上面的现象。

上面的图显示虽然直观,但是对于误差项的分析仍然不够明朗,这里我们把误差项从误判的数量转换成比例,如下图:

就可以很直观的看出哪些数字容易发生误判,比方说,3和5交叉点比较亮,那么就说明3和5容易误判,我们就可以通过一些措施,例如增加样本量等方法来提高该数字的分类性能等等。

4. 小结

至此,我们已经从查看mnist数据集,二分类分类器的构建,混淆矩阵,ROC曲线等评估方法进行分类器性能的评估,OvA, OvO多分类分类器的构建策略以及实战,多分类分类器的性能评判误差分析等方面进行系统学习。由于小编水平有限,难免会有一些理解有偏差的地方,欢迎大家多多指正,欢迎留言,多多交流。

原文发布于微信公众号 - 智能算法(AI_Algorithm)

原文发表时间:2017-12-18

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

如何用 Caffe 生成对抗样本?这篇文章告诉你一个更高效的算法

Fast Gradient Sign方法 先回顾一下《杂谈CNN:如何通过优化求解输入图像》中通过加噪音生成对抗样本的方法,出自Christian Szeg...

3133
来自专栏有趣的Python

15- OpenCV+TensorFlow 入门人工智能图像处理- Adaboost与Hog特征

1572
来自专栏机器学习算法工程师

Batchnorm原理详解

作者:刘威威 小编:赵一帆 前言:Batchnorm是深度网络中经常用到的加速神经网络训练,加速收敛速度及稳定性的算法,可以说是目前深度网络必不可少的一部分。...

8386
来自专栏智能算法

机器学习三人行(系列七)----支持向量机实践指南(附代码)

其实逻辑回归算法和今天要讲的支持向量机有些类似,他们都是从感知机发展而来,支持向量机是一个非常强大而且应用面很广的机器学习算法,能够胜任线性分类器,非线性分类器...

35812
来自专栏专知

【专知-PyTorch手把手深度学习教程02】CNN快速理解与PyTorch实现: 图文+代码

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

4.4K11
来自专栏鸿的学习笔记

LSTMs

由于使用权重矩阵的方式,会对典型RNN可以学习的模式类型存在一些显着的限制。因此,对于称为长短期存储器网络(Long Short-Term Memory net...

1031
来自专栏深度学习

图像分类 | 深度学习PK传统机器学习

图像分类,顾名思义,是一个输入图像,输出对该图像内容分类的描述的问题。它是计算机视觉的核心,实际应用广泛。 图像分类的传统方法是特征描述及检测,这类传统方法可能...

4709
来自专栏AI科技评论

开发 | Keras版faster-rcnn算法详解(RPN计算)

AI科技评论按:本文首发于知乎专栏Learning Machine,作者张潇捷, AI科技评论获其授权转载。 前段时间学完Udacity的机器学习和深度学习的课...

67611
来自专栏算法channel

机器学习逻辑回归:算法兑现为python代码

0 回顾 昨天推送了逻辑回归的基本原理:从逻辑回归的目标任务,到二分类模型的构建,再到如何用梯度下降求出二分类模型的权重参数。今天,我们将对这个算法兑现为代码...

3575
来自专栏云时之间

深度学习与TensorFlow:FCN论文翻译(二)

Each layer of data in a convnet is a three-dimensional array of size h × w × d, ...

2202

扫码关注云+社区

领取腾讯云代金券