【算法】xgboost算法

小编邀请您,先思考:

1 XGBoost和GDBT算法有什么差异?

XGBoost的全称是 eXtremeGradient Boosting,2014年2月诞生的专注于梯度提升算法的机器学习函数库,作者为华盛顿大学研究机器学习的大牛——陈天奇。他在研究中深深的体会到现有库的计算速度和精度问题,为此而着手搭建完成 xgboost 项目。xgboost问世后,因其优良的学习效果以及高效的训练速度而获得广泛的关注,并在各种算法大赛上大放光彩。

1.CART

CART(回归树, regressiontree)是xgboost最基本的组成部分。其根据训练特征及训练数据构建分类树,判定每条数据的预测结果。其中构建树使用gini指数计算增益,即进行构建树的特征选取,gini指数公式如式(1), gini指数计算增益公式如式(2):

P_{k}表示数据集中类别的概率,表示类别个数。

注:此处图的表示分类类别。

D表示整个数据集,D_{1}D_{2} 分别表示数据集中特征为的数据集和特征非的数据集,Gini(D_{1}) 表示特征为的数据集的gini指数。

以是否打网球为例(只是举个栗子):

其中,

最小,所以构造树首先使用温度适中。然后分别在左右子树中查找构造树的下一个条件。

本例中,使用温度适中拆分后,是子树刚好类别全为是,即温度适中时去打网球的概率为1。

2.Boostingtree

一个CART往往过于简单,并不能有效地做出预测,为此,采用更进一步的模型boosting tree,利用多棵树来进行组合预测。具体算法如下:

输入:训练集

输出:提升树f_{M}(x)

步骤:

(1)初始化f_{0}(x)=0

(2) 对m=1,2,3……M

a)计算残差

b)拟合残差r_{mi} 学习一个回归树,得到T(x:\theta _{m})

c)更新

(3)得到回归提升树:

例子详见后面代码部分。

3.xgboost

首先,定义一个目标函数:

constant为一个常数,正则项\Omega(f_{t}) 如下,

其中,T表示叶子节点数,W_{j} 表示第j个叶子节点的权重。

例如下图,叶子节点数为3,每个叶子节点的权重分别为2,0.1,-1,正则项计算见图:

利用泰勒展开式

,对式(3)进行展开:

其中,g_{i} 表示L(y_{i},\widehat{y}^{t-1})\widehat{y}^{t-1} 的一阶导数,h_{i} 表示L(y_{i},\widehat{y}^{t-1})\widehat{y}^{t-1} 的二阶导数。L(y_{i},\widehat{y}^{t-1}) 为真实值与前一个函数计算所得残差是已知的(我们都是在已知前一个树的情况下计算下一颗树的),同时,在同一个叶子节点上的数的函数值是相同的,可以做合并,于是:

通过对求导等于0,可以得到

W_{j} 带入得目标函数的简化公式如下:

目标函数简化后,可以看到xgboost的目标函数是可以自定义的,计算时只是用到了它的一阶导和二阶导。得到简化公式后,下一步针对选择的特征计算其所带来的增益,从而选取合适的分裂特征。

提升树例子代码:

# !/usr/bin/env python # -*- coding: utf-8 -*- # 目标函数为真实值与预测值的差的平方和 import math # 数据集,只包含两列 test_list = [[1,5.56], [2,5.7], [3,5.81], [4,6.4], [5,6.8],\ [6,7.05], [7,7.9], [8,8.7], [9,9],[10,9.05]] step = 1 #eta # 起始拆分点 init = 1.5 # 最大拆分次数 max_times = 10 # 允许的最大误差 threshold = 1.0e-3 def train_loss(t_list): sum = 0 for fea in t_list: sum += fea[1] avg = sum * 1.0 /len(t_list) sum_pow = 0 for fea in t_list: sum_pow =math.pow((fea[1]-avg), 2) return sum_pow, avg def boosting(data_list): ret_dict = {} split_num = init while split_num <data_list[-1][0]: pos = 0 for idx, data inenumerate(data_list): if data[0]> split_num: pos = idx break if pos > 0: l_train_loss,l_avg = train_loss(data_list[:pos]) r_train_loss,r_avg = train_loss(data_list[pos:]) ret_dict[split_num] = [pos,l_train_loss+r_train_loss, l_avg, r_avg] split_num += step return ret_dict def main(): ret_list = [] data_list =sorted(test_list, key=lambda x:x[0]) time_num = 0 while True: time_num += 1 print 'beforesplit:',data_list ret_dict =boosting(data_list) t_list =sorted(ret_dict.items(), key=lambda x:x[1][1]) print 'splitnode:',t_list[0] ret_list.append([t_list[0][0], t_list[0][1][1]]) if ret_list[-1][1]< threshold or time_num > max_times: break for idx, data inenumerate(data_list): if idx <t_list[0][1][0]: data[1] -=t_list[0][1][2] else: data[1] -=t_list[0][1][3] print 'after split:',data_list print 'split node andloss:' print'\n'.join(["%s\t%s" %(str(data[0]), str(data[1])) for data inret_list]) if __name__ == '__main__': main()

原文发布于微信公众号 - 数据科学与人工智能(DS_AI_shujuren)

原文发表时间:2018-03-10

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能

机器学习算法之旅

在这篇文章中, 我们将介绍最流行的机器学习算法.

4485
来自专栏大数据文摘

数据科学家需要了解的45个回归问题测试题(附答案)

2272
来自专栏CreateAMind

Geoffrey Hinton的“胶囊理论” 多语言实现代码、效果、论文解读

https://github.com/XifengGuo/CapsNet-Keras

1465
来自专栏AI科技评论

学界 | 超越何恺明等组归一化 Group Normalization,港中文团队提出自适配归一化取得突破

AI 科技评论:港中文最新论文研究表明目前的深度神经网络即使在人工标注的标准数据库中训练(例如 ImageNet),性能也会出现剧烈波动。这种情况在使用少批量数...

1221
来自专栏机器之心

CVPR 2018 | UNC&amp;Adobe提出模块化注意力模型MAttNet,解决指示表达的理解问题

3199
来自专栏计算机视觉战队

ECCV-2018最佼佼者的目标检测算法

转眼间,离上次9月3日已有9天的时间,好久没有将最新最好的“干货”分享给大家,让大家一起在学习群里讨论最新技术,那今天我给大家带来ECCV-2018年最优pap...

1.5K3
来自专栏机器之心

深度 | 一文介绍3篇无需Proposal的实例分割论文

选自Medium 作者:Bar Vinograd 机器之心编译 参与:Nurhachu Null、黄小天 本文解析了实例分割领域中的三篇论文,它们不同于主流的基...

3825
来自专栏大数据挖掘DT机器学习

判别模型、生成模型与朴素贝叶斯方法

1、判别模型与生成模型 回归模型其实是判别模型,也就是根据特征值来求结果的概率。形式化表示为 ? ,在参数 ? 确定的情况下,求解条件概率 ? 。通俗的解...

3656
来自专栏大数据挖掘DT机器学习

情感分析的新方法,使用word2vec对微博文本进行情感分析和分类

情感分析是一种常见的自然语言处理(NLP)方法的应用,特别是在以提取文本的情感内容为目标的分类方法中。通过这种方式,情感分析可以被视为利用一些情感得分指标来...

1K10
来自专栏机器学习和数学

[高大上的DL] Deep Learning中常用loss function损失函数的小结

在前面我们分享的如何来训练CNN中,提到了BP算法,还记得BP算法是怎么更新参数w,b的吗?当我们给网络一个输入,乘以w的初值,然后经过激活函数得到一个输出。然...

4.2K8

扫码关注云+社区