我有一个学校的项目,有深入学习的面孔识别。我需要倒数矩阵来衡量性能指标,如准确性,精确性。我试过下面的代码。但是,y_test参数会给出一个错误。我怎么才能解决这个问题?
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(img_array, img_labels,
shuffle=True, stratify=img_labels,
test_size=0.1, random_state=42)
print('Eğitim için eleman sayısı, yükseklik/genişlik ve kanal sayısı: ', x_train.shape)
print('Test için eleman sayısı, yükseklik/genişlik ve kanal sayısı: : ',x_test.shape)
print('Eğitimdeki örnek ve sınıf sayısı :', y_train.shape)
print('Testteki örnek ve sınıf sayısı : ',y_test.shape)
我的代码
cm = confusion_matrix(y_test, y_pred)
print(cm)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [55], in <cell line: 1>()
----> 1 cm = confusion_matrix(y_test, y_pred)
2 print(cm)
File ~\anaconda3\lib\site-packages\sklearn\metrics\_classification.py:307, in confusion_matrix(y_true, y_pred, labels, sample_weight, normalize)
222 def confusion_matrix(
223 y_true, y_pred, *, labels=None, sample_weight=None, normalize=None
224 ):
225 """Compute confusion matrix to evaluate the accuracy of a classification.
226
227 By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
(...)
305 (0, 2, 1, 1)
306 """
--> 307 y_type, y_true, y_pred = _check_targets(y_true, y_pred)
308 if y_type not in ("binary", "multiclass"):
309 raise ValueError("%s is not supported" % y_type)
File ~\anaconda3\lib\site-packages\sklearn\metrics\_classification.py:93, in _check_targets(y_true, y_pred)
90 y_type = {"multiclass"}
92 if len(y_type) > 1:
---> 93 raise ValueError(
94 "Classification metrics can't handle a mix of {0} and {1} targets".format(
95 type_true, type_pred
96 )
97 )
99 # We can't have more than one value on y_type => The set is no more needed
100 y_type = y_type.pop()
ValueError: Classification metrics can't handle a mix of multilabel-indicator and continuous-multioutput targets
发布于 2022-05-12 08:55:17
我知道我不应该在回答中提供这一点,但我现在不能补充评论。分类报告期望y_pred和y_test都是一个一维数组,类标签作为整数。TensorFlow模型的预测主要是二维数组,每个条目都是具有给定行类概率的一维数组。因此,您需要对y_pred进行一些预处理。几周前我遇到了类似的情况,我要分享几行可能有用的代码。
res = np.array(res)
res = res.flatten()
res = np.round(res)
请注意,以上代码用于二进制分类。对于多标签分类,可以使用np.argmax
。
https://stackoverflow.com/questions/72212393
复制相似问题