Java 机器学习库Smile实战(二)AdaBoost

1. AdaBoost算法简介

Boost 算法系列的起源来自于PAC Learnability(PAC 可学习性)。这套理论主要研究的是什么时候一个问题是可被学习的,当然也会探讨针对可学习的问题的具体的学习算法。这套理论是由Valiant提出来的,也因此(还有其他贡献哈)他获得了2010年的图灵奖。下面是他的照片:

PAC 定义了学习算法的强弱:

弱学习算法 — 识别错误率小于1/2(即准确率仅比随机猜测略高的学习算法);

强学习算法 — 识别准确率很高并能在多项式时间内完成的学习算法。

同时 ,Valiant和 Kearns首次提出了 PAC学习模型中弱学习算法和强学习算法的等价性问题,即任意给定仅比随机猜测略好的弱学习算法 ,是否可以将其提升为强学习算法 ? 如果二者等价 ,那么只需找到一个比随机猜测略好的弱学习算法就可以将其提升为强学习算法 ,而不必寻找很难获得的强学习算法。 也就是这种猜测,让无数牛人去设计算法来验证PAC理论的正确性。

不过很长一段时间都没有一个切实可行的办法来实现这个理想。细节决定成败,再好的理论也需要有效的算法来执行。终于功夫不负有心人, Schapire在1996年提出一个有效的算法真正实现了这个夙愿,它的名字叫AdaBoost。AdaBoost把多个不同的决策树用一种非随机的方式组合起来,表现出惊人的性能!第一,把决策树的准确率大大提高,可以与SVM媲美。第二,速度快,且基本不用调参数。第三,几乎不Overfitting。我估计当时Breiman和Friedman肯定高兴坏了,因为眼看着他们提出的CART正在被SVM比下去的时候,AdaBoost让决策树起死回生!Breiman情不自禁地在他的论文里赞扬AdaBoost是最好的现货方法(off-the-shelf,即“拿下了就可以用”的意思)。

下面是Adaboost 算法的实现思想:

接下来是Smile库中的AdaBoost算法的用法实战程序:

import smile.classification.AdaBoost;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Created by zhanghuayan on 2017/1/17.
 */
public class AdaBoostTest {

  public static void main(String[] args) throws Exception {

        List<List<Double>> datas = 
                          new ArrayList<List<Double>>();
        List<Double> data = new ArrayList<Double>();
        List<Integer> labels = new ArrayList<Integer>();

        String line;
        List<String> lines;
        File file = new File("iris.txt");
        BufferedReader reader = 
                        new BufferedReader(new FileReader(file));
        while ((line = reader.readLine()) != null) {
            lines = Arrays.asList(line.trim().split("\t"));
            for (int i = 0; i < lines.size() - 1; i++) {
                data.add(Double.parseDouble(lines.get(i)));
            }
            labels.add(Integer.parseInt(lines.get(lines.size() - 1)));

            datas.add(data);
            data = new ArrayList<Double>();

        }

        //转换label
        int[] label = new int[labels.size()];
        for (int i = 0; i < label.length; i++) {
            label[i] = labels.get(i);
        }

        //转换属性
        int rows = datas.size();
        int cols = datas.get(0).size();
        double[][] srcData = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                srcData[i][j] = datas.get(i).get(j);
            }
        }


        AdaBoost adaBoost = 
                     new AdaBoost(srcData, label, 4, 8);
        double right = 0;
        for (int i = 0; i < srcData.length; i++) {
            int tag = adaBoost.predict(srcData[i]);
            if (i % 10 == 0) System.out.println();
            System.out.print(tag + " ");
            if (tag == label[i]) {
                right += 1;
            }
        }
        right = right / srcData.length;
        System.out.println("Accrurate: " + right * 100 + "%");
    }
} 

数据集与Java机器学习库Smile实战(一)SVM中的iris数据集格式一致。下面是程序输出的结果:

0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2

Accrurate: 100.0%

当然这里的训练数据和测试数据使用的是同一数据,所以可能出现过拟合现象。因此现实训练的时候应采取例如交叉验证等技术手段调整参数。

2. 参数解释

第3个参数是子树的个数,第4个参数是每个子树最大叶子节点数:

AdaBoost adaBoost = new AdaBoost(srcData, label, 4, 8);

训练好了之后就可以如下调用predict()方法:

int tag = adaBoost.predict(srcData[i]);

原文发布于微信公众号 - 机器学习算法全栈工程师(Jeemy110)

原文发表时间:2017-08-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习算法与Python学习

n-gram文法中数据稀疏问题解决方案之一:Good-Turing平滑

关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第二 【Python】:排名第三 【算法】:排名第四 统计语言模型中,N元语法模型不可避免...

2974
来自专栏Python小屋

计算Fibonacci数列第n项的第8种方法(数学推导与Python实现)

感谢山东工商学院学院厉玉蓉老师提供的完美数学推导,我在重写和整理时略加修改,比如变量替换时她喜欢用字母z,而我喜欢用x,哈哈。当然,还有另外几个小地方^_^ 本...

2765
来自专栏深度学习自然语言处理

【机器学习】决策树的理论与实践

zenRRan二十出头了,到了婚配的年龄啦。又因为家是名门望族,所以一堆人抢着想来应聘配偶的职位。但是zenRRan比较挑剔,必须达到他的要求才能有机会成为他的...

701
来自专栏奔跑的蛙牛技术博客

递归简论递归的重要法则

612
来自专栏华章科技

与数据挖掘有关或有帮助的R包和函数的集合

rpart,party,randomForest,rpartOrdinal,tree,marginTree,

683
来自专栏小小挖掘机

算法面试太难?反手就是一波面经

八月参加了一些提前批的面试,包括阿里、百度、头条、贝壳、一点资讯等。整理了一些面试题,分享给大家。

953
来自专栏数说工作室

【分类战车SVM】第六话:SMO算法(像smoke一样简单!)

分类战车SVM (第六话:SMO算法) 查看本《分类战车SVM》系列的内容: 第一话:开题话 第二话:线性分类 第三话:最大间隔分类器 第四话:拉格朗日对偶问题...

46012
来自专栏云时之间

聚类分析的简单理解(1)

各位小伙伴们大家好,这几天我在学习聚类分析这个统计方法,所以希望通过这个文章来概括下自己所学的知识,并且希望大家可以指出不足 1:什么是聚类分析? 聚类分析(...

3436
来自专栏C/C++基础

算法设计的方法

作用:要使计算机能完成人们预定的工作,首先必须为如何完成预定的工作设计一个算法,然后再根据算法编写程序。

733
来自专栏企鹅号快讯

《教育统计与SPSS应用》学习笔记(8)

第8讲 回归分析 主要内容 回归分析简介 一元线性回归分析 多元线性回归分析 第一部分 回归分析简介 一、回归分析的意义 表示变量之间的不确定性关系以...

2228

扫码关注云+社区