如何使用PYTORCH从RESNET获取概率?

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (1)
  • 关注 (0)
  • 查看 (429)

我正在整理我的数据集中的RESNET,它有多个标签。

我想将分类层的“分数”转换为概率,并使用这些概率来计算训练中的损失。

你能给出一个示例代码吗?我能这样用吗?

       P = net.forward(x)
       p = torch.nn.functional.softmax(P, dim=1)
       loss = torch.nn.functional.cross_entropy(P, y)
提问于
用户回答回答于

所以,你在modelpytorch 中训练一个带有交叉熵的resnet。你的损失计算将如下所示。

logit = model(x)
loss = torch.nn.functional.cross_entropy(logits=logit, target=y)

在这种情况下,你可以通过执行来计算所有类的概率,

logit = model(x)
p = torch.nn.functional.softmax(logit, dim=1)
# to calculate loss using probabilities you can do below 
loss = torch.nn.functional.nll_loss(torch.log(p), y)

扫码关注云+社区

领取腾讯云代金券