机器学习实战:Adaboost元算法(上)

大家好!我是MPIG的李超杰,今天我给大家带来的是Adaboost元算法的理论讲解以及相关代码的讲解。当做重要决定时,大家可能都会考虑多个专家而不只是一个人的意见。机器学习处理问题也是如此。这就是元算法背后的思路。元算法是对其他算法进行组合的一种方式。这一讲我们将集中关注一个称作AdaBoost的最流行的元算法,该方法是机器学习工具中最有力的工具之一。

首先我们介绍了基学习器以及强可学习和弱可学习的概念:

接着我们介绍了两种主要的集成方法:Bagging和Boosting,两者主要的区别在于:Bagging对训练数据采用自举采样(boostrap sampling),即有放回地采样数据。Boosting的思路则是采用重赋权(re-weighting)法迭代地训练基分类器。其他区别如下:

接着我们讲解Adaboost算法的运行过程:

Adaboost算法示意图如下:

左边是数据集,其中直方图的不同宽度表示每个样例上的不同权重。在经过一个分类器之后,加权的预测结果会通过三角形中的alpha值进行加权。每个三角形中输出的加权结果在圆形中求和,从而得到最终的输出结果。

单层决策树是一种简单的决策树,也称为决策树桩。单层决策树可以看做是由一个根节点直接连接两个叶结点的简单决策树,比如x>v或x

接下来,我们就要通过上述数据集来寻找最佳的单层决策树,最佳单层决策树是具有最低分类错误率的单层决策树。单层决策树的生成函数代码:

上面的代码包含两个函数,第一个函数是分类器的阈值过滤函数,即设定某一阈值,凡是超过该阈值的结果被归为一类,小于阈值的结果都被分为另外一类,这里的两类依然同SVM一样,采用+1和-1作为类别。第二个函数,就是建立单层决策树的具体代码,基于样本值的各个特征及特征值的大小,设定合适的步长,获得不同的阈值,然后以此阈值作为根结点,对数据集样本进行分类,并计算错误率,需要指出的是,这里的错误率计算是基于样本权重的,所有分错的样本乘以其对应的权重,然后进行累加得到分类器的错误率。错误率得到之后,根据错误率的大小,跟当前存储的最小错误率的分类器进行比较,选择出错误率最小的特征训练出来的分类器,作为最佳单层决策树输出,并通过字典类型保存其相关重要的信息。

上面已经构建好了基于加权输入值进行决策的单层分类器,那么就有了实现一个完整AdaBoost算法所需要的所有信息了。完整AdaBoost实现代码如下:

对于上面的代码,需要说明以下几点:

(1)上面的输入除了数据集和标签之外,还有用户自己指定的迭代次数,用户可以根据自己的成本需要和实际情况,设定合适的迭代次数,构建出需要的弱分类器数量。

(2)权重向量D包含了当前单层决策树分类器下,各个数据集样本的权重,一开始它们的值都相等。但是,经过分类器分类之后,会根据分类的权重加权错误率对这些权重进行修改,修改的方向为:提高分类错误样本的权重,减少分类正确的样本的权重。

(3)分类器系数alpha,是另外一个非常重要的参数,它在最终的分类器组合决策分类结果的过程中,起到了非常重要的作用,如果某个弱分类器的分类错误率更低,那么根据错误率计算出来的分类器系数将更高,这样,这些分类错误率更低的分类器在最终的分类决策中,会起到更加重要的作用。

(4)上述代码的训练过程是以达到迭代的用户指定的迭代次数或者训练错误率达到要求而跳出循环。而最终的分类器决策结果,会通过sign函数,将结果指定为+1或者-1。

有了训练好的分类器,是不是要测试一下呢,毕竟训练错误率针对的是已知的数据,我们需要在分类器未知的数据上进行测试,看看分类效果。上面的训练代码会帮我们保存每个弱分类器的重要信息,比如分类器系数,分类器的最优特征,特征阈值等。有了这些重要的信息,我们拿到之后,就可以对测试数据进行预测分类了。测试代码如下:

代码很简单,在之前代码的基础上,添加adaClassify()函数,该函数遍历所有训练得到的弱分类器,利用单层决策树,输出的类别估计值乘以该单层决策树的分类器权重alpha,然后累加到aggClassEst上,最后通过sign函数最终的结果。可以看到,分类没有问题,(5,5)属于正类,(0,0)属于负类。

本讲主要讲解了Adaboost元算法的基本概念以及相关函数代码的构建,下一讲我们会用本讲的内容重新测试前几讲中Logistic回归中预测马疝病是否死亡的案例,大家敬请期待下一讲精彩内容。

想要更加详细了解本讲更多细节的内容吗?那就一起来观看下面的Presentation的具体讲解吧:

想获取本presentation的对应文稿和代码,可以点击如下链接下载:

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180710G0GVBA00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券