专栏首页相约机器人如何针对数据不平衡做处理?

如何针对数据不平衡做处理?

背景

数据和特征决定了机器学习的上限,模型和算法只是不断逼近这个上限。

无论是做比赛还是做项目,都会遇到一个问题:类别不平衡。这与 数据分布不一致所带来的影响不太一样,前者会导致你的模型在训练过程中无法拟合所有类别的数据,也就是会弄混,后者则更倾向于导致模型泛华能力减弱。

举个例子,让你从一千张狗的图中找到放进去的一只猫,你看了一遍,由于狗的特征你观察的太多了,所以很难会及时分辨出哪只是猫(请忽略人的先验知识)。

下面给出两种解决办法:

1. 数据扩充

数据不平衡,某个类别的数据量太少,那就新增一些呗,简单直接。

但是,怎么增加?如果是实际项目且能够与数据源直接或方便接触的时候,就可以直接去采集新数据。如果是比赛,那就行不通了,最好的办法就是对数据做有效增强后进行扩充。

数据增强的手段:

  • 水平 / 竖直翻转
  • 90°,180°,270° 旋转
  • 翻转 + 旋转
  • 亮度,饱和度,对比度的随机变化
  • 随机裁剪(Random Crop)
  • 随机缩放(Random Resize)
  • 加模糊(Blurring)
  • 加高斯噪声(Gaussian Noise)

以上是我在实际过程中常用一些增强手段,但是除了前三种以外,其他的要慎重考虑。因为不同的任务场景下数据特征依赖不同,比如高斯噪声,在天池铝材缺陷检测竞赛中,如果高斯噪声增加不当,有些图片原本在采集的时候相机就对焦不准,导致工件难以看清,倘若再增加高斯模糊属性,基本就废了。

以前在做处理的时候,也是瞎凑一块,暴力堆数据,但是这样很容易导致噪声过大,从而影响模型效果。后来从 刘思聪大佬的竞赛分享中得到了启发(原文链接:Kaggle 求生:亚马逊热带雨林篇),以下是一些转移理解:

以下图为例

我们做数据增强一定要保证有效性,即不能跟原始数据特征差别太大也不能直接复制,旋转和翻转其实是保证了数据特征的旋转不变性能被模型学习到。就下面一张图而言,结合旋转和翻转,做了八次增强,效果如下:

即使我做了这么多次的旋转工作,模型能从第一张图中识别出雨林和河流,那理所当然从其他角度也能识别出。

在做旋转的时候,也有一个疑问,不做 90° 倍数的旋转不行吗?做 30° 倍数的旋转,最后得到的数据岂不是更多?

个人理解是这样的:一方面考虑存储和模型训练周期的影响,增益比太小,划不来;另一方面,我让模型从这八个角度去看一张图片理论来说已经把图片的旋转特征看了一遍了,这对深度学习模型而言已经足够了。

附上做旋转的代码:

  • from PIL import ImageEnhance
  • from PIL import Image
  • #原图
  • raw_image = Image.open("./raw_images/amazon.jpg")
  • #旋转90°倍数
  • rotate_90 = raw_image.rotate(90)
  • rotate_180 = raw_image.rotate(180)
  • rotate_270 = raw_image.rotate(270)
  • #旋转结合翻转
  • flip_vertical_raw = raw_image.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_90 = rotate_90.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_180 = rotate_180.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_270 = rotate_270.transpose(Image.FLIP_TOP_BOTTOM)
  • #存储
  • flip_vertical_raw.save("./processed_images/flip_vertical_raw.jpg")
  • flip_vertical_90.save("./processed_images/flip_vertical_90.jpg")
  • flip_vertical_180.save("./processed_images/flip_vertical_180.jpg")
  • flip_vertical_270.save("./processed_images/flip_vertical_270.jpg")
  • raw_image.save("./processed_images/amazon.jpg")
  • rotate_90.save("./processed_images/rotate_90.jpg")
  • rotate_180.save("./processed_images/rotate_180.jpg")
  • rotate_270.save("./processed_images/rotate_270.jpg")

2. sampler

2.1 采样

如果说类别之间的差距过大,有效的数据增强方式肯定不能弥补这种严重的不平衡,这个时候就需要在模型训练过程中对采样过程进行处理了。常见的采样方式分为两种:过采样和欠采样,效果图如下 (图片来源见参考文献 2):

原理就是 “删图片” 和 “增加图片”,从而保证在训练过程中类别之间的数据量大致相同。所带来的影响如下

过采样:重复正比例数据,实际上没有为模型引入更多数据,过分强调正比例数据,会放大正比例噪音对模型的影响。

欠采样:丢弃大量数据,和过采样一样会存在过拟合的问题。

但总的来肯定是利大于弊

2.2 pytorch 权重采样

pytorch 在 DataLoader () 的时候可以传入 sampler ,这里只说一下加权采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

源码:

  • class WeightedRandomSampler(Sampler):
  • r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
  • Arguments:
  • weights (sequence) : a sequence of weights, not necessary summing up to one
  • num_samples (int): number of samples to draw
  • replacement (bool): if ``True``, samples are drawn with replacement.
  • If not, they are drawn without replacement, which means that when a
  • sample index is drawn for a row, it cannot be drawn again for that row.
  • """
  • def __init__(self, weights, num_samples, replacement=True):
  • if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
  • num_samples <= 0:
  • raise ValueError("num_samples should be a positive integeral "
  • "value, but got num_samples={}".format(num_samples))
  • if not isinstance(replacement, bool):
  • raise ValueError("replacement should be a boolean value, but got "
  • "replacement={}".format(replacement))
  • self.weights = torch.tensor(weights, dtype=torch.double)
  • self.num_samples = num_samples
  • self.replacement = replacement
  • def __iter__(self):
  • return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
  • def __len__(self):
  • return self.num_samples

使用方法:

  • import torch
  • from torch.utils.data import DataLoader,WeightedRandomSampler
  • from dataset import train_dataset
  • weights = torch.FloatTensor([1,2,2,4,4,1])
  • train_sampler = WeightedRandomSampler(weights,len(train_dataset),replacement=True)
  • train_sampler = DataLoader(train_dataset,sampler=sampler)

解释:

  • weights:指每一个类别在采样过程中得到权重大小(不要求综合为 1),权重越大的样本被选中的概率越大;
  • num_samples: 共选取的样本总数,待选取的样本数目一般小于全部的样本数目;
  • replacement :指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为 False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights 参数失效。

3. 损失函数加权

还有一种方法是在计算损失函数过程中,对每个类别的损失做加权,具体的方式如下

weights = torch.FloatTensor([1,1,8,8,4]) 
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()

4. 其他方法

暂时没用到,如果有大佬有更好的办法,欢迎评论或联系我。

参考文献

[1] Kaggle 求生:亚马逊热带雨林篇

https://zhuanlan.zhihu.com/p/28084438

[2] Resampling strategies for imbalanced datasets

https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets

[3] pytorch sampler 对数据进行采样

https://blog.csdn.net/TH_NUM/article/details/80877772

本文分享自微信公众号 - 相约机器人(xiangyuejiqiren)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-04-26

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 关于如何使用以下技术微调机器和深度学习模型的简介:随机搜索,自动超参数调整和人工神经网络调整

    模型参数定义了如何使用输入数据来获得所需的输出,并在训练时进行学习。相反,超参数首先确定了模型的结构。

    代码医生工作室
  • 推特800赞:图网络论文实现大合集| 已过1k星

    自从科学家发现,图神经网络 (GNN) 能处理不规则数据、攻克从前难解的问题,后每每出现图网络的资源,便广受人类的喜爱。

    代码医生工作室
  • 人工智能在网络安全领域的应用与提高

    网络安全领域中的加密流量的检测是一个老生常谈的话题,随着人工智能的发展,给同样的问题,带来了不同的解决思路。

    代码医生工作室
  • AI in WAF | 腾讯云网站管家 WAF AI 引擎实践

    腾讯云安全
  • java开发C语言编译器:消除冗余语句和把ifelse控制语句编译成字节码

    望月从良
  • 前端day09-JS学习笔记

    .注意点 : if-else if -else结构中必须以if开头,中间的else if可以是多个,末尾的esle可以省略(一般都不会省略)

    帅的一麻皮
  • 大数据新机遇,教育系统将建设完整安全体系

    随着网络规模的扩大,Web应用承载的业务系统越来越复杂,Web系统也受到越来越多的攻击和威胁。大数据时代,网络安全也直接影响到每一个用户的个人信息安全,但是大数...

    安恒信息
  • Blade 模板引擎进阶篇

    除了基本的数据渲染及控制结构指令之外,Blade 还提供了模板继承和组件引入功能,从而允许视图模板之间继承、覆盖及引入。

    学院君
  • 漫谈C变量——天然原子性是怎么回事?

    在20世纪初叶,人们曾经一度认为原子是物质的最小组成单位,原子不可再分。虽然很快人们就发现这是一个谬误——原子不仅可以再分,由质子、中字、电子组成,事实上这些微...

    GorgonMeducer 傻孩子
  • 提高工作效率的GitHub Chrome插件

    以悬浮小框的形式展示作者,仓库,Issues,Pull requests的概述信息

    Java识堂

扫码关注云+社区

领取腾讯云代金券