Java 机器学习库Smile实战(一)SVM

本文不会介绍SVM的基本原理,如果想了解SVM基本原理,请参阅相关书籍。

要使用Java机器学习库Smile,需首先在项目的Maven配置文件pom.xml中添加如下的maven依赖项:

<dependency>
    <groupId>com.github.haifengl</groupId>
    <artifactId>smile-core</artifactId>
    <version>1.4.0</version>
</dependency>

Smile中的SVM是一个泛型类,他可以支持二分类和多分类两种使用方法,而且这两种使用方法差异较大,所以分开介绍。

1. 二分类

Smile 库的SVM类是一个泛型类型,默认情况下进行二分类,选择参数为核函数类型和惩罚项参数。

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

public class Demo {
    public static void main(String[]args){

        double gamma = 1.0;
        double C = 1.0;

        //通过某种方式获取训练数据及其类标
        double[][] data = ...
        int[] label = ...

        SVM<double[]> svm = new SVM<double[]>(
           new GaussianKernel(gamma), C);
        
        svm.learn(data, label); //训练模型
        svm.finish();

        //获取测试数据
        double[][] testData = ...
        int[] result = new int[testData.length];
        for(int i=0; i < testData.length; i++){
            result[i] = svm.predict(testData[i]);
        }
    }
}

2. 多分类

接下来是我利用SVM对iris数据集进行分类的程序。首先我们将iris数据保存iris.txt文件,如下结构:

5.1 3.5 1.4 0.2 0

4.9 3 1.4 0.2 0

...

每一行代表一个测试数据项,前4列是属性向量,最后一列是类标(在Smile中类标不能为负数,并且只能是从0开始的正整数,所以上述类标为:0、1、2)。检测的完整的源代码如下:

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Created by zhanghuayan on 2017/1/16.
 */
public class ClassificationTest {

    public static void main(String[] args) throws Exception {

        List<List<Double>> datas = 
            new ArrayList<List<Double>>();
        List<Double> data = new ArrayList<Double>();
        List<Integer> labels = new ArrayList<Integer>();

        String line;
        List<String> lines;
        File file = new File("iris.txt");
        BufferedReader reader = 
        new BufferedReader(new FileReader(file));
        
        while ((line = reader.readLine()) != null) {
            lines = Arrays.asList(line.trim().split("\t"));
            for (int i = 0; i < lines.size() - 1; i++) {
                data.add(Double.parseDouble(lines.get(i)));
            }
            labels.add(Integer.parseInt(
                  lines.get(lines.size() - 1)));

            datas.add(data);
            data = new ArrayList<Double>();

        }

        //转换label
        int[] label = new int[labels.size()];
        for (int i = 0; i < label.length; i++) {
            label[i] = labels.get(i);
        }

        //转换属性
        int rows = datas.size();
        int cols = datas.get(0).size();
        double[][] srcData = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                srcData[i][j] = datas.get(i).get(j);
            }
        }

        SVM<double[]> svm = new SVM<double[]>(
           new GaussianKernel(1.0), 1.0, 3, 
           SVM.Multiclass.ONE_VS_ALL);

        svm.learn(srcData, label);
        svm.finish();

        double right = 0;
        for (int i = 0; i < srcData.length; i++) {
            int tag = svm.predict(srcData[i]);
            if (tag == label[i]) {
                right += 1;
            }
        }
        right = right / srcData.length;

        System.out.println(
        "Accrurate: " + right * 100 + "%");
    }
}

原文发布于微信公众号 - 机器学习算法全栈工程师(Jeemy110)

原文发表时间:2017-08-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏李智的专栏

Deep learning基于theano的keras学习笔记(1)-Sequential模型

《统计学习方法》中指出,机器学习的三个要素是模型,策略和优算法,这当然也适用于深度学习,而我个人觉得keras训练也是基于这三个要素的,先建立深度模型,然后选用...

781
来自专栏人工智能LeadAI

使用TensorFlow实现手写识别(Softmax)

准备工作 由于将TensorFlow安装到了Conda的tensorflow环境,虽然可以用Jupyter notebook打开,但是没有提示,写代码不方便,所...

3555
来自专栏人工智能LeadAI

机器学习实战 | 第一章:sklearn常用工具介绍

写在前面: 花了大力气学了很多的理论,也用Python实现了其中大部分的算法.接下来开始就进入实战阶段了. 实战阶段有三个重点: 1.选择合适的机器学习框...

26710
来自专栏ascii0x03的安全笔记

使用sklearn构建含有标量属性的决策树

网络上使用sklearn生成决策树的资料很多,这里主要说明遇见标量数据的处理。 经查验参考资料,sklearn并非使用了课上以及书上讲的ID3算法,而是选择了C...

3206
来自专栏潇涧技术专栏

Python Algorithms - C7 Greedy

Python算法设计篇(7) Chapter 7: Greed is good? Prove it!

542
来自专栏生信宝典

R语言学习 - 箱线图(小提琴图、抖动图、区域散点图)

箱线图 箱线图是能同时反映数据统计量和整体分布,又很漂亮的展示图。在2014年的Nature Method上有2篇Correspondence论述了使用箱线图的...

35410
来自专栏深度学习那些事儿

在pytorch中实现与TensorFlow类似的"same"方式padding

文章来自Oldpan博客:https://oldpan.me/archives/pytorch-same-padding-tflike

8017
来自专栏王小雷

基于Python3 神经网络的实现

基于Python3 神经网络的实现(下载源码) 本次学习是Denny Britz(作者)的Python2神经网络项目修改为基于Python3实现的神经网络(本篇...

21110
来自专栏漫漫深度学习路

MXNET学习笔记(一):Module类(1)

Module 是 mxnet 提供给用户的一个高级封装的类。有了它,我们可以很容易的来训练模型。 Module 包含以下单元的一个 wraper symbol ...

2035
来自专栏人工智能LeadAI

Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数

在实际的业务中,可能会遇到很大量的特征,这些特征良莠不齐,层次不一,可能有缺失,可能有噪声,可能规模不一致,可能类型不一样,等等问题都需要我们在建模之前,先预处...

33310

扫描关注云+社区