首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在形状上绘制错误分类的样本?

如何在形状上绘制错误分类的样本?
EN

Stack Overflow用户
提问于 2020-10-09 15:42:40
回答 2查看 1.2K关注 0票数 1

我有一个基因数据集,得分在引起疾病的可能性的0到1之间(基因得分为1是已知的致病基因,而如果基因得分为0.74,则很可能导致疾病)。我试图建立一个机器学习模型来预测新基因在回归分类中的疾病评分。

我想看一看那些已知的疾病基因,但得分很低的基因的形状决定图(例如,得分为1的基因,但我的模型得分小于0.8)。我正挣扎着把这些基因组合在一起来作图。

我的数据看起来是:

代码语言:javascript
运行
复制
X:
Index   Feature1  Feature2   ... FeatureN
Gene1     1           0.2          10
Gene2     1           0.1          7
Gene3     0           0.3          10
#index is actually the index and not a column

Y:
Score
1
0.6
0.4

我运行了一个带有嵌套交叉验证的xgboost回归器,并查看了MSE,预测了r2,并绘制了观察到的和预期的值。我可以从观察到的和预期的图中看到,Y中得分1的基因有很多被模型预测的低分数,我想知道为什么这个模型用shap来做这个。不幸的是,我不能给出示例数据。

我正在尝试调整用于标签分类的示例shap代码:

代码语言:javascript
运行
复制
import shap

xgbr = xgboost.XGBRegressor()
xgbr.fit(X_train, Y_train)

select = range(8) #I have 8 features after feature selection with BorutaShap
features = X.iloc[select]
features_display = X.loc[features.index]

explainer = shap.TreeExplainer(xgbr)
expected_value = explainer.expected_value

#Example code from https://slundberg.github.io/shap/notebooks/plots/decision_plot.html: 

y_pred = xgbr.predict(X) 
y_pred = (shap_values.sum(1) + expected_value) > 0
misclassified = y_pred != y_test[select]
shap.decision_plot(expected_value, shap_values, features_display, link='logit', highlight=misclassified)

我如何选择y_pred,以便预测/基因应该是1,但实际上低于0.8 (或任何低的数字)?

编辑:作为对给出的答案的回应,我尝试过:

代码语言:javascript
运行
复制
explainer = shap.TreeExplainer(xgbr)
shap_values = explainer.shap_values(X_test)

y_pred = xgbr.predict(X_test)
m = (y_pred <= 0.5) & (Y_test == 1)

shap.initjs()
shap.decision_plot(explainer.expected_value, shap_values,  X_test[m],  return_objects=True)

这是运行的,但是m的长度是171 (我的Y_test数据中的全部行数),然后绘制出它所显示的171个数据--从数据来看,我知道应该只有一个基因,即<= 0.5,但实际上是1。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-10-14 15:08:47

首先,你提到预测新基因的疾病评分在回归分类中是什么意思?输出似乎是二进制的,01,因此这是一个二进制分类问题。您应该使用xgboost的分类器。Update:根据注释,让我们假设一个回归问题来模拟您的情况。尽管对于下面的示例,我们应该设置'objective':'multi:softmax'来输出实际的标签。

根据你的问题,你想要做的似乎是在那些没有正确预测的样本上对测试集进行索引,并分析误导性的特征,这是合理的。

让我们用一些示例数据集重现您的问题:

代码语言:javascript
运行
复制
from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split
import shap
import xgboost

X,y = shap.datasets.iris()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

model = xgboost.train(params={"learning_rate": 0.01}, 
                      dtrain=xgboost.DMatrix(X_train, label=y_train), 
                      num_boost_round =100)

使用整个测试集的SHAP图,是直接向前的。例如,对于一个force_plot

代码语言:javascript
运行
复制
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

shap.initjs()
shap.force_plot(explainer.expected_value, shap_values, X_test)

现在,如果我们想对错误分类的样本进行同样的处理,我们需要看看输出概率。由于虹膜数据集有几个类,假设我们希望对本应被归类为force_plot的样本可视化2,但是我们的输出值低于1.7

代码语言:javascript
运行
复制
y_pred = model.predict(xgboost.DMatrix(X_test))
m = (y_pred <= 1.7) & (y_test == 2)

现在,我们可以使用掩码对X_test集执行布尔索引,并更新shap_values

代码语言:javascript
运行
复制
shap.initjs()
c= explainer.shap_values(X_test[m])
shap.force_plot(explainer.expected_value, shap_values, X_test[m])

这告诉我们,花瓣的长度和宽度主要推动回归到一个更高的值。因此,它们大概是在错误分类中起主要作用的变量。

类似地,对于decision_plot

代码语言:javascript
运行
复制
shap.decision_plot(explainer.expected_value, shap_values, 
                   X_test[m], feature_order='hclust', 
                   return_objects=True)

票数 2
EN

Stack Overflow用户

发布于 2020-10-14 14:10:29

因为我没有您的数据集,所以我无法检查代码,但是这里有一些想法,可以告诉您方向。

似乎你没有训练你的回归者。应该是线,就像

代码语言:javascript
运行
复制
xgbr = xgboost.XGBRegressor()
xgbr.train(X, Y)

现在您可以使用xgbr.predict(X) ;)

你还需要训练解释者:

代码语言:javascript
运行
复制
explainer = shap.TreeExplainer(xgbr)
with warnings.catch_warnings():
     warnings.simplefilter("ignore")
     sh = explainer.shap_values(X)

现在您可以选择值:

代码语言:javascript
运行
复制
misclassified = (y_pred <= 0.7) & (Y == 1)
shap.decision_plot(expected_value, sh, features_display, link='logit', highlight=misclassified)

在使用shap之前,我建议您检查回归器是否适合您的数据。因此,为此,我建议您将部分数据用于测试,而不是在培训中使用。然后,通过计算和比较测试集和训练集上的MSE,来评估拟合优度。差异越大,预测器的性能就越差。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64283186

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档