前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >kaggle实战-基于机器学习肾脏病预测

kaggle实战-基于机器学习肾脏病预测

作者头像
皮大大
发布2023-08-25 11:57:46
3660
发布2023-08-25 11:57:46
举报

kaggle实战:机器学习建模预测肾脏疾病

本文是针对kaggle上面一份肾脏疾病数据的建模

原数据集地址:

https://www.kaggle.com/datasets/mansoordaku/ckdisease?datasetId=1111&sortBy=voteCount

结果

先看看最终的结果对比:

  • KNN是分数最低的;LGBM第一。一般在kaggle,分类问题LGBM高频使用,且效果一般都比较好
  • 树模型中,以决策树为基础,效果都有所提升。

导入库

笔记1📒:一般在建模中,导入库包含:

  • 数据处理pandas为主
  • 可视化库:笔者一般用的Plotly结合seaborn;偶尔用原生的matplotlib和pyecharts
  • 各种回归和分类模型 + 评价指标
  • 其他:切分数据、降维、采样、标准化等
代码语言:javascript
复制
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import missingno as ms
import plotly.express as px
import plotly.graph_objs as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import plotly.offline as pyo
pyo.init_notebook_mode()
sns.set_style('darkgrid')

plt.style.use('fivethirtyeight')
%matplotlib inline

from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split,cross_val_score

from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score as f1
from sklearn.metrics import confusion_matrix

import eli5
from eli5.sklearn import PermutationImportance
import shap
from pdpbox import pdp, info_plots

plt.rc('figure',figsize=(18,9))

import warnings
warnings.filterwarnings("ignore")

pd.set_option('display.max_columns', 26)

数据基本信息

很明显:上面数据中id字段是对建模无用的,直接drop函数删除:

In [3]:

代码语言:javascript
复制
df.drop("id",axis=1,inplace=True)

查看数据量大小:行数和字段属性数量

In [4]:

代码语言:javascript
复制
df.shape

Out[4]:

代码语言:javascript
复制
(400, 25)

总共是400条数据,25个字段

不同的字段类型统计:

In [5]:

代码语言:javascript
复制
df.dtypes

Out[5]:

代码语言:javascript
复制
age               float64
bp                float64
sg                float64
al                float64
su                float64
rbc                object
pc                 object
pcc                object
ba                 object
bgr               float64
bu                float64
sc                float64
sod               float64
pot               float64
hemo              float64
pcv                object
wc                 object
rc                 object
htn                object
dm                 object
cad                object
appet              object
pe                 object
ane                object
classification     object
dtype: object

In [6]:

代码语言:javascript
复制
pd.value_counts(df.dtypes)

Out[6]:

只包含两个类型的字段

代码语言:javascript
复制
object     14
float64    11
dtype: int64

查看缺失值情况:

In [7]:

代码语言:javascript
复制
df.isnull().sum().sort_values(ascending=False)

Out[7]:

代码语言:javascript
复制
rbc               152
rc                130
wc                105
pot                88
sod                87
pcv                70
pc                 65
hemo               52
su                 49
sg                 47
al                 46
bgr                44
bu                 19
sc                 17
bp                 12
age                 9
ba                  4
pcc                 4
htn                 2
dm                  2
cad                 2
appet               1
pe                  1
ane                 1
classification      0
dtype: int64

数值型字段的描述统计信息,通常是查看这些字段的统计值信息:总统计量、最值、四分位数等:

In [8]:

代码语言:javascript
复制
df.describe().style.background_gradient(cmap="ocean_r")  # 描述统计信息

数据基本信息:

In [9]:

代码语言:javascript
复制
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 400 entries, 0 to 399
Data columns (total 25 columns):
 #   Column          Non-Null Count  Dtype
---  ------          --------------  -----
 0   age             391 non-null    float64
 1   bp              388 non-null    float64
 2   sg              353 non-null    float64
 3   al              354 non-null    float64
 4   su              351 non-null    float64
 5   rbc             248 non-null    object
 6   pc              335 non-null    object
 7   pcc             396 non-null    object
 8   ba              396 non-null    object
 9   bgr             356 non-null    float64
 10  bu              381 non-null    float64
 11  sc              383 non-null    float64
 12  sod             313 non-null    float64
 13  pot             312 non-null    float64
 14  hemo            348 non-null    float64
 15  pcv             330 non-null    object
 16  wc              295 non-null    object
 17  rc              270 non-null    object
 18  htn             398 non-null    object
 19  dm              398 non-null    object
 20  cad             398 non-null    object
 21  appet           399 non-null    object
 22  pe              399 non-null    object
 23  ane             399 non-null    object
 24  classification  400 non-null    object
dtypes: float64(11), object(14)
memory usage: 78.2+ KB

字段解释

针对每个字段的中文含义解释:

In [10]:

代码语言:javascript
复制
columns = df.columns
columns

Out[10]:

代码语言:javascript
复制
Index(['age', 'bp', 'sg', 'al', 'su', 'rbc', 'pc', 'pcc', 'ba', 'bgr', 'bu',
       'sc', 'sod', 'pot', 'hemo', 'pcv', 'wc', 'rc', 'htn', 'dm', 'cad',
       'appet', 'pe', 'ane', 'classification'],
      dtype='object')
  • age:年龄
  • bp:blood_pressure,血压
  • sg:specific_gravity,比重值;肾脏疾病通常是检测尿比重
  • al:albumin,白蛋白
  • su:sugar,葡萄糖
  • rbc:red_blood_cells,【红血细胞】是否正常?
  • pc:pus_cell,【脓细胞】含量是否正常?
  • pcc:pus_cell_clumps,【脓细胞群】是否正常
  • ba:bacteria,是否【细菌】感染?
  • bgr:blood_glucose_random,随机血糖量
  • bu:blood_urea,血尿素
  • sc:serum_creatinine,血清肌酐
  • sod:sodium,钠
  • pot:potassium,钾
  • hemo:haemoglobin,血红蛋白
  • pcv:packed_cell_volume(PCV),血细胞压积,红细胞在血液中所占容积比
  • wc:white_blood_cell_count,白血细胞计数
  • rc:red_blood_cell_count,红血细胞计数
  • htn:hypertension,是否有【高血压】?
  • dm:diabetes_mellitus,是否有【糖尿病】?
  • cad:coronary_artery_disease,是否有【冠状动脉疾病】?
  • appet:appetite,是否有【食欲】?
  • pe:peda_edema,足部是否【水肿】?
  • ane:aanemia,是否【贫血】?
  • classification:分类结果,是否患病

字段预处理

下面我们对部分字段进行处理

字段classification

最终分类结果的处理

In [11]:

代码语言:javascript
复制
df["classification"].value_counts()  # 修改前

Out[11]:

代码语言:javascript
复制
ckd       248
notckd    150
ckd\t       2
Name: classification, dtype: int64

可以看到有2个记录是异常的,这种情况就是属于数据异常,需要手动定位发现统一改成ckd:

In [12]:

代码语言:javascript
复制
df["classification"] = df["classification"].apply(lambda x: x if x == "notckd" else "ckd")

In [13]:

代码语言:javascript
复制
df["classification"].value_counts()  # 修改后

Out[13]:

代码语言:javascript
复制
ckd       250
notckd    150
Name: classification, dtype: int64

年龄age

In [14]:

代码语言:javascript
复制
px.violin(df,y="age",color="classification")

pcv:packed_cell_volume(PCV)

PCV-血细胞压积,红细胞在血液中所占容积比

In [15]:

代码语言:javascript
复制
df["pcv"].value_counts()  # 修改前

可以看到这个字段存在不规范的记录,也需要处理:

In [16]:

代码语言:javascript
复制
df["pcv"] = pd.to_numeric(df["pcv"], errors="coerce")

In [17]:

代码语言:javascript
复制
df["pcv"].value_counts()  # 修改后

wc:white_blood_cell_count

白血细胞计数

In [18]:

代码语言:javascript
复制
df["wc"].value_counts()  # 修改后

Out[18]:

代码语言:javascript
复制
9800     11
6700     10
9200      9
9600      9
7200      9
         ..
19100     1
\t?       1
12300     1
14900     1
12700     1
Name: wc, Length: 92, dtype: int64

In [19]:

代码语言:javascript
复制
df["wc"] = pd.to_numeric(df["wc"], errors="coerce")

rc:red_blood_cell_count

红血细胞计数

In [20]:

代码语言:javascript
复制
df["rc"].value_counts()  # 修改前

也需要进行转化:

In [21]:

代码语言:javascript
复制
df["rc"] = pd.to_numeric(df["rc"], errors="coerce")

In [22]:

代码语言:javascript
复制
# 不同字段类型统计

pd.value_counts(df.dtypes)

Out[22]:

代码语言:javascript
复制
float64    14
object     11
dtype: int64

dm:diabetes_mellitus

是否有【糖尿病】?

In [23]:

代码语言:javascript
复制
df["dm"].value_counts()

Out[23]:

代码语言:javascript
复制
no       258
yes      134
\tno       3
\tyes      2
 yes       1
Name: dm, dtype: int64

dm字段存在异常,一般是空格和换行符引起的;我们将取值统一成no和yes

In [24]:

代码语言:javascript
复制
df["dm"] = df["dm"].str.strip()  # 去除空格

In [25]:

代码语言:javascript
复制
df["dm"].value_counts()

Out[25]:

代码语言:javascript
复制
no     261
yes    137
Name: dm, dtype: int64

cad:coronary_artery_disease

是否有【冠状动脉疾病】?

In [26]:

代码语言:javascript
复制
df["cad"].value_counts()

Out[26]:

代码语言:javascript
复制
no      362
yes      34
\tno      2
Name: cad, dtype: int64

In [27]:

代码语言:javascript
复制
df["cad"] = df["cad"].str.strip()  # 去除空格

查看处理后df的信息:

In [28]:

代码语言:javascript
复制
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 400 entries, 0 to 399
Data columns (total 25 columns):
 #   Column          Non-Null Count  Dtype
---  ------          --------------  -----
 0   age             391 non-null    float64
 1   bp              388 non-null    float64
 2   sg              353 non-null    float64
 3   al              354 non-null    float64
 4   su              351 non-null    float64
 5   rbc             248 non-null    object
 6   pc              335 non-null    object
 7   pcc             396 non-null    object
 8   ba              396 non-null    object
 9   bgr             356 non-null    float64
 10  bu              381 non-null    float64
 11  sc              383 non-null    float64
 12  sod             313 non-null    float64
 13  pot             312 non-null    float64
 14  hemo            348 non-null    float64
 15  pcv             329 non-null    float64
 16  wc              294 non-null    float64
 17  rc              269 non-null    float64
 18  htn             398 non-null    object
 19  dm              398 non-null    object
 20  cad             398 non-null    object
 21  appet           399 non-null    object
 22  pe              399 non-null    object
 23  ane             399 non-null    object
 24  classification  400 non-null    object
dtypes: float64(14), object(11)
memory usage: 78.2+ KB

不同特征分布

In [29]:

代码语言:javascript
复制
# 分类型
cat_cols = [col for col in df.columns if df[col].dtype == "object"]
#  连续型
num_cols = [col for col in df.columns if df[col].dtype != "object"]

分类型变量取值

下面查看分类型变量的不同取值情况:

In [30]:

代码语言:javascript
复制
for col in cat_cols:
    print("变量:", col)
    print(df[col].value_counts())
    print("-"* 10)
变量: rbc
normal      201
abnormal     47
Name: rbc, dtype: int64
----------
变量: pc
normal      259
abnormal     76
Name: pc, dtype: int64
----------
变量: pcc
notpresent    354
present        42
Name: pcc, dtype: int64
----------
变量: ba
notpresent    374
present        22
Name: ba, dtype: int64
----------
变量: htn
no     251
yes    147
Name: htn, dtype: int64
----------
变量: dm
no     261
yes    137
Name: dm, dtype: int64
----------
变量: cad
no     364
yes     34
Name: cad, dtype: int64
----------
变量: appet
good    317
poor     82
Name: appet, dtype: int64
----------
变量: pe
no     323
yes     76
Name: pe, dtype: int64
----------
变量: ane
no     339
yes     60
Name: ane, dtype: int64
----------
变量: classification
ckd       250
notckd    150
Name: classification, dtype: int64
----------

In [31]:

代码语言:javascript
复制
# 分类型变量统计

len(cat_cols)

Out[31]:

代码语言:javascript
复制
11

In [32]:

代码语言:javascript
复制
plt.figure(figsize = (20, 16))
sub_plotnumber = 1

for col in cat_cols:
    if  sub_plotnumber <= 12:
        ax = plt.subplot(4, 3, sub_plotnumber)  # 4行3列的画布;plotnumber表示子图位置
        sns.countplot(df[col], palette = "gist_earth") # 绘图
        plt.xlabel(col)  # x轴标签

    sub_plotnumber += 1  # 自加1

plt.tight_layout()
plt.show()

连续型变量分布

In [33]:

代码语言:javascript
复制
len(num_cols)  # 总共是14个连续型数值变量

Out[33]:

代码语言:javascript
复制
14

In [34]:

代码语言:javascript
复制
plt.figure(figsize = (20, 16))
sub_plotnumber = 1

for col in num_cols:
    if  sub_plotnumber <= 14:
        ax = plt.subplot(4, 4, sub_plotnumber)  # 3行5列的画布;plotnumber表示子图位置
        sns.distplot(df[col])  # 绘图
        plt.xlabel(col)  # x轴标签

    sub_plotnumber += 1  # 自加1

plt.tight_layout()
plt.show()

小结:可以看到多个特征的分布存在一定的偏度skewness(更多的是左偏)

不同特征分布

定义3个不同绘图函数

In [35]:

代码语言:javascript
复制
# 1、小提琴图:查看数据分布情况
def violin(col):
    fig = px.violin(df,
                    y=col,
                    x="classification",
                    color="classification",
                    box=True,
                    template = 'plotly_dark')

    return fig.show()

# 2、kde密度图:是否有正态分布
def kde(col):
    grid = sns.FacetGrid(df,
                         hue="classification",
                         height = 6,
                         aspect=2)
    grid.map(sns.kdeplot, col)
    grid.add_legend()

# 3、散点图:两个变量之间的关系
def scatter(col1, col2):
    fig = px.scatter(df,
                     x=col1,
                     y=col2,
                     color="classification",
                     template = 'plotly_dark')
    return fig.show()

In [36]:

代码语言:javascript
复制
violin("rc")
代码语言:javascript
复制
kde("rc")

两两变量关系

两个连续型变量之间的关系

缺失值处理

整体缺失情况

In [56]:

代码语言:javascript
复制
# 全部字段的缺失值情况

df.isnull().sum().sort_values(ascending = False)

Out[56]:

代码语言:javascript
复制
rbc               152
rc                131
wc                106
pot                88
sod                87
pcv                71
pc                 65
hemo               52
su                 49
sg                 47
al                 46
bgr                44
bu                 19
sc                 17
bp                 12
age                 9
ba                  4
pcc                 4
htn                 2
dm                  2
cad                 2
appet               1
pe                  1
ane                 1
classification      0
dtype: int64

In [57]:

代码语言:javascript
复制
# 连续型变量缺失

df[num_cols].isnull().sum()

Out[57]:

代码语言:javascript
复制
age       9
bp       12
sg       47
al       46
su       49
bgr      44
bu       19
sc       17
sod      87
pot      88
hemo     52
pcv      71
wc      106
rc      131
dtype: int64

In [58]:

代码语言:javascript
复制
# 分类型变量缺失

df[cat_cols].isnull().sum()

Out[58]:

代码语言:javascript
复制
rbc               152
pc                 65
pcc                 4
ba                  4
htn                 2
dm                  2
cad                 2
appet               1
pe                  1
ane                 1
classification      0
dtype: int64

两种填充方式

  1. 随机采样填充:在字段现有值的数据中随机采样进行填充,针对的缺失值较多的字段
  2. 均值或众数填充:针对缺失值较少的字段,用该字段现有数据的均值或者众数填充

In [59]:

代码语言:javascript
复制
df["rbc"].isna().sum() # 表示某个字段的缺失量

Out[59]:

代码语言:javascript
复制
152

In [60]:

代码语言:javascript
复制
df["dm"].mode()[0]   # 某个字段的众数

Out[60]:

代码语言:javascript
复制
'no'

In [61]:

代码语言:javascript
复制
def random_value_imputate(col):
    """
    函数:随机填充方法(缺失值较多的字段)
    """

    # 1、确定填充的数量;在取出缺失值随机选择缺失值数量的样本
    random_sample = df[col].dropna().sample(df[col].isna().sum())
    # 2、索引号就是原缺失值记录的索引号
    random_sample.index = df[df[col].isnull()].index
    # 3、通过loc函数定位填充
    df.loc[df[col].isnull(), col] = random_sample


def mode_impute(col):
    """
    函数:众数填充缺失值
    """
    # 1、确定众数
    mode = df[col].mode()[0]
    # 2、fillna函数填充众数
    df[col] = df[col].fillna(mode)

1、连续型变量使用随机填充方法:

In [62]:

代码语言:javascript
复制
for col in num_cols:
    random_value_imputate(col)

2、分类型变量,针对字段不同方法不同:

In [63]:

代码语言:javascript
复制
# 随机填充
random_value_imputate('rbc')
random_value_imputate('pc')

In [64]:

代码语言:javascript
复制
# 其他字段是众数填充

for col in cat_cols:
    mode_impute(col)

填充完成后数据就没有缺失值:

In [65]:

代码语言:javascript
复制
df.isnull().sum()

Out[65]:

代码语言:javascript
复制
age               0
bp                0
sg                0
al                0
su                0
rbc               0
pc                0
pcc               0
ba                0
bgr               0
bu                0
sc                0
sod               0
pot               0
hemo              0
pcv               0
wc                0
rc                0
htn               0
dm                0
cad               0
appet             0
pe                0
ane               0
classification    0
dtype: int64

相关性分析

In [67]:

代码语言:javascript
复制
df["classification"] = df["classification"].map({"ckd":0, "notckd":1})

corr = df.corr()  # 仅针对数值连续变量

plt.figure(figsize = (15, 8))

sns.heatmap(corr,
            annot = True,
            linewidths = 2,
            linecolor = 'lightgrey')

plt.show()

可以看到和classification强相关的特征主要是:sg(尿比重)、hemo(血红蛋白)、pcv(血细胞压积,红细胞在血液中所占容积比)、rc(红细胞数量)

特征编码

针对分类型变量编码:

In [68]:

代码语言:javascript
复制
for col in cat_cols:
    print(f"Categories of {col}: {df[col].nunique()} ")
Categories of rbc: 2
Categories of pc: 2
Categories of pcc: 2
Categories of ba: 2
Categories of htn: 2
Categories of dm: 2
Categories of cad: 2
Categories of appet: 2
Categories of pe: 2
Categories of ane: 2
Categories of classification: 2

所有的分类型变量都是两种取值情况,我们直接使用类型编码,变成0-1即可:

In [69]:

代码语言:javascript
复制
from sklearn.preprocessing import LabelEncoder

led = LabelEncoder()

for col in cat_cols:
    df[col] = led.fit_transform(df[col])

为了分析的方便,也对classification字段进行编码:

In [70]:

代码语言:javascript
复制
df["classification"].value_counts()

Out[70]:

代码语言:javascript
复制
0    250
1    150
Name: classification, dtype: int64

建模

特征和目标

In [71]:

代码语言:javascript
复制
X = df.drop("classification",axis=1)
y = df["classification"]

训练集和测试集

In [72]:

代码语言:javascript
复制
# 随机打乱数据

from sklearn.utils import shuffle
df = shuffle(df)

In [73]:

代码语言:javascript
复制
# from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.20, random_state = 0)

定义建模函数

In [75]:

代码语言:javascript
复制
def create_model(model):
    # 模型训练
    model.fit(X_train, y_train)
    # 模型预测
    y_pred = model.predict(X_test)
    # 准确率acc
    acc = accuracy_score(y_test, y_pred)
    # 混淆矩阵
    cm = confusion_matrix(y_test, y_pred)
    # 分类报告
    cr = classification_report(y_test,y_pred)

    print(f"Test Accuracy of {model} : {acc}")
    print(f"Confusion Matrix of {model}: \n{cm}")
    print(f"Classification Report of {model} : \n {cr}")

8种模型

KNN

In [76]:

代码语言:javascript
复制
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier()
create_model(knn)
Test Accuracy of KNeighborsClassifier() : 0.7375
Confusion Matrix of KNeighborsClassifier():
[[35 17]
 [ 8 20]]
Classification Report of KNeighborsClassifier() :
               precision    recall  f1-score   support

           0       0.81      0.67      0.74        52
           1       0.54      0.71      0.62        28

    accuracy                           0.69        80
   macro avg       0.68      0.69      0.68        80
weighted avg       0.72      0.69      0.69        80

决策树

In [77]:

代码语言:javascript
复制
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier()
create_model(dt)
Test Accuracy of DecisionTreeClassifier() : 0.9375
Confusion Matrix of DecisionTreeClassifier():
[[50  2]
 [ 3 25]]
Classification Report of DecisionTreeClassifier() :
               precision    recall  f1-score   support

           0       0.94      0.96      0.95        52
           1       0.93      0.89      0.91        28

    accuracy                           0.94        80
   macro avg       0.93      0.93      0.93        80
weighted avg       0.94      0.94      0.94        80

随机森林Random Forest Classifier

In [78]:

代码语言:javascript
复制
from sklearn.ensemble import RandomForestClassifier

rd_clf = RandomForestClassifier(criterion = 'entropy',
                                max_depth = 11,
                                max_features = 'auto',
                                min_samples_leaf = 2, min_samples_split = 3, n_estimators = 130)
create_model(rd_clf)
Test Accuracy of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
                       min_samples_split=3, n_estimators=130) : 0.95
Confusion Matrix of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
                       min_samples_split=3, n_estimators=130):
[[52  0]
 [ 4 24]]
Classification Report of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
                       min_samples_split=3, n_estimators=130) :
               precision    recall  f1-score   support

           0       0.93      1.00      0.96        52
           1       1.00      0.86      0.92        28

    accuracy                           0.95        80
   macro avg       0.96      0.93      0.94        80
weighted avg       0.95      0.95      0.95        80

Ada Boost Classifier

In [80]:

代码语言:javascript
复制
from sklearn.ensemble import AdaBoostClassifier

ada = AdaBoostClassifier(base_estimator = dt)
create_model(ada)
Test Accuracy of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()) : 0.95
Confusion Matrix of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()):
[[51  1]
 [ 3 25]]
Classification Report of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()) :
               precision    recall  f1-score   support

           0       0.94      0.98      0.96        52
           1       0.96      0.89      0.93        28

    accuracy                           0.95        80
   macro avg       0.95      0.94      0.94        80
weighted avg       0.95      0.95      0.95        80

Gradient Boosting Classifier

In [81]:

代码语言:javascript
复制
from sklearn.ensemble import GradientBoostingClassifier

gb = GradientBoostingClassifier()
create_model(gb)
Test Accuracy of GradientBoostingClassifier() : 0.9625
Confusion Matrix of GradientBoostingClassifier():
[[51  1]
 [ 3 25]]
Classification Report of GradientBoostingClassifier() :
               precision    recall  f1-score   support

           0       0.94      0.98      0.96        52
           1       0.96      0.89      0.93        28

    accuracy                           0.95        80
   macro avg       0.95      0.94      0.94        80
weighted avg       0.95      0.95      0.95        80

XgBoost

In [82]:

代码语言:javascript
复制
from xgboost import XGBClassifier

xgb = XGBClassifier(objective = 'binary:logistic',
                    learning_rate = 0.5,
                    max_depth = 5,
                    n_estimators = 150)

create_model(xgb)
Test Accuracy of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.5, max_delta_step=0, max_depth=5,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None) : 0.9625
Confusion Matrix of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.5, max_delta_step=0, max_depth=5,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None):
[[52  0]
 [ 3 25]]
Classification Report of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.5, max_delta_step=0, max_depth=5,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None) :
               precision    recall  f1-score   support

           0       0.95      1.00      0.97        52
           1       1.00      0.89      0.94        28

    accuracy                           0.96        80
   macro avg       0.97      0.95      0.96        80
weighted avg       0.96      0.96      0.96        80

Cat Boost Classifier

In [83]:

代码语言:javascript
复制
from catboost import CatBoostClassifier

cab = CatBoostClassifier(iterations=10)
create_model(cab)
Learning rate set to 0.432149
0:	learn: 0.2531464	total: 62.5ms	remaining: 562ms
1:	learn: 0.1524287	total: 63.9ms	remaining: 256ms
2:	learn: 0.0906595	total: 65.3ms	remaining: 152ms
3:	learn: 0.0578563	total: 66.8ms	remaining: 100ms
4:	learn: 0.0460263	total: 68.2ms	remaining: 68.2ms
5:	learn: 0.0356541	total: 69.6ms	remaining: 46.4ms
6:	learn: 0.0268575	total: 70.9ms	remaining: 30.4ms
7:	learn: 0.0206936	total: 72.2ms	remaining: 18ms
8:	learn: 0.0186242	total: 73.6ms	remaining: 8.17ms
9:	learn: 0.0162996	total: 75ms	remaining: 0us
Test Accuracy of <catboost.core.CatBoostClassifier object at 0x1296bfe50> : 0.9750
Confusion Matrix of <catboost.core.CatBoostClassifier object at 0x1296bfe50>:
[[51  1]
 [ 3 25]]
Classification Report of <catboost.core.CatBoostClassifier object at 0x1296bfe50> :
               precision    recall  f1-score   support

           0       0.94      0.98      0.96        52
           1       0.96      0.89      0.93        28

    accuracy                           0.95        80
   macro avg       0.95      0.94      0.94        80
weighted avg       0.95      0.95      0.95        80

Extra Trees Classifier

In [84]:

代码语言:javascript
复制
from sklearn.ensemble import ExtraTreesClassifier

etc = ExtraTreesClassifier()
create_model(etc)
Test Accuracy of ExtraTreesClassifier() : 0.9625
Confusion Matrix of ExtraTreesClassifier():
[[52  0]
 [ 3 25]]
Classification Report of ExtraTreesClassifier() :
               precision    recall  f1-score   support

           0       0.95      1.00      0.97        52
           1       1.00      0.89      0.94        28

    accuracy                           0.96        80
   macro avg       0.97      0.95      0.96        80
weighted avg       0.96      0.96      0.96        80

LGBM

In [85]:

代码语言:javascript
复制
from lightgbm import LGBMClassifier

lgbm = LGBMClassifier(learning_rate = 0.1)
create_model(lgbm)
Test Accuracy of LGBMClassifier() : 0.9625
Confusion Matrix of LGBMClassifier():
[[51  1]
 [ 2 26]]
Classification Report of LGBMClassifier() :
               precision    recall  f1-score   support

           0       0.96      0.98      0.97        52
           1       0.96      0.93      0.95        28

    accuracy                           0.96        80
   macro avg       0.96      0.95      0.96        80
weighted avg       0.96      0.96      0.96        80

模型对比

In [86]:

代码语言:javascript
复制
models = pd.DataFrame({"model":["KNN","Decision Tree","Random Forest","Ada Boost ",
                                "Gradient Boosting","Xgboost","Cat Boost","Extra Trees","LGBM"],
                      "acc":[0.6875,0.9375,0.95,0.95,0.95,0.9625,0.95,0.9625,0.9625]})

models

Out[86]:

In [87]:

代码语言:javascript
复制
models = models.sort_values("acc",ascending=True)  # 升序排列
models

Out[87]:

In [88]:

代码语言:javascript
复制
px.bar(models,
       x="acc",
       y="model",
       text="acc",
       color = 'acc',
       template = 'plotly_dark',
       title = 'Nine Models Comparison')

模型可解释性

我们在这里选择随机森林模型(rd_clf)同时使用shap库来进行解释

shap值计算

In [89]:

代码语言:javascript
复制
explainer = shap.TreeExplainer(rd_clf)
# 在explainer中传入特征值的数据,计算shap值
shap_values = explainer.shap_values(X_test)
shap_values

Out[89]:

代码语言:javascript
复制
[array([[ 0.00082722, -0.00174422,  0.08114011, ..., -0.00147523,
          0.01219736,  0.00090329],
        [ 0.00241676, -0.00674329, -0.10784976, ..., -0.00708956,
         -0.00392391, -0.00032504],
        [ 0.00211145, -0.00874862, -0.12467772, ..., -0.00752903,
         -0.00413827, -0.00035173],
        ...,
        [-0.00775482, -0.00930411, -0.12662254, ..., -0.00807709,
         -0.00414294, -0.00035052],
        [-0.00719232, -0.00715492, -0.11854169, ..., -0.00821187,
         -0.00448367, -0.00035973],
        [-0.00499459, -0.00897095, -0.12441733, ..., -0.0081471 ,
         -0.00434574, -0.00035173]]),
 array([[-0.00082722,  0.00174422, -0.08114011, ...,  0.00147523,
         -0.01219736, -0.00090329],
        [-0.00241676,  0.00674329,  0.10784976, ...,  0.00708956,
          0.00392391,  0.00032504],
        [-0.00211145,  0.00874862,  0.12467772, ...,  0.00752903,
          0.00413827,  0.00035173],
        ...,
        [ 0.00775482,  0.00930411,  0.12662254, ...,  0.00807709,
          0.00414294,  0.00035052],
        [ 0.00719232,  0.00715492,  0.11854169, ...,  0.00821187,
          0.00448367,  0.00035973],
        [ 0.00499459,  0.00897095,  0.12441733, ...,  0.0081471 ,
          0.00434574,  0.00035173]])]

Feature Importance

In [90]:

代码语言:javascript
复制
shap.summary_plot(shap_values[1], X_test, plot_type="bar")

从结果来看,sg(尿比重)、sc(血清肌酐)、hemo(血红蛋白)是重点影响特征。

代码语言:javascript
复制
shap.summary_plot(shap_values[1], X_test)

summary plot 为每个样本绘制其每个特征的SHAP值;一个点代表一个样本,颜色表示特征值的高低(红色高,蓝色低)

个体差异

查看单个病人的不同特征属性对其结果的影响:

从选择3个病人的结果来看,即使同样是患病者shap值的个体差异仍然很大。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2022-10-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • kaggle实战:机器学习建模预测肾脏疾病
  • 结果
  • 导入库
  • 数据基本信息
  • 字段解释
  • 字段预处理
    • 字段classification
      • 年龄age
        • pcv:packed_cell_volume(PCV)
          • wc:white_blood_cell_count
            • rc:red_blood_cell_count
              • dm:diabetes_mellitus
                • cad:coronary_artery_disease
                • 不同特征分布
                  • 分类型变量取值
                    • 连续型变量分布
                      • 不同特征分布
                        • 两两变量关系
                          • 整体缺失情况
                            • 两种填充方式
                            • 相关性分析
                            • 特征编码
                            • 建模
                              • 特征和目标
                                • 训练集和测试集
                                  • 定义建模函数
                                  • 8种模型
                                    • KNN
                                      • 决策树
                                        • 随机森林Random Forest Classifier
                                          • Ada Boost Classifier
                                            • Gradient Boosting Classifier
                                              • XgBoost
                                                • Cat Boost Classifier
                                                  • Extra Trees Classifier
                                                    • LGBM
                                                    • 模型对比
                                                    • 模型可解释性
                                                      • shap值计算
                                                        • Feature Importance
                                                          • 个体差异
                                                          领券
                                                          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档