深入机器学习系列7-Random Forest

1 Bagging

  采用自助采样法()采样数据。给定包含个样本的数据集,我们先随机取出一个样本放入采样集中,再把该样本放回初始数据集,使得下次采样时,样本仍可能被选中, 这样,经过次随机采样操作,我们得到包含个样本的采样集。

  按照此方式,我们可以采样出个含个训练样本的采样集,然后基于每个采样集训练出一个基本学习器,再将这些基本学习器进行结合。这就是的一般流程。在对预测输出进行结合时,通常使用简单投票法, 对回归问题使用简单平均法。若分类预测时,出现两个类收到同样票数的情形,则最简单的做法是随机选择一个,也可以进一步考察学习器投票的置信度来确定最终胜者。

  的算法描述如下图所示。

2随机森林

  随机森林是的一个扩展变体。随机森林在以决策树为基学习器构建集成的基础上,进一步在决策树的训练过程中引入了随机属性选择。具体来讲,传统决策树在选择划分属性时, 在当前节点的属性集合(假设有个属性)中选择一个最优属性;而在随机森林中,对基决策树的每个节点,先从该节点的属性集合中随机选择一个包含个属性的子集,然后再从这个子集中选择一个最优属性用于划分。 这里的参数控制了随机性的引入程度。若令,则基决策树的构建与传统决策树相同;若令,则是随机选择一个属性用于划分。在中,有两种选择用于分类,即、; 一种选择用于回归,即。在源码分析中会详细介绍。

  可以看出,随机森林对只做了小改动,但是与中基学习器的“多样性”仅仅通过样本扰动(通过对初始训练集采样)而来不同,随机森林中基学习器的多样性不仅来自样本扰动,还来自属性扰动。 这使得最终集成的泛化性能可通过个体学习器之间差异度的增加而进一步提升。

3 随机森林在分布式环境下的优化策略

  随机森林算法在单机环境下很容易实现,但在分布式环境下特别是在平台上,传统单机形式的迭代方式必须要进行相应改进才能适用于分布式环境 ,这是因为在分布式环境下,数据也是分布式的,算法设计不得当会生成大量的操作,例如频繁的网络数据传输,从而影响算法效率。 因此,在上进行随机森林算法的实现,需要进行一定的优化,中的随机森林算法主要实现了三个优化策略:

1).切分点抽样统计,如下图所示。在单机环境下的决策树对连续变量进行切分点选择时,一般是通过对特征点进行排序,然后取相邻两个数之间的点作为切分点,这在单机环境下是可行的,但如果在分布式环境下如此操作的话, 会带来大量的网络传输操作,特别是当数据量达到级时,算法效率将极为低下。为避免该问题,中的随机森林在构建决策树时,会对各分区采用一定的子特征策略进行抽样,然后生成各个分区的统计数据,并最终得到切分点。 (从源代码里面看,是先对样本进行抽样,然后根据抽样样本值出现的次数进行排序,然后再进行切分)。

2).特征装箱(),如下图所示。决策树的构建过程就是对特征的取值不断进行划分的过程,对于离散的特征,如果有个值,最多有个划分。如果值是有序的,那么就最多个划分。 比如年龄特征,有老,中,少3个值,如果无序有个划分,即。;如果是有序的,即按老,中,少的序,那么只有个,即2种划分,。 对于连续的特征,其实就是进行范围划分,而划分的点就是(切分点),划分出的区间就是。对于连续特征,理论上是无数的,在分布环境下不可能取出所有的值,因此它采用的是切点抽样统计方法。

3).逐层训练(),如下图所示。单机版本的决策树生成过程是通过递归调用(本质上是深度优先)的方式构造树,在构造树的同时,需要移动数据,将同一个子节点的数据移动到一起。 此方法在分布式数据结构上无法有效的执行,而且也无法执行,因为数据太大,无法放在一起,所以在分布式环境下采用的策略是逐层构建树节点(本质上是广度优先),这样遍历所有数据的次数等于所有树中的最大层数。 每次遍历时,只需要计算每个节点所有切分点统计参数,遍历完后,根据节点的特征划分,决定是否切分,以及如何切分。

4 使用实例

下面的例子用于分类。

(提示:代码块部分可以左右滑动屏幕完整查看哦)

  下面的例子用于回归。

5 源码分析5.1 训练分析

训练过程简单可以分为两步,第一步是初始化,第二步是迭代构建随机森林。这两大步还分为若干小步,下面会分别介绍这些内容。

5.1.1 初始化

初始化的第一步就是决策树元数据信息的构建。它的代码如下所示。

初始化的第二步就是找到切分点(`splits`)及箱子信息(`Bins`)。这时,调用了`DecisionTree.findSplitsBins`方法,进入该方法了解详细信息。

我们进入findSplitsBinsBySorting方法了解Sort分裂策略的实现。

计算连续特征的所有切分位置需要调用方法`findSplitsForContinuousFeature`方法。

5.1.2 迭代构建随机森林

这里有两点需要重点介绍,第一点是取得每个树所有需要切分的节点,通过RandomForest.selectNodesToSplit方法实现;第二点是找出最优的切分,通过DecisionTree.findBestSplits方法实现。下面分别介绍这两点。

取得每个树所有需要切分的节点

选中最优切分

5.2 预测分析

在利用随机森林进行预测时,调用的predict方法扩展自TreeEnsembleModel,它是树结构组合模型的表示,其核心代码如下所示:

参考文献

【1】机器学习.周志华

【2】Spark 随机森林算法原理、源码分析及案例实战

【3】Scalable Distributed Decision Trees in Spark MLlib

本文来自企鹅号 - 智子AI媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能LeadAI

什么!卷积要旋转180度?!

一看这个标题就会想,这有什么大惊小怪的,可能好多人觉得这是个脑残话题,但我确实误解了两三年……

711
来自专栏编程

关于反向传播在Python中应用的入门教程

我来这里的目的是为了测试我对于Karpathy的博客《骇客的神经网络指导》以及Python的理解,也是为了掌握最近精读的Derek Banas的文章《令人惊奇的...

1887
来自专栏AILearning

卷积神经网络

注意:本教程面向TensorFlow 的高级用户,并承担机器学习方面的专业知识和经验。 概观 CIFAR-10分类是机器学习中常见的基准问题。问题是将R...

19110
来自专栏利炳根的专栏

学习笔记CB010:递归神经网络、LSTM、自动抓取字幕

递归神经网络(RNN),时间递归神经网络(recurrent neural network),结构递归神经网络(recursive neural network...

5694
来自专栏人工智能

使用Keras在训练深度学习模型时监控性能指标

Keras库提供了一套供深度学习模型训练时的用于监控和汇总的标准性能指标并且开放了接口给开发者使用。

1.3K10
来自专栏鸿的学习笔记

写给开发者的机器学习指南(四)

查全率是定义由给定查询和数据语料库的算法检索的相关性的大小。因此,给定一组文档和应该返回这些文档的子集的查询,查全率的值表示实际返回了多少相关文档。 此值计算如...

531
来自专栏刁寿钧的专栏

使用 Tensorflow 构建 CNN 进行情感分析实践

本次实验参照的是 Kim Yoon 的论文,代码放在我的 github 上,可直接使用。

2.5K1
来自专栏瓜大三哥

基于FPGA的均值滤波(二)

基于FPGA的均值滤波(二) 之一维求和模块 均值滤波按照整体设计可以分为以下几个子模块: (1)一维求和模块,这里记为sum_1D; (2)二维求和模块,这里...

2519
来自专栏目标检测和深度学习

如何从零开发一个复杂深度学习模型

深度学习框架中涉及很多参数,如果一些基本的参数如果不了解,那么你去看任何一个深度学习框架是都会觉得很困难,下面介绍几个新手常问的几个参数。 batch 深度学习...

4607
来自专栏机器之心

NIPS 2018 | 将RNN内存占用缩小90%:多伦多大学提出可逆循环神经网络

循环神经网络(RNN)在语音识别 [1]、语言建模 [2,3] 和机器翻译 [4,5] 等多种任务上都取得了极优的性能。然而,训练 RNN 需要大量的内存。标准...

704

扫码关注云+社区