前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何针对数据不平衡做处理?

如何针对数据不平衡做处理?

作者头像
代码医生工作室
发布2020-05-09 10:33:12
1.3K0
发布2020-05-09 10:33:12
举报
文章被收录于专栏:相约机器人

背景

数据和特征决定了机器学习的上限,模型和算法只是不断逼近这个上限。

无论是做比赛还是做项目,都会遇到一个问题:类别不平衡。这与 数据分布不一致所带来的影响不太一样,前者会导致你的模型在训练过程中无法拟合所有类别的数据,也就是会弄混,后者则更倾向于导致模型泛华能力减弱。

举个例子,让你从一千张狗的图中找到放进去的一只猫,你看了一遍,由于狗的特征你观察的太多了,所以很难会及时分辨出哪只是猫(请忽略人的先验知识)。

下面给出两种解决办法:

1. 数据扩充

数据不平衡,某个类别的数据量太少,那就新增一些呗,简单直接。

但是,怎么增加?如果是实际项目且能够与数据源直接或方便接触的时候,就可以直接去采集新数据。如果是比赛,那就行不通了,最好的办法就是对数据做有效增强后进行扩充。

数据增强的手段:

  • 水平 / 竖直翻转
  • 90°,180°,270° 旋转
  • 翻转 + 旋转
  • 亮度,饱和度,对比度的随机变化
  • 随机裁剪(Random Crop)
  • 随机缩放(Random Resize)
  • 加模糊(Blurring)
  • 加高斯噪声(Gaussian Noise)

以上是我在实际过程中常用一些增强手段,但是除了前三种以外,其他的要慎重考虑。因为不同的任务场景下数据特征依赖不同,比如高斯噪声,在天池铝材缺陷检测竞赛中,如果高斯噪声增加不当,有些图片原本在采集的时候相机就对焦不准,导致工件难以看清,倘若再增加高斯模糊属性,基本就废了。

以前在做处理的时候,也是瞎凑一块,暴力堆数据,但是这样很容易导致噪声过大,从而影响模型效果。后来从 刘思聪大佬的竞赛分享中得到了启发(原文链接:Kaggle 求生:亚马逊热带雨林篇),以下是一些转移理解:

以下图为例

我们做数据增强一定要保证有效性,即不能跟原始数据特征差别太大也不能直接复制,旋转和翻转其实是保证了数据特征的旋转不变性能被模型学习到。就下面一张图而言,结合旋转和翻转,做了八次增强,效果如下:

即使我做了这么多次的旋转工作,模型能从第一张图中识别出雨林和河流,那理所当然从其他角度也能识别出。

在做旋转的时候,也有一个疑问,不做 90° 倍数的旋转不行吗?做 30° 倍数的旋转,最后得到的数据岂不是更多?

个人理解是这样的:一方面考虑存储和模型训练周期的影响,增益比太小,划不来;另一方面,我让模型从这八个角度去看一张图片理论来说已经把图片的旋转特征看了一遍了,这对深度学习模型而言已经足够了。

附上做旋转的代码:

代码语言:javascript
复制
  • from PIL import ImageEnhance
  • from PIL import Image
  • #原图
  • raw_image = Image.open("./raw_images/amazon.jpg")
  • #旋转90°倍数
  • rotate_90 = raw_image.rotate(90)
  • rotate_180 = raw_image.rotate(180)
  • rotate_270 = raw_image.rotate(270)
  • #旋转结合翻转
  • flip_vertical_raw = raw_image.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_90 = rotate_90.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_180 = rotate_180.transpose(Image.FLIP_TOP_BOTTOM)
  • flip_vertical_270 = rotate_270.transpose(Image.FLIP_TOP_BOTTOM)
  • #存储
  • flip_vertical_raw.save("./processed_images/flip_vertical_raw.jpg")
  • flip_vertical_90.save("./processed_images/flip_vertical_90.jpg")
  • flip_vertical_180.save("./processed_images/flip_vertical_180.jpg")
  • flip_vertical_270.save("./processed_images/flip_vertical_270.jpg")
  • raw_image.save("./processed_images/amazon.jpg")
  • rotate_90.save("./processed_images/rotate_90.jpg")
  • rotate_180.save("./processed_images/rotate_180.jpg")
  • rotate_270.save("./processed_images/rotate_270.jpg")

2. sampler

2.1 采样

如果说类别之间的差距过大,有效的数据增强方式肯定不能弥补这种严重的不平衡,这个时候就需要在模型训练过程中对采样过程进行处理了。常见的采样方式分为两种:过采样和欠采样,效果图如下 (图片来源见参考文献 2):

原理就是 “删图片” 和 “增加图片”,从而保证在训练过程中类别之间的数据量大致相同。所带来的影响如下

过采样:重复正比例数据,实际上没有为模型引入更多数据,过分强调正比例数据,会放大正比例噪音对模型的影响。

欠采样:丢弃大量数据,和过采样一样会存在过拟合的问题。

但总的来肯定是利大于弊

2.2 pytorch 权重采样

pytorch 在 DataLoader () 的时候可以传入 sampler ,这里只说一下加权采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

源码:

代码语言:javascript
复制
  • class WeightedRandomSampler(Sampler):
  • r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
  • Arguments:
  • weights (sequence) : a sequence of weights, not necessary summing up to one
  • num_samples (int): number of samples to draw
  • replacement (bool): if ``True``, samples are drawn with replacement.
  • If not, they are drawn without replacement, which means that when a
  • sample index is drawn for a row, it cannot be drawn again for that row.
  • """
  • def __init__(self, weights, num_samples, replacement=True):
  • if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
  • num_samples <= 0:
  • raise ValueError("num_samples should be a positive integeral "
  • "value, but got num_samples={}".format(num_samples))
  • if not isinstance(replacement, bool):
  • raise ValueError("replacement should be a boolean value, but got "
  • "replacement={}".format(replacement))
  • self.weights = torch.tensor(weights, dtype=torch.double)
  • self.num_samples = num_samples
  • self.replacement = replacement
  • def __iter__(self):
  • return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
  • def __len__(self):
  • return self.num_samples

使用方法:

代码语言:javascript
复制
  • import torch
  • from torch.utils.data import DataLoader,WeightedRandomSampler
  • from dataset import train_dataset
  • weights = torch.FloatTensor([1,2,2,4,4,1])
  • train_sampler = WeightedRandomSampler(weights,len(train_dataset),replacement=True)
  • train_sampler = DataLoader(train_dataset,sampler=sampler)

解释:

  • weights:指每一个类别在采样过程中得到权重大小(不要求综合为 1),权重越大的样本被选中的概率越大;
  • num_samples: 共选取的样本总数,待选取的样本数目一般小于全部的样本数目;
  • replacement :指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为 False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights 参数失效。

3. 损失函数加权

还有一种方法是在计算损失函数过程中,对每个类别的损失做加权,具体的方式如下

代码语言:javascript
复制
weights = torch.FloatTensor([1,1,8,8,4]) 
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()

4. 其他方法

暂时没用到,如果有大佬有更好的办法,欢迎评论或联系我。

参考文献

[1] Kaggle 求生:亚马逊热带雨林篇

https://zhuanlan.zhihu.com/p/28084438

[2] Resampling strategies for imbalanced datasets

https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets

[3] pytorch sampler 对数据进行采样

https://blog.csdn.net/TH_NUM/article/details/80877772

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-04-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 数据扩充
    • 在做旋转的时候,也有一个疑问,不做 90° 倍数的旋转不行吗?做 30° 倍数的旋转,最后得到的数据岂不是更多?
    • 2. sampler
      • 2.1 采样
        • 2.2 pytorch 权重采样
        • 3. 损失函数加权
        • 4. 其他方法
        • 参考文献
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档