首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >PyTorch CNN教程的混淆矩阵和测试精度

PyTorch CNN教程的混淆矩阵和测试精度
EN

Stack Overflow用户
提问于 2020-04-07 16:28:22
回答 1查看 412关注 0票数 0

我只对报告训练和测试的准确性以及混淆矩阵感兴趣(比如使用sklearn混淆矩阵)。我该怎么做呢?目前的教程只报告了train/val的准确性,我很难弄清楚如何在那里整合sklearn confusionmatrix代码。链接到原始教程:https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/4%20-%20Convolutional%20Sentiment%20Analysis.ipynb

EN

回答 1

Stack Overflow用户

发布于 2020-04-07 17:05:35

与本教程中定义的binary_accuracy函数非常相似,您可以实现所需的任何指标。您所需要的只是一组模型预测(本例中为preds)和真实目标(y)。

例如,对于混淆矩阵,您可以执行以下操作:

代码语言:javascript
运行
复制
from sklearn.metrics import confusion_matrix

def compute_confusion_matrix(preds, y):
    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    return confusion_matrix(y, rounded_preds)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61075810

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档