我对SHAP很陌生,并试图在我的RandomForestClassifier上使用它。下面是我已经运行clf.fit(train_x, train_y)
之后的代码片段
explainer = shap.Explainer(clf)
shap_values = explainer(train_x.to_numpy()[0:5, :])
shap.summary_plot(shap_values, plot_type='bar')
由此产生的情节如下:
这有两个问题。其一,即使我设置了plot_type
参数,它也不是条形图。另一个原因是,我似乎以某种方式丢失了我的功能名称(是的,在调用clf.fit()
时,它们确实存在于数据流中)。
我试着把最后一行替换成:
shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], plot_type='bar')
但这没什么改变。我还试图用以下内容替换它,以确定是否至少可以恢复我的功能名称:
shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')
但这造成了一个错误:
Traceback (most recent call last):
File "sklearn_model_runs.py", line 41, in <module>
main()
File "sklearn_model_runs.py", line 38, in main
shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')
File "C:\Users\kapoo\anaconda3\envs\sci\lib\site-packages\shap\plots\_beeswarm.py", line 554, in summary_legacy
feature_names=feature_names[sort_inds],
TypeError: only integer scalar arrays can be converted to a scalar index
我现在有点不知所措。我只是尝试了5行的训练集,但想要使用整个东西,一旦我通过了这个绊脚石。如果有帮助的话,分类器有5个标签,我的SHAP版本是0.40.0。
发布于 2021-11-23 23:55:54
好吧,问题出在这里。代之以:
shap_values = explainer(train_x.to_numpy()[0:5, :])
在这方面:
shap_values = explainer.shap_values(train_x) # Use whole thing as dataframe
然后,您可以在绘图过程中使用此方法:
feature_names=list(train_x.columns.values)
文档这里确实应该更新..。
https://stackoverflow.com/questions/70076314
复制相似问题