今天我们使用逻辑回归(Logistic Regression)做图片机器学习,使用到了digits和MNIST两种数据集,这两种数据集分别代表小数据集和大数据集。
一、Digits数据集
1.1导入数据
digits数据集是scikit-learn中不需要从网络上下载,直接自带的数据集。
现在我们看看digits数据集统计性信息
运行
1.2 打印照片和其标签
因为数据的维度是1797条,一共有64个维度。那么每一条数据是一个列表。但是我们知道图片是二维结构,而且我们知道digits数据集的图片是方形,所以我们要将图片原始数据重构(reshape)为(8,8)的数组。
为了让大家对于数据集有一个更直观的印象,我们在这里打印digits数据集的前5张照片。
运行
在notebook中显示matplotlib的图片
png
1.3 将数据分为训练集合测试集
为了减弱模型对数据的过拟合的可能性,增强模型的泛化能力。保证我们训练的模型可以对新数据进行预测,我们需要将digits数据集分为训练集和测试集。
1.4 训练、预测、准确率
在本文中,我们使用LogisticRegression。由于digits数据集合较小,我们就是用默认的solver即可
对新数据进行预测,注意如果只是对一个数据(一维数组)进行预测,一定要把该一维数组转化为矩阵形式。
data.reshape(n_rows, n_columns)
将data转化为维度为(n_rows, n_columns)的矩阵。注意,如果我们不知道要转化的矩阵的某一个维度的尺寸,可以将该值设为-1.
运行
对多个数据进行预测
运行结果
哇,还是很准的啊
1.5 混淆矩阵
一般评价预测准确率经常会用到混淆矩阵(Confusion Matrix),这里我们使用seaborn和matplotlib绘制混淆矩阵。
png
二、MNIST数据集
digits数据集特别的小,刚刚的训练和预测都只需几秒就可以搞定。但是如果数据集很大时,我们对于训练的速度的要求就变得紧迫起来,模型的参数调优就显得很有必要。所以,我们拿MNIST这个大数据集试试手。我从网上将mnist下载下来,整理为csv文件。其中第一列为标签,之后的列为图片像素点的值。共785列。MNIST数据集的图片是28*28组成的。
运行结果
2.1 打印MNIST图片和标签
png
2.2 训练、预测、准确率
之前digits数据集才1797个,而且每个图片的尺寸是(8,8)。但是MNIST数据集高达70000,每张图片的尺寸是(28,28)。所以如果不考虑参数合理选择,训练的速度会很慢。
运行结果
经过测试发现,在我的macbook air2015默认
solver='liblinear'训练时间3840秒。
solver='lbfgs'训练时间65秒。
solver从liblinear变为lbfgs,只牺牲了0.0003的准确率,速度却能提高了将近60倍。 在机器学习训练中,算法参数不同,训练速度差异很大,看看下面这个图。
2.3 打印预测错误的图片
digits数据集使用的混淆矩阵查看准确率,但不够直观。这里我们打印预测错误的图片
将错误图片打印出来
png
数据及代码获取
往期文章
领取专属 10元无门槛券
私享最新 技术干货