前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习入门 10-8 多分类问题中的混淆矩阵

机器学习入门 10-8 多分类问题中的混淆矩阵

作者头像
触摸壹缕阳光
发布2020-05-25 23:20:13
4.8K0
发布2020-05-25 23:20:13
举报

人生值得欣慰之处,便是每一天都有结束的时候。今天也不例外。

——卡罗琳·帕克丝特《星星上的人》

前言

本系列是《玩转机器学习教程》一个整理的视频笔记。本小节主要介绍如何求解多分类问题中的指标,着重介绍多分类问题中的混淆矩阵,将混淆矩阵进行处理转换为error_matrix矩阵,并通过可视化的方式直观的观察分类算法错误分类的地方。

a

求解多分类问题中的指标

前几个小节针对极度有偏的二分类问题介绍了很多新的分类指标。本小节来简单看一下如何利用前几个小节学习的指标来处理多分类问题,在前几个小节的二分类问题中介绍了一个非常重要的小工具混淆矩阵,由混淆矩阵推导出了很多重要的指标,因此最后着重介绍多分类问题中的混淆矩阵。

这一小节依然使用手写数字识别的数据集,不过由于本小节主要介绍多分类问题中的混淆矩阵,所以不再构造针对某一个类别的二分类数据集,使用手写数字识别中0-9的10个类别。使用train_test_split将数据集划分为训练集和测试集。

接下来使用逻辑回归算法对多分类手写数字识别数据集进行训练。在第八章介绍逻辑回归算法的时候提到,如果为sklearn提供的逻辑回归算法传入多分类数据集,逻辑回归算法默认使用OVR的方式来训练多分类数据集。在fit训练逻辑回归算法之后,调用score函数计算训练好的模型在测试样本上的accuracy分类准确度为93%。

使用predict函数对所有测试样本进行预测,预测每个测试样本的所属类别。

二分类问题中计算precision精准率可以直接使用sklearn中的precision_score函数。下面来看看对于多分类问题,能不能也直接使用precision_score函数求出多分类问题的精准率呢?

抛出ValueError的错误,重点看最后打印的错误信息"Target is multiclass but average='binary'. please choose another average setting.",默认precision_score函数只能求出二分类问题的精准率(average参数默认为binary),不过如果要想能够求出多分类问题的精准率可以选择合理average参数的值。

查阅sklearn官方文档,precision_score函数除了我们一直传入的真实值y_true和预测值y_predict还有很多其它的参数。

▲sklearn文档

主要看average这个参数,默认情况下average = 'binary'。如果传入其它的参数值相应的就可以解决多分类问题。

▲average参数值

不同的average参数值是如何计算多分类问题指标的可以查阅上面的文档,这里主要来介绍一下'micro'参数值,micro是将多分类中的所有类别放到一起计算。比如计算多分类问题的precision精准率,micro方式是将所有类别的TP值相加,再除以所有类别的TP和FN的和,因此micro方法下的precision、recall和F1 Score的值都相等,并且都等于accuracy。

使用micro方式计算多分类问题的精准率,只需要调用precision_score函数的同时传入average参数,并将average参数值设置为micor。在recall_score和f1_score函数中有同样的参数。

由于使用micro的计算方式,所以此时多分类的precision精准率、recall召回率以及F1 Score的值都是相等的并且都等于精确率,这和前面介绍micro的计算方式所描述的结果是一致的,不同的参数值对应不同的计算方式。

b

多分类问题中的混淆矩阵

这一小节的重点是介绍多分类问题中的混淆矩阵,不同于sklearn中的precision_score、recall_score和f1_score,sklearn中的混淆矩阵天然支持多分类问题,所以我们可以直接在confusion_matrix函数中传入多分类的真实值y_test和预测值y_predict。

对于十分类的问题得到了一个(10 x 10)的混淆矩阵。对于这个(10 x 10)矩阵的解读方式和二分类问题中的(2 x 2)矩阵的解读方式是一模一样的。

在这个(10 x 10)的矩阵中,行依然表示样本的真实值,列依然表示算法预测样本的预测值,因此对于第 i 行第 j 列的位置表示的是样本的真实值为 i 预测值为 j 的样本数量。由于使用逻辑回归算法训练手写数字识别的十分类精确率达到了93%(十分类的手写数字识别各个类别的样本比较均衡,所以可以使用精确率作为分类指标),换句话说整个分类算法预测错误的情况相对来说比较少,因此样本真实值为 i 预测值为 i 的样本数量相对来说比较多,而第 i 行第 i 列的位置就是混淆矩阵中的对角线。

通过观察混淆矩阵依然会发现算法会犯一些错误,比如矩阵第4行的第0列为2(从第0行第0列开始,下同),表示的是有两个样本真实值为数字4但是算法错误预测为数字0。接下来简单介绍一种能够直观的看到犯错误比较多的位置的方法,也就是将整个混淆矩阵映射成灰度图像。

绘制整个混淆矩阵的具体步骤:

  • 为了方便将得到的混淆矩阵保存在一个名为cfm的变量中;
  • 调用plt中的matshow函数,matshow全称为matrix show,也就是绘制一个矩阵,matshow函数需要传入两个参数:
    • 第一个参数传入需要绘制的矩阵cfm;
    • 第二个参数是cmap,cmap全称为color map,也就是颜色映射,将矩阵中的每一个元素与对应的颜色映射起来。这里将混淆矩阵映射成灰度图像,因此传入plt.cm.gray;
  • 调用plt.show()绘制混淆矩阵映射的灰度图像;

通过matplotlib将混淆矩阵映射成了灰度图像,在灰度图像上越亮的地方代表数值越大,越暗的地方代表数值越小,很显然在上面的灰度图像中对角线位置的方块是最亮的,这是因为逻辑回归算法在十分类的手写数字识别任务上准确率达到了93%,因此绝大多数的样本算法都能正确预测,这和之前的分析相吻合。

不过关注算法预测正确的部分是没有意义的,我们真正需要找到的是那些算法预测错误的部分,所以我们需要适当的处理一下混淆矩阵。

具体的处理方式:

  • 使用np.sum计算出混淆矩阵中每一个行的样本总数(将axis参数设置为1表示按照列的方向求和 ),将结果向量命名为row_sums;
  • 设计一个名为error_matrix的新矩阵,error_matrix矩阵的值等于混淆矩阵中的每一个元素值除以混淆矩阵每一个行的和,即cfm / row_sums;
  • 我们并不关注那些完全预测正确的结果,所以使用fill_diagonal函数将error_matrix中对角线位置的值设置为0。

现在的error_matrix矩阵中预测正确的位置被我们忽略(对角线的值全为0),并且error_matrix矩阵中的每一个元素表示错误样本占真实类别样本的比重。比如对于error_matrix矩阵的第1行第9列元素值(从第0行第0列开始),表示真实为数字1但是算法错误预测为数字9的所有样本数占所有真实为数字1的样本总数的比重。

经过处理后的error_matrix矩阵中有很多元素都是小数,并且对角线的位置全为0。通过矩阵中的元素值很难直观的找到算法犯错误最多的地方,因此依然是将error_matrix矩阵映射成灰度图像。

error_matrix矩阵中的每一个元素表示错误样本占真实类别样本的比重,因此越亮的地方表示算法预测错误越多的地方。比如在右上角非常亮的白色方格的位置表示的是真实的样本为数字1但是算法预测的结果却为数字9,此时这个位置比较亮表示算法将很多真实值为数字1的样本错误预测成了数字9。同理,在左下角同样有一个比较亮的白色方格的位置表示的是真实的样本为数字8但是算法预测结果却为数字1,此时这个位置比较亮表示算法将很多真实值为数字8的样本错误预测成了数字1。

通过error_matrix矩阵映射的灰度图像,可以非常直观的看到算法把那些类别的样本分类错误,并且可以根据灰度图像的亮度来确定算法在那些类别的分类上犯错误最多,最重要的是能够直观的看到算法具体犯了什么错误。比如根据error_matrix矩阵绘制的灰度图可以看到算法将很多真实为数字1的样本预测成了数字9,将很多真实为数字8的样本预测成了数字1。有了这些提示,我们可以进一步的改进我们的算法,可以把这些容易混淆的类别规约成二分类问题,换句话说现在这个算法比较容易混淆数字1和数字9也比较容易混淆数字8和数字1,相应的可以微调在数字1和数字9以及数字8和数字1这两个二分类问题中的threshold阈值,进而提高整个分类算法的准确率。

这里需要强调的是,由于这个课程主要介绍的是机器学习算法,所以在遇到问题的时候,通常是寻找在算法层面上可以采用什么策略来进行改进,但是对于机器学习领域来说,很多时候问题的关键可能不在算法层面上,有可能出现在样本数据的层面上,所以在真实的环境中,当面对error_matrix矩阵的时候,老师建议将这些类别为数字1、类别为数字9以及类别为数字8的样本取出来看一看。通过感性的理解一下,为什么算法会出现那么多将数字1错误分类成数字9以及将数字8错误分类成数字1的情况。很有可能样本数据有问题,即使样本数据没有问题,通过有针对性的观察样本,很有可能总结出更多新的特征。

不过观察样本这些对数据的处理都属于特征工程的范畴,特征工程是一个专门的学科,所以这里不过多介绍。我们应该时刻注意在具体解决机器学习任务的时候,当不能通过算法来很好解决问题的时候,要回到样本数据本身,看看样本数据有没有问题,能不能更好的整理数据、清理数据或从数据中挖掘更有用的特征等,这些都是机器学习领域中非常重要的事情。

这一章介绍了很多能够更好评价分类算法的指标,并对每一个指标都进行了详细的解释,通过这一章的学习会对如何评价分类算法有了更加全面的认识。下一章将会介绍一个大名鼎鼎的分类算法:支持向量机SVM。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档