Apache Spark 2.0预览:机器学习模型持久性

使用在Databricks中的笔记

介绍

机器学习(ML)的应用场景:

  • 数据科学家生成一个ML模型,并让工程团队将其部署在生产环境中。
  • 每个数据引擎集成一个Python模型训练集和一个Java模型服务集。
  • 数据科学家创任务去训练各种ML模型,然后将它们保存并进行评估。

以上所有应用场景在模型持久性、保存和加载模型的能力方面都更为容易。随着Apache Spark 2.0即将发布,Spark的机器学习库MLlib将在DataFrame-based的API中对ML提供长期的近乎完整的支持。本博客给出了关于它的早期概述、代码示例以及MLlib的持久性API的一些细节。

ML持久性的关键特征包括:

  • 支持所有Spark API中使用的语言:Scala,Java,Python&R
  • 支持几乎所有的DataFrame-based的API中的ML算法
  • 支持单个模型和完整的Pipelines,包括非适应(a recipe)和适应(a result)
  • 使用可交换格式的分布式存储

感谢所有帮助MLlib实现飞跃的社区贡献者!参阅JIRA获取Scala / JavaPythonR贡献者的完整名单。

学习API

在Apache Spark 2.0中,MLlib的DataFrame-based的API在Spark上占据了ML的重要地位(请参阅曾经的博客文章获取针对此API的介绍以及它所介绍的“Pipelines”概念)。此MLlib的DataFrame-based的API提供了用于保存和加载模拟相似的Spark Data Source API模型的功能。

我们将用多种编程语言演示保存和加载模型,使用流行的MNIST数据集进行手写数字识别(LeCun et al., 1998; 可从LibSVM数据集页面获得)。该数据集包含手写数字0-9,以及地面实况标签。几个例子:

我们的目标是通过拍摄手写的数字然后识别图像中的数字。点击笔记获取完整的加载数据、填充模型、保存和加载它们的完整示例代码。

保存和加载单个模型

我们首先给出如何保存和加载单个模型以在语言之间共享。我们使用Python语言填充Random Forest Classifier并保存,然后使用Scala语言加载这个模型。

training = sqlContext.read…  # data: features, label
rf = RandomForestClassifier(numTrees=20)
model = rf.fit(training)

我们可以调用save方法来轻松地保存这个模型,调用load方法来加载模型:

model.save("myModelPath")
sameModel = RandomForestClassificationModel.load("myModelPath")

我们还可以加载模型(之前使用Python语言保存的)到一个Scala应用或者一个Java应用中:

// Load the model in Scala
val sameModel = RandomForestClassificationModel.load("myModelPath")

这种用法适用于小型的局部模型,例如K-Means模型(用于聚类),也适用于大型分布式模型,如ALS模型(推荐使用的场景)。因为加载到的模型具有相同的参数和数据,所以即使模型部署在完全不同的Spark上也会返回相同的预测结果。

保存和加载完整的Pipelines

我们目前只讨论了保存和加载单个ML模型。在实际应用中,ML工作流程包括许多阶段,从特征提取及转换到模型的拟合和调整。MLlib提供Pipelines来帮助用户构建这些工作流程。(点击笔记获取使用ML Pipelines分析共享自行车数据集的教程。)

MLlib允许用户保存和加载整个Pipelines。我们来看一个在Pipeline上完成这些步骤的例子:

  • 特征提取:二进制转换器将图像转换为黑白图像
  • 模型拟合:Random Forest Classifier拍摄图像并预测数字0-9
  • 调整:交叉验证以调整森林中树木的深度

这是我们的笔记中生成这个管道的一个部分代码:

// Construct the Pipeline: Binarizer + Random Forest
val pipeline = new Pipeline().setStages(Array(binarizer, rf))
// Wrap the Pipeline in CrossValidator to do model tuning.
val cv = new CrossValidator().setEstimator(pipeline) … 

在我们填充这个Pipeline之前,我们将展示我们可以保存整个工作流程(在填充之前)。这个工作流程稍后可以加载到另一个在Spark集群上运行的数据集。

cv.save("myCVPath")
val sameCV = CrossValidator.load("myCVPath")

最后,我们填充Pipeline并保存,然后把它加载回来。这节省了特征提取步骤、交叉验证调整后的Random Forest模型的步骤,模型调整过程中的统计步骤。

val cvModel = cv.fit(training)
cvModel.save("myCVModelPath")
val sameCVModel = CrossValidatorModel.load("myCVModelPath") 

了解详细信息

Python调整

Spark 2.0中缺少Python的调整部分。Python目前还不支持保存和加载用于调整模型超参数的CrossValidator和TrainValidationSplit, 这个问题将在Spark 2.1(SPARK-13​​786)中进行考虑。尽管如此,我们仍可以保存Python中的CrossValidator和TrainValidationSplit的结果。例如我们使用交叉验证来调整Random Forest,然后调整过程中找到的最佳模型并保存。

Define the workflow
rf = RandomForestClassifier()
cv = CrossValidator(estimator=rf, …)
Fit the model, running Cross-Validation
cvModel = cv.fit(trainingData)
Extract the results, i.e., the best Random Forest model
bestModel = cvModel.bestModel
Save the RandomForest model
bestModel.save("rfModelPath")

点击笔记查看完整代码。

可交换的存储格式

在内部,我们将模型元数据和参数保存为JSON和Parquet格式。这些存储格式是可交换的并且可以使用其他库进行读取。我们能够使用Parquet 存储小模型(如朴素贝叶斯分类)和大型分布式模型(如推荐的ALS)。存储路径可以是任何URI支持的可以进行保存和加载的Dataset / DataFrame,还包括S3、本地存储等路径。

语言交叉兼容性

模型可以在Scala、Java和Python中轻松地进行保存和加载。R语言有两个限制,首先,R并非支持全部的MLlib模型,所以并不是所有使用其他语言训练过的模型都可以使用R语言加载。第二,R语言模型的格式还存储了额外数据,所以用其他语言加载使用R语言训练和保存后的模型有些困难(供参考的笔记本)。在不久的将来R语言将会有更好的跨语言支持。

总结

随着即将到来的2.0版本的发布,DataFrame-based的MLlib API将为持久化模型和Pipelines提供近乎全面的覆盖。持久性对于在团队之间共享模型、创建多语言ML工作流以及将模型转移到生产环境至关重要。准备将DataFrame-based的MLlib API变成Apache Spark中的机器学习的主要API是这项功能的最后一部分。

接下来?

高优先级的项目包括完整的持久性覆盖,包括Python模型调整算法以及R和其他语言API之间的兼容性改进。

使用Scala和Python的教程笔记开始。您也可以只更新您当前的MLlib工作流程以使用保存和加载功能。

实验性功能:使用在Apache Spark2.0的分支(Databricks Community Edition中的测试代码)预览版中的API。加入beta版的等待名单。

阅读更多

本文的版权归 用户1208073 所有,如需转载请联系作者。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

教程 | 从预处理到部署:如何使用Lore快速构建机器学习模型

选自Medium 作者:Montana Low 机器之心编译 参与:李诗萌、思源 机器学习的构建和部署通常需要非常多的工作与努力,这对于软件开发者和入门者造成了...

41150
来自专栏新智元

深度学习开源框架PaddlePaddle发布新版API,简化深度学习编程

【新智元导读】 本文来自百度PaddlePaddle团队成员骆涛,他在文章中介绍了百度深度学习开源框架Paddlepaddle新推出的API,它们能更好地支持分...

31270
来自专栏企鹅号快讯

重合散点图绘制:neat

hello诸君,暖阳高照,午间一杯清茶,又到了爬虫俱乐部向大家种草新命令新方法的时候啦! 许多同学学到的第一个Stata绘图命令想必就是scatter命令,该命...

29990
来自专栏新智元

【解放程序员】MIT“创世纪”机器学习新系统,自动生成补丁修复Bug

【新智元导读】当您辛辛苦苦写了大半年程序,终于要享受一下国庆长假的时候,别让 bug 把您的假期毁了。MIT 研究团队开发了一个称为“创世纪”的系统,能够对以前...

37650
来自专栏媒矿工厂

视频编码的GPU加速

前言 随着视频编解码技术的不断发展,视频逐步向着高清晰、高动态、高数据量的方向演进。这对视频编解码终端的计算能力提出了越来越高的要求。同时,在GPU领域,随着C...

71540
来自专栏腾讯开源的专栏

【开源公告】腾讯第三代高性能计算平台Angel 正式全面开源

Angel 项目简介 Angel是一个基于参数服务器(Parameter Server)理念开发的高性能分布式机器学习框架,在其之上,用户能轻松开发适用于高维度...

44970
来自专栏大数据挖掘DT机器学习

Python+Hadoop 从DBLP数据库中挖掘经常一起写作的合作者

任务描述: 本文的写作目的是从DBLP数据库中找到经常一起写作的合作者。熟悉数据挖掘中频繁项挖掘的经典算法(FP-Growth)并作出改进和优化。实验代码...

43250
来自专栏编程

安卓手机如何玩转动作手势检测?有TensorFlow就够了,附实用教程

? 原文来源:Lemberg Solutions Ltd 作者:Zahra Mahoor、Jack Felag、 Josh Bongard 编译:嗯~阿童木呀...

44970
来自专栏程序员宝库

使用机器学习预测天气(第一部分)

本章是使用机器学习预测天气系列教程的第一部分,使用Python和机器学习来构建模型,根据从Weather Underground收集的数据来预测天气温度。

53550
来自专栏吉浦迅科技

为啥在Matlab上用NVIDIA Titan V训练的速度没有GTX1080快?

在Matlab官方论坛上看到这个帖子,希望给大家带来参考 有一天,有人在Matlab的论坛上发出了求救帖: ? 楼主说: 我想要加快我的神经网络训练,所以把G...

58380

扫码关注云+社区

领取腾讯云代金券