混淆矩阵是用于评估图像分类器性能的一种常用工具,它可以展示分类器在不同类别上的预测结果。对于CIFAR-10数据集,它包含10个不同的类别,分别是飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。
要打印图像分类器的混淆矩阵,可以按照以下步骤进行:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 加载测试集数据和预测结果
test_labels = np.load('test_labels.npy')
predicted_labels = np.load('predicted_labels.npy')
# 计算混淆矩阵
cm = confusion_matrix(test_labels, predicted_labels)
# 可视化混淆矩阵
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix - CIFAR-10')
plt.colorbar()
tick_marks = np.arange(10)
plt.xticks(tick_marks, ['飞机', '汽车', '鸟类', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车'], rotation=45)
plt.yticks(tick_marks, ['飞机', '汽车', '鸟类', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车'])
plt.tight_layout()
plt.ylabel('真实标签')
plt.xlabel('预测标签')
plt.show()
以上代码中,test_labels
是测试集的真实标签,predicted_labels
是分类器的预测标签。通过调用confusion_matrix
函数计算混淆矩阵,然后使用imshow
函数将混淆矩阵可视化出来。plt.xticks
和plt.yticks
用于设置x轴和y轴的刻度标签,plt.title
设置图表标题,plt.colorbar
添加颜色条。
这样,就可以打印出CIFAR-10图像分类器的混淆矩阵,并通过可视化直观地了解分类器在不同类别上的预测结果。
领取专属 10元无门槛券
手把手带您无忧上云