分类问题样本不均衡常见的解决方法

分类时,由于训练集合中各样本数量不均衡,导致模型训偏在测试集合上的泛化性不好。解决样本不均衡的方法主要包括两类:(1)数据层面,修改各类别的分布;(2)分类器层面,修改训练算法或目标函数进行改进。还有方法是将上述两类进行融合。

数据层面

1. 过采样

(1) 基础版本的过采样:随机过采样训练样本中数量比较少的数据;缺点,容易过拟合;

(2) 改进版本的过采样:SMOTE,通过插值的方式加入近邻的数据点;

(3) 基于聚类的过采样:先对数据进行聚类,然后对聚类后的数据分别进行过采样。这种方法能够降低类间和类内的不平衡。

(4) 神经网络中的过采样:SGD训练时,保证每个batch内部样本均衡。

2. 欠采样

与过采样方法相对立的是欠采样方法,主要是移除数据量较多类别中的部分数据。这个方法的问题在于,丢失数据带来的信息缺失。为克服这一缺点,可以丢掉一些类别边界部分的数据。

分类器层面

1. Thresholding

Thresholding的方法又称为post scaling的方法,即根据测试数据的不同类别样本的分布情况选取合适的阈值判断类别,也可以根据贝叶斯公式重新调整分类器输出概率值。一般的基础做法如下:

假设对于某个类别class在训练数据中占比为x,在测试数据中的占比为x’。分类器输出的概率值需要做scaling,概率转换公式为:

当然这种加权的方式亦可在模型训练过程中进行添加,即对于二分类问题目标函数可以转换为如下公式:

2. Cost sensitive learning

根据样本中不同类别的误分类样本数量,重新定义损失函数。Threshold moving和post scaling是常见的在测试过程进行cost调整的方法。这种方法在训练过程计算损失函数时亦可添加,具体参见上一部分。

另外一种cost sensitive的方法是动态调节学习率,认为容易误分的样本在更新模型参数时的权重更大一些。

3. One-class分类

区别于作类别判决,One-class分类只需要从大量样本中检测出该类别即可,对于每个类别均是一个独立的detect model。这种方法能很好地样本极度不均衡的问题。

4. 集成的方法

主要是使用多种以上的方法。例如SMOTEBoost方法是将Boosting和SMOTE 过采样进行结合。

CNN分类处理方法

CNN神经网络有效地应用于图像分类、文本分类。目前成功解决数据不均衡的问题的一种方法是two-phrase training,即分两阶段训练。首先,在均衡的数据集上进行训练,然后在不均衡的原始数据集合上fine tune最后的output layer。

参考文献

[1] https://www.kaggle.com/c/quora-question-pairs/discussion/31179

[2] Buda M, MakiA, Mazurowski M A, et al. A systematic study of the class imbalance problem inconvolutional neural networks[J]. Neural Networks, 2018: 249-259.

原文发布于微信公众号 - CodeInHand(CodeInHand)

原文发表时间:2018-10-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏计算机视觉战队

详聊CNN的精髓

现在的深度学习发展速度已经超出每个人的想象,很大一部分人只是觉得我用他人的框架去实现自己的目的,并且效果很好就可以了,这也是现在一大部分的一个瓶颈。曾经有一个老...

3655
来自专栏AI研习社

用Kaggle经典案例教你用CNN做图像分类!

前言 在上一篇专栏《利用卷积自编码器对图片进行降噪》中,我们利用卷积自编码器对 MNIST 数据进行了实验,这周我们来看一个 Kaggle 上比较经典的一...

4056
来自专栏深度学习与计算机视觉

学习SVM(四) 理解SVM中的支持向量(Support Vector)

学习SVM(一) SVM模型训练与分类的OpenCV实现 学习SVM(二) 如何理解支持向量机的最大分类间隔 学习SVM(三)理解SVM中的对偶问题 ...

2388
来自专栏技术随笔

[Detection] CNN 之 "物体检测" 篇IndexRCNNFast RCNNFaster RCNNR-FCNYOLOSSDNMS

44910
来自专栏贾志刚-OpenCV学堂

理解CNN卷积层与池化层计算

深度学习中CNN网络是核心,对CNN网络来说卷积层与池化层的计算至关重要,不同的步长、填充方式、卷积核大小、池化层策略等都会对最终输出模型与参数、计算复杂度产生...

2231
来自专栏深度学习思考者

深度学习目标检测算法——Faster-Rcnn

Faster-Rcnn代码下载地址:https://github.com/ShaoqingRen/faster_rcnn 一 前言   Faster rcnn是...

3225
来自专栏闪电gogogo的专栏

《统计学习方法》笔记七(1) 支持向量机——线性可分支持向量机

应用拉格朗日对偶性,通过求解对偶问题得到原始问题的最优解,一是因为对偶问题往往更容易求解,二是自然引入核函数,进而推广到非线性分类的问题。

892
来自专栏AI深度学习求索

深度学习基础学习 | 为什么要进行特征提取

在计算机中,图片以有序的多维矩阵进行存储,按颜色分为灰度图片用二维数组存储图片的像素值,和彩色图片用三维数组存储图片的三个通道颜色的像素值。

1922
来自专栏null的专栏

简单易学的机器学习算法——EM算法

一、机器学习中的参数估计问题 image.png 二、EM算法简介     在上述存在隐变量的问题中,不能直接通过极大似然估计求出模型中的参数,EM算法是一种解...

1.7K5
来自专栏机器学习、深度学习

人脸对齐--Dense Face Alignment

Dense Face Alignment ICCVW2017 http://cvlab.cse.msu.edu/project-pifa.html ...

5177

扫码关注云+社区

领取腾讯云代金券