前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Focal Loss 论文理解及公式推导

Focal Loss 论文理解及公式推导

作者头像
AIHGF
修改2020-06-12 16:30:55
5K0
修改2020-06-12 16:30:55
举报
文章被收录于专栏:AIUAIAIUAI

原文:Focal Loss 论文理解及公式推导 - AIUAI

题目: Focal Loss for Dense Object Detection - ICCV2017 作者: Tsung-Yi, Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollar 团队: FAIR

精度最高的目标检测器往往基于 RCNN 的 two-stage 方法,对候选目标位置再采用分类器处理. 而,one-stage 目标检测器是对所有可能的目标位置进行规则的(regular)、密集采样,更快速简单,但是精度还在追赶 two-stage 检测器. <论文所关注的问题于此.>

论文发现,密集检测器训练过程中,所遇到的极端前景背景类别不均衡(extreme foreground-background class imbalance)是核心原因.

对此,提出了 Focal Loss,通过修改标准的交叉熵损失函数,降低对能够很好分类样本的权重(down-weights the loss assigned to well-classified examples),解决类别不均衡问题.

Focal Loss 关注于在 hard samples 的稀疏子集进行训练,并避免在训练过程中大量的简单负样本淹没检测器.

Focal Loss 是动态缩放的交叉熵损失函数,随着对正确分类的置信增加,缩放因子(scaling factor) 衰退到 0. 如图:

Focal Loss 的缩放因子能够动态的调整训练过程中简单样本的权重,并让模型快速关注于困难样本(hard samples).

基于 Focal Loss 的 RetinaNet 的目标检测器表现.

1. Focal Loss

1.2 Focal Loss 定义

1.3. Focal Loss 例示

1.4. Focal Loss 求导

2. SoftmaxFocalLoss 求导

Focal Loss 损失函数:

3. Pytorch 实现

FocalLoss-PyTorch

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = torch.Tensor([gamma])
        self.size_average = size_average
        if isinstance(alpha, (float, int, long)):
            if self.alpha > 1:
                raise ValueError('Not supported value, alpha should be small than 1.0')
            else:
                self.alpha = torch.Tensor([alpha, 1.0 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.alpha /= torch.sum(self.alpha)

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # [N,C,H,W]->[N,C,H*W] ([N,C,D,H,W]->[N,C,D*H*W])
        # target
        # [N,1,D,H,W] ->[N*D*H*W,1]
        if self.alpha.device != input.device:
            self.alpha = torch.tensor(self.alpha, device=input.device)
        target = target.view(-1, 1)
        logpt = torch.log(input + 1e-10)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1, 1)
        pt = torch.exp(logpt)
        alpha = self.alpha.gather(0, target.view(-1))

        gamma = self.gamma

        if not self.gamma.device == input.device:
            gamma = torch.tensor(self.gamma, device=input.device)

        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

4. Keras 实现

keras-focal-loss

基于 Keras 和 TensorFlow 后端实现的 Binary Focal Loss 和 Categorical/Multiclass Focal Loss.

主要设计两个参数:alphagamma.

用法

代码语言:javascript
复制
model.compile(optimizer='adam', loss=categorical_focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])

实现

代码语言:javascript
复制
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 19 08:20:58 2018

@OS: Ubuntu 18.04
@IDE: Spyder3
@author: Aldi Faizal Dimara (Steam ID: phenomos)
"""

import keras.backend as K
import tensorflow as tf

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """
    Implementation of Focal Loss from the paper in multiclass classification
    Formula:
        loss = -alpha*((1-p)^gamma)*log(p)
    Parameters:
        alpha -- the same as wighting factor in balanced cross entropy
        gamma -- focusing parameter for modulating factor (1-p)
    Default value:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper
    """
    def focal_loss(y_true, y_pred):
        # Define epsilon so that the backpropagation will not result in NaN
        # for 0 divisor case
        epsilon = K.epsilon()
        # Add the epsilon to prediction value
        #y_pred = y_pred + epsilon
        # Clip the prediction value
        y_pred = K.clip(y_pred, epsilon, 1.0-epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true*K.log(y_pred)
        # Calculate weight that consists of  modulating factor and weighting factor
        weight = alpha * y_true * K.pow((1-y_pred), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses in mini_batch
        loss = K.sum(loss, axis=1)
        return loss
    
    return focal_loss

def binary_focal_loss(gamma=2.0, alpha=0.25):
    """
    Implementation of Focal Loss from the paper in multiclass classification
    Formula:
        loss = -alpha_t*((1-p_t)^gamma)*log(p_t)
        
        p_t = y_pred, if y_true = 1
        p_t = 1-y_pred, otherwise
        
        alpha_t = alpha, if y_true=1
        alpha_t = 1-alpha, otherwise
        
        cross_entropy = -log(p_t)
    Parameters:
        alpha -- the same as wighting factor in balanced cross entropy
        gamma -- focusing parameter for modulating factor (1-p)
    Default value:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper
    """
    def focal_loss(y_true, y_pred):
        # Define epsilon so that the backpropagation will not result in NaN
        # for 0 divisor case
        epsilon = K.epsilon()
        # Add the epsilon to prediction value
        #y_pred = y_pred + epsilon
        # Clip the prediciton value
        y_pred = K.clip(y_pred, epsilon, 1.0-epsilon)
        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), y_pred, 1-y_pred)
        # Calculate alpha_t
        alpha_factor = K.ones_like(y_true)*alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1-alpha_factor)
        # Calculate cross entropy
        cross_entropy = -K.log(p_t)
        weight = alpha_t * K.pow((1-p_t), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses in mini_batch
        loss = K.sum(loss, axis=1)
        return loss
    
    return focal_loss    

Related

[1] - Focal Loss 的前向与后向公式推导

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年10月31日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. Focal Loss
    • 1.2 Focal Loss 定义
      • 1.3. Focal Loss 例示
        • 1.4. Focal Loss 求导
        • 2. SoftmaxFocalLoss 求导
        • 3. Pytorch 实现
        • 4. Keras 实现
        • Related
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档