机器学习之识别图片中的数字

今天我们使用逻辑回归(Logistic Regression)做图片机器学习,使用到了digits和MNIST两种数据集,这两种数据集分别代表小数据集和大数据集。

一、Digits数据集

1.1导入数据

digits数据集是scikit-learn中不需要从网络上下载,直接自带的数据集。

现在我们看看digits数据集统计性信息

运行

1.2 打印照片和其标签

因为数据的维度是1797条,一共有64个维度。那么每一条数据是一个列表。但是我们知道图片是二维结构,而且我们知道digits数据集的图片是方形,所以我们要将图片原始数据重构(reshape)为(8,8)的数组。

为了让大家对于数据集有一个更直观的印象,我们在这里打印digits数据集的前5张照片。

运行

在notebook中显示matplotlib的图片

png

1.3 将数据分为训练集合测试集

为了减弱模型对数据的过拟合的可能性,增强模型的泛化能力。保证我们训练的模型可以对新数据进行预测,我们需要将digits数据集分为训练集和测试集。

1.4 训练、预测、准确率

在本文中,我们使用LogisticRegression。由于digits数据集合较小,我们就是用默认的solver即可

对新数据进行预测,注意如果只是对一个数据(一维数组)进行预测,一定要把该一维数组转化为矩阵形式。

data.reshape(n_rows, n_columns)

将data转化为维度为(n_rows, n_columns)的矩阵。注意,如果我们不知道要转化的矩阵的某一个维度的尺寸,可以将该值设为-1.

运行

对多个数据进行预测

运行结果

哇,还是很准的啊

1.5 混淆矩阵

一般评价预测准确率经常会用到混淆矩阵(Confusion Matrix),这里我们使用seaborn和matplotlib绘制混淆矩阵。

png

二、MNIST数据集

digits数据集特别的小,刚刚的训练和预测都只需几秒就可以搞定。但是如果数据集很大时,我们对于训练的速度的要求就变得紧迫起来,模型的参数调优就显得很有必要。所以,我们拿MNIST这个大数据集试试手。我从网上将mnist下载下来,整理为csv文件。其中第一列为标签,之后的列为图片像素点的值。共785列。MNIST数据集的图片是28*28组成的。

运行结果

2.1 打印MNIST图片和标签

png

2.2 训练、预测、准确率

之前digits数据集才1797个,而且每个图片的尺寸是(8,8)。但是MNIST数据集高达70000,每张图片的尺寸是(28,28)。所以如果不考虑参数合理选择,训练的速度会很慢。

运行结果

经过测试发现,在我的macbook air2015默认

solver='liblinear'训练时间3840秒。

solver='lbfgs'训练时间65秒。

solver从liblinear变为lbfgs,只牺牲了0.0003的准确率,速度却能提高了将近60倍。 在机器学习训练中,算法参数不同,训练速度差异很大,看看下面这个图。

2.3 打印预测错误的图片

digits数据集使用的混淆矩阵查看准确率,但不够直观。这里我们打印预测错误的图片

将错误图片打印出来

png

数据及代码获取

往期文章

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180601G03SCL00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券