
scala-sparkML study notes: distributed training with xgboost

MachineLP
Published 2019-08-31 19:25:47
Included in the column: Xiaopeng's Column (小鹏的专栏)

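Copyright notice: This is an original article by the blogger, licensed under CC 4.0 BY-SA. If you repost it, include a link to the original and this notice.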

Original article: https://blog.csdn.net/u014365862/article/details/100146395

There are two common ways to build a jar for Java/Scala: sbt and Maven. This post uses Maven. The full project is also on GitHub: https://github.com/MachineLP/Spark-/tree/master/scala-xgboost

The xgboost SparkMLlibPipeline.scala code is below. (Note: the source file must be placed in the directory layout that matches its package: src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala)

Code language: Scala
package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)

object SparkMLlibPipeline {

  def main(args: Array[String]): Unit = {

    if (args.length != 3) {
      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
      sys.exit(1)
    }

    val inputPath = args(0)
    val nativeModelPath = args(1)
    val pipelineModelPath = args(2)

    val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

    // Load dataset
    val schema = new StructType(Array(
      StructField("sepal length", DoubleType, true),
      StructField("sepal width", DoubleType, true),
      StructField("petal length", DoubleType, true),
      StructField("petal width", DoubleType, true),
      StructField("class", StringType, true)))

    val rawInput = spark.read.schema(schema).csv(inputPath)

    // Split training and test dataset
    val Array(training, test) = rawInput.randomSplit(Array(0.8, 0.2), 123)

    // Build ML pipeline, it includes 4 stages:
    // 1, Assemble all features into a single vector column.
    // 2, From string label to indexed double label.
    // 3, Use XGBoostClassifier to train classification model.
    // 4, Convert indexed double label back to original string label.
    val assembler = new VectorAssembler()
      .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
      .setOutputCol("features")
    val labelIndexer = new StringIndexer()
      .setInputCol("class")
      .setOutputCol("classIndex")
      .fit(training)
    val booster = new XGBoostClassifier(
      Map("eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "num_workers" -> 2
      )
    )
    booster.setFeaturesCol("features")
    booster.setLabelCol("classIndex")
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("realLabel")
      .setLabels(labelIndexer.labels)

    val pipeline = new Pipeline()
      .setStages(Array(assembler, labelIndexer, booster, labelConverter))
    val model = pipeline.fit(training)

    // Batch prediction
    val prediction = model.transform(test)
    prediction.show(false)

    // Model evaluation
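    val evaluator = new MulticlassClassificationEvaluator()
    evaluator.setLabelCol("classIndex")
    evaluator.setPredictionCol("prediction")
    // The default metric is "f1"; request accuracy explicitly so the printed value matches its label
    evaluator.setMetricName("accuracy")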
    val accuracy = evaluator.evaluate(prediction)
    println("The model accuracy is : " + accuracy)

    // Tune model using cross validation
    val paramGrid = new ParamGridBuilder()
      .addGrid(booster.maxDepth, Array(3, 8))
      .addGrid(booster.eta, Array(0.2, 0.6))
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    val cvModel = cv.fit(training)

    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(2)
      .asInstanceOf[XGBoostClassificationModel]
    println("The params of best XGBoostClassification model : " +
      bestModel.extractParamMap())
    println("The training summary of best XGBoostClassificationModel : " +
      bestModel.summary)

    // Export the XGBoostClassificationModel as local XGBoost model,
    // then you can load it back in local Python environment.
    bestModel.nativeBooster.saveModel(nativeModelPath)

    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)

    // Load a saved model and serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(test).show(false)

    // Release cluster resources once the job is done
    spark.stop()
  }
}
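The label handling and tuning cost in the pipeline above can be made concrete with a plain-Scala sketch that needs no Spark. It assumes StringIndexer's default frequencyDesc ordering, where the most frequent label gets index 0. Ties are broken alphabetically, which is this sketch's own choice. It also assumes the prediction column is the argmax of the per-class probabilities produced by multi:softprob. MulticlassClassificationEvaluator defaults to the f1 metric, so the accuracy function here corresponds to setMetricName("accuracy"). The object and method names are hypothetical. This illustrates the logic and is not Spark's implementation.

```scala
// Plain-Scala sketch (no Spark) of label indexing, argmax prediction, accuracy
// and cross-validation cost. All names are hypothetical, not Spark/XGBoost APIs.
object PipelineLogicSketch {

  // Like StringIndexer's default "frequencyDesc" order: most frequent label -> index 0.
  // Ties are broken alphabetically for determinism (an assumption of this sketch).
  def fitLabels(labels: Seq[String]): Array[String] =
    labels.groupBy(identity).toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .map { case (label, _) => label }
      .toArray

  // "multi:softprob" yields one probability per class; the predicted index is the argmax.
  def predictIndex(probabilities: Array[Double]): Int =
    probabilities.indices.maxBy(i => probabilities(i))

  // Like IndexToString with setLabels(labelIndexer.labels): index back to the original label.
  def toLabel(labels: Array[String], index: Int): String = labels(index)

  // Accuracy: share of rows whose predicted index equals the true index.
  def accuracy(truth: Seq[Int], predicted: Seq[Int]): Double =
    truth.zip(predicted).count { case (t, p) => t == p }.toDouble / truth.size

  // CrossValidator fits every grid point on every fold, then refits the best one on all data.
  def crossValidationFits(valuesPerParam: Seq[Int], numFolds: Int): Int =
    valuesPerParam.product * numFolds + 1

  def main(args: Array[String]): Unit = {
    val labels = fitLabels(Seq("Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-virginica"))
    val predicted = toLabel(labels, predictIndex(Array(0.1, 0.7, 0.2)))
    println(s"labels = ${labels.mkString(", ")}; predicted = $predicted")
    // maxDepth has 2 candidate values, eta has 2, and numFolds is 3
    println(s"pipeline fits during CV = ${crossValidationFits(Seq(2, 2), 3)}")
  }
}
```

With the 2 x 2 grid and numFolds = 3 used in the example, crossValidationFits returns 13. This means CrossValidator trains the whole pipeline, including its num_workers XGBoost workers, 13 times. Tuning therefore costs about 13 times as much as the single pipeline.fit.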

The pom.xml is as follows. Make sure the dependency versions, in particular Scala and Spark, match your cluster:

Code language: XML
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>ml.dmlc</groupId>
  <artifactId>xgboost4j-example_2.11</artifactId>
  <version>1.0.0</version>
  <packaging>jar</packaging>
  <name>${project.artifactId}</name>
  <description>This is a boilerplate maven project to start using Spark in Scala</description>
  <inceptionYear>2010</inceptionYear>

  <properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <encoding>UTF-8</encoding>
    <scala.tools.version>2.11</scala.tools.version>
    <!-- Put the Scala version of the cluster --> 
    <scala.version>2.11.12</scala.version> 
    <scala.binary.version>2.11</scala.binary.version> 
    <spark.version>2.4.3</spark.version> 
  </properties>
  
  <!-- repository to add org.apache.spark -->
  <repositories>
    <repository>
      <id>cloudera-repo-releases</id>
      <url>https://repository.cloudera.com/artifactory/repo/</url>
    </repository>
    <repository>
      <id>GitHub Repo</id>
      <name>GitHub Repo</name>
      <url>https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/</url>
    </repository>
  </repositories>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <testSourceDirectory>src/test/scala</testSourceDirectory>
    <plugins>
      <plugin>
        <!-- see http://davidb.github.com/scala-maven-plugin -->
        <!-- https://mvnrepository.com/artifact/net.alchim31.maven/scala-maven-plugin -->
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
        <!-- <version>3.1.3</version> -->
        <version>4.0.2</version>
        <!-- <version>3.4.6</version> -->
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
            <configuration>
              <args>
                <arg>-dependencyfile</arg>
                <arg>${project.build.directory}/.scala_dependencies</arg>
              </args>
            </configuration>
          </execution>
        </executions>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <version>2.13</version>
        <configuration>
          <useFile>false</useFile>
          <disableXmlReport>true</disableXmlReport>
          <!-- If you have classpath issue like NoDefClassError,... -->
          <!-- useManifestOnlyJar>false</useManifestOnlyJar -->
          <includes>
            <include>**/*Test.*</include>
            <include>**/*Suite.*</include>
          </includes>
        </configuration>
      </plugin>

      <!-- "package" command plugin -->
      <plugin>
        <artifactId>maven-assembly-plugin</artifactId>
        <version>2.4.1</version>
        <configuration>
          <descriptorRefs>
            <descriptorRef>jar-with-dependencies</descriptorRef>
          </descriptorRefs>
        </configuration>
        <executions>
          <execution>
            <id>make-assembly</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j-spark</artifactId>
      <version>0.90</version>
    </dependency>
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j</artifactId>
      <version>0.90</version>
    </dependency>
    <!-- Scala and Spark dependencies -->
    <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library -->
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.4</version>
    </dependency>
    <dependency>
      <groupId>org.apache.velocity</groupId>
      <artifactId>velocity</artifactId>
      <version>1.7</version>
    </dependency>
    <dependency>
      <groupId>commons-logging</groupId>
      <artifactId>commons-logging</artifactId>
      <version>1.2</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/com.github.scopt/scopt_2.11 -->
    <dependency>
      <groupId>com.github.scopt</groupId>
      <artifactId>scopt_2.11</artifactId>
      <version>3.5.0</version>
    </dependency>
  </dependencies>
</project>
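The pom comment asks for the cluster's Scala version, and every _2.11-suffixed artifact (spark-core, spark-mllib, scopt) has to agree with it. Mixing Scala binary versions typically shows up only at runtime, for example as NoSuchMethodError. The following is a minimal sketch of that consistency check in plain Scala, assuming the Scala 2.x rule that binary compatibility follows major.minor. The helper is hypothetical and is not part of Maven or sbt.

```scala
// Hypothetical helper: flags "_2.xx"-suffixed artifacts whose suffix does not
// match the Scala binary version implied by scala.version (Scala 2.x rule).
object ScalaBinaryVersionCheck {

  // For Scala 2.x, binary compatibility follows major.minor: "2.11.12" -> "2.11".
  def binaryVersion(fullVersion: String): String =
    fullVersion.split('.').take(2).mkString(".")

  // Artifacts carrying a "_2.xx" suffix that differs from the expected one.
  // Unsuffixed artifacts (like xgboost4j-spark 0.90 in this pom) are ignored.
  def mismatched(scalaVersion: String, artifactIds: Seq[String]): Seq[String] = {
    val expectedSuffix = "_" + binaryVersion(scalaVersion)
    artifactIds.filter(id => id.matches(""".*_2\.\d+""") && !id.endsWith(expectedSuffix))
  }

  def main(args: Array[String]): Unit = {
    val artifacts = Seq("spark-core_2.11", "spark-mllib_2.11", "scopt_2.11", "xgboost4j-spark")
    println(s"binary version = ${binaryVersion("2.11.12")}")
    println(s"mismatched artifacts = ${mismatched("2.11.12", artifacts)}")
  }
}
```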

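Then build the project. Because of the maven-assembly-plugin configuration, the package phase writes two jars to target/: the plain xgboost4j-example_2.11-1.0.0.jar and xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar, which bundles all dependencies: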

Code language: Shell
mvn clean package

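Finally, submit the job to the cluster. The three arguments after the application jar are input_path, native_model_path and pipeline_model_path, in that order: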

Code language: Shell
spark-2.4.3-bin-hadoop2.7/bin/spark-submit \
  --class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline \
  --jars /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar \
  /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0.jar \
  /tmp/rd/lp/iris.data \
  /***/scala_workSpace/test/nativeModel \
  /tmp/rd/lp/pipelineModel
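The input argument, /tmp/rd/lp/iris.data, is read with the explicit StructType defined in the code. The file is therefore expected to be a headerless CSV: four numeric columns followed by the class name, as in the UCI Iris file (for example 5.1,3.5,1.4,0.2,Iris-setosa). The following plain-Scala sketch expresses the same per-line expectation. It returns None for blank lines or for a header row, which would not parse as doubles. The object and case class are hypothetical and are not how Spark's CSV reader works internally.

```scala
// Hypothetical plain-Scala parser mirroring the schema in SparkMLlibPipeline:
// four double columns followed by the string class label, no header row.
object IrisLineSketch {

  final case class IrisRow(
      sepalLength: Double,
      sepalWidth: Double,
      petalLength: Double,
      petalWidth: Double,
      label: String)

  // Returns None for blank or malformed lines instead of throwing.
  def parse(line: String): Option[IrisRow] =
    line.trim.split(",").map(_.trim) match {
      case Array(sl, sw, pl, pw, label) =>
        try Some(IrisRow(sl.toDouble, sw.toDouble, pl.toDouble, pw.toDouble, label))
        catch { case _: NumberFormatException => None }
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    val lines = Seq(
      "5.1,3.5,1.4,0.2,Iris-setosa",
      "",
      "sepal length,sepal width,petal length,petal width,class")
    lines.foreach(line => println(s"'$line' -> ${parse(line)}"))
  }
}
```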
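
This article is part of the Tencent Cloud self-media sync program and is shared from the author's personal site/blog.
Originally published: August 29, 2019. For infringement concerns, contact cloudcommunity@tencent.com to have it removed.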