专栏首页xingoo, 一个梦想做发明家的程序员在Java Web中使用Spark MLlib训练的模型

在Java Web中使用Spark MLlib训练的模型

PMML是一种通用的配置文件,只要遵循标准的配置文件,就可以在Spark中训练机器学习模型,然后再web接口端去使用。目前应用最广的就是基于Jpmml来加载模型在javaweb中应用,这样就可以实现跨平台的机器学习应用了。

训练模型

首先在spark MLlib中使用mllib包下的逻辑回归训练模型:

import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
    
val training = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1", "1,2 4 1 5", "0,7 8 3 6", "1,2 5 6 9").map( line => LabeledPoint.parse(line)))

// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)

val test = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1").map( line => LabeledPoint.parse(line)))


// Compute raw scores on the test set.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  val prediction = model.predict(features)
  (prediction, label)
}

// Get evaluation metrics.
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")

// Save and load model
//    model.save(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
//    val sameModel = LogisticRegressionModel.load(spark.sparkContext,"target/tmp/scalaLogisticRegressionWithLBFGSModel")

model.toPMML(spark.sparkContext, "/tmp/xhl/data/test2")

训练得到的模型保存到hdfs。

PMML模型文件

模型下载到本地,重新命名为xml。 可以看到默认四个特征分别叫做feild_0field_1...目标为target

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
    <Header description="logistic regression">
        <Application name="Apache Spark MLlib" version="2.2.0"/>
        <Timestamp>2018-11-15T10:22:25</Timestamp>
    </Header>
    <DataDictionary numberOfFields="5">
        <DataField name="field_0" optype="continuous" dataType="double"/>
        <DataField name="field_1" optype="continuous" dataType="double"/>
        <DataField name="field_2" optype="continuous" dataType="double"/>
        <DataField name="field_3" optype="continuous" dataType="double"/>
        <DataField name="target" optype="categorical" dataType="string"/>
    </DataDictionary>
    <RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
        <MiningSchema>
            <MiningField name="field_0" usageType="active"/>
            <MiningField name="field_1" usageType="active"/>
            <MiningField name="field_2" usageType="active"/>
            <MiningField name="field_3" usageType="active"/>
            <MiningField name="target" usageType="target"/>
        </MiningSchema>
        <RegressionTable intercept="0.0" targetCategory="1">
            <NumericPredictor name="field_0" coefficient="-5.552297758753701"/>
            <NumericPredictor name="field_1" coefficient="-1.4863480719075117"/>
            <NumericPredictor name="field_2" coefficient="-5.7232298850417855"/>
            <NumericPredictor name="field_3" coefficient="8.134075057437393"/>
        </RegressionTable>
        <RegressionTable intercept="-0.0" targetCategory="0"/>
    </RegressionModel>
</PMML>

接口使用

在接口的web工程中引入maven jar:

<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator-extension -->
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.3</version>
</dependency>

接口代码中直接读取pmml,使用模型进行预测:

package soundsystem;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PMMLDemo2 {
    private Evaluator loadPmml(){
        PMML pmml = new PMML();
        try(InputStream inputStream = new FileInputStream("/Users/xingoo/Desktop/test2.xml")){
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (Exception e) {
            e.printStackTrace();
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }
    private Object predict(Evaluator evaluator,int a, int b, int c, int d) {
        Map<String, Integer> data = new HashMap<String, Integer>();
        data.put("field_0", a);
        data.put("field_1", b);
        data.put("field_2", c);
        data.put("field_3", d);
        List<InputField> inputFields = evaluator.getInputFields();
        //过模型的原始特征,从画像中获取数据,作为模型输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);
        System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
        return target;
    }
    public static void main(String args[]){
        PMMLDemo2 demo = new PMMLDemo2();
        Evaluator model = demo.loadPmml();
        demo.predict(model,2,5,6,8);
        demo.predict(model,7,9,3,6);
        demo.predict(model,1,2,3,1);
        demo.predict(model,2,4,1,5);
    }
}

得到输出内容:

2 5 6 8:ProbabilityDistribution{result=1, probability_entries=[1=0.9999949538769296, 0=5.046123070395758E-6]}
7 9 3 6:ProbabilityDistribution{result=0, probability_entries=[1=1.1216598160542013E-9, 0=0.9999999988783402]}
1 2 3 1:ProbabilityDistribution{result=0, probability_entries=[1=2.363331367481431E-8, 0=0.9999999763666864]}
2 4 1 5:ProbabilityDistribution{result=1, probability_entries=[1=0.9999999831203591, 0=1.6879640907241367E-8]}

其中result为LR最终的结果,概率为二分类的概率。

参考资料

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【插件开发】—— 9 编辑器代码分块着色-高亮显示!

    前文回顾: 1 插件学习篇 2 简单的建立插件工程以及模型文件分析 3 利用扩展点,开发透视图 4 SWT编程须知 5 SWT简单控件的使用与布局搭...

    用户1154259
  • JQuery ztree 异步加载实践

    本来要做一个文件目录浏览界面,需要遍历所有的文件和目录,很显然一次性读取时很费时费力的一件事情。 因此就需要做异步加载.... 不过网上的几篇帖子还挺坑的...

    用户1154259
  • 【Hibernate那点事儿】—— Hibernate知识总结

    前言: 上一篇简单的讲解了下Hibernate的基础知识。这里对Hibernate比较重要的一些知识点,进行总结和归纳。 总结的知识点: 1 关于...

    用户1154259
  • 第73节:Java中的HTTPServletReauest和HTTPServletResponse

    ServletContext可以获取全局配置参数,可以获取web工程中的资源,存储数据,servlet简共享数据。

    达达前端
  • 一键生成前后端代码,Mybatis-Plus代码生成器让我舒服了

    在日常的软件开发中,程序员往往需要花费大量的时间写CRUD,不仅枯燥效率低,而且每个人的代码风格不统一。MyBatis-Plus 代码生成器,通过 AutoGe...

    不会飞的小鸟
  • Transformers2.0让你三行代码调用语言模型,兼容TF2.0和PyTorch

    最近,专注于自然语言处理(NLP)的初创公司 HuggingFace 对其非常受欢迎的 Transformers 库进行了重大更新,从而为 PyTorch 和 ...

    机器之心
  • 蓝桥杯 基础练习 数的读法

      Tom教授正在给研究生讲授一门关于基因的课程,有一件事情让他颇为头疼:一条染色体上有成千上万个碱基对,它们从0开始编号,到几百万,几千万,甚至上亿。   比...

    Debug客栈
  • java高级进阶|对数据库事务传播行为再次理解

    自己在18,19年的时候分别写过一个示例程序关于数据库事务传播行为的演练操作,但是示例程序主要还是针对mongodb数据库是否支持数据库事务的操作和Mysql这...

    后端Coder
  • Java并发编程锁系列之ReentrantLock对象总结

    在Java并发编程中,根据不同维度来区分锁的话,锁可以分为十五种。ReentranckLock就是其中的多个分类。

    凯哥Java
  • caffe的demo测试

    当运行 demo.py 有如上输出时, 说明我们已经可以通过之前别人训练好的模型进行测试, 下面我们将自己动手训练一个模型。该模型数据采用 voc2007 数据...

    foochane

扫码关注云+社区

领取腾讯云代金券