机器学习决策树的分裂到底是什么?这篇文章讲明白了!

作者 | Prashant Gupta

译者 | AI100(rgznai100)

在实际生活中,树的类比如影随形。事实证明,树形结构对于机器学习领域同样有着广泛的影响,特别是对分类和回归两大任务来说。

在决策分析中,决策树可以非常清晰地呈现决策的过程和结果。“树”如其名,决策树所用的正是一个树形的决策模型。数据挖掘领域经常会用决策树来搜寻给定问题的解决策略,机器学习领域同样会广泛用到这一方法。这将会是这篇博客的主题。

算法如何能被表示成树形?

对于这一点,我们来看一个基本的例子:用泰坦尼克号的数据集每位乘客的信息来预测他能否幸存下来。接下来这个模型,选取了数据集内的3项特征(features)/ 属性(attributes)/ 列(columns):性别、年龄、sibsp(Sibling及Spouse:表示随行的亲属数量)。

决策树自上而下进行绘制,顶部节点被称为根节点。如上图所示:黑色加粗文字代表了条件(condition) / 内部节点(internal node),分支(branches)/ 边界(edges)基于这些节点分裂而出。不再分裂的节点就是计算出来的决策 / 叶节点。在本例中,乘客的幸存或遇难,分别对应于绿色和红色的叶节点。

尽管真实的数据集将会有着更多的特征,前面的树形图仅能算作某颗大树的部分分支,但你无法否认这一算法的简洁性。特征的重要性一目了然,数据间的关系也一眼就能看清。这里的方法通常被称为从数据中构建决策树(learning decision tree from data),上面的树形图也被称作分类树(Classification tree),因为算法在这里的目标是把乘客分为幸存或遇难这两类。回归树(Regression trees)也是同样的方式进行呈现,只是它们所预测的是连续的数值,比如房价。通常来说,决策树算法指的是CART,即分类回归树(Classification And Regression Tree)。

那么,决策树背后机制到底是什么?在形成决策树的过程中,分裂涉及到的问题是选择哪个特征和分裂的条件是什么,同时还要知道何时终止分裂。由于树的生成相对比较武断,你需要对其进行修剪,才能让它看起来更好。让我们先来看一个比较常用的分裂方法。

递归二元分裂

在递归二元分裂(Recursive Binary Splitting)中,所有的特征都会被考虑,各种不同的分裂方式都将被尝试,并使用成本函数(cost function)来评估每种分裂方式的优劣。成本最优(最低)的方法将被选用来进行分裂。

以前面泰坦尼克号数据集的分类树为例:第一次分裂或在根节点时,所有的属性/特征都会纳入进来考虑,训练数据基于这一点被分成不同的组。我们共有3个特征,因此会有3个待定的分裂。而后,我们用一个函数来计算每种分裂所消耗的成本。算法自动选择损失最小的那一个,具体到本例中就是乘客的性别。这个算法本质上是递归的,因为每次分裂所形成的数据组都能以同样的方式再次进行划分。由于这一步骤,该算法也被称为是贪心算法,因为我们极度渴望降低损失。这也使得根节点成为最好的预测器 / 分类点。

分裂成本

让我们进一步讨论用于分类和回归的成本函数。在这两种情况下,成本函数都在试图寻找分裂后结构相似程度最高的那种方式。其中的意义是,我们可以更加确信测试数据将会跟随哪个路径。

Regression : sum(y — prediction)²

回归

比如预测房价:决策树开始分裂时需要考虑训练数据的所有特征;对于训练数据的特定分组,其输入响应的均值会被作为该组的预测值。上述函数会被用在所有的数据点,用以计算所有可能分裂的成本。损失最低的分裂方式将被筛选出来。另一种成本函数涉及到约化和标准差,更多信息可参考这里:http://www.saedsayad.com/decision_tree_reg.htm。

Classification : G = sum(pk * (1 — pk))

分类

为评估某个分裂方式的优劣,我们用Gini分数来衡量训练数据分裂后的混乱程度。其中,pk表示特定分组中相同输入类别所占的比例。当某一数据组的所有输入都来自同一类别时,我们就得到了一个完美分类,此时的pk值不是1就是0,而G必定为0。但如果某个数据组中两个类别的数据各占一半,这就发生了最坏的情况,此时二元分类的pk=0.5, G=0.5。

何时停止分裂?

接下来你可能会问,决策树何时停止分裂?由于一个问题通常有着大量的相关特征,进而生成大量的分裂,形成一个巨大的树形图。如此复杂的树,就容易出现过拟合。因此,我们有必要知道何时来停止分裂。

一种方法是在每个叶节点上设置训练输入量的最小阈值。比如,我们可以把每个叶节点设置成最少要有10位乘客的数据,忽略掉那些乘客数量少于10的叶节点。另一种方法是设定出模型的最大深度。决策树最大深度指的是从根节点到叶节点的最长路径所对应的分裂长度。

修剪

决策树的性能可通过修剪来进一步提升,这就涉及到移除那些特征并不重要的分支。通过这种方式,我们降低了决策树的复杂性,也就是通过降低过拟合程度来提升它的预测能力。

修剪既可从根节点开始,又可从叶节点开始。最简单的办法是从叶节点开始,并移除所有使用该叶节点主分类的节点,如果该操作不会削弱决策树的准确度,这一修剪就可被执行。这一方法被称为减少错误修剪(reduced error pruning)。你还能使用其它更为成熟的修剪方法,比如成本复杂修剪(cost complexity pruning),它用一个学习参数来衡量当前节点的子树大小,并以此来决定是否保留它。这一方法也被称作最弱连接修剪(weakest link pruning)。

分类回归树优点

  • 易于理解、阐释,易于可视化。
  • 决策树潜在进行的是变量筛选(variable screening)或特征选取(feature selection)。
  • 能够处理数值与标注这两类数据,并能处理多输出问题。
  • 对用户而言,决策树仅需相对较少的数据预处理
  • 参数间的非线性关系不会影响决策树性能。

分类回归树缺点

  • 决策树容易创造出过于复杂的树,致使数据泛化不够。这就是所谓的过拟合(overfitting)。
  • 决策树不够稳定,因为数据的微小变化可能会生成一个完全不同的树形图。这被称为变异(variance),需要采取办法进行优化。
  • 贪心算法无法保证所生成的决策树全局最优。这可通过训练多颗树来加以缓解,它们的特征和样本可通过重置随机取样来获得。
  • 如果某些类别的权重过大,决策树就会生成偏差树(biased trees)。因此,在用数据生成决策树前,要注意平衡数据集的特征。

关于决策树的这些概念都非常基础。目前,实现该算法的一个非常流行的库是Scikit-learn。它拥有非常好的API,只需要几行的Python代码,你就能很方便地构建出模型。

原文发布于微信公众号 - AI科技大本营(rgznai100)

原文发表时间:2017-06-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

一文详解 Word2vec 之 Skip-Gram 模型(训练篇)

第一部分我们了解 skip-gram 的输入层、隐层、输出层。在第二部分,会继续深入讲如何在 skip-gram 模型上进行高效的训练。 在第一部分讲解完成后,...

7465
来自专栏SIGAI学习与实践平台

深入浅出聚类算法

原创声明:本文为 SIGAI 原创文章,仅供个人学习使用,未经允许,不能用于商业目的。

2120
来自专栏技术随笔

[译] End-to-end people detection in crowded scenes

3726
来自专栏AI科技评论

开发 | Google 软件工程师解读:深度学习的activation function哪家强?

AI科技评论按:本文作者夏飞,清华大学计算机软件学士,卡内基梅隆大学人工智能硕士。现为谷歌软件工程师。本文首发于知乎,AI科技评论获授权转载。 ? TLDR (...

4094
来自专栏iOSDevLog

机器学习术语表机器学习术语表

3487
来自专栏leland的专栏

机器学习算法简介

本文是对机器学习算法的一个概览,以及个人的学习小结。通过阅读本文,可以快速地对机器学习算法有一个比较清晰的了解。

1.5K1
来自专栏数据派THU

独家 | 一文为你解析神经网络(附实例、公式)

原文标题:Introduction To Neural Networks 作者:Ben Gorman 翻译:申利彬 校对:和中华 本文长度为4000字,建议阅读...

2455
来自专栏AI科技评论

卷积神经网络新手指南之二

卷积神经网络新手指南之二 ? 引言 本文将进一步探讨有关卷积神经网络的更多细节,注:以下文章中部分内容较为复杂,为了保证其简明性,部分内容详细解释的研究文献会标...

3727
来自专栏人工智能LeadAI

神经网络中 BP 算法的原理与 Python 实现源码解析

最近这段时间系统性的学习了BP算法后写下了这篇学习笔记,因为能力有限,若有明显错误,还请指出。 ? 目录 1、什么是梯度下降和链式求导法则 2、神经网络的结构 ...

6676
来自专栏SnailTyan

Inception-V3论文翻译——中英文对照

Rethinking the Inception Architecture for Computer Vision Abstract Convolutional...

4450

扫码关注云+社区

领取腾讯云代金券