前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >交叉验证的Java weka实现,并保存和重载模型

交叉验证的Java weka实现,并保存和重载模型

作者头像
百川AI
发布2021-10-19 15:58:05
8620
发布2021-10-19 15:58:05
举报
文章被收录于专栏:我还不懂对话我还不懂对话

我觉得首先有必要简单说说交叉验证,即用只有一个训练集的时候,用一部分数据训练,一部分做测试,当然怎么分配及时不同的方法了。

1)k-folder cross-validation:

k个子集,每个子集均做一次测试集,其余的作为训练集。交叉验证重复k次,每次选择一个子集作为测试集,并将k次的平均交叉验证识别正确率作为结果。 优点:所有的样本都被作为了训练集和测试集,每个样本都被验证一次。10-folder通常被使用。

2)K * 2 folder cross-validation

是k-folder cross-validation的一个变体,对每一个folder,都平均分成两个集合s0,s1,我们先在集合s0训练用s1测试,然后用s1训练s0测试。 优点是:测试和训练集都足够大,每一个个样本都被作为训练集和测试集。一般使用k=10

3)least-one-out cross-validation(loocv)

假设dataset中有n个样本,那LOOCV也就是n-CV,意思是每个样本单独作为一次测试集,剩余n-1个样本则做为训练集。 优点: 1)每一回合中几乎所有的样本皆用于训练model,因此最接近母体样本的分布,估测所得的generalization error比较可靠。 2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。 但LOOCV的缺点则是计算成本高,为需要建立的models数量与总样本数量相同,当总样本数量相当多时,LOOCV在实作上便有困难,除非每次训练model的速度很快,或是可以用平行化计算减少计算所需的时间。

废话不多说,直接上代码:

关键代码:

代码语言:javascript
复制
//直接调用Evaluation即可完成
Evaluation eval = null;
for (int i = 0; i < 10; i++) {
   eval = new Evaluation(Train);
   eval.crossValidateModel(m_classifier, Train, 10, new Random(i),
args);// 实现交叉验证模型
}
System.out.println(eval.toSummaryString());// 输出总结信息
System.out.println(eval.toClassDetailsString());// 输出分类详细信息
System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵

Java调用weka实现算法,并保存模型,以及读取。

这个在网上找了很久,没找到,却偶然一次发现了,其实很简单,只要因为好一点的话,看国外论坛就好多了。 保存模型方法:

代码语言:javascript
复制
SerializationHelper.write("LibSVM.model", classifier4);//参数一为模型保存文件,classifier4为要保存的模型

加载模型:

代码语言:javascript
复制
Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read("LibSVM.model");

全部代码:

代码语言:javascript
复制
package weka_test;
import java.io.File;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.experiment.InstanceQuery;
import weka.classifiers.Evaluation;
import java.util.Random;

public class test {
   /**
    * oracleInput
    * @return data
    * @throws Exception
*/
   public static Instances oracleInput() throws Exception{
      InstanceQuery query = new InstanceQuery();
      String sql = "SELECT to_char(z.cydate,'yyyy/mm') AS d,sum(z.bcmoney) as c FROM zybc z"
+ " WHERE to_char(z.cydate,'yyyy/mm') IS NOT NULL"
+ " GROUP BY to_char(z.cydate,'yyyy/mm') ORDER BY to_date(to_char(z.cydate,'yyyy/mm'),'yyyy/mm') ASC";
      //System.out.println(sql);
      query.setCustomPropsFile(new File("weka/weka_oracle.props"));
      query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.133:1521/XE");
      query.setUsername("***");
      query.setPassword("***");
      query.setQuery(sql);
      Instances data = query.retrieveInstances();          
return data;
   }
   /**
    * mysqlInput
    * @return data
    * @throws Exception
*/
   public static Instances mysqlInput() throws Exception{
      InstanceQuery query = new InstanceQuery();
      String sql = "SELECT * FROM iris";
      //System.out.println(sql);
      query.setCustomPropsFile(new File("weka/weka_mysql.props"));
      query.setDatabaseURL("jdbc:mysql://localhost:3306/test");
      query.setUsername("***");
      query.setPassword("***");
      query.setQuery(sql);
      Instances data = query.retrieveInstances();    
return data;
   }

   /**
    * @param args
    * @throws Exception 
*/
   public static void main(String[] args) throws Exception {
      // TODO Auto-generated method stub
      Classifier m_classifier = new J48();

        /*File inputFile = new File("D://Program Files//Weka-3-7//data//iris.arff");//训练语料文件
        ArffLoader atf = new ArffLoader(); 
        atf.setFile(inputFile);
        Instances instancesTrain = atf.getDataSet(); // 读入训练文件    */ 

      Instances Train = mysqlInput();
      Instances Test = mysqlInput();
        Test.setClassIndex(4); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数
        double sum = Test.numInstances(),//测试语料实例数
        right = 0.0f;
        Train.setClassIndex(4);
        m_classifier.buildClassifier(Train); //训练          
        //System.out.println(m_classifier.toString());

      //2、利用模型进行预测 
int a=0,b=0,c=0,d=0;//记录每个类别的个数,方便计算评价指标
for (int i = 0; i < Test.numInstances(); i++) {
         double classification = m_classifier.classifyInstance(Train.instance(i));
         double classValue = Train.instance(i).classValue();
if (classification == 0.0 && classValue == 0.0) {
            a++;
         } else if (classification == 0.0 && classValue == 1.0) {
            b++;
         } else if (classification == 1.0 && classValue == 0.0) {
            c++;
         } else if (classification == 1.0 && classValue == 1.0) {
            d++;
         }
      }
      // 3、得出预测效果评测指标
      double precision = (double) a / (a + b);
      double recall = (double) a / (a + c);
      double fMeasure = 2 * precision * recall / (precision + recall);
      System.out.println("precision\trecall\tF-Measure");
      System.out.println((precision) + "\t\t"
+ (recall) + "\t"
+ (fMeasure));

for(int  i = 0;i

CSDN博客原文

http://blog.csdn.net/shine19930820/article/details/50921109

授人以鱼不如授人以渔:

python sklearn数据预处理:

http://blog.csdn.net/shine19930820/article/details/50915361

广义线性模型--Generalized Linear Models

http://blog.csdn.net/shine19930820/article/details/50997645

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2016-03-18 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1)k-folder cross-validation:
  • 2)K * 2 folder cross-validation
  • 3)least-one-out cross-validation(loocv)
  • Java调用weka实现算法,并保存模型,以及读取。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档