首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MLlib中的随机森林和提升方法

MLlib中的随机森林和提升方法

作者头像
花落花飞去
发布2018-02-01 16:48:41
1.3K0
发布2018-02-01 16:48:41
举报
文章被收录于专栏:人工智能人工智能人工智能

本帖是与来自于Origami Logic 的Manish Amd共同撰写的。

Apache Spark 1.2将随机森林梯度提升树(GBT)引入到MLlib中。这两个算法适用于分类和回归,是最成功的且被广泛部署的机器学习方法之一。随机森林和GBT是两类集成学习算法,它们结合了多个决策树,以生成更强大的模型。在这篇文章中,我们将描述这些模型和它们在MLlib中的分布式实现。我们还展示了一些简单的例子,并提供了一些我们该如何开始学习的建议。

集成方法

简而言之,集成学习算法通过组合不同的模型,是建立在其他机器学习方法之上的算法。这种组合可以比任意的单个模型更加强大且准确。

在MLlib 1.2中,我们使用决策树作为基础模型。我们提供了两种集成方法:随机森林梯度提升树(GBT)。这两种算法的主要区别在于集成模型中每个树部件的训练顺序。

随机森林使用数据的随机样本独立地训练每棵树。这种随机性有助于使模型比单个决策树更健壮,而且不太可能会在训练数据上过拟合。

GBT(梯度提升树)每次只训练一棵树,每棵新树帮助纠正先前训练过的树所产生的错误。随着每一棵新树的加入,模型变得更加具有表现力。

最后,这两种方法都会产生一个决策树的加权集合。集成模型通过结合所有单个树的结果进行预测。下图显示了一个采用三棵树进行集成的简单例子。

合奏示例
合奏示例

在上面的集成回归的例子中,每棵树都预测了一个实值。然后将这三个预测结合起来获得集成模型的最终预测。在这里,我们使用均值来将结合不同的预测值(但具体的算法设计时,需要根据预测任务的特点来使用不同的技术)。

分布式集成学习

在MLlib中,随机森林和GBT(梯度提升树)通过实例(行)来对数据进行划分。该实现建立在最初的决策树代码之上,该代码实现了单个决策树的学习(在较早的博客文章中进行了描述)。我们的许多优化都基于Google的PLANET项目,这是发表过的、在分布式环境下进行决策树集成学习的主要作品之一。

随机森林:由于随机森林中的每棵树都是独立训练的,所以可以并行地训练多棵树(作为并行化训练单颗树的补充)。MLlib正是这样做的:并行地训练可变数目的子树,这里的子树的数目根据内存约束在每次迭代中都进行优化。

GBT:由于GBT(梯度提升树)必须一次训练一棵树,所以训练只在单颗树的水平上进行并行化。

我们想强调在MLlib中使用的两个关键优化:

  • 内存:随机森林使用不同的数据子样本来训练每棵树。我们不使用显式复制数据,而是使用TreePoint结构来保存内存信息,该结构存储每个子样本中每个实例的副本数量。
  • 通信:在决策树中的每个决策节点,决策树通常是通过从所有特征中选择部分特征来进行训练的,随机森林经常在每个节点将特征的选择限制在某个随机子集上。MLlib的实现利用了这种二次采样的优点来减少通信开销:例如,如果在每个节点只使用1/3的特征,那么我们可以将通信减少到原来的1/3。

更多的详细信息,请参见“MLlib编程指南”中的“集成”部分

使用MLlib集成

我们演示如何使用MLlib来学习集成模型。以下Scala示例展示了如何读取数据集、将数据拆分为训练集和测试集、学习模型、打印模型和测试其精度。有关Java和Python中的示例,请参阅MLlib编程指南。请注意,GBT(梯度提升树)还没有Python API,但我们预计它将在Spark 1.3的发行版中出现(通过Github PR 3951)。

随机森林示例
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
  MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练随机森林模型。
val treeStrategy = Strategy.defaultStrategy("Classification")
val numTrees = 3 // 在实际中使用更多的numTrees
val featureSubsetStrategy = "auto" // 让算法进行选择。
val model = RandomForest.trainClassifier(trainingData,
  treeStrategy, numTrees, featureSubsetStrategy, seed = 12345)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
  val prediction = model.predict(point.features)
  if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned Random Forest:n" + model.toDebugString)
梯度提升树示例
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
  MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练梯度提升树模型。
val boostingStrategy =
  BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // 注意: 在实际中使用更多的numIterations
val model =
  GradientBoostedTrees.train(trainingData, boostingStrategy)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
  val prediction = model.predict(point.features)
  if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned GBT model:n" + model.toDebugString)

可扩展性

我们利用一些关于二元分类问题的实证结果展示了MLlib集成学习的可扩展性。下面的每张图比较了梯度增强树("GBT")和随机森林("RF"),这些图中的树被构建到不同的最大深度。

这些测试是在一个根据音频特征来预测歌曲发行日期的回归任务上进行的(特征来自UCI(加州大学尔湾分校)的ML(机器学习)库的YearPredictionMSD数据集)。我们使用EC2 r3.2xlarge机器。除另有说明外,算法参数保持为默认值。

扩展模型大小:训练时间和测试错误

下面的两幅图显示了增加集成模型中树的数量时的效果。对于两者而言,增加树的个数需要更长的时间来学习(第一张图),但在测试时的均方误差(MSE)上却获得了更好的结果(第二张图)。

这两种方法相比较,随机森林训练速度更快,但是他们通常比GBT(梯度提升树)需要训练更深的树来达到相同的误差。GBT(梯度提升树)可以进一步减少每次迭代的误差,但是经过多次迭代后,他们可能开始过拟合(即增加了测试的误差)。随机森林不容易过拟合,但他们的测试错误趋于平稳,无法进一步降低。

合奏 - 树x时间
合奏 - 树x时间

为了解MSE均方误差的基础,以下请注意,最左边的点显示了使用单个决策树时的错误率(深度分别为2、5或10)。

合奏 - 树x mse
合奏 - 树x mse

详情:463715个训练实例,16个工作节点

扩展训练数据集大小:训练时间和测试错误

接下来的两张图片显示了使用更大的训练数据集时的效果。在有更多的数据时,这两种方法都需要更长时间的训练,但取得了更好的测试结果。

合奏 -  ntrain x时间
合奏 - ntrain x时间
合奏 -  ntrain x mse
合奏 - ntrain x mse

详细信息:16个工作节点。

强大的扩展:利用更多的工作节点完成更快的训练

最后这张图显示了使用更大的计算集群来解决同一个问题时的效果。使用更多的工作节点时,这两种方法都会变快很多。例如,利用深度为2的树进行GBT(梯度提升树)集成训练时,在16个工作节点上训练的速度比在2个工作节点上快4.7倍,较大的数据集能够产生更大倍数的加速。

合奏 - 工人x时间
合奏 - 工人x时间

详情:有463715个训练实例。

下一步有什么?

GBT将很快包含有一个Python API。未来发展的另一个重点是可插拔性:集成方法几乎可以应用在任何分类或回归算法上,而不仅仅是决策树。由Spark 1.2中实验性spark.ml包引入的管道 API 将使我们能够将集成学习方法拓展为真正可插拔的算法。

要开始自己使用决策树,请下载Spark 1.2

进一步阅读

致谢

MLlib集成学习算法是由本文的作者李奇平(阿里巴巴)、宋钟(Alpine数据实验室)和Davies·刘(Databricks)合作开发的。我们也感谢Lee Yang,Andrew Feng和Hirakendu Das(雅虎)在设计和测试方面的帮助。我们也欢迎您的贡献!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 集成方法
  • 分布式集成学习
  • 使用MLlib集成
    • 随机森林示例
      • 梯度提升树示例
      • 可扩展性
        • 扩展模型大小:训练时间和测试错误
          • 扩展训练数据集大小:训练时间和测试错误
            • 强大的扩展:利用更多的工作节点完成更快的训练
            • 下一步有什么?
            • 进一步阅读
            • 致谢
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档