MLSQL: a programming language designed for big data and AI.
General: MLSQL is designed for big data and machine learning. It also includes a distributed runtime that can run on EC2, Hadoop YARN, Mesos, or Kubernetes, and can access data stored in HDFS, Alluxio, Cassandra, HBase, Hive, and hundreds of other data sources.
Ease of use: MLSQL combines SQL and Python, making it easier to use for big data and AI. If you know the command line, you know MLSQL. If you know SQL, you know MLSQL. If you know Python, you know MLSQL.
Open source: MLSQL is free for everyone to use, and all of its source code is publicly available on GitHub.
The MLSQL community hopes everyone can take part. Open source should be inclusive, both in the value it creates and in how people participate. Active community participation can take many forms, and all of them are valuable contributions to the community.
Starter tasks
We will curate a set of starter tasks and mark each of them with two basic labels: the first label makes starter tasks easy to find, and the second mainly indicates difficulty. You can filter issues by these two labels to pick a starter task to work on.
MLSQL gives users two ways to define UDFs: built-in UDFs and write-and-use (ad hoc) UDFs. A write-and-use UDF looks roughly like this:
register ScriptUDF.`` as fetch_es where
code='''
def apply(query: String) = {
  import org.apache.http.client.fluent.Request
  import org.apache.http.entity.ContentType
  // NOTE: uploadUrl is a placeholder -- point it at your own Elasticsearch endpoint.
  val uploadUrl = "http://127.0.0.1:9200/_search"
  val res = Request.Post(uploadUrl).connectTimeout(60 * 1000)
    .socketTimeout(10 * 60 * 1000)
    .bodyString(query, ContentType.APPLICATION_JSON)
    .execute().returnContent().asString()
  res
}
''';
set load_es='''
set es_res=`select fetch_es("{}") as content` where type="sql";
load jsonStr.`es_res` as es_res_table;
''';
!load_es esTable newTable;
A built-in UDF, by contrast, is defined in an Object written in a particular style:
import java.util.UUID
import org.apache.spark.sql.UDFRegistration

object Functions {
  def uuid(uDFRegistration: UDFRegistration) = {
    uDFRegistration.register("uuid", () => {
      UUID.randomUUID().toString.replace("-", "")
    })
  }
}
The function above can then be used like this:
select uuid() as uuid as output;
MLSQL supports most Spark SQL and Hive functions out of the box, and it also ships with a large number of algorithm-oriented UDFs, such as vector-manipulation and metric functions.
Still, there are plenty of useful and fun UDFs/UDAFs left to add. This area is relatively easy to work on and well suited to starter tasks: contributors can add any UDF they consider valuable to tech.mlsql.udf.Functions, for example along the lines of the sketch below.
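As a rough sketch of what such a contribution could look like: the method would be added to the existing tech.mlsql.udf.Functions object (it is wrapped in its own object here only to keep the snippet self-contained), and the function name array_last and its behavior are hypothetical; only the registration pattern follows the uuid example above.

import org.apache.spark.sql.UDFRegistration

object Functions {
  // Hypothetical UDF: return the last element of a string array, or null when the array is empty.
  def arrayLast(uDFRegistration: UDFRegistration) = {
    uDFRegistration.register("array_last", (arr: Seq[String]) => {
      if (arr == null || arr.isEmpty) null else arr.last
    })
  }
}

Once registered, it could be called from MLSQL just like uuid(), e.g. select array_last(array("a","b")) as last_item as output;.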
MLSQL's high-level API introduces three core concepts, Train/Register/Predict, which let you train and apply algorithm models without writing any code.
Below is a typical text-classification scenario:
load delta.`ai_data.20newsgroups` as 20newsgroups;
-- Extract the category from the data; it is implicit in the file path, so this code depends on the path layout and is not portable.
select *, split(file,"/")[4] as label from 20newsgroups as 20newsgroups;
-- We only need two fields
select label,value from 20newsgroups as 20newsgroups;
-- Convert the label into a numeric id
train 20newsgroups as StringIndex.`/ai_model/label_mapping` where
inputCol="label" and outputCol="label_num" as 20newsgroups;
-- Shrink the dataset: we only need half of it for training
run 20newsgroups as RateSampler.`/tmp/model`
where labelCol="label_num"
and sampleRate="0.5,0.5"
as 20newsgroups;
select * from 20newsgroups where __split__=0 as 20newsgroups;
-- Use the TfIdfInPlace module to apply tf/idf to the text.
train 20newsgroups as TfIdfInPlace.`/ai_model/tfidf`
where inputCol="value"
and ignoreNature="true"
-- and nGrams="2,3"
as tfTable;
select * from tfTable limit 100 as tfTable_test;
-- Next, train with the built-in random forest algorithm
train tfTable as RandomForest.`/ai_model/rf` where
-- Keep a model version for every run
keepVersion="true"
-- Evaluation set, to make it easy to check the result
and evaluateTable="tfTable_test"
-- Specify the first group of parameters
and `fitParam.0.featuresCol`="value"
and `fitParam.0.labelCol`="label_num"
and `fitParam.0.maxDepth`="2"
-- Specify the second group of parameters
and `fitParam.1.featuresCol`="value"
and `fitParam.1.labelCol`="label_num"
and `fitParam.1.maxDepth`="3"
;
-- Turn what was learned during training into functions
register StringIndex.`/ai_model/label_mapping` as label_convert;
register TfIdfInPlace.`/ai_model/tfidf` as tfidf_convert;
register RandomForest.`/ai_model/rf` as rf_predict;
-- Build a few test records
set testData = '''
God is love
OpenGL on the GPU is fast
''';
load csvStr.`testData` as testData;
select _c0 as doc from testData as testData;
-- Everything learned at training time has been turned into functions, so we can compose them to predict on a single record:
-- tfidf_convert turns the text into a feature vector
-- rf_predict runs the prediction and returns a probability vector
-- vec_argmax picks the index with the highest probability
-- label_convert_r maps that index back to the label name
-- Done!
select doc,label_convert_r(vec_argmax((rf_predict(tfidf_convert(doc) )))) as predicted_label from testData as output;
In MLSQL, StringIndex, RateSampler, TfIdfInPlace, and RandomForest are all called ETs (Estimator/Transformer). An ET is just an ordinary class; its three core interfaces (train, load, and predict) correspond to the train, register, and predict statements used above.
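A minimal sketch of those three methods, with signatures taken from the SQLRandomForest implementation below (the trait name ETSketch is only for illustration; the real base type is streaming.dsl.mmlib.SQLAlg):

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction

// Sketch only: the shape of the three core methods an ET implements.
trait ETSketch {
  // the `train`/`run` statement: fit a model or transform data, persisting results under `path`
  def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame

  // the `register` statement, step 1: load whatever `train` saved under `path`
  def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any

  // the `register` statement, step 2: wrap the loaded model as a UDF usable in select/predict
  def predict(sparkSession: SparkSession, model: Any, name: String, params: Map[String, String]): UserDefinedFunction
}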
The code below is a complete implementation of the random forest ET:
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import streaming.dsl.mmlib._
import scala.collection.mutable.ArrayBuffer
import streaming.dsl.mmlib.algs.classfication.BaseClassification
import streaming.dsl.mmlib.algs.param.BaseParams
/**
* Created by allwefantasy on 13/1/2018.
*/
class SQLRandomForest(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseClassification {
def this() = this(BaseParams.randomUID())
override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val keepVersion = params.getOrElse("keepVersion", "true").toBoolean
setKeepVersion(keepVersion)
val evaluateTable = params.get("evaluateTable")
setEvaluateTable(evaluateTable.getOrElse("None"))
SQLPythonFunc.incrementVersion(path, keepVersion)
val spark = df.sparkSession
trainModelsWithMultiParamGroup[RandomForestClassificationModel](df, path, params, () => {
new RandomForestClassifier()
}, (_model, fitParam) => {
evaluateTable match {
case Some(etable) =>
val model = _model.asInstanceOf[RandomForestClassificationModel]
val evaluateTableDF = spark.table(etable)
val predictions = model.transform(evaluateTableDF)
multiclassClassificationEvaluate(predictions, (evaluator) => {
evaluator.setLabelCol(fitParam.getOrElse("labelCol", "label"))
evaluator.setPredictionCol("prediction")
})
case None => List()
}
}
)
formatOutput(getModelMetaData(spark, path))
}
override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[RandomForestClassificationModel]].head
model.transform(df)
}
override def explainParams(sparkSession: SparkSession): DataFrame = {
_explainParams(sparkSession, () => {
new RandomForestClassifier()
})
}
override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
val model = RandomForestClassificationModel.load(bestModelPath(0))
ArrayBuffer(model)
}
override def explainModel(sparkSession: SparkSession, path: String, params: Map[String, String]): DataFrame = {
val models = load(sparkSession, path, params).asInstanceOf[ArrayBuffer[RandomForestClassificationModel]]
val rows = models.flatMap { model =>
val modelParams = model.params.filter(param => model.isSet(param)).map { param =>
val tmp = model.get(param).get
val str = if (tmp == null) {
"null"
} else tmp.toString
Seq(("fitParam.[group]." + param.name), str)
}
Seq(
Seq("uid", model.uid),
Seq("numFeatures", model.numFeatures.toString),
Seq("numClasses", model.numClasses.toString),
Seq("numTrees", model.treeWeights.length.toString),
Seq("treeWeights", model.treeWeights.mkString(","))
) ++ modelParams
}.map(Row.fromSeq(_))
sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(rows, 1),
StructType(Seq(StructField("name", StringType), StructField("value", StringType))))
}
override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
predict_classification(sparkSession, _model, name)
}
override def modelType: ModelType = AlgType
override def doc: Doc = Doc(HtmlDoc,
"""
| <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
| classification.
| It supports both binary and multiclass labels, as well as both continuous and categorical
| features.
|
| Use "load modelParams.`RandomForest` as output;"
|
| to check the available hyper parameters;
|
| Use "load modelExample.`RandomForest` as output;"
| to get an example.
|
| If you want to check the params of a model you have trained, use this command:
|
| ```
| load modelExplain.`/tmp/model` where alg="RandomForest" as output;
| ```
|
""".stripMargin)
override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr +
"""
|load jsonStr.`jsonStr` as data;
|select vec_dense(features) as features ,label as label from data
|as data1;
|
|-- use RandomForest
|train data1 as RandomForest.`/tmp/model` where
|
|-- once set to true, every time you run this script MLSQL will generate a new directory for your model
|keepVersion="true"
|
|-- specify the test dataset which will be fed to the evaluator to generate metrics such as F1 and accuracy
|and evaluateTable="data1"
|
|-- specify group 0 parameters
|and `fitParam.0.featuresCol`="features"
|and `fitParam.0.labelCol`="label"
|and `fitParam.0.maxDepth`="2"
|
|-- specify group 1 parameters
|and `fitParam.1.featuresCol`="features"
|and `fitParam.1.labelCol`="label"
|and `fitParam.1.maxDepth`="10"
|;
""".stripMargin)
}
Pay special attention to two methods here: doc and codeExample. With them, users can look up how to use an ET module at any time:
load modelExample.`RandomForest` as output;
Running this returns a detailed description of the ET.
Many ETs in MLSQL currently have rather sparse documentation; you can read the relevant source code, try the ET out, and then contribute your own examples or explanations. Most algorithm ETs live in streaming.dsl.mmlib.algs, and such a contribution could look like the sketch below.
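For example, a documentation contribution might look roughly like the following sketch. The two overrides belong inside the ET class itself (here the TableRepartition ET shown later in this post, whose doc and codeExample are currently empty); the wording and the example script are illustrative, not official text.

import streaming.dsl.mmlib.{Code, Doc, MarkDownDoc, SQLCode}

// Illustrative only -- these overrides would replace the empty ones inside the ET class.
override def doc: Doc = Doc(MarkDownDoc,
  """
    |TableRepartition changes the number of partitions of a table.
    |Use "load modelParams.`TableRepartition` as output;" to list the supported parameters.
  """.stripMargin)

override def codeExample: Code = Code(SQLCode,
  """
    |run table1 as TableRepartition.`` where partitionNum="10" as newTable;
  """.stripMargin)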
Writing documentation is genuinely hard, especially good documentation: it tests your understanding, your writing, and your stamina. The MLSQL website documentation at http://docs.mlsql.tech/zh/ is written entirely in markdown, and the content lives in the docs/gitbook/zh directory of the main project.
You can modify existing documents or create new ones. Either way, you need to update the header at the top of the document:
【Document changelog: 2020-04-07】
> Note: this document applies to MLSQL Engine 1.6.0-SNAPSHOT/1.6.0 and above.
> Supported Spark versions: 2.3.2/2.4.3/3.0.0-preview2
That is, update the date. You can also add your own name, for example:
【Document changelog: 2020-04-07】
【威廉 (William), Respect】
> Note: this document applies to MLSQL Engine 1.6.0-SNAPSHOT/1.6.0 and above.
> Supported Spark versions: 2.3.2/2.4.3/3.0.0-preview2
MLSQL has a plugin project, https://github.com/allwefantasy/mlsql-plugins, which already contains quite a few plugins: a data source plugin for Excel, a deep-learning plugin based on BigDL, a plugin for repartitioning tables, and more.
We also provide some plugin documentation at http://docs.mlsql.tech/zh/develop/et.html. In general, storage (data sources), ETs, and API endpoints can all be extended through plugins, which makes MLSQL easier to adapt to your needs.
The extension points are briefly described below.
4.1 Data source extension point
load excel.`/tmp/jack` where useHeader="true" as excelTable;
save overwrite excelTable as excel.`/tmp/jack` where useHeader="true";
Both the load and save formats for data storage can be extended and packaged as plugins; the Excel reading and writing above, for example, is functionality provided by a plugin. If a Spark-compatible data source already exists, writing the extension is very simple. The Excel plugin code looks like this:
package tech.mlsql.plugins.ds
import org.apache.spark.ml.param.Param
import streaming.core.datasource._
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.dsl.{ConnectMeta, DBMappingKey}
import org.apache.spark.sql._
import tech.mlsql.version.VersionCompatibility
/**
* Created by latincross on 12/29/2018.
*/
class MLSQLExcel(override val uid: String)
extends MLSQLSource
with MLSQLSink
with MLSQLSourceInfo
with MLSQLRegistry
with VersionCompatibility
with WowParams {
def this() = this(BaseParams.randomUID())
override def fullFormat: String = "tech.mlsql.plugins.ds"
override def shortFormat: String = "excel"
override def dbSplitter: String = ":"
override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = {
val Array(_dbname, _dbtable) = if (config.path.contains(dbSplitter)) {
config.path.split(dbSplitter, 2)
} else {
Array("", config.path)
}
var namespace = ""
val format = config.config.getOrElse("implClass", fullFormat)
if (_dbname != "") {
ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
if (options.contains("namespace")) {
namespace = options("namespace")
}
reader.options(options)
})
}
if (config.config.contains("namespace")) {
namespace = config.config("namespace")
}
val inputTableName = if (namespace == "") _dbtable else s"${namespace}:${_dbtable}"
reader.option("inputTableName", inputTableName)
//load configs should overwrite connect configs
reader.options(config.config)
reader.format(format).load()
}
override def save(writer: DataFrameWriter[Row], config: DataSinkConfig): Unit = {
val Array(_dbname, _dbtable) = if (config.path.contains(dbSplitter)) {
config.path.split(dbSplitter, 2)
} else {
Array("", config.path)
}
var namespace = ""
val format = config.config.getOrElse("implClass", fullFormat)
if (_dbname != "") {
ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
if (options.contains("namespace")) {
namespace = options.get("namespace").get
}
writer.options(options)
})
}
if (config.config.contains("namespace")) {
namespace = config.config.get("namespace").get
}
val outputTableName = if (namespace == "") _dbtable else s"${namespace}:${_dbtable}"
writer.mode(config.mode)
writer.option("outputTableName", outputTableName)
//load configs should overwrite connect configs
writer.options(config.config)
config.config.get("partitionByCol").map { item =>
writer.partitionBy(item.split(","): _*)
}
writer.format(config.config.getOrElse("implClass", fullFormat)).save()
}
override def register(): Unit = {
DataSourceRegistry.register(MLSQLDataSourceKey(fullFormat, MLSQLSparkDataSourceType), this)
DataSourceRegistry.register(MLSQLDataSourceKey(shortFormat, MLSQLSparkDataSourceType), this)
}
override def sourceInfo(config: DataAuthConfig): SourceInfo = {
val format = config.config.getOrElse("implClass", fullFormat)
val Array(connect, namespace, table) = if (config.path.contains(dbSplitter)) {
config.path.split(dbSplitter) match {
case Array(connect, namespace, table) => Array(connect, namespace, table)
case Array(connectOrNameSpace, table) =>
ConnectMeta.presentThenCall(DBMappingKey(format, connectOrNameSpace), (op) => {}) match {
case Some(i) => Array(connectOrNameSpace, "", table)
case None => Array("", connectOrNameSpace, table)
}
case Array(connect, namespace, table, _*) => Array(connect, namespace, table)
}
} else {
Array("", "", config.path)
}
var finalNameSpace = config.config.getOrElse("namespace", namespace)
ConnectMeta.presentThenCall(DBMappingKey(format, connect), (options) => {
if (options.contains("namespace")) {
finalNameSpace = options.get("namespace").get
}
})
SourceInfo(shortFormat, finalNameSpace, table)
}
override def explainParams(spark: SparkSession) = {
_explainParams(spark)
}
final val zk: Param[String] = new Param[String](this, "zk", "zk address")
final val family: Param[String] = new Param[String](this, "family", "default cf")
override def supportedVersions: Seq[String] = {
Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
}
}
4.2 ET extension point
We already covered this when discussing algorithms. As another example, I can implement a feature that repartitions a temporary table (changes its parallelism). The code looks like this:
package tech.mlsql.plugins.et
import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import streaming.dsl.auth.TableAuthResult
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.Functions
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import tech.mlsql.dsl.auth.ETAuth
import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod
import tech.mlsql.version.VersionCompatibility
class TableRepartition(override val uid: String) extends SQLAlg with VersionCompatibility with Functions with WowParams with ETAuth {
def this() = this(BaseParams.randomUID())
//
override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
params.get(partitionNum.name).map { item =>
set(partitionNum, item.toInt)
item
}.getOrElse {
throw new MLSQLException(s"${partitionNum.name} is required")
}
params.get(partitionType.name).map { item =>
set(partitionType, item)
item
}.getOrElse {
set(partitionType, "hash")
}
params.get(partitionCols.name).map { item =>
set(partitionCols, item)
item
}.getOrElse {
set(partitionCols, "")
}
$(partitionType) match {
case "range" =>
require(params.contains(partitionCols.name), "At least one partition-by expression must be specified.")
df.repartitionByRange($(partitionNum), $(partitionCols).split(",").map(name => F.col(name)): _*)
case _ =>
df.repartition($(partitionNum))
}
}
override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
List()
}
override def supportedVersions: Seq[String] = {
Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
}
override def doc: Doc = Doc(MarkDownDoc,
s"""
|
""".stripMargin)
override def codeExample: Code = Code(SQLCode,
"""
|
""".stripMargin)
override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = train(df, path, params)
override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???
override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???
final val partitionNum: IntParam = new IntParam(this, "partitionNum",
"")
final val partitionType: Param[String] = new Param[String](this, "partitionType",
"")
final val partitionCols: Param[String] = new Param[String](this, "partitionCols",
"")
override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)
}
It has exactly the same shape as an algorithm ET; a usage sketch follows.
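Assuming the plugin is installed and the ET is registered under the name TableRepartition, invoking it from MLSQL might look like this (the parameter names come from the code above; table1 and repartitionedTable are placeholders):

run table1 as TableRepartition.`` where partitionNum="8" as repartitionedTable;

-- range partitioning additionally requires partitionCols
run table1 as TableRepartition.`` where partitionNum="8"
and partitionType="range"
and partitionCols="id"
as repartitionedTable;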
4.3 App extension point
Sometimes you need to expose new API endpoints together with a whole batch of ETs and data source features, and shipping each of them as a separate plugin is clearly impractical. In that case you bundle them into a single package, which is exactly what the App extension point is for. Below is the BigDL one:
package tech.mlsql.plugins.bigdl

import tech.mlsql.ets.register.ETRegister
import tech.mlsql.version.VersionCompatibility

/**
 * 5/4/2020 WilliamZhu(allwefantasy@gmail.com)
 */
class BigDLApp extends tech.mlsql.app.App with VersionCompatibility {
  override def run(args: Seq[String]): Unit = {
    ETRegister.register("ImageLoaderExt", classOf[SQLImageLoaderExt].getName)
    ETRegister.register("MnistLoaderExt", classOf[SQLMnistLoaderExt].getName)
    ETRegister.register("BigDLClassifyExt", classOf[SQLBigDLClassifyExt].getName)
    ETRegister.register("LeNet5Ext", classOf[SQLLeNet5Ext].getName)
  }

  override def supportedVersions: Seq[String] = Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
}
As you can see, the BigDL app simply registers a batch of ET components that cover image loading, deep learning, and some ready-made models.
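Once such an app plugin is installed, the ETs it registers behave like built-in ones. For instance, assuming BigDLClassifyExt provides a codeExample, its usage could be looked up the same way as shown earlier:

load modelExample.`BigDLClassifyExt` as output;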