前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >机器学习入门 12-5 CART与决策树中的超参数

机器学习入门 12-5 CART与决策树中的超参数

作者头像
触摸壹缕阳光
发布2021-01-28 15:46:08
1.8K0
发布2021-01-28 15:46:08
举报

CART

前面介绍的决策树通常还有一个名字,叫做 CART(读音与cut相近)。CART 是 Classification And Regression Tree 的首字母缩写,通过 Classification And Regression Tree 的字面意思可以看出,CART 这种决策树既能够解决分类问题(Classification)也能够解决回归问题(Regression)。每个节点根据某种衡量系统不确定性的指标(信息熵或基尼系数)来找到某个合适的维度 d 以及维度 d 上的阈值 v,根据 d 和 v 对当前节点中的数据进行二分,通过这种方式得到的决策树一定是一颗二叉树,这也是 CART 这种决策树的特点。

在 sklearn 中的决策树都是 CART。在文献或者资料中看到的 ID3、C4.5 和 C5.0 都是构建决策树的不同方法。

CART的复杂度

决策树的预测时间复杂度为 O(logm),训练的时间复杂度为 O(n*m*logm)「需要注意我们介绍的都是 CART 这种决策树。」

  • 决策树预测的时间复杂度为 O(logm)

通过之前对决策树上的各个节点上的数据进行划分的模拟,对于构建好的一棵决策树进行预测,平均而言的时间复杂度为 O(logm),其中 m 为样本个数。每次在一个节点上都是对当前的数据进行对半划分,因此最终这棵决策树的高度大概为 logm,当构建好了这棵高度为 logm 的决策树之后,来了一个新样本,这个新样本就需要从这棵决策树的根节点开始一步一步进行决策判断,最终走到叶子节点。根据对应叶子节点上的数据来决定新样本的标签(分类问题)和目标(回归问题)。

  • 决策树训练的时间复杂度为 O(n*m*logm)

由于决策树算法属于无参数学习,因此构造决策树的过程就是决策树的训练过程,决策树训练的时间复杂度为 O(n*m*logm)。在模拟构建决策树的过程中,需要对样本中的每一个维度 n 和每一个样本 m 进行遍历,最终找到在哪个维度上的哪个阈值上进行划分的最佳划分点,训练的时间复杂度相对来说还是比较高的。

CART的剪枝

决策树和 kNN 算法一样都属于非参数学习算法,所有的非参数学习算法都非常容易发生过拟合。因此对于决策树来说,不仅训练的时间复杂度比较高,而且非常容易发生过拟合。基于决策树的高时间复杂度以及容易产生过拟合的问题,实际在构建决策树的时候必须要对决策树进行剪枝的操作,剪枝操作有两个目的:

  • 降低时间复杂度;
  • 解决过拟合;

其实前面构建决策树的例子中一直在使用剪枝操作。在使用 sklearn 创建决策树的对象时传入的 max_depth 参数,指定 max_depth 参数一直是 2,即构建决策树的最大深度,这其实就是剪枝的一种手段。换句话说,所谓的剪枝,其实就是在创建决策树对象时传入参数的平衡。

除了前面一直使用的 max_depth 最大深度这个参数外,还有很多参数可以用于剪枝,这些参数既可以降低决策树训练过程的时间复杂度,同时也可以减轻过拟合的问题。

首先导入相应的库函数。

代码语言:javascript
复制
 In[1]: import numpy as np
        import matplotlib.pyplot as plt

本小节使用的是由 sklearn.datasets 中的 make_moons 函数生成的噪声为 0.25 的非线性虚拟数据集。使用非线性数据集是为了能够更好的看出决策树发生过拟合的样子,以及使用超参数解决过拟合后的结果。

代码语言:javascript
复制
 In[2]: from sklearn import datasets
        X, y = datasets.make_moons(noise = 0.25, random_state = 666)

通过散点图绘制数据集的分布。

代码语言:javascript
复制
 In[3]: plt.scatter(X[y == 0, 0], X[y == 0, 1])
        plt.scatter(X[y == 1, 0], X[y == 1, 1])
        plt.show()
  • 默认参数的决策树
代码语言:javascript
复制
 In[4]: from sklearn.tree import DecisionTreeClassifier
        
        dt_clf = DecisionTreeClassifier()
        dt_clf.fit(X, y)
        
Out[4]: DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

使用 sklearn 中的决策树不传入任何参数,默认节点上的划分标准使用基尼系数。默认 max_depth = None,对决策树的深度不做限定,决策树会一直向下划分,直到划分后的节点的基尼系数都为 0 为止。

训练好了默认参数的决策树,接下来使用前面一直使用的 plot_decision_boundary 函数绘制决策边界。

代码语言:javascript
复制
 In[5]: def plot_decision_boundary(model, axis):
            x0, x1 = np.meshgrid(
          np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
          np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
      )
            X_new = np.c_[x0.ravel(), x1.ravel()]
            
            y_predict = model.predict(X_new)
            zz = y_predict.reshape(x0.shape)
            from matplotlib.colors import ListedColormap
            custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
            
            plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)

选择合适的绘制范围,在绘制决策边界的同时将数据集以散点图的形式也绘制出来。

代码语言:javascript
复制
 In[6]: plot_decision_boundary(dt_clf, axis=[-1.5, 2.5, -1.0, 1.5])
        plt.scatter(X[y==0,0], X[y==0,1])
        plt.scatter(X[y==1,0], X[y==1,1])
        plt.show()

此时绘制出来的决策边界的形状相对不规则,显然默认参数的决策树模型发生了过拟合。

  • 指定 max_depth 参数为 2,限制整个决策树的最大深度为 2
代码语言:javascript
复制
 In[7]: dt_clf2 = DecisionTreeClassifier(max_depth=2)
        dt_clf2.fit(X, y)
        
        plot_decision_boundary(dt_clf2, axis=[-1.5, 2.5, -1.0, 1.5])
        plt.scatter(X[y==0,0], X[y==0,1])
        plt.scatter(X[y==1,0], X[y==1,1])
        plt.show()

此时绘制出来的决策边界非常清晰,没有那种特别不规则的边界。换句话说,决策树模型不会针对某几个特别的样本点进行特殊的变化。相对于前面默认参数的决策树,显然指定参数max_depth = 2 的决策树模型的过拟合程度降低。当然此时的模型又可能欠拟合,所以对于这些参数,我们需要进行比较精细的调整,让决策树模型处在一个既不过拟合又不欠拟合的位置上。

  • 指定 min_samples_split 参数为 10,节点再划分所需要的最小样本数为 10
代码语言:javascript
复制
 In[8]: dt_clf3 = DecisionTreeClassifier(min_samples_split=10)
        dt_clf3.fit(X, y)
        
        plot_decision_boundary(dt_clf3, axis=[-1.5, 2.5, -1.0, 1.5])
        plt.scatter(X[y==0,0], X[y==0,1])
        plt.scatter(X[y==1,0], X[y==1,1])
        plt.show()

此时绘制出来的决策边界非常清晰,没有那种特别不规则的边界。换句话说,决策树模型不会针对某几个特别的样本点进行特殊的变化。相对于前面默认参数的决策树,显然指定参数min_samples_split = 10 的决策树模型的过拟合程度降低。

「将 min_samples_split 参数的值设置的越低,决策树模型越容易发生过拟合。」 考虑极端情况下,如果将 min_samples_split 设置的值大于等于样本总数,此时仅有的根节点不需要进行划分,显然此时仅有一个根节点的决策树模型欠拟合。

  • 指定 min_samples_leaf 参数为 6,叶子节点最少样本数为 6
代码语言:javascript
复制
 In[9]: dt_clf4 = DecisionTreeClassifier(min_samples_leaf=6)
        dt_clf4.fit(X, y)
        
        plot_decision_boundary(dt_clf4, axis=[-1.5, 2.5, -1.0, 1.5])
        plt.scatter(X[y==0,0], X[y==0,1])
        plt.scatter(X[y==1,0], X[y==1,1])
        plt.show()

此时绘制出来的决策边界非常清晰,没有那种特别不规则的边界。换句话说,决策树模型不会针对某几个特别的样本点进行特殊的变化。相对于前面默认参数的决策树,显然指定参数min_samples_leaf = 6 的决策树模型的过拟合程度降低。

「将 min_samples_leaf 参数的值设置的越低,决策树模型越容易发生过拟合。」 考虑极端情况下,如果将 min_samples_leaf 设置为 1,对于只有一个样本点的叶子节点,在具体预测的时候,测试样本点需要根据所到达的叶子节点上的样本点来决定预测的类别(或目标),而如果此时叶子节点仅有一个样本点,那么测试样本点非常容易受到这一个样本点的影响,测试样本点的预测类别(或目标)会变得非常敏感。

  • 指定 max_leaf_nodes 参数为 4,最大叶子节点数为 4
代码语言:javascript
复制
 In[10]: dt_clf5 = DecisionTreeClassifier(max_leaf_nodes=4)
         dt_clf5.fit(X, y)

         plot_decision_boundary(dt_clf5, axis=[-1.5, 2.5, -1.0, 1.5])
         plt.scatter(X[y==0,0], X[y==0,1])
         plt.scatter(X[y==1,0], X[y==1,1])
         plt.show()

此时绘制出来的决策边界非常清晰,没有那种特别不规则的边界。换句话说,决策树模型不会针对某几个特别的样本点进行特殊的变化。相对于前面默认参数的决策树,显然指定参数max_leaf_nodes = 4 的决策树模型的过拟合程度降低。

「将 max_leaf_nodes 参数的值设置的越高,决策树模型越容易发生过拟合。」

小结

「决策树这种非参数学习很容易过拟合,所以在实际使用这些参数的时候,要注意避免决策树模型被过渡调节参数,从而导致决策树模型欠拟合。同时这些参数并不是相互独立的,它们之间可以相互组合,所以可以使用网格搜索的方式寻找最优的参数组合。」

对于决策树算法来说,可以调节的参数还有很多,「但是需要注意,无论将决策树用于分类问题,还是用于回归问题,可能无论怎样调节这些参数都不能得到特别好的效果,这就是决策树的局限性。」 尽管如此,决策树依然非常重要,因为机器学习中非常重要的算法随机森林,使用了决策树的思想,所以本小节介绍的这些参数,也能够用于随机森林的调参。

References:

  1. Python3入门机器学习 经典算法与应用: https://coding.imooc.com/class/chapter/169.html#Anchor
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-01-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • CART
  • CART的复杂度
  • CART的剪枝
  • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档