前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【强基固本】5分钟理解Focal Loss与GHM——解决样本不平衡利器

【强基固本】5分钟理解Focal Loss与GHM——解决样本不平衡利器

作者头像
马上科普尚尚
发布2021-07-30 16:40:26
7720
发布2021-07-30 16:40:26
举报

“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。

来源:知乎—中国移不动

地址:https://zhuanlan.zhihu.com/p/80594704

Focal Loss for Dense Object Detection 是ICCV2017的Best student paper,文章思路很简单但非常具有开拓性意义,效果也非常令人称赞。

GHM(gradient harmonizing mechanism) 发表于 “Gradient Harmonized Single-stage Detector",AAAI2019,是基于Focal loss的改进,也是个人推荐的一篇深度学习必读文章。

01

第一部分 Focal Loss

Focal Loss的实现:

代码语言:javascript
复制

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

02

第二部分 GHM

那么,Focal Loss存在什么问题呢?

首先,让模型过多关注那些特别难分的样本肯定是存在问题的,样本中有离群点(outliers),可能模型已经收敛了但是这些离群点还是会被判断错误,让模型去关注这样的样本,怎么可能是最好的呢?

代码语言:javascript
复制
class GHMC(nn.Module):
    def __init__(self, bins=10, ......):
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
......

>>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 
                  0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])

edges是每个区域的边界,有了边界就很容易计算出梯度模长落入哪个区间内。

然后根据网络输出pred和ground true计算loss:

注意,不管是Focal Loss还是GHM其实都是对不同样本赋予不同的权重,所以该代码前面计算的都是样本权重,最后计算GHM Loss就是调用了Pytorch自带的binary_cross_entropy_with_logits,将样本权重填进去。

代码语言:javascript
复制
代码语言:javascript
复制
# 计算梯度模长
g = torch.abs(pred.sigmoid().detach() - target)
# n 用来统计有效的区间数。
# 假如某个区间没有落入任何梯度模长,密度为0,需要额外考虑,不然取个倒数就无穷了。
n = 0  # n valid bins
# 通过循环计算落入10个bins的梯度模长数量
for i in range(self.bins):
    inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        # 重点,所谓的梯度密度就是1/num_in_bin
        weights[inds] = num_labels / num_in_bin 
        n += 1
if n > 0:
    weights = weights / n
# 把上面计算的weights填到binary_cross_entropy_with_logits里就行了
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    pred, target, weights, reduction='sum') / num_labels

看看抑制的效果吧,也就是文章开头的这张图片:

End~

因为本文着重论文的理解,很多细节没有写出,大家还是要去看一下原文的。

如果文中有错误还请批评指出!

参考资料:

https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf

https://www.aaai.org/ojs/index.php/AAAI/article/download/4877/4750

https://blog.csdn.net/watermelon1123/article/details/89362220

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-07-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能前沿讲习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Focal Loss的实现:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档