
scala-sparkML学习笔记:serializable custom transformer with spark-scala

Author: MachineLP
Published: 2019-08-31 19:25:02

Copyright notice: this is an original article by the author, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original and this notice.

Original link: https://blog.csdn.net/u014365862/article/details/100146543

Sometimes, when building a pipeline, a piece of functionality you need does not exist in Spark ML and you have to define it yourself. The following example shows how:

src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/OwnMLlibPipeline.scala

/*
-------------------------------------------------
   Description :  Serializable Custom Transformer with Spark 2.0 (Scala)
   Author :       liupeng
   Date :         2019/08/29
-------------------------------------------------
 */

package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType


// A custom Transformer that renames inputCol to outputCol. Mixing in
// DefaultParamsWritable is what makes the stage serializable through
// standard ML pipeline persistence (model.write.save / PipelineModel.load).
class ColRenameTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("ColRenameTransformer"))
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol: String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: Dataset[_]): DataFrame = {
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")

    // Drop any pre-existing column with the target name, then rename.
    dataset.drop(outCol).withColumnRenamed(inCol, outCol)
  }

  override def copy(extra: ParamMap): ColRenameTransformer = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = schema
}

// The companion object must extend DefaultParamsReadable so that the stage
// can be deserialized when a saved pipeline is loaded.
object ColRenameTransformer extends DefaultParamsReadable[ColRenameTransformer] {
  override def load(path: String): ColRenameTransformer = super.load(path)
}


object OwnMLlibPipeline {

  def main(args: Array[String]): Unit = {

    val pipelineModelPath = args(0)

    val spark = SparkSession.builder().getOrCreate()
    val data = spark.createDataFrame(Seq(
                     ("hi,there", 1),
                     ("a,b,c", 2),
                     ("no", 3) )).toDF("myInputCol", "id")
    data.show(false)
    val myTransformer = new ColRenameTransformer().setInputCol( "id" ).setOutputCol( "lpid" )
    println(s"Original data has ${data.count()} rows.")
    // val output = myTransformer.transform(data)
    // println(s"Output data has ${output.count()} rows.")
    // output.show(false)
  
    val pipeline = new Pipeline().setStages(Array( myTransformer ))
    val model = pipeline.fit(data)
    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)
    // Load a saved model and serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(data).show(false)
  }
}

Run result:

input:
+----------+---+
|myInputCol|id |
+----------+---+
|hi,there  |1  |
|a,b,c     |2  |
|no        |3  |
+----------+---+
res:
+----------+----+
|myInputCol|lpid|
+----------+----+
|hi,there  |1   |
|a,b,c     |2   |
|no        |3   |
+----------+----+
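Because `ColRenameTransformer` mixes in `DefaultParamsWritable` and has a `DefaultParamsReadable` companion, the custom stage round-trips through pipeline persistence with its params intact. As a minimal sketch (assuming the `spark` session, `pipelineModelPath`, and `ColRenameTransformer` from the example above), the reloaded stage can be recovered from the pipeline and inspected:

```scala
import org.apache.spark.ml.PipelineModel

// Reload the pipeline saved by OwnMLlibPipeline; this implicitly calls
// ColRenameTransformer.load for the custom stage.
val reloaded = PipelineModel.load(pipelineModelPath)

// The first (and only) stage is our custom transformer, with the
// params it was configured with before saving.
val stage = reloaded.stages(0).asInstanceOf[ColRenameTransformer]
println(stage.getOutputCol)  // "lpid", as set before saving
```

If the companion object were missing (or did not extend `DefaultParamsReadable`), `PipelineModel.load` would fail when it tries to instantiate the stage from the saved metadata.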
Originally published on the author's blog, 2019-08-29.