Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >Tensorflow度量混淆:准确性和损失都很高,但混淆矩阵表示预测错误

Tensorflow度量混淆:准确性和损失都很高,但混淆矩阵表示预测错误
EN

Stack Overflow用户
提问于 2022-03-31 02:41:31
回答 2查看 534关注 0票数 0

经过多年的阅读,终于是我提出第一个问题的时候了:

使用jupyter笔记本中的tensorflow和keras,我在20k声谱图(我自己的数据集)上训练了一个VGG16模型,并使用数据生成器进行了一些数据增强,以进行4级多类分类。下面,我的代码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16

model = VGG16(include_top=True,
              weights=None,
              input_tensor=None,
              pooling=None,
              classes=len(labels),
              classifier_activation="softmax")


from keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import optimizers

# Rescale by 1/255, add data augmentation:
train_datagen = ImageDataGenerator(
      rescale=1./255,
      width_shift_range=0.2,
      brightness_range=[0.8,1.2],
      fill_mode='nearest')

# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 224x224
        target_size=(224, 224),
        batch_size=20,
        # one hot label for multiclass
        class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(224, 224),
        batch_size=20,
        class_mode='categorical')

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=2e-5),
              metrics=[tf.keras.metrics.CategoricalAccuracy(), 
                       tf.keras.metrics.Precision(), 
                       tf.keras.metrics.Recall()])

# Train the model:
history = model.fit(
      train_generator,
      steps_per_epoch=100,
      epochs=100,
      validation_data=validation_generator,
      validation_steps=50,
      verbose=2)

为了评估训练过程,我绘制了acc,损失,精确,回忆和F1分数。他们看起来都很好,这表明训练进行得很顺利。

当我在我的测试集上使用modell.evaluate时,我得到了91%的acc。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(224, 224),
        batch_size=20,
        class_mode='categorical')

test_loss, test_acc, test_precison, test_recall = model.evaluate(test_generator, steps=50)
print('test_acc:' + str(test_acc))

发现4724幅图像,分属于4类。50/50 ============================== - 2s 49 2s/步进损失: 0.2739 - categorical_accuracy: 0.9120 -精度: 0.9244 -召回: 0.9050 test_acc:0.9120000004768372

但是,当我试图用下面的方式绘制一个混淆矩阵时,它看起来很可怕,当我从创建混淆矩阵的数据中手工计算acc时,我得到了25%的acc。这意味着我的模型完全没有学到任何…

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import numpy as np
import sklearn.metrics

# Print confuision matrix for test set:

test_pred_raw = model.predict(test_generator)
print('raw preditcitons:')
print(test_pred_raw)

test_pred = np.argmax(test_pred_raw, axis=1)
print('prediction:')
print(test_pred)

test_labels = test_generator.classes
print('labels')
print(test_labels)

# Calculate accuracy manualy:
my_test_acc = sum(test_pred == test_labels) / len(test_labels)
print('My_acc:')
print(my_test_acc)

# Calculate the confusion matrix using sklearn.metrics
cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)    
figure = plot_confusion_matrix(cm, class_names=labels)

原始预处理:

[2.9204198e-12 2.8631955e-09 1.0000000e+00 7.3386294e-16

0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00

..。

2.2919695e-03 3.8061540e-07 9.9770677e-01 8.1024604e-07

0.0000000e+00 1.0000000e+00 4.0776377e-37 2.6318860e-38]

预测:

2 2 1.2 1

标签

0 0 0.3 3

My_acc:

0.2491532599491956\

我现在的问题是,我可以信任哪些指标,另一个指标有什么问题?

EN

回答 2

Stack Overflow用户

发布于 2022-03-31 05:29:35

好吧。我想我明白了!

shuffle = False中设置test_datagen.flow_from_directory()似乎可以解决问题。现在,混淆矩阵看起来要好得多,my_acc = 89%看起来也不错。

当调用两次时,数据生成器似乎会产生不同的批。首先,通过model.predict(test_generator),然后再通过test_generator.classes,基本使标签和预测不匹配,因为它们是针对不同的批次。

有人能确认我说得对吗?

票数 1
EN

Stack Overflow用户

发布于 2022-03-31 03:06:31

问题可能是:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
my_test_acc = sum(test_pred == test_labels) / len(test_labels)

也许您应该添加一个舍入步骤,以确保预测值实际上是1.0,而不是0.99。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71691097

复制
相关文章
pandas如何获取Excel文件下所有的sheet名称
一定要加sheet_name=None,才能读取出所有的sheet,否则默认读取第一个sheet,且获取到的keys是第一行的值
py3study
2020/12/22
13.9K0
SAS获取某目录下所有指定类型的文件名称
今天看到一个群友提的一个问题:SAS中如何简单地获取某一目录下所有指定类型的文件名称并赋值为宏变量?用常规的方法可能要20多行代码,如果用FILENAME PIPE只需要9行代码就可以轻松解决,语法如下:
专业余码农
2020/07/16
4.7K1
使用VBA在工作表中列出所有定义的名称
有时候,工作簿中可能有大量的命名区域。然而,如果名称太多,虽然有名称管理器,可能名称的命名也有清晰的含义,但查阅起来仍然不是很方便,特别是想要知道名称引用的区域时,如果经常要打开名称管理器查找命名区域,会非常麻烦,也浪费时间。
fanjy
2022/11/16
6.6K0
js获取url链接中的域名部分
因为一个正确的url必定是由http://或者是https://、domain、路径/参数组成,所以可以用split以/进行分割成数组,取第3部分就是域名了。
全栈程序员站长
2022/07/08
9.2K0
PHP 获取指定 URL 页面中的所有链接
以下代码可以获取到指定 URL 页面中的所有链接,即所有 a 标签的 href 属性:
Z4
2020/04/22
7.6K0
根据 PID 获取 K8S Pod名称 - 反之 POD名称 获取 PID
随着 Kubernetes 越来越火爆,运维人员排查问题难度越来越大。比如我们收到监控报警,某台 Kubernetes Node 节点负载高。通过 top 或者 pidstat 命令获取 Pid,问题来了,这个 Pid 对应那个 Kubernetes Pod 呢?
YP小站
2020/07/21
3.4K0
Qt中获取当前应用程序全路径
在Qt中获取应用程序全路径非常简单,直接使用`QCoreApplication::applicationDirPath()`这个静态函数就OK了,
ccf19881030
2021/05/30
3.4K0
Qt中获取当前应用程序全路径
如何获取流式应用程序中checkpoint的最新offset
对于流式应用程序,保证应用7*24小时的稳定运行,是非常必要的。因此对于计算引擎,要求必须能够适应与应用程序逻辑本身无关的问题(比如driver应用失败重启、网络问题、服务器问题、JVM崩溃等),具有自动容错恢复的功能。
大数据学习与分享
2020/08/10
1.3K0
Android中获取应用程序(包)的信息-----PackageManager的使用
Android系统为我们提供了很多服务管理的类,包括ActivityManager、PowerManager(电源管理)、AudioManager(音频管理)
forrestlin
2022/04/02
2.2K0
Android中获取应用程序(包)的信息-----PackageManager的使用
arcengine开发如何遍历MapControl和PageLaoutControl中的图层,获取图层名称
一般的GIS开发者都知道arcengine开发中如何遍历MapControl中的图层,代码如下:
acoolgiser
2019/01/17
2.3K0
js中,如何获取批量传入文件的大小,名称,进行循环展示。
<div class="handle"> <div class="handle-box" id="drop_area" v-on:drop="dropClick"> <div class="handle-btn"> <img class="btn-icon" src="./images/compress/new-btn-icon.png" alt="">
用户4344670
2022/09/02
10K0
根据 PID 获取容器所在的 Pod 名称
在管理 Kubernetes 集群的过程中,我们经常会遇到这样一种情况:在某台节点上发现某个进程资源占用量很高,却又不知道是哪个容器里的进程。有没有办法可以根据 PID 快速找到 Pod 名称呢?
米开朗基杨
2020/07/17
6.9K0
Python获取网卡信息(名称、MAC、
    “人生苦短,我用Python”。Python的高效有一部分是跟它丰富的模块分不开的。Python有很多第三方模块可以帮助我们完成一些事情,减少开发时间。
py3study
2020/01/03
4.7K0
Python获取网卡信息(名称、MAC、
在不是Thread类的子类中,如何获取线程对象的名称呢?
我想要获取main方法所在的线程对象的名称,该怎么办呢?   遇到这种情况,Thread类就提供了一个很好玩的方法:     public static Thread currentThread()
黑泽君
2018/10/11
4.9K0
2021最新微博爬虫——根据话题名称获取所有相关微博与评论
-首先确定抓取微博内容、评论数、点赞数、发布时间、发布者名称等主要字段。选择weibo.com作为主要数据来源。(就是因为搜索功能好使)
MinChess
2022/12/27
4.4K1
2021最新微博爬虫——根据话题名称获取所有相关微博与评论
java准确的获取操作系统的名称
在我们日常开发中,经常需要判断操作系统的版本或者系统的名字等等。这就需要我们用到jdk默认带的一些属性了。这里我对各个版本的系统都做了区分,分别能判断mac,linux,window等大众的操作系统名称。直接看代码(OSUtil.java):
业余草
2019/01/21
4.3K0
java准确的获取操作系统的名称
Netty中的线程名称
创建的第一个步骤就是创建线程执行器ThreadPerTaskExecutor, 这个线程执行器就是用来创建Netty底层的线程的. 在学习Java的Thread时候,线程默认名称类似thread-0,thread-1,thread-2...以此类推. 而线程的名称对于我们排查问题的时候也是起到很大作用的, 因此我们在设计线程池, 也会根据一定的规则给线程池中的线程命名, 这也是一个好的习惯.
书唐瑞
2022/06/02
1.1K0
Netty中的线程名称
JavaScript获取url网址中域名后面的部分
lastIndexOf() 方法返回调用 String 对象的指定值最后一次出现的索引,在一个字符串中的指定位置 fromIndex 处从后向前搜索。如果没找到这个特定值则返回-1 。
德顺
2021/01/18
7.2K0
获取两个list中相互不包含的部分
代码如下:提供了几种方法(自个写的) import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; class Scratch { public static void main(String[] args) { List<Integer> list = new ArrayList<>(); list.add(1); list.a
阿超
2022/08/21
1.6K0
获取两个list中相互不包含的部分
iPhone应用程序名称本地化
iPhone的应用程序名称也可以本地化,可以按照以下步骤来实施: 1. 修改项目目录下的'-info.plist'文件名</h2> 将'-info.plist' 修改为 Info.plist ## 2. 将Info.plist本地化 在Info.plist上右键点选Get Info,在General标签下,点击Make File Localizable按钮。 里面会有一个默认的英文版本,点击Add Localization... 按钮,添加你需要的本地化语言。 如简体中文"zh-hans",然后点击添加
EltonZheng
2021/01/26
5800

相似问题

查找所有并删除Regedit的所有函数

20

在regedit中创建名称中带有“/”的密钥

211

获取具有字符串部分名称的所有类

110

获取顶层部分的名称?

10

按属性名称部分获取所有元素的最佳方法

18
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文