前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用dtreeviz实现决策树可视化

用dtreeviz实现决策树可视化

作者头像
磐创AI
发布2021-04-21 10:22:00
2.1K0
发布2021-04-21 10:22:00
举报

磐创AI分享

作者 | Eryk Lewinson 编译 | VK 来源 | Towards Data Science

决策树是一类非常重要的机器学习模型,也是许多更高级算法的组成部分,如随机林或著名的XGBoost。这些树也是基线模型的良好起点,我们随后尝试使用更复杂的算法对其进行改进。

决策树的最大优点之一是它的可解释性——在拟合模型之后,它是一组有效的规则,可以用来预测目标变量。这也是为什么很容易绘制规则并将其展示给涉众,这样他们就可以很容易地理解模型的底层逻辑。当然,只要树不太深。

使用scikitlearn和matplotlib的组合,可视化决策树非常简单。然而,有一个很好的名为dtreeviz的库,它带来了更多内容,可以创建了不仅更漂亮而且能传达更多决策过程信息的可视化效果。

在本文中,我将首先展示绘制决策树的“旧方法”,然后介绍使用dtreeviz的改进方法。

安装程序

一如既往,我们需要从导入所需的库开始。

代码语言:javascript
复制
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris, load_boston
from sklearn import tree

from dtreeviz.trees import *

然后,我们从scikit learn加载Iris数据集。我们还将讨论一个回归示例,但稍后将为此加载波士顿住房数据集。

代码语言:javascript
复制
# 加载数据集
iris = load_iris()
boston = load_boston()

“老办法”

下一步包括创建训练/测试集,并将决策树分类器与iris数据集相匹配。在本文中,我们只关注可视化决策树。因此,我们不注意拟合模型或寻找一组好的超参数(关于这些主题的文章很多)。我们唯一要“调整”的是树的最大深度—我们将其限制为3,这样树仍然可以适应图像并保持可读性。

代码语言:javascript
复制
# 准备数据
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 拟合
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)

现在我们有了一个合适的决策树模型,我们可以继续可视化的树。我们从最简单的方法开始-使用scikit learn中的plot_tree函数。

代码语言:javascript
复制
tree.plot_tree(clf);

好吧,这也不错。但是它的可读性不强,例如,没有特征名称(只有它们的列索引)或类标签。我们可以通过运行以下代码片段轻松地改进这一点。

代码语言:javascript
复制
tree.plot_tree(clf,
               feature_names = iris.feature_names, 
               class_names=iris.target_names,
               rounded=True, 
               filled = True);

好多了!现在,我们可以很容易地解释决策树。也可以使用graphviz库来可视化决策树,但是,结果非常相似,具有与上图相同的元素集。这就是为什么我们将在这里跳过它。

dtreeviz

在了解了绘制决策树的老方法之后,让我们直接进入dtreeviz方法。

代码语言:javascript
复制
viz = dtreeviz(clf, 
               x_data=X_train,
               y_data=y_train,
               target_name='class',
               feature_names=iris.feature_names, 
               class_names=list(iris.target_names), 
               title="Decision Tree - Iris data set")
viz

代码片段几乎是不言自明的,因此我们可以继续讨论结果。首先,让我们花一点时间来确认它有多大的改进,特别是考虑到函数调用非常相似。

让我们一步一步地看图表。在每个节点上,我们都可以看到用于分割观测值的特征的堆叠直方图,并按类别着色。

通过这种方式,我们可以看到类是如何通过来分割的。x轴的小三角形是拆分点。在第一个柱状图中,我们可以清楚地看到,所有观察到的刚毛类的花瓣长度都小于2.45厘米。

树的右分支表示选择大于或等于拆分值的值,而左分支表示选择小于拆分值的值。叶节点用饼图表示,饼图显示叶中的观察值属于哪个类。这样,我们就可以很容易地看到哪个类是最主要的,所以也可以看到模型的预测。

在这张图上,我们没有看到的是每个节点的基尼系数。在我看来,柱状图提供了更多关于分割的直观信息,在向利益相关者呈现的情况下,基尼的值可能没有那么重要。

注意:我们也可以为测试集创建一个类似的可视化,我们只需要在调用函数时替换x_data和y_data参数。

如果你不喜欢直方图并且希望简化绘图,可以指定fancy=False来接收以下简化绘图。

dtreeviz的另一个方便的功能是提高模型的可解释性,即在绘图上突出显示特定观测值的路径。通过这种方式,我们可以清楚地看到哪些特征有助于类预测。

使用下面的代码片段,我们突出显示测试集的第一个样本的路径。

代码语言:javascript
复制
viz = dtreeviz(clf, 
               x_data=X_train,
               y_data=y_train,
               target_name='class',
               feature_names=iris.feature_names, 
               class_names=list(iris.target_names),
               title="Decision Tree - Iris data set",
               #orientation="LR", 
               X=X_test[0])  
viz

这张图与前一张非常相似,然而,橙色突出显示清楚地显示了样本所遵循的路径。此外,我们可以在每个直方图上看到橙色三角形。它表示给定特征的观察值。最后,我们看到了这个样本的所有特征的值,用于决策的特征用橙色突出显示。在这种情况下,只有两个特征被用来预测观察属于花色类。

提示:我们还可以通过设置orientation=“LR”从上到下再从左到右更改绘图的方向。在本文中我们不展示它,因为对于屏幕较窄的设备,图表的缩放效果不会很好。

最后,我们可以用通俗易懂的英语打印这个观察预测所用的决定。为此,我们运行以下命令。

代码语言:javascript
复制
print(explain_prediction_path(clf, X_test[0], 
                              feature_names=iris.feature_names, 
                              explanation_type="plain_english"))

# 2.45 <= petal length (cm)  < 4.75 
# petal width (cm) < 1.65

这样,我们就可以清楚地看到这个观察所满足的条件。

回归示例

我们已经介绍了一个分类示例,它显示了库的大多数有趣的功能。但为了完整性起见,我们还讨论了一个回归问题的例子,来说明曲线图是如何不同的。我们使用另一个流行的数据集——波士顿住房数据集。我们使用一组不同的地区来预测波士顿某些地区的房价中值。

代码语言:javascript
复制
# 准备数据
X = boston.data
y = boston.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 拟合
reg = tree.DecisionTreeRegressor(max_depth=2, random_state=42)
reg.fit(X_train, y_train)

# 绘图
viz = dtreeviz(reg,
               x_data=X_train,
               y_data=y_train,
               target_name='price',
               feature_names=boston.feature_names,
               title="Decision Tree - Boston housing",
               show_node_labels = True)
viz

代码已经让人感觉很相似了。唯一的变化是我们添加了show_node_labels = True。对于较大的决策树,它尤其方便。

让我们深入研究分类树和回归树之间的区别。这一次,我们不看直方图,而是检查用于分割和目标的特征散点图。在这些散点图上,我们看到一些虚线。其解释如下:

  • 水平线是决策节点中左右边的目标平均值。
  • 垂直线是分割点。它与黑色三角形表示的信息完全相同。

在叶节点中,虚线表示叶内目标的平均值,这也是模型的预测。

我们已经展示了我们可以突出某个观察的决策路径。我们可以更进一步,只绘制用于预测的节点。为此,我们指定show_just_path=True。下图仅显示上面树中选定的节点。

结论

在本文中,我演示了如何使用dtreeviz库来创建决策树的优雅而有见地的可视化。玩了一段时间之后,我肯定会继续使用它作为可视化决策树的工具。我相信使用这个库创建的图对于那些不经常使用ML的人来说更容易理解,并且可以帮助向涉众传达模型的逻辑。

还值得一提的是,dtreeviz支持XGBoost和Spark MLlib树的一些可视化。

你可以在我的GitHub上找到本文使用的代码:https://github.com/erykml/medium_articles/blob/master/Machine%20Learning/decision_tree_visualization.ipynb

如果你喜欢这篇文章,你可能还对以下内容之一感兴趣:

https://towardsdatascience.com/improve-the-train-test-split-with-the-hashing-function-f38f32b721fb

https://towardsdatascience.com/lazy-predict-fit-and-evaluate-all-the-models-from-scikit-learn-with-a-single-line-of-code-7fe510c7281

https://towardsdatascience.com/explaining-feature-importance-by-example-of-a-random-forest-d9166011959e

参考引用

https://github.com/parrt/dtreeviz

https://explained.ai/decision-tree-viz/index.html

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-04-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 参考引用
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档