How to Participate in the MLSQL Community


For readers with slow access to GitHub, this article is also mirrored on our Zhihu column. The GitHub version is at: https://github.com/allwefantasy/mlsql/issues/1326

Motivation

The MLSQL community hopes everyone can take part. Open source should be inclusive, both in the value it delivers and in who gets to participate in the community.

We see active community participation in the following ways:

  • Actively helping community users with technical questions
  • Adding tests to the project
  • Improving the technical documentation
  • Filing valuable issues
  • Reporting or fixing known and unknown bugs
  • Writing source-code analysis or practical use-case articles about the project

All of these are highly valuable contributions to the community.

Beginner Tasks

We curate a set of beginner tasks and mark them with two basic labels:

  • for new contributors
  • Difficulty:Low, Difficulty:Middle, Difficulty:High

The first label makes beginner tasks easy to find; the second marks their difficulty. Filter issues by these two labels to pick a beginner task to work on.

Contributing Built-in UDF Functions (Difficulty: Low)

MLSQL lets you define UDFs in two ways: built in, or written on the fly. An on-the-fly UDF looks roughly like this:

register ScriptUDF.`` as fetch_es where
code='''
def apply(query: String) = {
  import org.apache.http.client.fluent.Request
  import org.apache.http.entity.ContentType

  // The target endpoint; replace with your own Elasticsearch address
  val uploadUrl = "http://127.0.0.1:9200/_search"

  val res = Request.Post(uploadUrl).connectTimeout(60 * 1000)
          .socketTimeout(10 * 60 * 1000)
          .bodyString(query, ContentType.APPLICATION_JSON)
          .execute().returnContent().asString()
  res
}
''';

set load_es='''
  set es_res=`select fetch_es("{}") as content` where type="sql";
  load jsonStr.`es_res` as es_res_table;
''';

!load_es esTable newTable;

A built-in UDF, by contrast, is an object written in a particular shape:

import java.util.UUID

import org.apache.spark.sql.UDFRegistration

object Functions {
  def uuid(uDFRegistration: UDFRegistration) = {
    uDFRegistration.register("uuid", () => {
      UUID.randomUUID().toString.replace("-", "")
    })
  }
}

The function above can then be used like this:

select uuid() as uuid as output;

MLSQL ships with most Spark SQL and Hive functions, plus a large number of algorithm-oriented UDFs, such as functions for manipulating vectors and computing metrics.

There are surely many more useful and fun UDFs/UDAFs to add. This work is relatively easy and well suited to newcomers: contributors can add any UDF they find valuable to tech.mlsql.udf.Functions.
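
As a rough sketch of what such a contribution could look like, the snippet below adds a hypothetical str_reverse function next to uuid, in the same style (the name and behavior are invented for illustration and are not an existing MLSQL UDF):

import org.apache.spark.sql.UDFRegistration

object Functions {
  // Hypothetical example: register a UDF that reverses a string.
  // It would sit alongside the existing functions in tech.mlsql.udf.Functions.
  def strReverse(uDFRegistration: UDFRegistration) = {
    uDFRegistration.register("str_reverse", (s: String) => {
      if (s == null) null else s.reverse
    })
  }
}

Once compiled into the engine, it would be callable just like uuid: select str_reverse("olleh") as s as output;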

Contributing Documentation (Difficulty: Medium)

Documentation is not easy; good documentation tests your hands-on experience, understanding, writing, and persistence. In short, it is hard to do well. The MLSQL documentation site http://docs.mlsql.tech/zh/ is written entirely in Markdown, and the sources live in the docs/gitbook/zh directory of the main project.

You can edit existing documents or create new ones. In either case, update the header at the top of the document:

【Doc update log: 2020-04-07】

> Note: This document applies to MLSQL Engine 1.6.0-SNAPSHOT/1.6.0 and later.
> Supported Spark versions: 2.3.2/2.4.3/3.0.0-preview2

That is, update the date. You can also add your own name, for example:

【Doc update log: 2020-04-07】
【威廉, Respect】

> Note: This document applies to MLSQL Engine 1.6.0-SNAPSHOT/1.6.0 and later.
> Supported Spark versions: 2.3.2/2.4.3/3.0.0-preview2

Contributing Built-in Help for Algorithm Modules (Difficulty: Low)

MLSQL's high-level API is built around three core concepts, Train/Register/Predict, which let you train and serve algorithm models without writing any code.

Here is a typical text-classification scenario:

load delta.`ai_data.20newsgroups` as 20newsgroups;


-- Extract the category, which is implied by the file path; this code depends on
-- the directory layout, so it is not portable.
select *, split(file,"/")[4] as label from 20newsgroups as 20newsgroups;

-- Keep only the two columns we need
select label,value from 20newsgroups as 20newsgroups;

-- Convert the label into a number
train 20newsgroups as StringIndex.`/ai_model/label_mapping` where 
inputCol="label" and outputCol="label_num" as 20newsgroups;

-- Shrink the dataset: we only use half of it for training
run 20newsgroups as RateSampler.`/tmp/model` 
where labelCol="label_num"
and sampleRate="0.5,0.5" 
as 20newsgroups;

select * from 20newsgroups where __split__=0 as 20newsgroups;


-- Use the TfIdf module to apply tf/idf processing to the data.

train 20newsgroups as TfIdfInPlace.`/ai_model/tfidf`
where inputCol="value"
and ignoreNature="true"
-- and nGrams="2,3"
as tfTable;

select * from tfTable limit 100 as tfTable_test;

-- Next, train with the built-in random forest algorithm
train tfTable as RandomForest.`/ai_model/rf` where

-- keep a model version for every run
keepVersion="true" 

-- test set, used to check how well the model performs
and evaluateTable="tfTable_test"

-- first group of parameters
and `fitParam.0.featuresCol`="value"
and `fitParam.0.labelCol`="label_num"
and `fitParam.0.maxDepth`="2"

-- second group of parameters
and `fitParam.1.featuresCol`="value"
and `fitParam.1.labelCol`="label_num"
and `fitParam.1.maxDepth`="3"
;

-- Turn what was learned during training into functions
register StringIndex.`/ai_model/label_mapping` as label_convert;
register TfIdfInPlace.`/ai_model/tfidf` as tfidf_convert;
register RandomForest.`/ai_model/rf` as rf_predict;

-- Build a few rows of test data
set testData = '''
God is love
OpenGL on the GPU is fast
''';

load csvStr.`testData` as testData;

select _c0 as doc from testData as testData;

-- Everything learned at training time is now a function, so we can chain the
-- functions to predict on a single record:
-- tfidf_convert turns the text into a vector
-- rf_predict produces a probability vector
-- vec_argmax picks the index with the highest probability
-- label_convert_r maps that number back to the label name
-- Done!
select doc,label_convert_r(vec_argmax((rf_predict(tfidf_convert(doc) )))) as predicted_label  from testData as output;

StringIndex, RateSampler, TfIdfInPlace, and RandomForest are what MLSQL calls ETs (Estimator/Transformer). An ET is an ordinary class with three core interfaces:

  1. train
  2. load
  3. predict

which correspond to:

  1. training
  2. loading the model
  3. turning the model into a function

The following code is a complete implementation of the random forest algorithm:

import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import streaming.dsl.mmlib._

import scala.collection.mutable.ArrayBuffer
import streaming.dsl.mmlib.algs.classfication.BaseClassification
import streaming.dsl.mmlib.algs.param.BaseParams

/**
  * Created by allwefantasy on 13/1/2018.
  */
class SQLRandomForest(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseClassification {

  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {


    val keepVersion = params.getOrElse("keepVersion", "true").toBoolean
    setKeepVersion(keepVersion)

    val evaluateTable = params.get("evaluateTable")
    setEvaluateTable(evaluateTable.getOrElse("None"))

    SQLPythonFunc.incrementVersion(path, keepVersion)
    val spark = df.sparkSession

    trainModelsWithMultiParamGroup[RandomForestClassificationModel](df, path, params, () => {
      new RandomForestClassifier()
    }, (_model, fitParam) => {
      evaluateTable match {
        case Some(etable) =>
          val model = _model.asInstanceOf[RandomForestClassificationModel]
          val evaluateTableDF = spark.table(etable)
          val predictions = model.transform(evaluateTableDF)
          multiclassClassificationEvaluate(predictions, (evaluator) => {
            evaluator.setLabelCol(fitParam.getOrElse("labelCol", "label"))
            evaluator.setPredictionCol("prediction")
          })

        case None => List()
      }
    }
    )

    formatOutput(getModelMetaData(spark, path))
  }


  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[RandomForestClassificationModel]].head
    model.transform(df)
  }

  override def explainParams(sparkSession: SparkSession): DataFrame = {
    _explainParams(sparkSession, () => {
      new RandomForestClassifier()
    })
  }

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {

    val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
    val model = RandomForestClassificationModel.load(bestModelPath(0))
    ArrayBuffer(model)
  }


  override def explainModel(sparkSession: SparkSession, path: String, params: Map[String, String]): DataFrame = {
    val models = load(sparkSession, path, params).asInstanceOf[ArrayBuffer[RandomForestClassificationModel]]
    val rows = models.flatMap { model =>
      val modelParams = model.params.filter(param => model.isSet(param)).map { param =>
        val tmp = model.get(param).get
        val str = if (tmp == null) {
          "null"
        } else tmp.toString
        Seq(("fitParam.[group]." + param.name), str)
      }
      Seq(
        Seq("uid", model.uid),
        Seq("numFeatures", model.numFeatures.toString),
        Seq("numClasses", model.numClasses.toString),
        Seq("numTrees", model.treeWeights.length.toString),
        Seq("treeWeights", model.treeWeights.mkString(","))
      ) ++ modelParams
    }.map(Row.fromSeq(_))
    sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(rows, 1),
      StructType(Seq(StructField("name", StringType), StructField("value", StringType))))
  }

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    predict_classification(sparkSession, _model, name)
  }

  override def modelType: ModelType = AlgType

  override def doc: Doc = Doc(HtmlDoc,
    """
      | <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
      | classification.
      | It supports both binary and multiclass labels, as well as both continuous and categorical
      | features.
      |
      | Use "load modelParams.`RandomForest` as output;"
      |
      | to check the available hyper parameters;
      |
      | Use "load modelExample.`RandomForest` as output;"
      | get example.
      |
      | If you wanna check the params of model you have trained, use this command:
      |
      | ```
      | load modelExplain.`/tmp/model` where alg="RandomForest" as outout;
      | ```
      |
    """.stripMargin)


  override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr +
    """
      |load jsonStr.`jsonStr` as data;
      |select vec_dense(features) as features ,label as label from data
      |as data1;
      |
      |-- use RandomForest
      |train data1 as RandomForest.`/tmp/model` where
      |
      |-- once set to true, every time you run this script MLSQL will create a new directory for your model
      |keepVersion="true"
      |
      |-- specify the test dataset, which is fed to the evaluator to produce metrics such as F1 and accuracy
      |and evaluateTable="data1"
      |
      |-- specify group 0 parameters
      |and `fitParam.0.featuresCol`="features"
      |and `fitParam.0.labelCol`="label"
      |and `fitParam.0.maxDepth`="2"
      |
      |-- specify group 1 parameters
      |and `fitParam.1.featuresCol`="features"
      |and `fitParam.1.labelCol`="label"
      |and `fitParam.1.maxDepth`="10"
      |;
    """.stripMargin)

}

Pay particular attention to two functions here: doc and codeExample. With them, a user can look up how to use an ET module at any time:

load modelExample.`RandomForest` as output;

and get a detailed description of that ET.

Many MLSQL ETs currently have sparse documentation. You can read the relevant source code and, after using an ET, contribute your own examples or explanations back.

Most algorithm ETs live in streaming.dsl.mmlib.algs.
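
As one concrete way to approach this task, here is a sketch of what filling in doc and codeExample might look like for the TableRepartition ET shown later in this article, whose overrides are currently empty. The wording and the usage line are illustrative, not the official documentation:

// Illustrative sketch: these overrides would go inside the TableRepartition class
// shown later in this article, replacing its currently empty doc/codeExample.
override def doc: Doc = Doc(MarkDownDoc,
  """
    |TableRepartition changes the number of partitions (parallelism) of a table.
    |
    |Parameters: partitionNum (required), partitionType ("hash" or "range"), partitionCols.
  """.stripMargin)

override def codeExample: Code = Code(SQLCode,
  """
    |-- repartition table `data` into 10 partitions
    |run data as TableRepartition.`` where partitionNum="10" as newData;
  """.stripMargin)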

Contributing an Algorithm Module (Difficulty: Medium)

The Train/Register/Predict workflow and the complete SQLRandomForest implementation shown in the previous section are exactly what an algorithm module looks like: an ET class implementing train, load, and predict (plus batchPredict, explainParams, doc, and codeExample). Contributing an algorithm module therefore means writing such a class for an algorithm MLSQL does not yet cover. Most algorithm ETs live in streaming.dsl.mmlib.algs.
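
To give a feel for the shape of a brand-new ET, here is a minimal, hypothetical sketch. The class name and behavior are invented for illustration; the interfaces mirror SQLRandomForest above and TableRepartition below:

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.Functions
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}

// Hypothetical minimal ET: it only appends a monotonically increasing `rowId`
// column to the input table. A real algorithm module would train a model in
// `train`, restore it in `load`, and wrap it as a UDF in `predict`.
class SQLAddRowId(override val uid: String) extends SQLAlg with Functions with WowParams {
  def this() = this(BaseParams.randomUID())

  // Entry point for `train`/`run` statements: transform the input and return the result.
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    df.withColumn("rowId", F.monotonically_increasing_id())
  }

  // For a pure transformation, batch prediction can simply reuse `train`.
  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame =
    train(df, path, params)

  // This ET keeps no state on disk, so load/predict are left unimplemented.
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???

  override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)

  override def doc: Doc = Doc(MarkDownDoc,
    """
      |Appends a monotonically increasing `rowId` column to the input table.
    """.stripMargin)

  override def codeExample: Code = Code(SQLCode,
    """
      |run data as AddRowId.`` as dataWithId;
    """.stripMargin)
}

Mapping the script-level name (AddRowId) to this class is done through the engine's ET registry; the ETRegister.register calls in the BigDLApp example below show the same mechanism for plugins.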

Contributing Plugins (Difficulty: High)

MLSQL has a plugin project, https://github.com/allwefantasy/mlsql-plugins, which already contains quite a few plugins: typical examples include an Excel data source plugin, a BigDL deep-learning plugin, and a plugin for repartitioning tables.

There is also plugin documentation at http://docs.mlsql.tech/zh/develop/et.html. In general, storage, ETs, and API endpoints can all be extended through plugins, making MLSQL fit user needs even better.

Here is a brief description of the extension points.

Data source extension point

load excel.`/tmp/jack` where useHeader="true" as excelTable;
save overwrite excelTable as excel.`/tmp/jack` where useHeader="true";

The load and save formats for data storage can both be extended and shipped as plugins; the Excel read and write above, for example, are provided by a plugin. If a Spark-compatible data source already exists, the extension is very simple. The Excel code looks like this:

package tech.mlsql.plugins.ds

import org.apache.spark.ml.param.Param
import streaming.core.datasource._
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.dsl.{ConnectMeta, DBMappingKey}
import org.apache.spark.sql._
import tech.mlsql.version.VersionCompatibility

/**
 * Created by latincross on 12/29/2018.
 */
class MLSQLExcel(override val uid: String)
  extends MLSQLSource
    with MLSQLSink
    with MLSQLSourceInfo
    with MLSQLRegistry
    with VersionCompatibility
    with WowParams {
  def this() = this(BaseParams.randomUID())


  override def fullFormat: String = "tech.mlsql.plugins.ds"

  override def shortFormat: String = "excel"

  override def dbSplitter: String = ":"

  override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = {
    val Array(_dbname, _dbtable) = if (config.path.contains(dbSplitter)) {
      config.path.split(dbSplitter, 2)
    } else {
      Array("", config.path)
    }

    var namespace = ""

    val format = config.config.getOrElse("implClass", fullFormat)
    if (_dbname != "") {
      ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
        if (options.contains("namespace")) {
          namespace = options("namespace")
        }
        reader.options(options)
      })
    }

    if (config.config.contains("namespace")) {
      namespace = config.config("namespace")
    }

    val inputTableName = if (namespace == "") _dbtable else s"${namespace}:${_dbtable}"

    reader.option("inputTableName", inputTableName)

    //load configs should overwrite connect configs
    reader.options(config.config)
    reader.format(format).load()
  }

  override def save(writer: DataFrameWriter[Row], config: DataSinkConfig): Unit = {
    val Array(_dbname, _dbtable) = if (config.path.contains(dbSplitter)) {
      config.path.split(dbSplitter, 2)
    } else {
      Array("", config.path)
    }

    var namespace = ""

    val format = config.config.getOrElse("implClass", fullFormat)
    if (_dbname != "") {
      ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
        if (options.contains("namespace")) {
          namespace = options.get("namespace").get
        }
        writer.options(options)
      })
    }

    if (config.config.contains("namespace")) {
      namespace = config.config.get("namespace").get
    }

    val outputTableName = if (namespace == "") _dbtable else s"${namespace}:${_dbtable}"

    writer.mode(config.mode)
    writer.option("outputTableName", outputTableName)
    //load configs should overwrite connect configs
    writer.options(config.config)
    config.config.get("partitionByCol").map { item =>
      writer.partitionBy(item.split(","): _*)
    }
    writer.format(config.config.getOrElse("implClass", fullFormat)).save()
  }

  override def register(): Unit = {
    DataSourceRegistry.register(MLSQLDataSourceKey(fullFormat, MLSQLSparkDataSourceType), this)
    DataSourceRegistry.register(MLSQLDataSourceKey(shortFormat, MLSQLSparkDataSourceType), this)
  }

  override def sourceInfo(config: DataAuthConfig): SourceInfo = {
    val format = config.config.getOrElse("implClass", fullFormat)
    val Array(connect, namespace, table) = if (config.path.contains(dbSplitter)) {
      config.path.split(dbSplitter) match {
        case Array(connect, namespace, table) => Array(connect, namespace, table)
        case Array(connectOrNameSpace, table) =>
          ConnectMeta.presentThenCall(DBMappingKey(format, connectOrNameSpace), (op) => {}) match {
            case Some(i) => Array(connectOrNameSpace, "", table)
            case None => Array("", connectOrNameSpace, table)
          }
        case Array(connect, namespace, table, _*) => Array(connect, namespace, table)
      }
    } else {
      Array("", "", config.path)
    }


    var finalNameSpace = config.config.getOrElse("namespace", namespace)

    ConnectMeta.presentThenCall(DBMappingKey(format, connect), (options) => {
      if (options.contains("namespace")) {
        finalNameSpace = options.get("namespace").get
      }

    })


    SourceInfo(shortFormat, finalNameSpace, table)
  }

  override def explainParams(spark: SparkSession) = {
    _explainParams(spark)
  }

  final val zk: Param[String] = new Param[String](this, "zk", "zk address")
  final val family: Param[String] = new Param[String](this, "family", "default cf")

  override def supportedVersions: Seq[String] = {
    Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
  }
}

ET extension point

We already covered this while discussing algorithms. For example, here is an ET that resets the parallelism of (repartitions) a temporary table:

package tech.mlsql.plugins.et

import org.apache.spark.ml.param.{IntParam, Param}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, SparkSession, functions => F}
import streaming.dsl.auth.TableAuthResult
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.Functions
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import tech.mlsql.dsl.auth.ETAuth
import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod
import tech.mlsql.version.VersionCompatibility


class TableRepartition(override val uid: String) extends SQLAlg with VersionCompatibility with Functions with WowParams with ETAuth {
  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    params.get(partitionNum.name).map { item =>
      set(partitionNum, item.toInt)
      item
    }.getOrElse {
      throw new MLSQLException(s"${partitionNum.name} is required")
    }

    params.get(partitionType.name).map { item =>
      set(partitionType, item)
      item
    }.getOrElse {
      set(partitionType, "hash")
    }

    params.get(partitionCols.name).map { item =>
      set(partitionCols, item)
      item
    }.getOrElse {
      set(partitionCols, "")
    }

    $(partitionType) match {
      case "range" =>

        require(params.contains(partitionCols.name), "At least one partition-by expression must be specified.")
        df.repartitionByRange($(partitionNum), $(partitionCols).split(",").map(name => F.col(name)): _*)

      case _ =>
        df.repartition($(partitionNum))
    }


  }

  override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = {
    List()
  }

  override def supportedVersions: Seq[String] = {
    Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
  }


  override def doc: Doc = Doc(MarkDownDoc,
    s"""
       |
    """.stripMargin)


  override def codeExample: Code = Code(SQLCode,
    """
      |
    """.stripMargin)

  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = train(df, path, params)

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ???

  final val partitionNum: IntParam = new IntParam(this, "partitionNum",
    "")
  final val partitionType: Param[String] = new Param[String](this, "partitionType",
    "")

  final val partitionCols: Param[String] = new Param[String](this, "partitionCols",
    "")

  override def explainParams(sparkSession: SparkSession): DataFrame = _explainParams(sparkSession)

}

Its shape is exactly the same as an algorithm ET's.

App extension point

Sometimes you need to provide new API endpoints together with a whole batch of ETs and data sources, and shipping each one as a separate plugin is clearly impractical. In that case you bundle them, and the App extension point exists for exactly this. Here is BigDL's:

package tech.mlsql.plugins.bigdl

import tech.mlsql.ets.register.ETRegister
import tech.mlsql.version.VersionCompatibility

/**
 * 5/4/2020 WilliamZhu(allwefantasy@gmail.com)
 */
class BigDLApp extends tech.mlsql.app.App with VersionCompatibility {
  override def run(args: Seq[String]): Unit = {
    ETRegister.register("ImageLoaderExt", classOf[SQLImageLoaderExt].getName)
    ETRegister.register("MnistLoaderExt", classOf[SQLMnistLoaderExt].getName)
    ETRegister.register("BigDLClassifyExt", classOf[SQLBigDLClassifyExt].getName)
    ETRegister.register("LeNet5Ext", classOf[SQLLeNet5Ext].getName)
  }

  override def supportedVersions: Seq[String] = Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
}

As you can see, the BigDL app simply registers a batch of ET components that provide image processing, deep learning, and some ready-made models.

Script extension point

[todo]
