前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CNN中的混淆矩阵 | PyTorch系列(二十三)

CNN中的混淆矩阵 | PyTorch系列(二十三)

作者头像
AI算法与图像处理
发布2020-07-01 16:29:43
5K0
发布2020-07-01 16:29:43
举报

文 |AI_study

原标题:CNN Confusion Matrix With PyTorch - Neural Network Programming

在这节课中,我们将建立一些函数,让我们能够得到训练集中每个样本的预测张量。然后,我们会看到如何使用这个预测张量,以及每个样本的标签,来创建一个混淆矩阵。这个混淆矩阵将允许我们查看我们的网络中哪些类别相互混淆。

  • 准备数据
  • 建立模型
  • 训练模型
  • 分析模型的结果
    • 构建、绘制和解释一个混淆矩阵

有关所有代码设置细节,请参阅本课程的前一节。

混淆矩阵要求

要为整个数据集创建一个混淆矩阵,我们需要一个与训练集长度相同的一维预测张量。

代码语言:javascript
复制
> len(train_set)
60000

这个预测张量将包含我们训练集中每个样本的10个预测(每个服装类别一个)。在我们得到这个张量之后,我们可以使用标签张量来生成一个混淆矩阵。

代码语言:javascript
复制
> len(train_set.targets)
60000

一个混淆矩阵将告诉我们模型在哪里被混淆了。更具体地说,混淆矩阵将显示模型正确预测的类别和模型不正确预测的类别。对于不正确的预测,我们将能够看到模型预测的类别,这将告诉我们哪些类别使模型混乱。

获取整个训练集的预测

为了得到所有训练集样本的预测,我们需要通过网络传递所有样本。为此,可以创建一个batch_size=1的DataLoader。这将一次性向网络传递一批数据,并为所有训练集样本提供所需的预测张量。

然而,根据计算资源和训练集的大小,如果我们在不同的数据集上训练,我们需要一种方法来预测更小的批量并收集结果。为了收集结果,我们将使用torch.cat()函数将输出张量连接在一起,以获得单个预测张量。我们来建立一个函数。

建立一个函数来获得所有样本的预测

我们将创建一个名为get_all_preds()的函数,并传递一个模型和一个数据加载器。该模型将用于获取预测,而数据加载器将用于提供来自训练集的批次。

所有函数需要做的就是遍历数据加载器,将批处理传递给模型,并将每个批处理的结果连接到一个预测张量,该张量将返回给调用者。

代码语言:javascript
复制
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for batch in loader:
        images, labels = batch

        preds = model(images)
        all_preds = torch.cat(
            (all_preds, preds)
            ,dim=0
        )
    return all_preds

此函数的植入会创建一个空张量,all_preds来保存输出预测。然后,迭代来自数据加载器的批处理,并将输出预测与all_preds张量连接在一起。最后,所有预测all_preds将返回给调用方。

请注意,在顶部,我们已使用@ torch.no_grad() PyTorch装饰对函数进行了注释。这是因为我们希望该函数执行忽略梯度跟踪。

这是因为梯度跟踪占用内存,并且在推理(在不训练的情况下获得预测)期间,无需跟踪计算图。装饰器是在执行特定功能时局部关闭梯度跟踪功能的一种方法。

本地禁用PyTorch梯度跟踪

我们现在准备调用以获取训练集的预测。我们需要做的就是创建一个具有合理批处理大小的数据加载器,并将模型和数据加载器传递给get_all_preds() 函数。

在上一节中,我们了解了在不需要时如何使用PyTorch的梯度跟踪功能,并在开始训练过程时将其重新打开。

每当我们要使用Backward()函数计算梯度时,我们特别需要梯度计算功能。否则,将其关闭是一个好主意,因为将其关闭会减少计算的内存消耗,例如 当我们使用网络进行预测(推理)时。

代码语言:javascript
复制
with torch.no_grad():
    prediction_loader = torch.utils.data.DataLoader(train_set, batch_size=10000)
    train_preds = get_all_preds(network, prediction_loader)

这两个选项均有效。让我们保留所有这些并获得我们的预测。

使用预测张量

现在,有了预测张量,我们可以将其传递给我们在上一节中创建的get_num_correct()函数以及训练集标签,以获取正确预测的总数。

代码语言:javascript
复制
> preds_correct = get_num_correct(train_preds, train_set.targets)

> print('total correct:', preds_correct)
> print('accuracy:', preds_correct / len(train_set))
total correct: 53578
accuracy: 0.8929666666666667

我们可以看到正确预测的总数,并通过除以训练集中的样本数来打印准确性。

建立混淆矩阵

我们构建混淆矩阵的任务是将预测值的数量与真实值(目标)进行比较。

这将创建一个充当热图的矩阵,告诉我们预测值相对于真实值的下降位置。

为此,我们需要具有目标张量和train_preds张量中的预测标签。

代码语言:javascript
复制
> train_set.targets
tensor([9, 0, 0,  ..., 3, 0, 5])

> train_preds.argmax(dim=1)
tensor([9, 0, 0,  ..., 3, 0, 5])

现在,如果我们逐元素比较两个张量,我们可以看到预测的标签是否与目标匹配。此外,如果我们要计算预测标签与目标标签的数量,则两个张量内的值将作为矩阵的坐标。让我们沿着第二维堆叠这两个张量,以便我们可以有60,000个有序对。

代码语言:javascript
复制
> stacked = torch.stack(
    (
        train_set.targets
        ,train_preds.argmax(dim=1)
    )
    ,dim=1
)


> stacked.shape
torch.Size([60000, 2])


> stacked
tensor([
    [9, 9],
    [0, 0],
    [0, 0],
    ...,
    [3, 3],
    [0, 0],
    [5, 5]
])

> stacked[0].tolist()
[9, 9]

现在,我们可以遍历这些对,并计算矩阵中每个位置的出现次数。让我们创建矩阵。由于我们有十个预测类别,因此将有一个十乘十的矩阵。检查此处以了解stack()函数。

https://deeplizard.com/learn/video/kF2AlpykJGY

代码语言:javascript
复制
> cmt = torch.zeros(10,10, dtype=torch.int64)
> cmt
tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
])

现在,我们将遍历预测目标对,并在每次发生特定位置时向矩阵内的值添加一个。

代码语言:javascript
复制
for p in stacked:
    tl, pl = p.tolist()
    cmt[tl, pl] = cmt[tl, pl] + 1

这为我们提供了以下混淆矩阵张量。

代码语言:javascript
复制
> cmt
tensor([
    [5637,    3,   96,   75,   20,   10,   86,    0,   73,    0],
    [  40, 5843,    3,   75,   16,    8,    5,    0,   10,    0],
    [  87,    4, 4500,   70, 1069,    8,  156,    0,  106,    0],
    [ 339,   61,   19, 5269,  203,   10,   72,    2,   25,    0],
    [  23,    9,  263,  209, 5217,    2,  238,    0,   39,    0],
    [   0,    0,    0,    1,    0, 5604,    0,  333,   13,   49],
    [1827,    7,  716,  104,  792,    3, 2370,    0,  181,    0],
    [   0,    0,    0,    0,    0,   22,    0, 5867,    4,  107],
    [  32,    1,   13,   15,   19,    5,   17,   11, 5887,    0],
    [   0,    0,    0,    0,    0,   28,    0,  234,    6, 5732]
])

请注意,下面的示例将具有不同的值,因为这两个示例是在不同的时间创建的。

绘制混淆矩阵

为了将实际的混淆矩阵生成为numpy.ndarray,我们使用sklearn.metrics库中的confusion_matrix()函数。让我们将其与其他需要的导入一起导入。

代码语言:javascript
复制
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from resources.plotcm import plot_confusion_matrix

对于最后一次导入,请注意plotcm是一个文件plotcm.py,位于当前目录中的资源文件夹中。在plotcm.py文件中,有一个称为plot_confusion_matrix()的函数,我们将调用该函数。您将需要在系统上实现此功能。我们将在稍后讨论如何执行此操作。首先,让我们生成混淆矩阵。

我们可以像这样生成混淆矩阵:

代码语言:javascript
复制
> cm = confusion_matrix(train_set.targets, train_preds.argmax(dim=1))
> print(type(cm))
> cm

<class 'numpy.ndarray'>
Out[74]:
array([[5431,   14,   88,  145,   26,    7,  241,    0,   48,    0],
        [   4, 5896,    6,   75,    8,    0,    8,    0,    3,    0],
        [  92,    6, 5002,   76,  565,    1,  232,    1,   25,    0],
        [ 191,   49,   23, 5504,  162,    1,   61,    0,    7,    2],
        [  15,   12,  267,  213, 5305,    1,  168,    0,   19,    0],
        [   0,    0,    0,    0,    0, 5847,    0,  112,    3,   38],
        [1159,   16,  523,  189,  676,    0, 3396,    0,   41,    0],
        [   0,    0,    0,    0,    0,   99,    0, 5540,    0,  361],
        [  28,    6,   29,   15,   32,   23,   26,   14, 5827,    0],
        [   0,    0,    0,    0,    1,   61,    0,  107,    1, 5830]],
        dtype=int64)

PyTorch张量是类似于数组的Python对象,因此我们可以将它们直接传递给confusion_matrix()函数。我们相对于train_preds张量的第一维传递训练集标签张量(targets)和argmax,这为我们提供了混淆矩阵数据结构。

要实际绘制混淆矩阵,我们需要一些自定义代码,这些代码已放入名为plotcm的本地文件中。该函数称为plot_confusion_matrix()。plotcm.py文件需要包含以下内容,并且位于当前目录的resources文件夹中。

请注意,您也可以只将此代码复制到笔记本中,或避免导入的任何内容。

plotcm.py:

代码语言:javascript
复制
import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

来源-scikit-learn.org

对于导入,我们这样做:

代码语言:javascript
复制
from plotcm import plot_confusion_matrix

我们已经准备好绘制混淆矩阵,但是首先我们需要创建一个预测类名称列表,以传递给plot_confusion_matrix()函数。下表给出了我们的预测类及其相应的索引:

这使我们可以调用以绘制矩阵:

代码语言:javascript
复制
> names = (
    'T-shirt/top'
    ,'Trouser'
    ,'Pullover'
    ,'Dress'
    ,'Coat'
    ,'Sandal'
    ,'Shirt'
    ,'Sneaker'
    ,'Bag'
    ,'Ankle boot'
)
> plt.figure(figsize=(10,10))
> plot_confusion_matrix(cm, names)

Confusion matrix, without normalization
[[5431   14   88  145   26    7  241    0   48    0]
[   4 5896    6   75    8    0    8    0    3    0]
[  92    6 5002   76  565    1  232    1   25    0]
[ 191   49   23 5504  162    1   61    0    7    2]
[  15   12  267  213 5305    1  168    0   19    0]
[   0    0    0    0    0 5847    0  112    3   38]
[1159   16  523  189  676    0 3396    0   41    0]
[   0    0    0    0    0   99    0 5540    0  361]
[  28    6   29   15   32   23   26   14 5827    0]
[   0    0    0    0    1   61    0  107    1 5830]]

解释混淆矩阵

混淆矩阵具有三个轴:

  1. 预测标签(类)
  2. 真实标签
  3. 热图值(彩色)

预测标签和真实标签向我们显示了我们正在处理的预测类。矩阵对角线表示矩阵中预测和真值相同的位置,因此我们希望此处的热图更暗。

任何不在对角线上的值都是不正确的预测,因为预测和真实标签不匹配。要读取该图,我们可以使用以下步骤:

  1. 在水平轴上选择一个预测标签。
  2. 检查此标签的对角线位置以查看正确的总数。
  3. 检查其他非对角线位置以查看网络混乱之处。

例如,网络正在将T恤/上衣与衬衫混淆,但并未将T恤/上衣与以下物质混淆:

  • Ankle boot
  • Sneaker
  • Sandal

如果我们考虑一下,这很有意义。随着我们模型的学习,我们将看到对角线之外的数字越来越小。

在本系列的这一点上,我们已经完成了许多在PyTorch中构建和训练CNN的工作。恭喜!

文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈或点个“在看”,支持一下我 ^_^

英文原文链接是:

<https://deeplizard.com/learn/video/0LhiS6yu2qQ>

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-06-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI算法与图像处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 混淆矩阵要求
    • 这个预测张量将包含我们训练集中每个样本的10个预测(每个服装类别一个)。在我们得到这个张量之后,我们可以使用标签张量来生成一个混淆矩阵。
      • 获取整个训练集的预测
        • 建立一个函数来获得所有样本的预测
          • 本地禁用PyTorch梯度跟踪
            • 使用预测张量
              • 建立混淆矩阵
                • 绘制混淆矩阵
                  • 解释混淆矩阵
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档