首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >使用PyTorch的交叉熵损失函数是否需要One-Hot编码?

使用PyTorch的交叉熵损失函数是否需要One-Hot编码?
EN

Stack Overflow用户
提问于 2020-06-19 02:04:11
回答 2查看 7.7K关注 0票数 10

例如,如果我想解决MNIST分类问题,我们有10个输出类。在PyTorch中,我想使用torch.nn.CrossEntropyLoss函数。我是否必须格式化目标以便它们是一次性编码的,或者我可以简单地使用数据集附带的它们的类标签?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-06-19 02:18:13

nn.CrossEntropyLoss需要整数标签。它在内部所做的是,它根本不会对类标签进行一次性编码,而是使用标签索引到输出概率向量中,以计算损失,如果您决定使用这个类作为最终标签的话。这个小但重要的细节使得计算损失变得更容易,并且等同于执行one-hot编码,测量每个输出神经元的输出损失,因为输出层中的每个值都将为零,但在目标类上索引的神经元除外。因此,如果您已经提供了标签,则不需要对数据进行一次性编码。

文档对此有更多的见解:https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html。在文档中,您将看到targets,它是输入参数的一部分。这些是您的标签,它们被描述为:

这清楚地显示了输入应该如何形成以及预期的内容。如果您实际上想要对数据进行一次性编码,则需要使用torch.nn.functional.one_hot。为了最好地复制交叉熵损失在幕后所做的事情,您还需要nn.functional.log_softmax作为最终输出,并且您还必须编写自己的损失层,因为没有一个PyTorch层使用log softmax输入和one-hot编码目标。但是,nn.CrossEntropyLoss将这两种操作组合在一起,如果您的输出只是类标签,那么最好使用它,这样就不需要进行转换。

票数 13
EN

Stack Overflow用户

发布于 2020-06-19 02:16:01

如果您正在加载ImageLoader以从文件夹本身加载数据集,则PyTorch将自动为您标记它们。您所要做的就是像这样构造文件夹:

代码语言:javascript
运行
复制
|
|__train
|   |
|   |__1
|   |_ 2
|   |_ 3
|   .
|   .
|   .
|   |_10
|
|__test
    |
    |__1
    |_ 2
    |_ 3
    .
    .
    .
    |_10

每个类都应该有一个单独的文件夹。如果您从DataFrame加载数据,您可以使用以下代码对其进行编码:

代码语言:javascript
运行
复制
one_hot = torch.nn.functional.one_hot(target)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62456558

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档