首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Keras -如何使用argmax进行预测

Keras -如何使用argmax进行预测
EN

Stack Overflow用户
提问于 2019-01-13 18:20:00
回答 1查看 17.2K关注 0票数 4

我有3个类别的类Tree, Stump, Ground。我为这些类别做了一个列表:

代码语言:javascript
复制
CATEGORIES = ["Tree", "Stump", "Ground"]

当我打印我的预测时,它给我的输出是

代码语言:javascript
复制
[[0. 1. 0.]]

我读过有关numpy的Argmax的资料,但我不完全确定在这种情况下如何使用它。

我试过用

代码语言:javascript
复制
print(np.argmax(prediction))

但这给了我1的输出。这很好,但是我想找出1的索引是什么,然后打印出类别,而不是最高值。

代码语言:javascript
复制
import cv2
import tensorflow as tf
import numpy as np

CATEGORIES = ["Tree", "Stump", "Ground"]


def prepare(filepath):
    IMG_SIZE = 150 # This value must be the same as the value in Part1
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)

# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")

prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))

我希望我的预测能告诉我:

代码语言:javascript
复制
Predictions:
[[0. 1. 0.]]
Stump

感谢阅读:)非常感谢大家的帮助。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-01-13 20:08:14

您只需使用np.argmax的结果对类别进行索引

代码语言:javascript
复制
pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)
票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54167910

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档