训练数据来源于sklearn的datasets
sklearn是通用的机器学习数据库,能实现分类和回归的目的。这篇文章主要聚焦实现手写数字识别。datasets的数据集有若干个手写数字图片,大小完全一致,每张图片取8*8=64个像素点信息,相当的粗糙。
先上结果
上面四幅图是我们要训练的数据和他对应的标签,都是对的。
下面四幅图是训练之后,给定数字对应的预测的结果,有一定模糊性
2. metrics.classification_report(expected, predicted))一个打分报告。精度都还可以大于90%,最后一列代表共899次的预测里,88次预测了0,91次预测了1,86次预测了2,我用投票代表我的态度。
3.混淆矩阵metrics.confusion_matrix(expected, predicted)
第一行代表对数字‘0’的预测中(87次right,1次wrong(正确为‘4’))
第二行代表对数字'1'的预测中(88次right,3次wrong(正确分别为2,8,9))~~~
下文附完整代码:
import matplotlib.pyplot as plt #画图用
from sklearn import datasets,svm,metrics #数据集,分类器,评价器
digits=datasets.load_digits()
images_and_labels=list(zip(digits.images,digits.target)) #组合自变量和因变量
#enumerate方法,把一个itarable和index结合到一起
#把一个大图分成2*4个小图,在第index+1的图片上画图,标签叫----
for index,(image,label) in enumerate(images_and_labels[:4]):
plt.subplot(2,4,index+1)
plt.axis('off')
plt.title('Training:%i'%label)
n_samples=len(digits.images) #看看有多少个数据
classfier=svm.SVC(gamma=0.001) #分类机用支持向量机里的SVC
classfier.fit(data[:n_samples//2],digits.target[:n_samples//2]) #train
predicted=classifier.predict(data[n_samples//2:])#test,predict结果
expected=digits.target[n_samples//2:] #标签数据,预期结果
#最后一步啦,要对最后test的结果评价一下,模型到低做的怎么样
#第一种打分metrics.classification_report(expected,predicted)
print("Classification report for classifier %s:\n%s\n"
%(classifier,metrics.classification_report(expected,predicted)))
#第二种打分,混淆矩阵
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))
#画出预测的图
images_and_predictions=list(zip(digits,images[n_samples//2:],predicted))
for index,(image,prediction) in enumerate(images_and_predictions[:4]):
plt.subplot(2,4,index+5)
plt.axis('off')
plt.title('Prediction:%i'%prediction)
plt.show()
领取专属 10元无门槛券
私享最新 技术干货