前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch里的CrossEntropyLoss详解

Pytorch里的CrossEntropyLoss详解

作者头像
marsggbo
发布2020-06-12 09:41:13
2.8K0
发布2020-06-12 09:41:13
举报
文章被收录于专栏:AutoML(自动机器学习)

在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。

首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

下面是对与cross entropy有关的函数做的总结:

torch.nn

torch.nn.functional (F)

CrossEntropyLoss

cross_entropy

LogSoftmax

log_softmax

NLLLoss

nll_loss

下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

在介绍cross_entropy之前先介绍两个基本函数:

log_softmax

这个很好理解,其实就是logsoftmax合并在一起执行。

nll_loss

该函数的全程是negative log likelihood loss,函数表达式为

\[f(x,class)=-x[class] \]

例如假设

x=[1,2,3], class=2

,那额

f(x,class)=-x[2]=-3

cross_entropy

交叉熵的计算公式为:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right) \]

其中

p

表示真实值,在这个公式中是one-hot形式;

q

是预测值,在这里假设已经是经过softmax后的结果了。

代码示例

代码语言:javascript
复制
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randint(5, (3,), dtype=torch.int64)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019-02-19 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • log_softmax
  • nll_loss
  • cross_entropy
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档