前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >学习| 如何处理不平衡数据集

学习| 如何处理不平衡数据集

作者头像
陆勤_数据人网
发布2019-05-16 10:06:31
2K0
发布2019-05-16 10:06:31
举报

编者按:数据集的目标变量分布不平衡问题是一个常见问题,它对特征集的相关性和模型的质量与性能都有影响。因此,在做有监督学习的时候,处理类别不平衡数据集问题是必要的。

分类是机器学习中最常见的问题之一。处理任何分类问题的最佳方法是从分析和探索数据集开始,我们称之为探索性数据分析(EDA)。唯一目的是生成尽可能多的关于数据的见解和信息。它还用于查找数据集中可能存在的任何问题。在用于分类的数据集中发现的一个常见问题是不平衡类问题。

什么是数据不平衡?

数据不平衡通常反映数据集中类的不均匀分布。例如,在信用卡欺诈检测数据集中,大多数信用卡交易都不是欺诈,只有很少的类是欺诈交易。这使得我们在欺诈类和非欺诈类之间的比例约为50:1。在本文中,我将使用Kaggle中的信用卡欺诈交易数据集,可以从这里下载。

首先,让我们绘制类分布以查看不平衡。

如您所见,非欺诈交易远远超过欺诈交易。如果我们训练一个二分类模型而不解决这个问题,这个模型将会完全有偏差。它还影响特性之间的相关性,稍后我将向您介绍How 和 Why。

现在,让我们讨论一些解决类不平衡问题的技术。在这里可以找到一个完整代码的笔记本。

1-重采样(过采样和欠采样)

这听起来很直观。欠采样是一个过程,在这个过程中,您从多数类中随机删除一些观察结果,以便与少数类中的数字匹配。一种简单的方法如下面的代码所示:

代码语言:javascript
复制
# Shuffle the Dataset.
shuffled_df = credit_df.sample(frac=1,random_state=4)

# Put all the fraud class in a separate dataset.
fraud_df = shuffled_df.loc[shuffled_df['Class'] == 1]

#Randomly select 492 observations from the non-fraud (majority class)
non_fraud_df = shuffled_df.loc[shuffled_df['Class'] == 0].sample(n=492,random_state=42)

# Concatenate both dataframes again
normalized_df = pd.concat([fraud_df, non_fraud_df])

#plot the dataset after the undersampling
plt.figure(figsize=(8, 8))
sns.countplot('Class', data=normalized_df)
plt.title('Balanced Classes')
plt.show()

对多数类欠采样

在对数据集进行欠采样后,我再次绘制它,它显示了相同数量的类:

第二种重采样技术称为过采样。这个过程比欠采样要复杂一些。它是生成综合数据的过程,试图从少数类的观察中随机生成属性的样本。对于典型的分类问题,有许多方法用于对数据集进行过采样。最常用的技术是SMOTE(Synthetic Minority Over-sampling Technique)。简单地说,它查看少数类数据点的特征空间,并考虑它的k个最近邻。

来源:https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html

要用python编写这段代码,我使用了一个名为imbalanced-learn或imblearn的库。下面的代码展示了如何实现SMOTE。

代码语言:javascript
复制
from imblearn.over_sampling import SMOTE

# Resample the minority class. You can change the strategy to 'auto' if you are not sure.
sm = SMOTE(sampling_strategy='minority', random_state=7)

# Fit the model to generate the data.
oversampled_trainX, oversampled_trainY = sm.fit_sample(credit_df.drop('Class', axis=1), credit_df['Class'])
oversampled_train = pd.concat([pd.DataFrame(oversampled_trainY), pd.DataFrame(oversampled_trainX)], axis=1)
oversampled_train.columns = normalized_df.columns

还记得我说过不平衡的数据会如何影响特征相关性吗?让我给你们展示一下处理不平衡类前后的相关关系。

重采样之前

下面的代码绘制了所有特征之间的相关矩阵。

代码语言:javascript
复制
# Sample figsize in inches
fig, ax = plt.subplots(figsize=(20,10))         
# Imbalanced DataFrame Correlation
corr = credit_df.corr()
sns.heatmap(corr, cmap='YlGnBu', annot_kws={'size':30}, ax=ax)
ax.set_title("Imbalanced Correlation Matrix", fontsize=14)
plt.show()
重采样之后

请注意,现在特征相关性更加明显。在处理不平衡问题之前,大多数特征都没有显示出任何相关性,这肯定会影响模型的性能。由于特征相关性对整个模型的性能非常重要,因此修复这种不平衡非常重要,因为它还会影响ML模型的性能。

2-集成方法(采样器集成)

在机器学习中,集成方法使用多种学习算法和技术来获得比单独使用任何一种组成学习算法更好的性能。(是的,就像民主党的投票制度一样)。当使用集成分类器时,bagging方法变得流行起来,它通过在不同随机选择的数据子集上构建多个估计器来工作。在scikit-learn库中,有一个名为baggingclassifier的集成分类器。但是,这个分类器不会平衡数据的每个子集。因此,当对不平衡数据集进行训练时,该分类器将有利于大多数类,并创建一个有偏差的模型。

为了解决这个问题,我们可以使用imblearn库中的BalancedBaggingClassifier。它允许在训练集合的每个估计器之前对数据集的每个子集进行重新采样。因此,BalancedBaggingClassifier除了使用sampling_strategy和replace这两个控制随机采样器行为的参数外,还使用了与scikit-learn baggingclassifierwith相同的参数。下面是一些代码,展示了如何做到这一点:

代码语言:javascript
复制
from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.tree import DecisionTreeClassifier

#Create an object of the classifier.
bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
                                sampling_strategy='auto',
                                replacement=False,
                                random_state=0)

y_train = credit_df['Class']
X_train = credit_df.drop(['Class'], axis=1, inplace=False)

#Train the classifier.
bbc.fit(X_train, y_train)
preds = bbc.predict(X_train)

这样,您就可以训练一个分类器来处理这种不平衡,而不必在训练之前手工欠采样和过采样。 总之,每个人都应该知道,建立在不平衡数据集上的ML模型的总体性能将受到其预测罕见点和少数点的能力的限制。识别和解决这些点之间的不平衡对生成模型的质量和性能至关重要。

原文链接: https://www.kdnuggets.com/2019/05/fix-unbalanced-dataset.html

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-05-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 数据科学与人工智能 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 什么是数据不平衡?
  • 1-重采样(过采样和欠采样)
    • 重采样之前
      • 重采样之后
      • 2-集成方法(采样器集成)
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档