专栏首页人工智能MLlib中的随机森林和提升方法

MLlib中的随机森林和提升方法

本帖是与来自于Origami Logic 的Manish Amd共同撰写的。

Apache Spark 1.2将随机森林梯度提升树(GBT)引入到MLlib中。这两个算法适用于分类和回归,是最成功的且被广泛部署的机器学习方法之一。随机森林和GBT是两类集成学习算法,它们结合了多个决策树,以生成更强大的模型。在这篇文章中,我们将描述这些模型和它们在MLlib中的分布式实现。我们还展示了一些简单的例子,并提供了一些我们该如何开始学习的建议。

集成方法

简而言之,集成学习算法通过组合不同的模型,是建立在其他机器学习方法之上的算法。这种组合可以比任意的单个模型更加强大且准确。

在MLlib 1.2中,我们使用决策树作为基础模型。我们提供了两种集成方法:随机森林梯度提升树(GBT)。这两种算法的主要区别在于集成模型中每个树部件的训练顺序。

随机森林使用数据的随机样本独立地训练每棵树。这种随机性有助于使模型比单个决策树更健壮,而且不太可能会在训练数据上过拟合。

GBT(梯度提升树)每次只训练一棵树,每棵新树帮助纠正先前训练过的树所产生的错误。随着每一棵新树的加入,模型变得更加具有表现力。

最后,这两种方法都会产生一个决策树的加权集合。集成模型通过结合所有单个树的结果进行预测。下图显示了一个采用三棵树进行集成的简单例子。

在上面的集成回归的例子中,每棵树都预测了一个实值。然后将这三个预测结合起来获得集成模型的最终预测。在这里,我们使用均值来将结合不同的预测值(但具体的算法设计时,需要根据预测任务的特点来使用不同的技术)。

分布式集成学习

在MLlib中,随机森林和GBT(梯度提升树)通过实例(行)来对数据进行划分。该实现建立在最初的决策树代码之上,该代码实现了单个决策树的学习(在较早的博客文章中进行了描述)。我们的许多优化都基于Google的PLANET项目,这是发表过的、在分布式环境下进行决策树集成学习的主要作品之一。

随机森林:由于随机森林中的每棵树都是独立训练的,所以可以并行地训练多棵树(作为并行化训练单颗树的补充)。MLlib正是这样做的:并行地训练可变数目的子树,这里的子树的数目根据内存约束在每次迭代中都进行优化。

GBT:由于GBT(梯度提升树)必须一次训练一棵树,所以训练只在单颗树的水平上进行并行化。

我们想强调在MLlib中使用的两个关键优化:

  • 内存:随机森林使用不同的数据子样本来训练每棵树。我们不使用显式复制数据,而是使用TreePoint结构来保存内存信息,该结构存储每个子样本中每个实例的副本数量。
  • 通信:在决策树中的每个决策节点,决策树通常是通过从所有特征中选择部分特征来进行训练的,随机森林经常在每个节点将特征的选择限制在某个随机子集上。MLlib的实现利用了这种二次采样的优点来减少通信开销:例如,如果在每个节点只使用1/3的特征,那么我们可以将通信减少到原来的1/3。

更多的详细信息,请参见“MLlib编程指南”中的“集成”部分

使用MLlib集成

我们演示如何使用MLlib来学习集成模型。以下Scala示例展示了如何读取数据集、将数据拆分为训练集和测试集、学习模型、打印模型和测试其精度。有关Java和Python中的示例,请参阅MLlib编程指南。请注意,GBT(梯度提升树)还没有Python API,但我们预计它将在Spark 1.3的发行版中出现(通过Github PR 3951)。

随机森林示例

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
  MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练随机森林模型。
val treeStrategy = Strategy.defaultStrategy("Classification")
val numTrees = 3 // 在实际中使用更多的numTrees
val featureSubsetStrategy = "auto" // 让算法进行选择。
val model = RandomForest.trainClassifier(trainingData,
  treeStrategy, numTrees, featureSubsetStrategy, seed = 12345)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
  val prediction = model.predict(point.features)
  if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned Random Forest:n" + model.toDebugString)

梯度提升树示例

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
  MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练梯度提升树模型。
val boostingStrategy =
  BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // 注意: 在实际中使用更多的numIterations
val model =
  GradientBoostedTrees.train(trainingData, boostingStrategy)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
  val prediction = model.predict(point.features)
  if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned GBT model:n" + model.toDebugString)

可扩展性

我们利用一些关于二元分类问题的实证结果展示了MLlib集成学习的可扩展性。下面的每张图比较了梯度增强树("GBT")和随机森林("RF"),这些图中的树被构建到不同的最大深度。

这些测试是在一个根据音频特征来预测歌曲发行日期的回归任务上进行的(特征来自UCI(加州大学尔湾分校)的ML(机器学习)库的YearPredictionMSD数据集)。我们使用EC2 r3.2xlarge机器。除另有说明外,算法参数保持为默认值。

扩展模型大小:训练时间和测试错误

下面的两幅图显示了增加集成模型中树的数量时的效果。对于两者而言,增加树的个数需要更长的时间来学习(第一张图),但在测试时的均方误差(MSE)上却获得了更好的结果(第二张图)。

这两种方法相比较,随机森林训练速度更快,但是他们通常比GBT(梯度提升树)需要训练更深的树来达到相同的误差。GBT(梯度提升树)可以进一步减少每次迭代的误差,但是经过多次迭代后,他们可能开始过拟合(即增加了测试的误差)。随机森林不容易过拟合,但他们的测试错误趋于平稳,无法进一步降低。

为了解MSE均方误差的基础,以下请注意,最左边的点显示了使用单个决策树时的错误率(深度分别为2、5或10)。

详情:463715个训练实例,16个工作节点

扩展训练数据集大小:训练时间和测试错误

接下来的两张图片显示了使用更大的训练数据集时的效果。在有更多的数据时,这两种方法都需要更长时间的训练,但取得了更好的测试结果。

详细信息:16个工作节点。

强大的扩展:利用更多的工作节点完成更快的训练

最后这张图显示了使用更大的计算集群来解决同一个问题时的效果。使用更多的工作节点时,这两种方法都会变快很多。例如,利用深度为2的树进行GBT(梯度提升树)集成训练时,在16个工作节点上训练的速度比在2个工作节点上快4.7倍,较大的数据集能够产生更大倍数的加速。

详情:有463715个训练实例。

下一步有什么?

GBT将很快包含有一个Python API。未来发展的另一个重点是可插拔性:集成方法几乎可以应用在任何分类或回归算法上,而不仅仅是决策树。由Spark 1.2中实验性spark.ml包引入的管道 API 将使我们能够将集成学习方法拓展为真正可插拔的算法。

要开始自己使用决策树,请下载Spark 1.2

进一步阅读

致谢

MLlib集成学习算法是由本文的作者李奇平(阿里巴巴)、宋钟(Alpine数据实验室)和Davies·刘(Databricks)合作开发的。我们也感谢Lee Yang,Andrew Feng和Hirakendu Das(雅虎)在设计和测试方面的帮助。我们也欢迎您的贡献!

本文的版权归 DeepValley 所有,如需转载请联系作者。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 如何重构你的时间序列预测问题

    你不必按照原样对你的时间序列预测问题进行建模。

    人工智能资讯小编
  • 基于TensorFlow的循环神经网络生成矢量格式的伪造汉字

    注意:对于中文汉字和日文汉字我根据具体情况交替使用它们。

    人工智能资讯小编
  • 神经张量网络:探索文本实体之间的关系

    在这篇文章中,我将介绍神经张量网络(NTN),如在用神经张量网络推理知识库的推理中所描述的那样 。我的NTN实现使用最新版本的Python 2.7,Keras...

    人工智能资讯小编
  • 【DB笔试面试730】在Oracle中,如果需要修改网卡、子网、网段等信息,那么应该如何操作?

    Oracle 11g RAC中的IP主要有:Public IP、VIP、SCAN VIP、Private IP这几种。一般这类改IP地址或者网卡名称的需求主要场...

    小麦苗DBA宝典
  • 深度 | 苹果揭秘“Hey Siri”的开发细节,原来不仅有两步检测,还能辨别说话人

    AI科技评论按:苹果的新一期机器学习开发日记来了~ 这次苹果介绍了通过讲话就能唤醒Siri的“Hey Siri”功能是如何从技术上实现的,同时也介绍了为了从用户...

    AI科技评论
  • K近邻算法小结

    什么是K近邻? K近邻一种非参数学习的算法,可以用在分类问题上,也可以用在回归问题上。 什么是非参数学习? 一般而言,机器学习算法都有相应的参数要学习,比如线...

    用户1631856
  • 解决:Failed to execute goal org.apache.maven.plugins:maven-deploy-plugin:2.8.2:deploy (default-deploy)

    1. 执行 mvn clean deploy ... 想把 jar 包更新到私服仓库,报错:

    微风-- 轻许--
  • Spark Streaming——Spark第一代实时计算引擎

    虽然SparkStreaming已经停止更新,Spark的重点也放到了 Structured Streaming ,但由于Spark版本过低或者其他技术选型问题...

    用户6070864
  • 解决 VS2012/2013/2015 下载帮助文档速度慢

    用过 VS2012 以上版本的人心里肯定清楚,想通过 Help Viewer 去下载帮助文档,那速度简直无法忍受,选择几个项目一晚上甚至几天都下载不完。很多人被...

    我与梦想有个约会
  • 整洁面向对象分层架构 (Clean Object-Oriented and Layered Architecture)

    在计算机领域,“分层” 概念无处不在。比如 web 开发时的 MVC ,网络编程时的 OSI 参考模型和 TCP/IP 协议族。

    一个会写诗的程序员

扫码关注云+社区

领取腾讯云代金券