前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Focal Loss --- 从直觉到实现

Focal Loss --- 从直觉到实现

作者头像
NewBeeNLP
发布2021-08-20 17:11:46
6490
发布2021-08-20 17:11:46
举报
文章被收录于专栏:NewBeeNLPNewBeeNLP

作者 | 虎哥@知乎 整理 | NewBeeNLP

做机器学习分类问题,难免遇到Biased-Data-Problem, 例如

  • CV的目标检测问题: 绝大多数检测框里都是 backgroud
  • NLP的异常文本检测: 绝大多数文本都是 normal

对此,以下套路可以缓解:

  • 升/降采样, 或者调整样本权重
  • 换个更鲁棒的loss函数 ,或者加正则
  • 集成模型: Bagging, RandomForest ...
  • 利于外部先验知识: 预训练+微调
  • 多任务联合学习
  • ...

今天要聊的就是一种针对该问题精心设计的loss函数——「Focal Loss」

现状

先来回顾一下常用的 BinaryCrossEntropyLoss 公式如下

不难看出,CE是个“笨学生”。

考前复习的时候,「他不会划重点,对所有知识点 “一视同仁”」

如果教科书上有100道例题,包括: 90道加减乘除 + 10道 三角函数。CE同学就会吭哧吭哧的“平均用力”反复练习这100道例题,结果可想而知——他会精通那90道个位数加减乘除题目,然后其他题目基本靠蒙。那10道他不会的题,往往还是分值高的压轴题。

嗯,大概就是这么个症状。

解决办法

机智如你,想必已经有办法了 —— 「给他指个方向,别再“平均用力”就好了」

方法一、分科复习

每个【科目】的难度是不同的;你要花 30%的精力在四则运算,70%的精力在三角函数。--- 老师告诉CE同学 第一个技巧

对应到公式中,就是针对每个类别赋予不同的权重,即下述

\alpha_t

这是个简单粗暴有效的办法。

方法二、刷题战术

每道【题目】的难度是不同的;你要根据以往刷类似题时候的正确率来合理分配精力。 --- 老师告诉CE同学 第二个技巧

观察CE中的

p_t

,它反映了模型对这个样本的识别能力(即 “这个知识点掌握得有多好”);显然,对于

p_t

越大的样本,我们越要打压它对loss的贡献。

FL是这么定义的:

这里有个超参数

\gamma

; 直观来看,

\gamma

越大 打压越重。如下图所示:

  • 横轴是
p_t

, 纵轴是

\text{FL}(p_t)
  • 总体来说,所有曲线都是单调下降的,即 “掌握越好的知识点越省力”
\gamma=0

时,FL退化成CE,即蓝色线条

\gamma

很大时,线条逐步压低到绿色位置,即各样本对于总loss的贡献受到打压;中间靠右区段承压尤其明显

方法三、综合上述两者

代码

基于Keras实现

代码语言:javascript
复制
from keras import backend as K

def focal_loss(alpha=0.75, gamma=2.0):
    """ 参考 https://blog.csdn.net/u011583927/article/details/90716942 """
    def focal_loss_fixed(y_true, y_pred):
        # y_true 是个一阶向量, 下式按照加号分为左右两部分
        # 注意到 y_true的取值只能是 0或者1 (假设二分类问题),可以视为“掩码”
        # 加号左边的 y_true*alpha 表示将 y_true中等于1的槽位置为标量 alpha
        # 加号右边的 (ones-y_true)*(1-alpha) 则是将等于0的槽位置为 1-alpha
        ones = K.ones_like(y_true)
        alpha_t = y_true*alpha + (ones-y_true)*(1-alpha)

        # 类似上面,y_true仍然视为 0/1 掩码
        # 第1部分 `y_true*y_pred` 表示 将 y_true中为1的槽位置为 y_pred对应槽位的值
        # 第2部分 `(ones-y_true)*(ones-y_pred)` 表示 将 y_true中为0的槽位置为 (1-y_pred)对应槽位的值
        # 第3部分 K.epsilon() 避免后面 log(0) 溢出
        p_t = y_true*y_pred + (ones-y_true)*(ones-y_pred) + K.epsilon()

        # 就是公式的字面意思
        focal_loss = -alpha_t * K.pow((ones-p_t),gamma) * K.log(p_t)
    return focal_loss_fixed

model = ...
model.compile(..., loss=focal_loss(gamma=3, alpha=0.5))

调参经验

\alpha \in (0,1)

反映了“方法一、分科复习”时,各科目的难度比率

  • 二分类场景下,类似于正例的sample_weight概念,可以按照样本占比,适度加权
  • e.g. 设有5条正例、95条负例,则建议
\alpha = 0.95
\alpha=0.5

相当于关掉该功能

\gamma \in [0,+\infty)

反映了 “方法二、刷题战术”时,对于难度的区分程度

\gamma=0

相当于关掉该功能; 即不考虑难度区别,一视同仁

\gamma

越大,则越重视难度,即专注于比较困难的样本。建议在

(0.5, 10.0)

范围尝试

总结

  • 机器学习分类问题中,各类别样本数差距悬殊是很常见的情况;这会干扰模型效果
  • 通过将CrossEntropyLoss替换为综合版的FocalLoss,可以有效缓解上述问题
  • 具体思路就是引入两个额外的变量来区分对待每个样本
    \alpha_t

    根据类别加权

    (1-p_t)^\gamma

    根据难度加权

  • 代码实现很简单、调参也不复杂,详见上文

参考文献

  • Focal Loss for Dense Object Detection
  • Demystifying Focal Loss I: A More Focused Cross Entropy Loss
  • Focal loss论文详解
  • Focal loss 原理及keras实战
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-07-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 NewBeeNLP 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 作者 | 虎哥@知乎 整理 | NewBeeNLP
  • 现状
  • 解决办法
    • 方法一、分科复习
      • 方法二、刷题战术
        • 方法三、综合上述两者
        • 代码
        • 调参经验
        • 总结
        • 参考文献
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档