This post builds classification models on a kidney disease dataset from Kaggle.
Original dataset:
https://www.kaggle.com/datasets/mansoordaku/ckdisease?datasetId=1111&sortBy=voteCount
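First, a look at the final comparison of results (the accuracy table and bar chart are built at the end of this post):
Note 1 📒: libraries typically imported for a modeling task: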
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import missingno as ms
import plotly.express as px
import plotly.graph_objs as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import plotly.offline as pyo
pyo.init_notebook_mode()
sns.set_style('darkgrid')
plt.style.use('fivethirtyeight')
%matplotlib inline
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split,cross_val_score
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score as f1
from sklearn.metrics import confusion_matrix
import eli5
from eli5.sklearn import PermutationImportance
import shap
from pdpbox import pdp, info_plots
plt.rc('figure',figsize=(18,9))
import warnings
warnings.filterwarnings("ignore")
pd.set_option('display.max_columns', 26)
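The cells that follow operate on a DataFrame named df, but the step that creates it is not shown. A minimal sketch of that step, assuming the Kaggle CSV has been downloaded locally as kidney_disease.csv (both the file name and the helper load_ckd are assumptions, not taken from the original notebook):

```python
import pandas as pd

def load_ckd(path="kidney_disease.csv"):
    """Read the CKD table from a local CSV file (the default path is an assumed location)."""
    return pd.read_csv(path)

# Usage, once the file is available locally:
# df = load_ckd()
# df.head()
```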
Clearly, the id field in the data above is useless for modeling, so remove it with drop:
In [3]:
df.drop("id",axis=1,inplace=True)
Check the size of the data: number of rows and number of fields
In [4]:
df.shape
Out[4]:
(400, 25)
400 records and 25 fields in total
The dtype of each field:
In [5]:
df.dtypes
Out[5]:
age float64
bp float64
sg float64
al float64
su float64
rbc object
pc object
pcc object
ba object
bgr float64
bu float64
sc float64
sod float64
pot float64
hemo float64
pcv object
wc object
rc object
htn object
dm object
cad object
appet object
pe object
ane object
classification object
dtype: object
In [6]:
df.dtypes.value_counts()
Out[6]:
object 14
float64 11
dtype: int64
Only two dtypes are present.
Check missing values:
In [7]:
df.isnull().sum().sort_values(ascending=False)
Out[7]:
rbc 152
rc 130
wc 105
pot 88
sod 87
pcv 70
pc 65
hemo 52
su 49
sg 47
al 46
bgr 44
bu 19
sc 17
bp 12
age 9
ba 4
pcc 4
htn 2
dm 2
cad 2
appet 1
pe 1
ane 1
classification 0
dtype: int64
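Raw counts are easier to judge as a share of the rows; for example, the 152 missing rbc values are 38% of the 400 records. A minimal sketch of that calculation on a made-up toy frame (the values below are illustrative, not from the dataset):

```python
import numpy as np
import pandas as pd

# Toy frame standing in for the real data (values are made up for illustration).
toy = pd.DataFrame({
    "rbc": ["normal", np.nan, np.nan, "abnormal"],
    "age": [48.0, 7.0, np.nan, 62.0],
    "classification": ["ckd", "ckd", "notckd", "notckd"],
})

# isnull().mean() is the fraction of missing rows per column; scale to percent.
missing_pct = (toy.isnull().mean() * 100).sort_values(ascending=False)
print(missing_pct)
```

On the real data the same expression would be applied to df instead of toy.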
Descriptive statistics for the numeric fields: count, minimum and maximum, quartiles, and so on:
In [8]:
df.describe().style.background_gradient(cmap="ocean_r") # descriptive statistics
Basic information about the data:
In [9]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 400 entries, 0 to 399
Data columns (total 25 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 age 391 non-null float64
1 bp 388 non-null float64
2 sg 353 non-null float64
3 al 354 non-null float64
4 su 351 non-null float64
5 rbc 248 non-null object
6 pc 335 non-null object
7 pcc 396 non-null object
8 ba 396 non-null object
9 bgr 356 non-null float64
10 bu 381 non-null float64
11 sc 383 non-null float64
12 sod 313 non-null float64
13 pot 312 non-null float64
14 hemo 348 non-null float64
15 pcv 330 non-null object
16 wc 295 non-null object
17 rc 270 non-null object
18 htn 398 non-null object
19 dm 398 non-null object
20 cad 398 non-null object
21 appet 399 non-null object
22 pe 399 non-null object
23 ane 399 non-null object
24 classification 400 non-null object
dtypes: float64(11), object(14)
memory usage: 78.2+ KB
Column names, as the starting point for explaining what each field means:
In [10]:
columns = df.columns
columns
Out[10]:
Index(['age', 'bp', 'sg', 'al', 'su', 'rbc', 'pc', 'pcc', 'ba', 'bgr', 'bu',
'sc', 'sod', 'pot', 'hemo', 'pcv', 'wc', 'rc', 'htn', 'dm', 'cad',
'appet', 'pe', 'ane', 'classification'],
dtype='object')
Next we clean some of the fields.
Cleaning the target labels
In [11]:
df["classification"].value_counts() # before cleaning
Out[11]:
ckd 248
notckd 150
ckd\t 2
Name: classification, dtype: int64
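Two records are abnormal: their label carries a trailing tab (ckd\t). This kind of data error has to be spotted by inspection; we unify both to ckd:
In [12]:
df["classification"] = df["classification"].apply(lambda x: x if x == "notckd" else "ckd")
In [13]:
df["classification"].value_counts() # after cleaning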
Out[13]:
ckd 250
notckd 150
Name: classification, dtype: int64
In [14]:
px.violin(df,y="age",color="classification")
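PCV: packed cell volume, the proportion of blood volume occupied by red blood cells
In [15]:
df["pcv"].value_counts() # before cleaning
This field contains malformed entries and also needs cleaning:
In [16]:
df["pcv"] = pd.to_numeric(df["pcv"], errors="coerce")
In [17]:
df["pcv"].value_counts() # after cleaning
White blood cell count
In [18]:
df["wc"].value_counts() # before cleaning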
Out[18]:
9800 11
6700 10
9200 9
9600 9
7200 9
..
19100 1
\t? 1
12300 1
14900 1
12700 1
Name: wc, Length: 92, dtype: int64
In [19]:
df["wc"] = pd.to_numeric(df["wc"], errors="coerce")
Red blood cell count
In [20]:
df["rc"].value_counts() # before cleaning
This field needs the same conversion:
In [21]:
df["rc"] = pd.to_numeric(df["rc"], errors="coerce")
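With errors="coerce", any entry that cannot be parsed as a number becomes NaN, which is why pcv, wc and rc each lose one non-null value after conversion. A small sketch on a made-up Series (the sample values are illustrative) that also lists which raw entries were turned into NaN:

```python
import pandas as pd

# Made-up sample: numeric strings, one malformed entry and one true missing value.
raw = pd.Series(["44", "38", "\t?", "52", None])

# Unparseable strings are coerced to NaN instead of raising an error.
clean = pd.to_numeric(raw, errors="coerce")

# Entries that were present in the raw data but could not be parsed.
unparsed = raw[clean.isna() & raw.notna()]
print(clean)
print(unparsed.tolist())
```

Inspecting the unparsed entries before overwriting the column makes it clear how many values the conversion discards.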
In [22]:
# count of each dtype after conversion
df.dtypes.value_counts()
Out[22]:
float64 14
object 11
dtype: int64
Does the patient have diabetes mellitus?
In [23]:
df["dm"].value_counts()
Out[23]:
no 258
yes 134
\tno 3
\tyes 2
yes 1
Name: dm, dtype: int64
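The dm field contains inconsistent values caused by stray whitespace (leading tabs and spaces); we normalize the values to no and yes
In [24]:
df["dm"] = df["dm"].str.strip() # strip leading and trailing whitespace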
In [25]:
df["dm"].value_counts()
Out[25]:
no 261
yes 137
Name: dm, dtype: int64
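Hidden whitespace is hard to see in a plain value_counts() listing. A short sketch on a made-up Series (sample values are illustrative) that uses repr() to expose it and then confirms that str.strip() collapses the variants:

```python
import pandas as pd

# Made-up sample mimicking the variants seen in the dm column.
dm = pd.Series(["no", "yes", "\tno", "\tyes", " yes", None])

# repr() makes tabs and spaces visible, e.g. '\tno' versus 'no'.
print(dm.map(repr).value_counts())

# str.strip() removes leading/trailing whitespace; missing values stay missing.
dm_clean = dm.str.strip()
print(dm_clean.value_counts())
```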
Does the patient have coronary artery disease?
In [26]:
df["cad"].value_counts()
Out[26]:
no 362
yes 34
\tno 2
Name: cad, dtype: int64
In [27]:
df["cad"] = df["cad"].str.strip() # strip leading and trailing whitespace
Check the DataFrame info after cleaning:
In [28]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 400 entries, 0 to 399
Data columns (total 25 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 age 391 non-null float64
1 bp 388 non-null float64
2 sg 353 non-null float64
3 al 354 non-null float64
4 su 351 non-null float64
5 rbc 248 non-null object
6 pc 335 non-null object
7 pcc 396 non-null object
8 ba 396 non-null object
9 bgr 356 non-null float64
10 bu 381 non-null float64
11 sc 383 non-null float64
12 sod 313 non-null float64
13 pot 312 non-null float64
14 hemo 348 non-null float64
15 pcv 329 non-null float64
16 wc 294 non-null float64
17 rc 269 non-null float64
18 htn 398 non-null object
19 dm 398 non-null object
20 cad 398 non-null object
21 appet 399 non-null object
22 pe 399 non-null object
23 ane 399 non-null object
24 classification 400 non-null object
dtypes: float64(14), object(11)
memory usage: 78.2+ KB
In [29]:
# categorical columns
cat_cols = [col for col in df.columns if df[col].dtype == "object"]
# continuous (numeric) columns
num_cols = [col for col in df.columns if df[col].dtype != "object"]
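The same split can be expressed with select_dtypes, which does not rely on comparing dtype names as strings. A minimal sketch on a made-up toy frame (column values are illustrative); it matches the list comprehensions above only because the data holds just numeric and text columns:

```python
import pandas as pd

# Toy frame with two numeric columns and one text column (made-up values).
toy = pd.DataFrame({"age": [48.0, 7.0], "htn": ["yes", "no"], "sc": [1.2, 0.8]})

# Numeric columns are the continuous features; everything else is categorical.
num_cols = toy.select_dtypes(include="number").columns.tolist()
cat_cols = toy.select_dtypes(exclude="number").columns.tolist()
print(num_cols, cat_cols)
```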
Next, inspect the values taken by each categorical variable:
In [30]:
for col in cat_cols:
    print("Variable:", col)
    print(df[col].value_counts())
    print("-" * 10)
Variable: rbc
normal 201
abnormal 47
Name: rbc, dtype: int64
----------
Variable: pc
normal 259
abnormal 76
Name: pc, dtype: int64
----------
Variable: pcc
notpresent 354
present 42
Name: pcc, dtype: int64
----------
Variable: ba
notpresent 374
present 22
Name: ba, dtype: int64
----------
Variable: htn
no 251
yes 147
Name: htn, dtype: int64
----------
Variable: dm
no 261
yes 137
Name: dm, dtype: int64
----------
Variable: cad
no 364
yes 34
Name: cad, dtype: int64
----------
Variable: appet
good 317
poor 82
Name: appet, dtype: int64
----------
Variable: pe
no 323
yes 76
Name: pe, dtype: int64
----------
Variable: ane
no 339
yes 60
Name: ane, dtype: int64
----------
Variable: classification
ckd 250
notckd 150
Name: classification, dtype: int64
----------
In [31]:
# number of categorical variables
len(cat_cols)
Out[31]:
11
In [32]:
plt.figure(figsize = (20, 16))
sub_plotnumber = 1
for col in cat_cols:
    if sub_plotnumber <= 12:
        ax = plt.subplot(4, 3, sub_plotnumber) # 4 x 3 grid; sub_plotnumber is the subplot position
        sns.countplot(x = df[col], palette = "gist_earth") # pass the column as x (required by recent seaborn)
        plt.xlabel(col) # x-axis label
    sub_plotnumber += 1 # increment
plt.tight_layout()
plt.show()
In [33]:
len(num_cols) # 14 continuous numeric variables in total
Out[33]:
14
In [34]:
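plt.figure(figsize = (20, 16))
sub_plotnumber = 1
for col in num_cols:
    if sub_plotnumber <= 14:
        ax = plt.subplot(4, 4, sub_plotnumber) # 4 x 4 grid; sub_plotnumber is the subplot position
        sns.histplot(df[col], kde = True) # distplot is deprecated in recent seaborn; histplot with kde=True is the closest replacement
        plt.xlabel(col) # x-axis label
    sub_plotnumber += 1 # increment
plt.tight_layout()
plt.show()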
Summary: several features show a visibly skewed distribution (mostly left-skewed).
Define three plotting functions
In [35]:
# 1. Violin plot: distribution of a column per class
def violin(col):
    fig = px.violin(df,
                    y=col,
                    x="classification",
                    color="classification",
                    box=True,
                    template = 'plotly_dark')
    return fig.show()
# 2. KDE density plot: does the column look normally distributed?
def kde(col):
    grid = sns.FacetGrid(df,
                         hue="classification",
                         height = 6,
                         aspect=2)
    grid.map(sns.kdeplot, col)
    grid.add_legend()
# 3. Scatter plot: relationship between two variables
def scatter(col1, col2):
    fig = px.scatter(df,
                     x=col1,
                     y=col2,
                     color="classification",
                     template = 'plotly_dark')
    return fig.show()
In [36]:
violin("rc")
kde("rc")
Relationship between two continuous variables
Handling missing values
In [56]:
# missing values across all fields
df.isnull().sum().sort_values(ascending = False)
Out[56]:
rbc 152
rc 131
wc 106
pot 88
sod 87
pcv 71
pc 65
hemo 52
su 49
sg 47
al 46
bgr 44
bu 19
sc 17
bp 12
age 9
ba 4
pcc 4
htn 2
dm 2
cad 2
appet 1
pe 1
ane 1
classification 0
dtype: int64
In [57]:
# missing values in the continuous variables
df[num_cols].isnull().sum()
Out[57]:
age 9
bp 12
sg 47
al 46
su 49
bgr 44
bu 19
sc 17
sod 87
pot 88
hemo 52
pcv 71
wc 106
rc 131
dtype: int64
In [58]:
# missing values in the categorical variables
df[cat_cols].isnull().sum()
Out[58]:
rbc 152
pc 65
pcc 4
ba 4
htn 2
dm 2
cad 2
appet 1
pe 1
ane 1
classification 0
dtype: int64
In [59]:
df["rbc"].isna().sum() # number of missing values in one field
Out[59]:
152
In [60]:
df["dm"].mode()[0] # mode of one field
Out[60]:
'no'
In [61]:
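def random_value_imputate(col):
    """
    Random-sample imputation (for fields with many missing values)
    """
    # 1. Draw as many observed (non-missing) values as there are missing ones
    random_sample = df[col].dropna().sample(df[col].isna().sum())
    # 2. Re-index the draws with the index of the rows that are missing
    random_sample.index = df[df[col].isnull()].index
    # 3. Assign the draws to the missing positions via loc
    df.loc[df[col].isnull(), col] = random_sample
def mode_impute(col):
    """
    Fill missing values with the mode
    """
    # 1. Find the mode
    mode = df[col].mode()[0]
    # 2. Fill with fillna
    df[col] = df[col].fillna(mode)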
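random_value_imputate draws without replacement and without a seed, so results differ between runs and the draw fails if a column has more missing than observed values. A self-contained sketch of a reproducible variant, under the assumption that sampling with replacement is acceptable (the helper random_impute and its seed parameter are hypothetical, not part of the original notebook):

```python
import numpy as np
import pandas as pd

def random_impute(s, seed=0):
    """Return a copy of s with NaNs replaced by random draws from its observed values."""
    s = s.copy()
    missing = s.isna()
    # replace=True lets the draw succeed even when missing values outnumber observed ones
    draws = s.dropna().sample(int(missing.sum()), replace=True, random_state=seed)
    # Assign by position so the draws' own index labels are ignored
    s.loc[missing] = draws.to_numpy()
    return s

toy = pd.Series([1.0, np.nan, 3.0, np.nan, 5.0])  # made-up values
filled = random_impute(toy)
print(filled)
```

Returning a new Series instead of mutating a global df also makes the step easier to test in isolation.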
1. Continuous variables are filled by random-sample imputation:
In [62]:
for col in num_cols:
    random_value_imputate(col)
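2. Categorical variables: the method depends on the field:
In [63]:
# random-sample imputation
random_value_imputate('rbc')
random_value_imputate('pc')
In [64]:
# the remaining fields are filled with the mode
for col in cat_cols:
    mode_impute(col)
After imputation the data has no missing values: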
In [65]:
df.isnull().sum()
Out[65]:
age 0
bp 0
sg 0
al 0
su 0
rbc 0
pc 0
pcc 0
ba 0
bgr 0
bu 0
sc 0
sod 0
pot 0
hemo 0
pcv 0
wc 0
rc 0
htn 0
dm 0
cad 0
appet 0
pe 0
ane 0
classification 0
dtype: int64
In [67]:
df["classification"] = df["classification"].map({"ckd":0, "notckd":1})
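corr = df.select_dtypes(include="number").corr() # numeric columns only; the remaining categorical columns are still strings at this point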
plt.figure(figsize = (15, 8))
sns.heatmap(corr,
annot = True,
linewidths = 2,
linecolor = 'lightgrey')
plt.show()
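The features most strongly correlated with classification are sg (urine specific gravity), hemo (hemoglobin), pcv (packed cell volume, the proportion of blood volume occupied by red blood cells) and rc (red blood cell count).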
Encoding the categorical variables:
In [68]:
for col in cat_cols:
    print(f"Categories of {col}: {df[col].nunique()} ")
Categories of rbc: 2
Categories of pc: 2
Categories of pcc: 2
Categories of ba: 2
Categories of htn: 2
Categories of dm: 2
Categories of cad: 2
Categories of appet: 2
Categories of pe: 2
Categories of ane: 2
Categories of classification: 2
Every categorical variable takes exactly two values, so label encoding them as 0/1 is enough:
In [69]:
from sklearn.preprocessing import LabelEncoder
led = LabelEncoder()
for col in cat_cols:
    df[col] = led.fit_transform(df[col])
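LabelEncoder assigns integers in sorted label order, so for example no becomes 0 and yes becomes 1. Because a single encoder is refitted for every column above, only the last column's mapping survives. A small sketch on a made-up toy frame that keeps one mapping per column so the codes can be decoded later (the mappings dictionary is a hypothetical addition):

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy frame with two binary text columns (made-up values).
toy = pd.DataFrame({"htn": ["yes", "no", "no"], "appet": ["good", "poor", "good"]})

mappings = {}
for col in toy.columns:
    le = LabelEncoder()
    toy[col] = le.fit_transform(toy[col])
    # classes_ is sorted, so position i holds the label that became integer i
    mappings[col] = {str(label): code for code, label in enumerate(le.classes_)}

print(mappings)
print(toy)
```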
For convenience, the classification field has been encoded as well:
In [70]:
df["classification"].value_counts()
Out[70]:
0 250
1 150
Name: classification, dtype: int64
In [71]:
X = df.drop("classification",axis=1)
y = df["classification"]
In [72]:
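# Shuffle the rows. Note: X and y were already extracted above, so this call does not
# change them; train_test_split below shuffles by default anyway.
from sklearn.utils import shuffle
df = shuffle(df)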
In [73]:
# from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.20, random_state = 0)
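The split above is not stratified, which is why the test fold ends up with 52 samples of class 0 and 28 of class 1 rather than the 5:3 ratio of the full data. A sketch of a stratified alternative on synthetic data with the same 250:150 class counts (the features are random numbers and the variable names are illustrative):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 400 rows, 3 random features, 250 vs 150 class labels.
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(400, 3))
y_toy = np.array([0] * 250 + [1] * 150)

# stratify=y_toy keeps the class ratio the same in the train and test folds.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.20, random_state=0, stratify=y_toy
)
print(np.bincount(y_tr), np.bincount(y_te))
```

With only 80 test samples, keeping the class ratio fixed makes accuracy figures from different runs easier to compare.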
In [75]:
def create_model(model):
    # train the model
    model.fit(X_train, y_train)
    # predict on the test set
    y_pred = model.predict(X_test)
    # accuracy
    acc = accuracy_score(y_test, y_pred)
    # confusion matrix
    cm = confusion_matrix(y_test, y_pred)
    # classification report
    cr = classification_report(y_test,y_pred)
    print(f"Test Accuracy of {model} : {acc}")
    print(f"Confusion Matrix of {model}: \n{cm}")
    print(f"Classification Report of {model} : \n {cr}")
In [76]:
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
create_model(knn)
Test Accuracy of KNeighborsClassifier() : 0.6875
Confusion Matrix of KNeighborsClassifier():
[[35 17]
[ 8 20]]
Classification Report of KNeighborsClassifier() :
precision recall f1-score support
0 0.81 0.67 0.74 52
1 0.54 0.71 0.62 28
accuracy 0.69 80
macro avg 0.68 0.69 0.68 80
weighted avg 0.72 0.69 0.69 80
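The accuracy and per-class recall can be cross-checked directly from the confusion matrix: accuracy is the diagonal sum divided by the total, and recall is each diagonal entry divided by its row sum. A short sketch using the KNN matrix printed above (the helper names are illustrative):

```python
import numpy as np

def accuracy_from_cm(cm):
    """Accuracy = correctly classified samples (diagonal) / all samples."""
    cm = np.asarray(cm)
    return np.trace(cm) / cm.sum()

def recall_from_cm(cm):
    """Per-class recall = diagonal entry / row total (rows are the true classes)."""
    cm = np.asarray(cm)
    return cm.diagonal() / cm.sum(axis=1)

knn_cm = [[35, 17], [8, 20]]
print(accuracy_from_cm(knn_cm))          # (35 + 20) / 80 = 0.6875
print(recall_from_cm(knn_cm).round(2))   # 35/52 and 20/28
```

The result agrees with the 0.69 accuracy and the 0.67 / 0.71 recall values in the classification report.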
In [77]:
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
create_model(dt)
Test Accuracy of DecisionTreeClassifier() : 0.9375
Confusion Matrix of DecisionTreeClassifier():
[[50 2]
[ 3 25]]
Classification Report of DecisionTreeClassifier() :
precision recall f1-score support
0 0.94 0.96 0.95 52
1 0.93 0.89 0.91 28
accuracy 0.94 80
macro avg 0.93 0.93 0.93 80
weighted avg 0.94 0.94 0.94 80
In [78]:
from sklearn.ensemble import RandomForestClassifier
rd_clf = RandomForestClassifier(criterion = 'entropy',
                                max_depth = 11,
                                max_features = 'sqrt', # 'auto' meant 'sqrt' for classifiers and is no longer accepted by recent scikit-learn
                                min_samples_leaf = 2, min_samples_split = 3, n_estimators = 130)
create_model(rd_clf)
Test Accuracy of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
min_samples_split=3, n_estimators=130) : 0.95
Confusion Matrix of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
min_samples_split=3, n_estimators=130):
[[52 0]
[ 4 24]]
Classification Report of RandomForestClassifier(criterion='entropy', max_depth=11, min_samples_leaf=2,
min_samples_split=3, n_estimators=130) :
precision recall f1-score support
0 0.93 1.00 0.96 52
1 1.00 0.86 0.92 28
accuracy 0.95 80
macro avg 0.96 0.93 0.94 80
weighted avg 0.95 0.95 0.95 80
In [80]:
from sklearn.ensemble import AdaBoostClassifier
ada = AdaBoostClassifier(base_estimator = dt) # scikit-learn >= 1.2 renames this parameter to estimator
create_model(ada)
Test Accuracy of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()) : 0.95
Confusion Matrix of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()):
[[51 1]
[ 3 25]]
Classification Report of AdaBoostClassifier(base_estimator=DecisionTreeClassifier()) :
precision recall f1-score support
0 0.94 0.98 0.96 52
1 0.96 0.89 0.93 28
accuracy 0.95 80
macro avg 0.95 0.94 0.94 80
weighted avg 0.95 0.95 0.95 80
In [81]:
from sklearn.ensemble import GradientBoostingClassifier
gb = GradientBoostingClassifier()
create_model(gb)
Test Accuracy of GradientBoostingClassifier() : 0.95
Confusion Matrix of GradientBoostingClassifier():
[[51 1]
[ 3 25]]
Classification Report of GradientBoostingClassifier() :
precision recall f1-score support
0 0.94 0.98 0.96 52
1 0.96 0.89 0.93 28
accuracy 0.95 80
macro avg 0.95 0.94 0.94 80
weighted avg 0.95 0.95 0.95 80
In [82]:
from xgboost import XGBClassifier
xgb = XGBClassifier(objective = 'binary:logistic',
learning_rate = 0.5,
max_depth = 5,
n_estimators = 150)
create_model(xgb)
Test Accuracy of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
importance_type='gain', interaction_constraints='',
learning_rate=0.5, max_delta_step=0, max_depth=5,
min_child_weight=1, missing=nan, monotone_constraints='()',
n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
tree_method='exact', validate_parameters=1, verbosity=None) : 0.9625
Confusion Matrix of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
importance_type='gain', interaction_constraints='',
learning_rate=0.5, max_delta_step=0, max_depth=5,
min_child_weight=1, missing=nan, monotone_constraints='()',
n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
tree_method='exact', validate_parameters=1, verbosity=None):
[[52 0]
[ 3 25]]
Classification Report of XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
importance_type='gain', interaction_constraints='',
learning_rate=0.5, max_delta_step=0, max_depth=5,
min_child_weight=1, missing=nan, monotone_constraints='()',
n_estimators=150, n_jobs=0, num_parallel_tree=1, random_state=0,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
tree_method='exact', validate_parameters=1, verbosity=None) :
precision recall f1-score support
0 0.95 1.00 0.97 52
1 1.00 0.89 0.94 28
accuracy 0.96 80
macro avg 0.97 0.95 0.96 80
weighted avg 0.96 0.96 0.96 80
In [83]:
from catboost import CatBoostClassifier
cab = CatBoostClassifier(iterations=10)
create_model(cab)
Learning rate set to 0.432149
0: learn: 0.2531464 total: 62.5ms remaining: 562ms
1: learn: 0.1524287 total: 63.9ms remaining: 256ms
2: learn: 0.0906595 total: 65.3ms remaining: 152ms
3: learn: 0.0578563 total: 66.8ms remaining: 100ms
4: learn: 0.0460263 total: 68.2ms remaining: 68.2ms
5: learn: 0.0356541 total: 69.6ms remaining: 46.4ms
6: learn: 0.0268575 total: 70.9ms remaining: 30.4ms
7: learn: 0.0206936 total: 72.2ms remaining: 18ms
8: learn: 0.0186242 total: 73.6ms remaining: 8.17ms
9: learn: 0.0162996 total: 75ms remaining: 0us
Test Accuracy of <catboost.core.CatBoostClassifier object at 0x1296bfe50> : 0.95
Confusion Matrix of <catboost.core.CatBoostClassifier object at 0x1296bfe50>:
[[51 1]
[ 3 25]]
Classification Report of <catboost.core.CatBoostClassifier object at 0x1296bfe50> :
precision recall f1-score support
0 0.94 0.98 0.96 52
1 0.96 0.89 0.93 28
accuracy 0.95 80
macro avg 0.95 0.94 0.94 80
weighted avg 0.95 0.95 0.95 80
In [84]:
from sklearn.ensemble import ExtraTreesClassifier
etc = ExtraTreesClassifier()
create_model(etc)
Test Accuracy of ExtraTreesClassifier() : 0.9625
Confusion Matrix of ExtraTreesClassifier():
[[52 0]
[ 3 25]]
Classification Report of ExtraTreesClassifier() :
precision recall f1-score support
0 0.95 1.00 0.97 52
1 1.00 0.89 0.94 28
accuracy 0.96 80
macro avg 0.97 0.95 0.96 80
weighted avg 0.96 0.96 0.96 80
In [85]:
from lightgbm import LGBMClassifier
lgbm = LGBMClassifier(learning_rate = 0.1)
create_model(lgbm)
Test Accuracy of LGBMClassifier() : 0.9625
Confusion Matrix of LGBMClassifier():
[[51 1]
[ 2 26]]
Classification Report of LGBMClassifier() :
precision recall f1-score support
0 0.96 0.98 0.97 52
1 0.96 0.93 0.95 28
accuracy 0.96 80
macro avg 0.96 0.95 0.96 80
weighted avg 0.96 0.96 0.96 80
In [86]:
models = pd.DataFrame({"model":["KNN","Decision Tree","Random Forest","Ada Boost ",
"Gradient Boosting","Xgboost","Cat Boost","Extra Trees","LGBM"],
"acc":[0.6875,0.9375,0.95,0.95,0.95,0.9625,0.95,0.9625,0.9625]})
models
Out[86]:
In [87]:
models = models.sort_values("acc",ascending=True) # sort in ascending order
models
Out[87]:
In [88]:
px.bar(models,
x="acc",
y="model",
text="acc",
color = 'acc',
template = 'plotly_dark',
title = 'Nine Models Comparison')
Here we take the random forest model (rd_clf) and use the shap library to explain it
In [89]:
explainer = shap.TreeExplainer(rd_clf)
# pass the feature data to the explainer to compute the SHAP values
shap_values = explainer.shap_values(X_test)
shap_values
Out[89]:
[array([[ 0.00082722, -0.00174422, 0.08114011, ..., -0.00147523,
0.01219736, 0.00090329],
[ 0.00241676, -0.00674329, -0.10784976, ..., -0.00708956,
-0.00392391, -0.00032504],
[ 0.00211145, -0.00874862, -0.12467772, ..., -0.00752903,
-0.00413827, -0.00035173],
...,
[-0.00775482, -0.00930411, -0.12662254, ..., -0.00807709,
-0.00414294, -0.00035052],
[-0.00719232, -0.00715492, -0.11854169, ..., -0.00821187,
-0.00448367, -0.00035973],
[-0.00499459, -0.00897095, -0.12441733, ..., -0.0081471 ,
-0.00434574, -0.00035173]]),
array([[-0.00082722, 0.00174422, -0.08114011, ..., 0.00147523,
-0.01219736, -0.00090329],
[-0.00241676, 0.00674329, 0.10784976, ..., 0.00708956,
0.00392391, 0.00032504],
[-0.00211145, 0.00874862, 0.12467772, ..., 0.00752903,
0.00413827, 0.00035173],
...,
[ 0.00775482, 0.00930411, 0.12662254, ..., 0.00807709,
0.00414294, 0.00035052],
[ 0.00719232, 0.00715492, 0.11854169, ..., 0.00821187,
0.00448367, 0.00035973],
[ 0.00499459, 0.00897095, 0.12441733, ..., 0.0081471 ,
0.00434574, 0.00035173]])]
In [90]:
shap.summary_plot(shap_values[1], X_test, plot_type="bar")
According to the plot, sg (urine specific gravity), sc (serum creatinine) and hemo (hemoglobin) are the most influential features.
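The bar variant of summary_plot ranks features by their mean absolute SHAP value across samples. A minimal numpy sketch of that ranking on a made-up SHAP matrix (the numbers are invented for illustration and the feature names are only labels for the toy columns):

```python
import numpy as np

# Made-up SHAP matrix: 4 samples x 3 features.
features = ["sg", "sc", "hemo"]
toy_shap = np.array([
    [ 0.10, -0.02,  0.05],
    [-0.12,  0.03, -0.04],
    [ 0.11, -0.01,  0.06],
    [-0.09,  0.02, -0.05],
])

# Average the magnitude of each feature's contribution over all samples.
mean_abs = np.abs(toy_shap).mean(axis=0)

# Sort features from most to least influential.
ranking = [features[i] for i in np.argsort(mean_abs)[::-1]]
print(dict(zip(features, mean_abs.round(3))))
print(ranking)
```

Taking absolute values first matters: positive and negative contributions would otherwise cancel out in the average.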
shap.summary_plot(shap_values[1], X_test)
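The summary plot draws one point per sample for every feature: the point's position is that feature's SHAP value for the sample, and its color encodes the feature value (red is high, blue is low).
Effect of the individual features on the result for single patients:
Across the three selected patients, the SHAP values still differ widely between individuals, even when all of them have the disease.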