如何获得每个类的Accuracy、Detection_Rate、False_Positive_Rate和False_Negative_Rate?
例如,class_1、class_2、class_3等中的所有这些指标。
发布于 2021-06-24 15:57:48
下面是我用来计算VGG分类器准确性的代码的一个例子,它是用Py火炬完成的:
# VGGClassifier is my Model being trained
# testloader is my dataset for testing
print('Now testing...')
res = 0 # This variable counts the number of good answer
for n, (Xtest, Ytest) in enumerate(testloader): # Batchsize of testloader should probably be 1
Xtest, Ytest = Xtest.to(device), Ytest.to(device) # Put the tensor on CPU/GPU
Y_pred = VGGClassifier(Xtest) # Compute the output of the model
if torch.argmax(Y_pred) == Ytest:
res += 1 # If predicted output = groundtruth output, then we add 1 to the counter
acc = res / (len(testloader)) # divide the number of good answer by the total number
acctab.append(acc) # This list stores the accuracy values during training
print("acc : ", acc)这段代码并不难,所以花点时间去理解它,然后您可以调整它来计算Detection_Rate、False_Positive_Rate和False_Negative_Rate。
例如,为了计算Detection_Rate of class1,您使用相同的代码作为精度,但是您只检查Ytest =1(它们属于第一类)的值,而不是检查每个值的输出。
不确定我的解释是否足够清楚,如果您难以理解一些指标,我可以给出更多的代码片段。
我很确定TF和Pytorch都有已经编码的基本指标,所以一定要查看它们的度量库:
https://datascience.stackexchange.com/questions/97034
复制相似问题