技术干货 XGBoost原理解析

作者简介

刘英涛:达观数据推荐算法工程师,负责达观数据个性化推荐系统的研发与优化。

XGBoost的全称是 eXtremeGradient Boosting,2014年2月诞生的专注于梯度提升算法的机器学习函数库,作者为华盛顿大学研究机器学习的大牛——陈天奇。他在研究中深深的体会到现有库的计算速度和精度问题,为此而着手搭建完成 xgboost 项目。xgboost问世后,因其优良的学习效果以及高效的训练速度而获得广泛的关注,并在各种算法大赛上大放光彩。

1.CART

CART(回归树, regressiontree)是xgboost最基本的组成部分。其根据训练特征及训练数据构建分类树,判定每条数据的预测结果。其中构建树使用gini指数计算增益,即进行构建树的特征选取,gini指数公式如式(1), gini指数计算增益公式如式(2):

表示数据集中类别的概率,表示类别个数。

注:此处图的表示分类类别。

D表示整个数据集,和分别表示数据集中特征为的数据集和特征非的数据集,表示特征为的数据集的gini指数。

以是否打网球为例(只是举个栗子):

其中,最小,所以构造树首先使用温度适中。然后分别在左右子树中查找构造树的下一个条件。

本例中,使用温度适中拆分后,是子树刚好类别全为是,即温度适中时去打网球的概率为1。

2.Boostingtree

一个CART往往过于简单,并不能有效地做出预测,为此,采用更进一步的模型boosting tree,利用多棵树来进行组合预测。具体算法如下:

输入:训练集

输出:提升树

步骤:

(1)初始化

(2) 对m=1,2,3……M

a)计算残差

b)拟合残差学习一个回归树,得到

c)更新

(3)得到回归提升树:

例子详见后面代码部分。

3.xgboost

首先,定义一个目标函数:

constant为一个常数,正则项如下,

其中,T表示叶子节点数,表示第j个叶子节点的权重。

例如下图,叶子节点数为3,每个叶子节点的权重分别为2,0.1,-1,正则项计算见图:

利用泰勒展开式,对式(3)进行展开:

其中,表示对的一阶导数,表示对的二阶导数。

为真实值与前一个函数计算所得残差是已知的(我们都是在已知前一个树的情况下计算下一颗树的),同时,在同一个叶子节点上的数的函数值是相同的,可以做合并,于是:

通过对求导等于0,可以得到

将带入得目标函数的简化公式如下:

目标函数简化后,可以看到xgboost的目标函数是可以自定义的,计算时只是用到了它的一阶导和二阶导。得到简化公式后,下一步针对选择的特征计算其所带来的增益,从而选取合适的分裂特征。

提升树例子代码:

# !/usr/bin/env python

# -*- coding: utf-8 -*-

# 目标函数为真实值与预测值的差的平方和

import math

# 数据集,只包含两列

test_list = [[1,5.56], [2,5.7], [3,5.81], [4,6.4], [5,6.8],\

[6,7.05], [7,7.9], [8,8.7], [9,9],[10,9.05]]

step = 1 #eta

# 起始拆分点

init = 1.5

# 最大拆分次数

max_times = 10

# 允许的最大误差

threshold = 1.0e-3

def train_loss(t_list):

sum = 0

for fea in t_list:

sum += fea[1]

avg = sum * 1.0 /len(t_list)

sum_pow = 0

for fea in t_list:

sum_pow =math.pow((fea[1]-avg), 2)

return sum_pow, avg

def boosting(data_list):

ret_dict = {}

split_num = init

while split_num

pos = 0

for idx, data inenumerate(data_list):

if data[0]> split_num:

pos = idx

break

if pos > 0:

l_train_loss,l_avg = train_loss(data_list[:pos])

r_train_loss,r_avg = train_loss(data_list[pos:])

ret_dict[split_num] = [pos,l_train_loss+r_train_loss, l_avg, r_avg]

split_num += step

return ret_dict

def main():

ret_list = []

data_list =sorted(test_list, key=lambda x:x[0])

time_num = 0

while True:

time_num += 1

print 'beforesplit:',data_list

ret_dict =boosting(data_list)

t_list =sorted(ret_dict.items(), key=lambda x:x[1][1])

print 'splitnode:',t_list[0]

ret_list.append([t_list[0][0], t_list[0][1][1]])

if ret_list[-1][1]< threshold or time_num > max_times:

break

for idx, data inenumerate(data_list):

if idx

data[1] -=t_list[0][1][2]

else:

data[1] -=t_list[0][1][3]

print 'after split:',data_list

print 'split node andloss:'

print'\n'.join(["%s\t%s" %(str(data[0]), str(data[1])) for data inret_list])

if __name__ == '__main__':

main()

本文来自企鹅号 - 达观数据媒体

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数说工作室

【分类战车SVM】第六话:SMO算法(像smoke一样简单!)

分类战车SVM (第六话:SMO算法) 查看本《分类战车SVM》系列的内容: 第一话:开题话 第二话:线性分类 第三话:最大间隔分类器 第四话:拉格朗日对偶问题...

46412
来自专栏AI研习社

不到 200 行代码 教你如何用 Keras 搭建生成对抗网络(GAN)

生成对抗网络(Generative Adversarial Networks,GAN)最早由 Ian Goodfellow 在 2014 年提出,是目前深度学习...

30910
来自专栏大数据挖掘DT机器学习

文本情感分析:特征提取(TFIDF指标)&随机森林模型实现

作者:Matt 自然语言处理实习生 http://blog.csdn.net/sinat__26917383/article/details/513024...

6084
来自专栏AI研习社

看完立刻理解 GAN!初学者也没关系

前言 GAN 从 2014 年诞生以来发展的是相当火热,比较著名的 GAN 的应用有 Pix2Pix、CycleGAN 等。本篇文章主要是让初学者通过代码了...

3135
来自专栏AI科技评论

开发 | 看完立刻理解GAN!初学者也没关系

AI 科技评论按:本文原作者天雨粟,原文载于作者的知乎专栏——机器不学习,经授权发布。 前言 GAN 从 2014 年诞生以来发展的是相当火热,比较著名的 GA...

35413
来自专栏一名叫大蕉的程序员

机器学习虾扯淡之Logistic回归No.44

0x00 前言 大家好我是小蕉。上一次我们说完了线性回归。不知道小伙伴有没有什么意见建议,是不是发现每个字都看得懂,但是全篇都不知道在说啥?哈哈哈哈哈哈,那就...

1665
来自专栏语言、知识与人工智能

基于深度学习的FAQ问答系统

| 导语 问答系统是信息检索的一种高级形式,能够更加准确地理解用户用自然语言提出的问题,并通过检索语料库、知识图谱或问答知识库返回简洁、准确的匹配答案。相较于...

9K10
来自专栏达观数据

技术干货 | XGBoost原理解析

作者简介 刘英涛:达观数据推荐算法工程师,负责达观数据个性化推荐系统的研发与优化。 XGBoost的全称是 eXtremeGradient Boosting,...

31410
来自专栏数据科学与人工智能

【智能】自然语言处理概述

1 什么是文本挖掘? 文本挖掘是信息挖掘的一个研究分支,用于基于文本信息的知识发现。文本挖掘的准备工作由文本收集、文本分析和特征修剪三个步骤组成。目前研究和应用...

3595
来自专栏刘明的小酒馆

文本相似度算法小结

首先是最简单粗暴的算法。为了对比两个东西的相似度,我们很容易就想到可以看他们之间有多少相似的内容,又有多少不同的内容,再进一步可以想到集合的交并集概念。

64510

扫码关注云+社区