# 如何处理机器学习中类的不平衡问题

## Balance Scale数据集

```import pandas as pd
import numpy as np

names=['balance', 'var1', 'var2', 'var3', 'var4'])

# Display example observations

• 它有一个目标变量，我们把它标记为balance。
• 它有四个输入特性，我们通过var4把它标记为var1。

• 右重的R，即当var3 * var4 > var1 * var2时
• 左重的L，即当var3 * var4 < var1 * var2时
• 平衡的B，即当var3 * var4 = var1 * var2时
```df['balance'].value_counts()
# R    288
# L    288
# B     49
# Name: balance, dtype: int64```

```# Transform into binary classification
df['balance'] = [1 if b=='B' else 0 for b in df.balance]

df['balance'].value_counts()
# 0    576
# 1     49
# Name: balance, dtype: int64

## 不平衡类的危害

```from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score```

```# Separate input features (X) and target variable (y)
y = df.balance
X = df.drop('balance', axis=1)

# Train model
clf_0 = LogisticRegression().fit(X, y)

# Predict on training set
pred_y_0 = clf_0.predict(X)```

```# How's the accuracy?
print( accuracy_score(pred_y_0, y) )
# 0.9216```

```# Should we be excited?
print( np.unique( pred_y_0 ) )
# [0]```

## 1．上采样少数类

`from sklearn.utils import resample`

1. 首先，我们将把每个类的观察分离到不同的DataFrames。
2. 接下来，我们将用替换来对少数类进行重新取样，并设置与多数类相匹配的样本数量。
3. 最后，我们将把上采样的少数类DataFrame与原始的多数类DataFrame合并在一起。

```# Separate majority and minority classes
df_majority = df[df.balance==0]
df_minority = df[df.balance==1]

# Upsample minority class
df_minority_upsampled = resample(df_minority,
replace=True,     # sample with replacement
n_samples=576,    # to match majority class
random_state=123) # reproducible results

# Combine majority class with upsampled minority class
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

# Display new class counts
df_upsampled.balance.value_counts()
# 1    576
# 0    576
# Name: balance, dtype: int64```

```# Separate input features (X) and target variable (y)
y = df_upsampled.balance
X = df_upsampled.drop('balance', axis=1)

# Train model
clf_1 = LogisticRegression().fit(X, y)

# Predict on training set
pred_y_1 = clf_1.predict(X)

# Is our model still predicting just one class?
print( np.unique( pred_y_1 ) )
# [0 1]

# How's our accuracy?
print( accuracy_score(y, pred_y_1) )
# 0.513888888889```

## 2．下采样多数类

1. 首先，我们将把每个类的观察分离到不同的DataFrames。
2. 接下来，我们将在没有替换的情况下对多数类进行重新取样，并设置与少数类相匹配的样本数量。
3. 最后，我们将把下采样的多数类DataFrame与原始的少数类DataFrame合并在一起。

```# Separate majority and minority classes
df_majority = df[df.balance==0]
df_minority = df[df.balance==1]

# Downsample majority class
df_majority_downsampled = resample(df_majority,
replace=False,    # sample without replacement
n_samples=49,     # to match minority class
random_state=123) # reproducible results

# Combine minority class with downsampled majority class
df_downsampled = pd.concat([df_majority_downsampled, df_minority])

# Display new class counts
df_downsampled.balance.value_counts()
# 1    49
# 0    49
# Name: balance, dtype: int64```

```# Separate input features (X) and target variable (y)
y = df_downsampled.balance
X = df_downsampled.drop('balance', axis=1)

# Train model
clf_2 = LogisticRegression().fit(X, y)

# Predict on training set
pred_y_2 = clf_2.predict(X)

# Is our model still predicting just one class?
print( np.unique( pred_y_2 ) )
# [0 1]

# How's our accuracy?
print( accuracy_score(y, pred_y_2) )
# 0.581632653061```

## 3．改变你的性能指标

• 我们不会在本教程中详细介绍它的细节，但你可以在这里阅读更多相关内容。
• 直观地说，AUROC代表了你的模型将观察与两个类区分开来的可能性。
• 换句话说，如果你随机地从每个类中选择一个观察，你的模型能够正确地排列它们的概率是多少?

`from sklearn.metrics import roc_auc_score`

```# Predict class probabilities
prob_y_2 = clf_2.predict_proba(X)

# Keep only the positive class
prob_y_2 = [p[1] for p in prob_y_2]

prob_y_2[:5] # Example
# [0.45419197226479618,
#  0.48205962213283882,
#  0.46862327066392456,
#  0.47868378832689096,
#  0.58143856820159667]```

```print( roc_auc_score(y, prob_y_2) )
# 0.568096626406```

```prob_y_0 = clf_0.predict_proba(X)
prob_y_0 = [p[1] for p in prob_y_0]

print( roc_auc_score(y, prob_y_0) )
# 0.530718537415```

## 4．惩罚算法(代价敏感训练)

`from sklearn.svm import SVC`

```# Separate input features (X) and target variable (y)
y = df.balance
X = df.drop('balance', axis=1)

# Train model
clf_3 = SVC(kernel='linear',
class_weight='balanced', # penalize
probability=True)

clf_3.fit(X, y)

# Predict on training set
pred_y_3 = clf_3.predict(X)

# Is our model still predicting just one class?
print( np.unique( pred_y_3 ) )
# [0 1]

# How's our accuracy?
print( accuracy_score(y, pred_y_3) )
# 0.688

prob_y_3 = clf_3.predict_proba(X)
prob_y_3 = [p[1] for p in prob_y_3]
print( roc_auc_score(y, prob_y_3) )
# 0.5305236678```

## 5．使用树型结构算法

`from sklearn.ensemble import RandomForestClassifier`

```# Separate input features (X) and target variable (y)
y = df.balance
X = df.drop('balance', axis=1)

# Train model
clf_4 = RandomForestClassifier()
clf_4.fit(X, y)

# Predict on training set
pred_y_4 = clf_4.predict(X)

# Is our model still predicting just one class?
print( np.unique( pred_y_4 ) )
# [0 1]

# How's our accuracy?
print( accuracy_score(y, pred_y_4) )
# 0.9744

prob_y_4 = clf_4.predict_proba(X)
prob_y_4 = [p[1] for p in prob_y_4]
print( roc_auc_score(y, prob_y_4) )
# 0.999078798186```

## 结论与展望

0 条评论

• ### 帮助数据科学家理解数据的23个pandas常用代码

返回给定轴缺失的标签对象，并在那里删除所有缺失数据（’any’：如果存在任何NA值，则删除该行或列。）。

• ### 机器学习项目：建立一个酒店推荐引擎

所有在线旅行社都在争先恐后地满足亚马逊和网飞（Netflix）设定的AI驱动的个性化标准。此外，在线旅游已经成为一个竞争激烈的领域，品牌试图通过推荐，对比，匹配...

• ### 优步正在开发AI来识别醉酒乘客

该专利描述了一种测量用户在手机上的行为与其通常行为的方法。该系统依靠一种算法来衡量多种因素，包括错别字，用户点击链接和按钮的准确度，用户行走的速度以及请求乘车所...

• ### 基于 CentOS 搭建 FTP 文件服务

此时，访问 ftp://192.168.1.170 可浏览机器上的 /var/ftp目录了。

概述 Stetho 是 Facebook 开源的一个 Android 调试工具。是一个 Chrome Developer Tools 的扩展，可用来检测应用的网...

• ### Python+KNN算法判断单词相似度小案例

本文代码用于判断待测单词与哪个候选单词最接近，判断标准为字母出现频次（直方图）最接近，只考虑了不小心的拼写错误，而没有考虑故意的拼写错误，例如故意把god写成d...

• ### 高薪编程，品牌公司——人往高处走，作为程序员的你够格吗

暑假马上就要结束了，暑假没回过家，在学校留校学习一个半月，每天键盘敲击声不断，很充实，每天都在不同程度的进步。且不说学了多少东西，头发反正是稀疏了不少，无奈之...

• ### 利用OpenCV的人脸检测给头像带上圣诞帽

我们来看下效果 原图： ? 效果： ? 原理其实很简单： 采用一张圣诞帽的png图像作为素材， ? 利用png图像背景是透明的，贴在背景图片上就是戴...

• ### LeetCode-192. 统计词频

写一个 bash 脚本以统计一个文本文件 words.txt 中每个单词出现的频率。

• ### 用Python实现WGS84、火星坐标系、百度坐标系、web墨卡托四种坐标相互转换

主流被使用的地理坐标系并不统一，常用的有WGS84、GCJ02（火星坐标系）、BD09（百度坐标系）以及百度地图中保存矢量信息的web墨卡托，本文利用Pyt...