A trap of parameter 'size_average' in pytorch

pytorch的入门教程中有这样的一句:

loss = torch.nn.MSELoss(size_average=True)
input = Variable(torch.ones(2,2), requires_grad=True)
target = Variable(torch.Tensor([[3,3],[3,3]]))
output = loss(input, target)
output.backward()

上面的程序很简单,设定一个loss函数,然后设定一个input和target进行loss计算,然后再backward。

这里我们不看反向过程,只看output的值是多少:

Variable containing:
 4
[torch.FloatTensor of size 1]

这个4是output的值,可以看到这个值是一个scalar而并不是一个向量。

我们再来看一段程序:

loss = torch.nn.MSELoss(size_average=False)
input = Variable(torch.ones(2,2), requires_grad=True)
target = Variable(torch.Tensor([[3,3],[3,3]]))
output = loss(input, target)
output.backward()
output

Out[11]: 
Variable containing:
 16
[torch.FloatTensor of size 1]

同样一段程序,怎么就变成16了,原因是在设置loss函数中我们把参数size_average设为了False。

当size_average为Ture的时候,我们得到4,反之我们得到16,结合size_average参数的官方解释:

size_average (bool, optional): By default, the losses are averaged over
    observations for each minibatch. However, if the field size_average
    is set to False, the losses are instead summed for each minibatch.

也就是说True的loss除以了loss的数量(此处n=4),False的时候loss为summed的也就是全部loss的总和。

size_average参数默认为True,也就是说我们平时计算的时候得到的loss是平均loss。

这样会出现怎么问题呢:

这个参数默认是True,我们平时的训练集都是规整的,也就是说训练集中的每个类数量几乎差不多的训练集,然后我们算出了损失,比如这个损失是个[10]的Tensor,这里的10就是你之前说的N 也就是分类的数量,这个参数设为True则在得出的所有loss中除以N 如果为Flase则不处理直接把所有类的loss直接加起来给你。这个平时加起来和除以N在实际计算中影响不大,但是如果是别的任务 比如RNN 这时候分类loss就需要权重了。每个分类的loss的权重系数应该不一样,或者说训练集中某些类的数据没有另一类多,得出的损失有可能也有缺陷也需要参数,这时在pytorch中我们就无能为力了,因为其只能输出平均或者总和的loss 不能输出a batch of losses 也就是一个类一个损失,这种输出后期我们可以自由处理 。所以说这个是pytorch的一个缺陷 ,在github中的issue中有人提出来了。pytorch的编写人员回复说正在修改底层。

更新:在pytorch0.3中部分损失函数已经得到修改!

另外几个问题:

如何自定义不同权重loss:https://github.com/pytorch/pytorch/issues/563

loss中奇怪的问题:https://discuss.pytorch.org/t/the-default-value-of-size-average-true-in-loss-function-is-a-trap/4251

How to combine multiple criterions to a loss function? Get loss for individual samples losses per-batch-element:https://github.com/pytorch/pytorch/issues/264

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏小鹏的专栏

机器学习进阶系列

本文为博主原创文章,未经博主允许不得转载。有问题可以加微信:lp9628(注明CSDN)。

2727
来自专栏AI研习社

都说 AllenNLP 好用,我们跑一遍看看究竟多好用

良好学习过程的关键原则之一,就是让学习的内容略高于当前的理解。如果该主题与你已知的内容太过于相似,那么你就不会有很大的进步。另一方面,如果这个主题太难的话,你就...

1952
来自专栏AI研习社

自定义损失函数Gradient Boosting

互联网上有很多关于梯度提升的很好的解释(我们在参考资料中分享了一些选择的链接),但是我们注意到很少有人提起自定义损失函数的信息:为什么要自定义损失函数,何时需要...

7613
来自专栏机器学习算法与Python学习

教程 | 一步一步,看图理解长短期记忆网络与门控循环网络

大家好,欢迎来到 LSTM 和 GRU 的图解指南。在本文中,Michael 将从 LSTM 和 GRU 的背后的原理开始,然后解释令 LSTM 和 GRU 具...

933
来自专栏集智的专栏

使用腾讯云 GPU 学习深度学习系列之四:深度学习的特征工程

本系列文章主要介绍如何使用 腾讯云GPU服务器 进行深度学习运算,前面主要介绍原理部分,后期则以实践为主。

3.2K4
来自专栏杂七杂八

xgboost初识

XGBoost使用 原始数据 数据介绍 鸢尾花数据集是由杰出的统计学家R.A.Fisher在20世纪30年代中期创建的,它被公认为用于数据挖掘的最著名的数据集。...

3114
来自专栏Python中文社区

Python机器学习方法智能识别亚马逊验证码

亚马逊网站验证码全部由英文字母组成,每个字母的形式也是多样的,通过Tesseract-OCR技术识别效率还是比较低,非常不理想。这里采用向量空间技术进行训练识别...

772
来自专栏大数据文摘

有这5小段代码在手,轻松实现数据可视化(Python+Matplotlib)

2446
来自专栏机器之心

教程 | 一步一步,看图理解长短期记忆网络与门控循环网络

循环神经网络(RNN)很大程度上会受到短期记忆的影响,如果序列足够长,它们将很难将信息从早期时间步传递到靠后的时间步。因此,如果你试图处理一段文字来做预测,RN...

973
来自专栏Coding迪斯尼

深度学习项目实践,使用神经网络分析电影评论的正能量与负能量

1151

扫码关注云+社区