首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如何计算测试数据的混淆矩阵?

如何计算测试数据的混淆矩阵?
EN

Stack Overflow用户
提问于 2021-02-13 16:52:18
回答 3查看 1.9K关注 0票数 0

我想在验证数据上绘制一个混淆矩阵。

具体来说,我想对验证数据计算模型输出的混淆矩阵。

我在网上什么都试过了,但搞不懂。

这是我的模型:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

model = models.Sequential()
# layers here

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=1, 
                    validation_data=(test_images, test_labels))
EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2021-02-13 17:00:13

下面是一个虚拟的例子。

DataSet

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# train set / data 
x_train = x_train.reshape(-1, 28*28)
x_train = x_train.astype('float32') / 255

# train set / target 
num_of_classess = 10 
y_train = tf.keras.utils.to_categorical(y_train , num_classes=num_of_classess )

模型

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model = Sequential()
model.add(Dense(800, input_dim=784, activation="relu"))
model.add(Dense(num_of_classess , activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer="SGD", metrics=["accuracy"])
history = model.fit(x_train, y_train, 
                    batch_size=200, 
                    epochs=20,  
                    verbose=1)

混淆矩阵

你的兴趣主要集中在这里。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# get predictions
y_pred = model.predict(x_train, verbose=2)

# compute confusion matrix with `tf` 
confusion = tf.math.confusion_matrix(
              labels = np.argmax(y_train, axis=1),      # get trule labels 
              predictions = np.argmax(y_pred, axis=1),  # get predicted labels 
              num_classes=num_of_classess)              # no. of classifier 

print(confusion)
<tf.Tensor: shape=(10, 10), dtype=int32, numpy=
array([[5750,    0,   16,   13,    9,   25,   40,    9,   54,    7],
       [   2, 6570,   28,   34,    8,   26,    6,   16,   45,    7],
       [  35,   44, 5425,   82,   93,   12,   69,   79,  100,   19],
       [  15,   24,  105, 5628,    4,  136,   26,   60,   82,   51],
       [   9,   29,   33,    6, 5483,    2,   60,   10,   33,  177],
       [  58,   32,   26,  159,   51, 4864,  101,   19,   67,   44],
       [  32,   18,   28,    3,   43,   60, 5697,    2,   33,    2],
       [  26,   46,   74,   19,   62,   10,    3, 5895,   15,  115],
       [  27,  101,   46,  142,   25,   71,   52,   15, 5304,   68],
       [  34,   30,   20,   94,  173,   21,    4,  162,   32, 5379]],
      dtype=int32)>

Visualization

让我们想象一下。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import seaborn as sns 
import pandas as pd 

cm = pd.DataFrame(confusion.numpy(), # use .numpy(), because now confusion is tensor
               range(num_of_classess),range(num_of_classess))

plt.figure(figsize = (10,10))
sns.heatmap(cm, annot=True, annot_kws={"size": 12}) # font size
plt.show()

更新

基于对话,如果你要用

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

然后,不要像我前面展示的那样转换整数标签(即y_train = tf.keras.utils.to_categorical(y_train, num_classes=10))。但如下所示

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# train set / data 
x_train = x_train.astype('float32') / 255

print(x_train.shape, y_train.shape) 
# (50000, 32, 32, 3) (50000, 1)

model ...
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
....
)

在预测时间时,不要使用np.argmax() on ,地面真理,,因为它们现在已经是一个整数,因为我们这次没有使用tf.keras.utils.to_categorical

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
print(np.argmax(y_pred, axis=1).shape, y_train.reshape(-1).shape)
# (50000,) (50000,)

y_pred = model.predict(x_train, verbose=2) # take prediction 
confusion = tf.math.confusion_matrix(
              labels = y_train.reshape(-1),             # get trule labels 
              predictions = np.argmax(y_pred, axis=1),  # get predicted labels
              )    

现在剩下的东西可以用了。

票数 3
EN

Stack Overflow用户

发布于 2021-02-13 16:59:46

应该相当直截了当。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
test_labels = np.array([0,0,1,1,2,2,3,3,3]) #actual labels
test_pred = np.array([0,1,1,1,1,2,3,3,0])   #predicted labels

cf = tf.math.confusion_matrix(test_labels, test_pred)

pd.DataFrame(cf.numpy(), columns=[0,1,2,3], index=[0,1,2,3])
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
   0  1  2  3
0  1  1  0  0
1  0  2  0  0
2  0  1  1  0
3  1  0  0  2

确保您正在应用np.argmax over axis=1 on test_pred,以确保它的1D带有标签,而不是像这样的2D逻辑

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
test_pred = np.argmax(model.predict(test_images), axis=1)
票数 0
EN

Stack Overflow用户

发布于 2021-02-13 17:05:21

另外,一旦您将混淆矩阵作为一个numpy数组,您就可以很容易地用sklearn's ConfusionMatrixDisplay绘制它。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from sklearn.metrics import ConfusionMatrixDisplay

def plot_cm(cm):
    ConfusionMatrixDisplay(cm).plot()

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66191450

复制
相关文章
pandas如何获取Excel文件下所有的sheet名称
一定要加sheet_name=None,才能读取出所有的sheet,否则默认读取第一个sheet,且获取到的keys是第一行的值
py3study
2020/12/22
13.9K0
SAS获取某目录下所有指定类型的文件名称
今天看到一个群友提的一个问题:SAS中如何简单地获取某一目录下所有指定类型的文件名称并赋值为宏变量?用常规的方法可能要20多行代码,如果用FILENAME PIPE只需要9行代码就可以轻松解决,语法如下:
专业余码农
2020/07/16
4.7K1
使用VBA在工作表中列出所有定义的名称
有时候,工作簿中可能有大量的命名区域。然而,如果名称太多,虽然有名称管理器,可能名称的命名也有清晰的含义,但查阅起来仍然不是很方便,特别是想要知道名称引用的区域时,如果经常要打开名称管理器查找命名区域,会非常麻烦,也浪费时间。
fanjy
2022/11/16
6.6K0
js获取url链接中的域名部分
因为一个正确的url必定是由http://或者是https://、domain、路径/参数组成,所以可以用split以/进行分割成数组,取第3部分就是域名了。
全栈程序员站长
2022/07/08
9.2K0
PHP 获取指定 URL 页面中的所有链接
以下代码可以获取到指定 URL 页面中的所有链接,即所有 a 标签的 href 属性:
Z4
2020/04/22
7.6K0
根据 PID 获取 K8S Pod名称 - 反之 POD名称 获取 PID
随着 Kubernetes 越来越火爆,运维人员排查问题难度越来越大。比如我们收到监控报警,某台 Kubernetes Node 节点负载高。通过 top 或者 pidstat 命令获取 Pid,问题来了,这个 Pid 对应那个 Kubernetes Pod 呢?
YP小站
2020/07/21
3.4K0
Qt中获取当前应用程序全路径
在Qt中获取应用程序全路径非常简单,直接使用`QCoreApplication::applicationDirPath()`这个静态函数就OK了,
ccf19881030
2021/05/30
3.4K0
Qt中获取当前应用程序全路径
如何获取流式应用程序中checkpoint的最新offset
对于流式应用程序,保证应用7*24小时的稳定运行,是非常必要的。因此对于计算引擎,要求必须能够适应与应用程序逻辑本身无关的问题(比如driver应用失败重启、网络问题、服务器问题、JVM崩溃等),具有自动容错恢复的功能。
大数据学习与分享
2020/08/10
1.3K0
Android中获取应用程序(包)的信息-----PackageManager的使用
Android系统为我们提供了很多服务管理的类,包括ActivityManager、PowerManager(电源管理)、AudioManager(音频管理)
forrestlin
2022/04/02
2.2K0
Android中获取应用程序(包)的信息-----PackageManager的使用
arcengine开发如何遍历MapControl和PageLaoutControl中的图层,获取图层名称
一般的GIS开发者都知道arcengine开发中如何遍历MapControl中的图层,代码如下:
acoolgiser
2019/01/17
2.3K0
js中,如何获取批量传入文件的大小,名称,进行循环展示。
<div class="handle"> <div class="handle-box" id="drop_area" v-on:drop="dropClick"> <div class="handle-btn"> <img class="btn-icon" src="./images/compress/new-btn-icon.png" alt="">
用户4344670
2022/09/02
10K0
根据 PID 获取容器所在的 Pod 名称
在管理 Kubernetes 集群的过程中,我们经常会遇到这样一种情况:在某台节点上发现某个进程资源占用量很高,却又不知道是哪个容器里的进程。有没有办法可以根据 PID 快速找到 Pod 名称呢?
米开朗基杨
2020/07/17
6.9K0
Python获取网卡信息(名称、MAC、
    “人生苦短,我用Python”。Python的高效有一部分是跟它丰富的模块分不开的。Python有很多第三方模块可以帮助我们完成一些事情,减少开发时间。
py3study
2020/01/03
4.7K0
Python获取网卡信息(名称、MAC、
在不是Thread类的子类中,如何获取线程对象的名称呢?
我想要获取main方法所在的线程对象的名称,该怎么办呢?   遇到这种情况,Thread类就提供了一个很好玩的方法:     public static Thread currentThread()
黑泽君
2018/10/11
4.9K0
2021最新微博爬虫——根据话题名称获取所有相关微博与评论
-首先确定抓取微博内容、评论数、点赞数、发布时间、发布者名称等主要字段。选择weibo.com作为主要数据来源。(就是因为搜索功能好使)
MinChess
2022/12/27
4.4K1
2021最新微博爬虫——根据话题名称获取所有相关微博与评论
java准确的获取操作系统的名称
在我们日常开发中,经常需要判断操作系统的版本或者系统的名字等等。这就需要我们用到jdk默认带的一些属性了。这里我对各个版本的系统都做了区分,分别能判断mac,linux,window等大众的操作系统名称。直接看代码(OSUtil.java):
业余草
2019/01/21
4.3K0
java准确的获取操作系统的名称
Netty中的线程名称
创建的第一个步骤就是创建线程执行器ThreadPerTaskExecutor, 这个线程执行器就是用来创建Netty底层的线程的. 在学习Java的Thread时候,线程默认名称类似thread-0,thread-1,thread-2...以此类推. 而线程的名称对于我们排查问题的时候也是起到很大作用的, 因此我们在设计线程池, 也会根据一定的规则给线程池中的线程命名, 这也是一个好的习惯.
书唐瑞
2022/06/02
1.1K0
Netty中的线程名称
JavaScript获取url网址中域名后面的部分
lastIndexOf() 方法返回调用 String 对象的指定值最后一次出现的索引,在一个字符串中的指定位置 fromIndex 处从后向前搜索。如果没找到这个特定值则返回-1 。
德顺
2021/01/18
7.2K0
获取两个list中相互不包含的部分
代码如下:提供了几种方法(自个写的) import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; class Scratch { public static void main(String[] args) { List<Integer> list = new ArrayList<>(); list.add(1); list.a
阿超
2022/08/21
1.6K0
获取两个list中相互不包含的部分
iPhone应用程序名称本地化
iPhone的应用程序名称也可以本地化,可以按照以下步骤来实施: 1. 修改项目目录下的'-info.plist'文件名</h2> 将'-info.plist' 修改为 Info.plist ## 2. 将Info.plist本地化 在Info.plist上右键点选Get Info,在General标签下,点击Make File Localizable按钮。 里面会有一个默认的英文版本,点击Add Localization... 按钮,添加你需要的本地化语言。 如简体中文"zh-hans",然后点击添加
EltonZheng
2021/01/26
5800

相似问题

查找所有并删除Regedit的所有函数

20

在regedit中创建名称中带有“/”的密钥

211

获取具有字符串部分名称的所有类

110

获取顶层部分的名称?

10

按属性名称部分获取所有元素的最佳方法

18
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文