Java 机器学习库Smile实战(二)AdaBoost

1. AdaBoost算法简介

Boost 算法系列的起源来自于PAC Learnability(PAC 可学习性)。这套理论主要研究的是什么时候一个问题是可被学习的,当然也会探讨针对可学习的问题的具体的学习算法。这套理论是由Valiant提出来的,也因此(还有其他贡献哈)他获得了2010年的图灵奖。下面是他的照片:

PAC 定义了学习算法的强弱:

弱学习算法 — 识别错误率小于1/2(即准确率仅比随机猜测略高的学习算法);

强学习算法 — 识别准确率很高并能在多项式时间内完成的学习算法。

同时 ,Valiant和 Kearns首次提出了 PAC学习模型中弱学习算法和强学习算法的等价性问题,即任意给定仅比随机猜测略好的弱学习算法 ,是否可以将其提升为强学习算法 ? 如果二者等价 ,那么只需找到一个比随机猜测略好的弱学习算法就可以将其提升为强学习算法 ,而不必寻找很难获得的强学习算法。 也就是这种猜测,让无数牛人去设计算法来验证PAC理论的正确性。

不过很长一段时间都没有一个切实可行的办法来实现这个理想。细节决定成败,再好的理论也需要有效的算法来执行。终于功夫不负有心人, Schapire在1996年提出一个有效的算法真正实现了这个夙愿,它的名字叫AdaBoost。AdaBoost把多个不同的决策树用一种非随机的方式组合起来,表现出惊人的性能!第一,把决策树的准确率大大提高,可以与SVM媲美。第二,速度快,且基本不用调参数。第三,几乎不Overfitting。我估计当时Breiman和Friedman肯定高兴坏了,因为眼看着他们提出的CART正在被SVM比下去的时候,AdaBoost让决策树起死回生!Breiman情不自禁地在他的论文里赞扬AdaBoost是最好的现货方法(off-the-shelf,即“拿下了就可以用”的意思)。

下面是Adaboost 算法的实现思想:

接下来是Smile库中的AdaBoost算法的用法实战程序:

import smile.classification.AdaBoost;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Created by zhanghuayan on 2017/1/17.
 */
public class AdaBoostTest {

  public static void main(String[] args) throws Exception {

        List<List<Double>> datas = 
                          new ArrayList<List<Double>>();
        List<Double> data = new ArrayList<Double>();
        List<Integer> labels = new ArrayList<Integer>();

        String line;
        List<String> lines;
        File file = new File("iris.txt");
        BufferedReader reader = 
                        new BufferedReader(new FileReader(file));
        while ((line = reader.readLine()) != null) {
            lines = Arrays.asList(line.trim().split("\t"));
            for (int i = 0; i < lines.size() - 1; i++) {
                data.add(Double.parseDouble(lines.get(i)));
            }
            labels.add(Integer.parseInt(lines.get(lines.size() - 1)));

            datas.add(data);
            data = new ArrayList<Double>();

        }

        //转换label
        int[] label = new int[labels.size()];
        for (int i = 0; i < label.length; i++) {
            label[i] = labels.get(i);
        }

        //转换属性
        int rows = datas.size();
        int cols = datas.get(0).size();
        double[][] srcData = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                srcData[i][j] = datas.get(i).get(j);
            }
        }


        AdaBoost adaBoost = 
                     new AdaBoost(srcData, label, 4, 8);
        double right = 0;
        for (int i = 0; i < srcData.length; i++) {
            int tag = adaBoost.predict(srcData[i]);
            if (i % 10 == 0) System.out.println();
            System.out.print(tag + " ");
            if (tag == label[i]) {
                right += 1;
            }
        }
        right = right / srcData.length;
        System.out.println("Accrurate: " + right * 100 + "%");
    }
} 

数据集与Java机器学习库Smile实战(一)SVM中的iris数据集格式一致。下面是程序输出的结果:

0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1 1 1
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2 
2 2 2 2 2 2 2 2 2 2

Accrurate: 100.0%

当然这里的训练数据和测试数据使用的是同一数据,所以可能出现过拟合现象。因此现实训练的时候应采取例如交叉验证等技术手段调整参数。

2. 参数解释

第3个参数是子树的个数,第4个参数是每个子树最大叶子节点数:

AdaBoost adaBoost = new AdaBoost(srcData, label, 4, 8);

训练好了之后就可以如下调用predict()方法:

int tag = adaBoost.predict(srcData[i]);

原文发布于微信公众号 - 机器学习算法全栈工程师(Jeemy110)

原文发表时间:2017-08-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏编程之旅

算法时间复杂度

很多程序员,做了很长时间的编程工作却始终都弄不明白算法的时间复杂度的估算,这是很可悲的一件事情。因为弄不清楚,所以也就从不深究自己写的代码是否效率底下,是不是可...

631
来自专栏老九学堂

为什么每个程序员都需要学习算法?

懂算法的程序员 ? 不懂算法的程序员 ? 算法的力量 算法是计算机科学领域最重要的基石之一,但却受到了一些程序员的冷落。 许多小伙伴看到一些公司在招聘时要求的...

3498
来自专栏华章科技

与数据挖掘有关或有帮助的R包和函数的集合

rpart,party,randomForest,rpartOrdinal,tree,marginTree,

623
来自专栏新智元

【ICLR 2016最佳论文】DeepMind 开发 NPI,有望取代初级程序员(附下载)

【新智元导读】特征学习和深度学习重要会议 ICLR 2016 最佳论文,DeepMind 团队开发了一个“神经编程解释器”(NPI),能自己学习并且编辑简单的程...

2996
来自专栏漫漫全栈路

数据库E-R模型关系图

早在专科阶段学习SqlServers时就学习过数据库E-R图,但是并没有真正的去了解这个东西,只是知道了大致的概念而已,借这次Oracle课程设计的机会,重新...

4064
来自专栏斑斓

Spark 1.4为DataFrame新增的统计与数学函数

Spark一直都在快速地更新中,性能越来越快,功能越来越强大。我们既可以参与其中,也可以乐享其成。 目前,Spark 1.4版本在社区已经进入投票阶段,在Gi...

3447
来自专栏PPV课数据科学社区

【学习】R语言与机器学习(分类算法)logistic回归

由于我们在前面已经讨论过了神经网络的分类问题,如今再从最优化的角度来讨论logistic回归就显得有些不合适了。Logistic回归问题的最优化问题可以表述为:...

2634
来自专栏WOLFRAM

Mathematica 11 在偏微分方程中的应用

1793
来自专栏我的python

char-rnn 语言模型

项目GitHub地址:https://github.com/ClownW/Char-RNN-Pytorch

1696
来自专栏大数据挖掘DT机器学习

基于Xgboost + LR + Keras 建模评估用户信用状态

项目背景 拍拍贷“魔镜风控系统”基于400多个数据维度来对当前用户的信用状态进行评估,通过历史数据每个借款人的性别、年龄、籍贯、学历信息、通讯方式、网站登录...

3094

扫描关注云+社区