专栏首页YZh学习记录深度学习trick--labelsmooth
原创

深度学习trick--labelsmooth

labelsmooth 分类问题中错误标注的一种解决方法

1. 应用背景

Label smoothing其全称是 Label Smoothing Regularization(LSR),即标签平滑正则化

其应用场景必须具备以下几个要素:

  1. 标签是one-hot向量
  2. 损失函数是交叉熵损失函数

其作用对象是真实标签.

在神经网络训练中,真实标签主要用于两个方面:1)计算loss; 2)计算accuracy。 计算accuracy时只拿真实标签值和预测索引值做比较,如果二者相等,则说明预测准确,此时真实标签并不参与计算。 计算loss时,用到的交叉熵损失函数如下:

可以看出损失函数的计算只与预测值和真实值有关.所以此处真实值的形式和数值大小对损失函数的计算影响非常大

常用的真实标签的形式是one-hot向量,其值非0即1,

计算Loss时如果是1,Loss中该项的log(Yinference)就得到保留;

如果是0,该项的log(Yinference)就彻底抛弃。

即:

这样就使得模型过分相信标签的标注,只要是标签为1的项就保留,标签为0的项就统统抛弃,那万一标签标注错了岂不是错杀好人?标签是人为标注的,如果人在标注的时候一个不留神标错了,而模型的判决又“充分”信任人为标注的标签,它作为模型判断保留或丢弃某项数据的标准,这就会使错误标签在模型训练中产生较大的影响。

2. 原理介绍

在多分类训练任务中,输入图片经过神级网络的计算,会得到当前输入图片对应于各个类别的置信度分数,这些分数会被softmax进行归一化处理,最终得到当前输入图片属于每个类别的概率。

之后在使用交叉熵函数来计算损失值:

最终在训练网络时,最小化预测概率和标签真实概率的交叉熵,从而得到最优的预测概率分布。在此过程中,为了达到最好的拟合效果,最优的预测概率分布为:

也就是说,网络会驱使自身往正确标签和错误标签差值大的方向学习,在训练数据不足以表征所以的样本特征的情况下,这就会导致网络过拟合。这会导致模型对正确分类的情况奖励最大,错误分类惩罚最大。如果训练数据能覆盖所有情况,或者是完全正确,那么这种方式没有问题。但事实上,这不可能。所以这种方式可能会带来泛化能力差的问题,即过拟合。

3. labelsmooth

标签平滑做的工作就是使真实标签不那么极端化,给予标签一定的容错概率。思想如下:

4. pytorch 实现

class LSR(nn.Module):

    def __init__(self, e=0.01,reduction='mean'):
        super().__init__()

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.e = e
        self.reduction = reduction

    def _one_hot(self, labels, classes, value=1):
        """
            Convert labels to one hot vectors

        Args:
            labels: torch tensor in format [label1, label2, label3, ...]
            classes: int, number of classes
            value: label value in one hot vector, default to 1

        Returns:
            return one hot format labels in shape [batchsize, classes]
        """
        #print("classes", classes)
        one_hot = t.zeros(labels.size(0), classes)

        # labels and value_added  size must match
        labels = labels.view(labels.size(0), -1)
        value_added = t.Tensor(labels.size(0), 1).fill_(value)

        value_added = value_added.to(labels.device)
        one_hot = one_hot.to(labels.device)

        one_hot.scatter_add_(1, labels, value_added)

        return one_hot

    def _smooth_label(self, target, length, smooth_factor):
        """convert targets to one-hot format, and smooth
        them.

        Args:
            target: target in form with [label1, label2, label_batchsize]
            length: length of one-hot format(number of classes)
            smooth_factor: smooth factor for label smooth

        Returns:
            smoothed labels in one hot format
        """
        #print("length", length)
        #print("smooth_fact", smooth_factor)
        one_hot = self._one_hot(target, length, value=1 - smooth_factor)
        one_hot += smooth_factor / length

        return one_hot.to(target.device)

    def forward(self, x, target):

        if x.size(0) != target.size(0):
            raise ValueError('Expected input batchsize ({}) to match target batch_size({})'
                             .format(x.size(0), target.size(0)))

        if x.dim() < 2:
            raise ValueError('Expected input tensor to have least 2 dimensions(got {})'
                             .format(x.size(0)))

        if x.dim() != 2:
            raise ValueError('Only 2 dimension tensor are implemented, (got {})'
                             .format(x.size()))
        #print("x: ", x)
        #print("target", target)

        smoothed_target = self._smooth_label(target, x.size(1), self.e)
        x = self.log_softmax(x)
        loss = t.sum(- x * smoothed_target, dim=1)
        if self.reduction == 'none':
            return loss

        elif self.reduction == 'sum':
            return t.sum(loss)

        elif self.reduction == 'mean':
            return t.mean(loss)

        else:
            raise ValueError('unrecognized option, expect reduction to be one of none, mean, sum')

将损失函数替换成LSR即可

试验后,结果并没有提升.

总结原因可能是使用的小数据集,在样本标注过程并没有出现错误.因此LSR并没有发挥作用.

参考链接:

https://www.jianshu.com/p/6a5ea4ddbf32 https://blog.csdn.net/qiu931110/article/details/86684241

https://blog.csdn.net/z2536083458/article/details/99215855

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • valgrind测试报告分析

    valgrind输出结果会报告5种内存泄露,"definitely lost", "indirectly lost", "possibly lost", "st...

    yzh
  • CMAKE学习记录(二)

    # 找到boost特定版本,以及需要的子模块,1.46.1为最低版本,REQUIRED强调必须找到该模块,否则编译失败;

    yzh
  • MobileNetV2 论文阅读

    MobileNetv2架构是基于倒置残差结构(inverted residual structure),原本的残差结构的主分支是有三个卷积,两个逐点卷积通道数较...

    yzh
  • leetcode: 77. Combinations

    JNingWei
  • python_面向对象编程

    py3study
  • Python3网络爬虫快速入门实战解析

    强烈建议:请在电脑的陪同下,阅读本文。本文以实战为主,阅读过程如稍有不适,还望多加练习。

    圆方圆PYTHON学院
  • Python3网络爬虫快速入门实战解析

    Python版本: Python3.x 运行平台: Windows IDE: Sublime text3 一 前言 强烈建议:请在电脑的陪同下,阅读...

    Jack_Cui
  • Python 面向对象介绍

    py3study
  • JsonPath实践(一)

    最近团队开发了一个平台,功能界面类似postman,用例都还是单接口的用例,就是可以绑定一个用户的登录状态和一些常量。验证功能主要分为两类:1、系统验证(包括请...

    FunTester
  • 省 市 区(县) 街道(乡) plist 文件

    developerbfl

扫码关注云+社区

领取腾讯云代金券