前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >​常用损失函数Loss和Python代码

​常用损失函数Loss和Python代码

作者头像
Minerva
修改2020-05-23 10:34:27
7.2K0
修改2020-05-23 10:34:27
举报

常用损失函数Loss和Python代码

1、损失函数

在机器学习和深度学习中,损失函数 Loss function 是用来估量训练过程中模型的预测值Prediction与真实值Target的偏差,损失函数越小,预测值和真实值越接近,模型的泛化性能越好,通过不断调整模型参数使得损失函数越来越小,从而指导模型的学习。

2、交叉熵损失

2.1 Softmax Loss

交叉熵损失一般配合softmax使用的,通过softmax操作得到每个类别的概率值,所有概率值和为1,然后计算loss,softmax函数为:

交叉熵损失函数为:

其中,M表示类别数, y_c 是一个one-hot向量,元素只有0和1两种取值,至于 p_c 表示预测样本属于的概率。 当类别数等于2的时候,这个损失就是二元交叉熵。 Python代码如下:

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

input = torch.randn(3, requires_grad=True)  # 从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的3个随机数
target = torch.empty(3).random_(2)  # 生成3个值,值为0 或者 1

#二值交叉熵,这里输入要经过sigmoid处理
out = F.sigmoid(input)
loss = nn.BCELoss(out, target)

#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层
out = activation fuction(input)#activation fuction是激活函数
loss= nn.CrossEntropyLoss(input, target)

loss.backward()

交叉熵Loss可以用在大多数语义分割场景中,但当某一类的数量占比远远小于其他类时,损失函数中这一类就会越来越不被重视,其他类的成分就会占据主导,导致效果不好。交叉熵损失还有一个特点,优化类间的距离非常棒,但是优化类内距离时比较弱,因此有很多研究人员对其进行改进优化。

2.2 带权交叉熵损失 Weighted Softmax Loss

在任务中当某一类数量占比很少,但这一类又很重要时,带权重的交叉熵损失就发挥用处了,函数如下:

W_c 就是加入的权重参数,该权重可以自行设定,也可以采用如下公式自动设定

N_c 为类别为c的像素个数,NN为所有类别的像素个数,在样本数量不均衡的情况下可以获得更好的效果。

Python代码如下:

代码语言:javascript
复制
nn.CrossEntropyLoss(input, target, weight=class_weight)

2.3 Soft Softmax loss

公式如下:

当T=1时,就是Softmax的定义,当T>1,就称之为Soft Softmax,T越大,因为 Z_k 产生的概率差异就会越小。该函数是为了迁移学习,生成软标签,然后将软标签和硬标签同时用于新网络的学习。

3、MSELoss

计算均方误差 Mean Squared Error (squared L2 Norm)。 公式如下,其中是真实值,是预测值:

Python代码如下:

代码语言:javascript
复制
import torch
Loss = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = loss(input, target)
loss.backward()

4、Dice Loss

是用来度量集合相似度的度量函数,通常用于计算两个样本之间的像素相似度,公式如下:

Dice是公式后面部分,是两个样本A和B的相似度度量。分子是矩阵A和B逐个元素相乘(点乘),再求和,即相交。分母是矩阵分别求和(矩阵内所有元素加起来),再相加。对于二分类问题,Target分割图是只有 , 两个值,因此 可以有效忽视背景像素,只关注要检测的目标部分,预测结果Prediction和真实标签Target越相似,Dice 系数越高,Dice Loss越小。 Dice Loss适用于目标样本极度不均衡的情况,但目标很小时,使用Dice Loss会因为分割不好导致Loss很大,对反向传播有不利的影响,使得训练不稳定。 Python代码如下:

代码语言:javascript
复制
import torch.nn as nn
import torch.nn.functional as F

class SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()
    def forward(self, logits, targets):
        num = targets.size(0)
        smooth = 1
        probs = F.sigmoid(logits)
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)
        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score

5、Focal Loss

何凯明团队在RetinaNet论文中引入了Focal Loss来解决难易样本数量不平衡,在one-stage目标检测中正负样本比例严重失衡,该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。 二元交叉熵函数如下:

其对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。为了解决正负样本数量不平衡的问题,我们经常在二元交叉熵损失前面加一个参数,成为带权重的交叉熵损失,这样虽然平衡了正负样本的数量,但实际上,目标检测中大量的候选目标都是易分样本(置信度高)。这些样本的损失很低,对模型的提升效果非常小,模型应该主要关注那些难分样本 。所以Focal Loss横空出世了。一个简单的想法就是只要我们将高置信度样本的损失降低 也即是下面的公式:

以GammaGamma等于3,易分样本置信度p=0.9为例,原来的损失为1-0.9=0.1,该方法的损失为(1-0.9)^3 = 0.0001,损失降低了1000倍,使得模型更加重视难分样本。这解决了难易样本的不平衡,而引入权重解决了正负样本的不平衡,Focal Loss同时解决正负难易两个问题,最终Focal Loss的形式如下:

当Gamma = 2, alpha = 0.5时,损失函数训练的过程中关注的样本优先级就是正难>负难>正易>负易。 Python 代码如下:

代码语言:javascript
复制
import torch
import torch.nn as nn

#二分类
class FocalLoss(nn.Module):

    def __init__(self, gamma=2,alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha=alpha
    def forward(self, input, target):
        # input:size is M*2. M is the batch number
        # target:size is M.
        pt=torch.softmax(input,dim=1)
        p=pt[:,1]
        loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\
               (1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))
        return loss.mean()

6、多种损失函数结合

6.1 BCE + Dice Loss

在数据较为均衡的情况下有所改善,但是在数据极度不均衡的情况下交叉熵Loss会在迭代几个Epoch之后远远小于Dice Loss,这个组合Loss会退化为Dice Loss。

6.2 Focal Loss + Dice Loss

即将Focal Loss和Dice Loss相加,以来处理小器官的分割问题。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-03-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python编程和深度学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 常用损失函数Loss和Python代码
  • 1、损失函数
  • 2、交叉熵损失
    • 2.1 Softmax Loss
      • 2.2 带权交叉熵损失 Weighted Softmax Loss
        • 2.3 Soft Softmax loss
        • 3、MSELoss
        • 4、Dice Loss
        • 5、Focal Loss
        • 6、多种损失函数结合
          • 6.1 BCE + Dice Loss
            • 6.2 Focal Loss + Dice Loss
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档