前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Apache Spark 2.0预览:机器学习模型持久性

Apache Spark 2.0预览:机器学习模型持久性

作者头像
花落花飞去
发布2018-02-01 16:45:09
1.9K0
发布2018-02-01 16:45:09
举报
文章被收录于专栏:人工智能人工智能

使用在Databricks中的笔记

介绍

机器学习(ML)的应用场景:

  • 数据科学家生成一个ML模型,并让工程团队将其部署在生产环境中。
  • 每个数据引擎集成一个Python模型训练集和一个Java模型服务集。
  • 数据科学家创任务去训练各种ML模型,然后将它们保存并进行评估。

以上所有应用场景在模型持久性、保存和加载模型的能力方面都更为容易。随着Apache Spark 2.0即将发布,Spark的机器学习库MLlib将在DataFrame-based的API中对ML提供长期的近乎完整的支持。本博客给出了关于它的早期概述、代码示例以及MLlib的持久性API的一些细节。

ML持久性的关键特征包括:

  • 支持所有Spark API中使用的语言:Scala,Java,Python&R
  • 支持几乎所有的DataFrame-based的API中的ML算法
  • 支持单个模型和完整的Pipelines,包括非适应(a recipe)和适应(a result)
  • 使用可交换格式的分布式存储

感谢所有帮助MLlib实现飞跃的社区贡献者!参阅JIRA获取Scala / JavaPythonR贡献者的完整名单。

学习API

在Apache Spark 2.0中,MLlib的DataFrame-based的API在Spark上占据了ML的重要地位(请参阅曾经的博客文章获取针对此API的介绍以及它所介绍的“Pipelines”概念)。此MLlib的DataFrame-based的API提供了用于保存和加载模拟相似的Spark Data Source API模型的功能。

我们将用多种编程语言演示保存和加载模型,使用流行的MNIST数据集进行手写数字识别(LeCun et al., 1998; 可从LibSVM数据集页面获得)。该数据集包含手写数字0-9,以及地面实况标签。几个例子:

手写数字的屏幕截图。
手写数字的屏幕截图。

我们的目标是通过拍摄手写的数字然后识别图像中的数字。点击笔记获取完整的加载数据、填充模型、保存和加载它们的完整示例代码。

保存和加载单个模型

我们首先给出如何保存和加载单个模型以在语言之间共享。我们使用Python语言填充Random Forest Classifier并保存,然后使用Scala语言加载这个模型。

代码语言:javascript
复制
training = sqlContext.read…  # data: features, label
rf = RandomForestClassifier(numTrees=20)
model = rf.fit(training)

我们可以调用save方法来轻松地保存这个模型,调用load方法来加载模型:

代码语言:javascript
复制
model.save("myModelPath")
sameModel = RandomForestClassificationModel.load("myModelPath")

我们还可以加载模型(之前使用Python语言保存的)到一个Scala应用或者一个Java应用中:

代码语言:javascript
复制
// Load the model in Scala
val sameModel = RandomForestClassificationModel.load("myModelPath")

这种用法适用于小型的局部模型,例如K-Means模型(用于聚类),也适用于大型分布式模型,如ALS模型(推荐使用的场景)。因为加载到的模型具有相同的参数和数据,所以即使模型部署在完全不同的Spark上也会返回相同的预测结果。

保存和加载完整的Pipelines

我们目前只讨论了保存和加载单个ML模型。在实际应用中,ML工作流程包括许多阶段,从特征提取及转换到模型的拟合和调整。MLlib提供Pipelines来帮助用户构建这些工作流程。(点击笔记获取使用ML Pipelines分析共享自行车数据集的教程。)

MLlib允许用户保存和加载整个Pipelines。我们来看一个在Pipeline上完成这些步骤的例子:

  • 特征提取:二进制转换器将图像转换为黑白图像
  • 模型拟合:Random Forest Classifier拍摄图像并预测数字0-9
  • 调整:交叉验证以调整森林中树木的深度

这是我们的笔记中生成这个管道的一个部分代码:

代码语言:javascript
复制
// Construct the Pipeline: Binarizer + Random Forest
val pipeline = new Pipeline().setStages(Array(binarizer, rf))
// Wrap the Pipeline in CrossValidator to do model tuning.
val cv = new CrossValidator().setEstimator(pipeline) … 

在我们填充这个Pipeline之前,我们将展示我们可以保存整个工作流程(在填充之前)。这个工作流程稍后可以加载到另一个在Spark集群上运行的数据集。

代码语言:javascript
复制
cv.save("myCVPath")
val sameCV = CrossValidator.load("myCVPath")

最后,我们填充Pipeline并保存,然后把它加载回来。这节省了特征提取步骤、交叉验证调整后的Random Forest模型的步骤,模型调整过程中的统计步骤。

代码语言:javascript
复制
val cvModel = cv.fit(training)
cvModel.save("myCVModelPath")
val sameCVModel = CrossValidatorModel.load("myCVModelPath") 

了解详细信息

Python调整

Spark 2.0中缺少Python的调整部分。Python目前还不支持保存和加载用于调整模型超参数的CrossValidator和TrainValidationSplit, 这个问题将在Spark 2.1(SPARK-13​​786)中进行考虑。尽管如此,我们仍可以保存Python中的CrossValidator和TrainValidationSplit的结果。例如我们使用交叉验证来调整Random Forest,然后调整过程中找到的最佳模型并保存。

代码语言:javascript
复制
Define the workflow
rf = RandomForestClassifier()
cv = CrossValidator(estimator=rf, …)
Fit the model, running Cross-Validation
cvModel = cv.fit(trainingData)
Extract the results, i.e., the best Random Forest model
bestModel = cvModel.bestModel
Save the RandomForest model
bestModel.save("rfModelPath")

点击笔记查看完整代码。

可交换的存储格式

在内部,我们将模型元数据和参数保存为JSON和Parquet格式。这些存储格式是可交换的并且可以使用其他库进行读取。我们能够使用Parquet 存储小模型(如朴素贝叶斯分类)和大型分布式模型(如推荐的ALS)。存储路径可以是任何URI支持的可以进行保存和加载的Dataset / DataFrame,还包括S3、本地存储等路径。

语言交叉兼容性

模型可以在Scala、Java和Python中轻松地进行保存和加载。R语言有两个限制,首先,R并非支持全部的MLlib模型,所以并不是所有使用其他语言训练过的模型都可以使用R语言加载。第二,R语言模型的格式还存储了额外数据,所以用其他语言加载使用R语言训练和保存后的模型有些困难(供参考的笔记本)。在不久的将来R语言将会有更好的跨语言支持。

总结

随着即将到来的2.0版本的发布,DataFrame-based的MLlib API将为持久化模型和Pipelines提供近乎全面的覆盖。持久性对于在团队之间共享模型、创建多语言ML工作流以及将模型转移到生产环境至关重要。准备将DataFrame-based的MLlib API变成Apache Spark中的机器学习的主要API是这项功能的最后一部分。

接下来?

高优先级的项目包括完整的持久性覆盖,包括Python模型调整算法以及R和其他语言API之间的兼容性改进。

使用Scala和Python的教程笔记开始。您也可以只更新您当前的MLlib工作流程以使用保存和加载功能。

实验性功能:使用在Apache Spark2.0的分支(Databricks Community Edition中的测试代码)预览版中的API。加入beta版的等待名单。

阅读更多

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 介绍
  • 学习API
    • 保存和加载单个模型
      • 保存和加载完整的Pipelines
      • 了解详细信息
        • Python调整
          • 可交换的存储格式
            • 语言交叉兼容性
            • 总结
            • 接下来?
            • 阅读更多
            相关产品与服务
            对象存储
            对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档