我有3个类别的类Tree, Stump, Ground。我为这些类别做了一个列表:
CATEGORIES = ["Tree", "Stump", "Ground"]当我打印我的预测时,它给我的输出是
[[0. 1. 0.]]我读过有关numpy的Argmax的资料,但我不完全确定在这种情况下如何使用它。
我试过用
print(np.argmax(prediction))但这给了我1的输出。这很好,但是我想找出1的索引是什么,然后打印出类别,而不是最高值。
import cv2
import tensorflow as tf
import numpy as np
CATEGORIES = ["Tree", "Stump", "Ground"]
def prepare(filepath):
IMG_SIZE = 150 # This value must be the same as the value in Part1
img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)
# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")
prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))我希望我的预测能告诉我:
Predictions:
[[0. 1. 0.]]
Stump感谢阅读:)非常感谢大家的帮助。
发布于 2019-01-13 20:08:14
您只需使用np.argmax的结果对类别进行索引
pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)https://stackoverflow.com/questions/54167910
复制相似问题