首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >当我将培训数据包含在expected_value中时,为什么会得到不同的TreeExplainer?

当我将培训数据包含在expected_value中时,为什么会得到不同的TreeExplainer?
EN

Stack Overflow用户
提问于 2020-11-13 13:30:05
回答 2查看 1.5K关注 0票数 5

包括SHAP中的训练数据,TreeExplainer给出了不同的expected_value在scikit学习GBT回归。

可复制的示例(运行在Google中):

代码语言:javascript
复制
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
import numpy as np
import shap

shap.__version__
# 0.37.0

X, y = make_regression(n_samples=1000, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

gbt = GradientBoostingRegressor(random_state=0)
gbt.fit(X_train, y_train)

# mean prediction:
mean_pred_gbt = np.mean(gbt.predict(X_train))
mean_pred_gbt
# -11.534353657511172

# explainer without data
gbt_explainer = shap.TreeExplainer(gbt)
gbt_explainer.expected_value
# array([-11.53435366])

np.isclose(mean_pred_gbt, gbt_explainer.expected_value)
# array([ True])

# explainer with training data
gbt_data_explainer = shap.TreeExplainer(model=gbt, data=X_train) # specifying feature_perturbation does not change the result
gbt_data_explainer.expected_value
# -23.564797322079635

因此,当包含训练数据时,gbt_data_explainer.expected_value的期望值与不提供数据的预期值(gbt_explainer.expected_value)相差很大。

当与(明显不同的)各自的shap_values一起使用时,这两种方法是相加的,并且是一致的。

代码语言:javascript
复制
np.abs(gbt_explainer.expected_value + gbt_explainer.shap_values(X_train).sum(1) - gbt.predict(X_train)).max() < 1e-4
# True

np.abs(gbt_data_explainer.expected_value + gbt_data_explainer.shap_values(X_train).sum(1) - gbt.predict(X_train)).max() < 1e-4
# True

但我想知道为什么它们不提供相同的expected_value,为什么gbt_data_explainer.expected_value与预测的平均值如此不同。

我在这里错过了什么?

EN

Stack Overflow用户

发布于 2020-11-21 13:46:07

虽然@Ben在挖掘data如何通过Independent掩码方面做了很大的工作,但他的回答并没有准确地显示(1)基值是如何计算的,从哪里得到不同的基值,以及(2)如何选择/降低max_samples参数。

,其中不同的值来自

掩蔽对象具有一个data属性,该属性在掩蔽处理后保存数据。获取在gbt_explainer.expected_value中显示的值

代码语言:javascript
复制
from shap.maskers import Independent
gbt = GradientBoostingRegressor(random_state=0)

# mean prediction:
mean_pred_gbt = np.mean(gbt.predict(X_train))
mean_pred_gbt
# -11.534353657511172

# explainer without data
gbt_explainer = shap.TreeExplainer(gbt)
gbt_explainer.expected_value
# array([-11.53435366])

gbt_explainer = shap.TreeExplainer(gbt, Independent(X_train,100))
gbt_explainer.expected_value
# -23.56479732207963

我们需要这样做:

代码语言:javascript
复制
masker = Independent(X_train,100)
gbt.predict(masker.data).mean()
# -23.56479732207963

max_samples**?** 如何选择

max_samples设置为原始数据集长度似乎也适用于其他解释程序:

代码语言:javascript
复制
import sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
import shap
from shap.maskers import Independent
from scipy.special import logit, expit

corpus,y = shap.datasets.imdb()
corpus_train, corpus_test, y_train, y_test = train_test_split(corpus, y, test_size=0.2, random_state=7)

vectorizer = TfidfVectorizer(min_df=10)
X_train = vectorizer.fit_transform(corpus_train)

model = sklearn.linear_model.LogisticRegression(penalty="l2", C=0.1)
model.fit(X_train, y_train)

explainer = shap.Explainer(model
                           ,masker = Independent(X_train,100)
                           ,feature_names=vectorizer.get_feature_names()
                          )
explainer.expected_value
# -0.18417413671991964

此值来自:

代码语言:javascript
复制
masker=Independent(X_train,100)
logit(model.predict_proba(masker.data.mean(0).reshape(1,-1))[...,1])
# array([-0.18417414])

max_samples=100似乎有点不适合使用true base_value (仅仅是提供功能):

代码语言:javascript
复制
logit(model.predict_proba(X_train.mean(0).reshape(1,-1))[:,1])
array([-0.02938039])

通过增加max_samples,可以合理地接近true基线,同时保持较低的样本数:

代码语言:javascript
复制
masker = Independent(X_train,1000)
logit(model.predict_proba(masker.data.mean(0).reshape(1,-1))[:,1])
# -0.05957302658674238

因此,要获得感兴趣的解释者的基值,(1)通过您的模型传递explainer.data (或masker.data),(2)选择max_samples,以便采样数据上的base_value足够接近真正的基值。您也可以尝试观察形状输入的值和顺序是否一致。

有些人可能会注意到,为了得到基本值,有时我们平均特征输入(LogisticRegression),有时输出(GBT)。

票数 2
EN
查看全部 2 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64821829

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档