前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >XGBoost 源码阅读笔记(2):树构造之 Exact Greedy Algorithm

XGBoost 源码阅读笔记(2):树构造之 Exact Greedy Algorithm

原创
作者头像
程飞翔
修改2017-08-14 09:33:26
2.6K0
修改2017-08-14 09:33:26
举报
文章被收录于专栏:程飞翔的专栏程飞翔的专栏

在上一篇《XGBoost 源码阅读笔记(1)--代码逻辑结构》中向大家介绍了 XGBoost 源码的逻辑结构,同时也简单介绍了 XGBoost 的基本情况。本篇将继续向大家介绍 XGBoost 源码是如何构造一颗回归树,不过在分析源码之前,还是有必要先和大家一起推导下 XGBoost 的目标函数。本次推导过程公式截图主要摘抄于陈天奇的论文《XGBoost:A Scalable Tree Boosting System》。在后续的源码分析中,会省略一些与本篇无关的代码,如并行化,多线程。

一、目标函数优化

XGBoost 和以往的 GBT(Gradient Boost Tree) 不同之一在于其将目标函数进行了二阶泰勒展开,在模型训练过程中使用了二阶导数加快其模型收敛速度。与此同时,为了防止模型过拟合其给目标函数加上了控制模型结构的惩罚项。

[1502415419671_5528_1502415419296.png]
[1502415419671_5528_1502415419296.png]

图 1-1  目标函数

目标函数主要有两部分组成。第一部分是表示模型的预测误差;第二部分是表示模型结构。当模型预测误差越大,树的叶子个数越多,树的权重越大,目标函数就越大。我们的优化目标是使目标函数尽可能的小,这样在降低预测误差的同时也会减少树叶子的个数以及降低叶子权重。这也正符合机器学习中的"奥卡姆剃刀"原则,即选择与经验观察一致的最简单假设。

图 1-1 的目标函数由于存在以函数为参数的模型惩罚项导致其不能使用传统的方式进行优化,所以将其改写成如下形式

[1502415442814_9161_1502415442193.png]
[1502415442814_9161_1502415442193.png]

图 1-2 改变形式的目标函数

图 1-2 与图 1-1 的区别在于图 1-1 是通过整个模型去优化函数,而图 1-2 的优化目标是每次迭代过程中构造一个使目标函数达到最小值的弱分类器,从这个过程中就可以看出图 1-2 使用的是贪婪算法。将图 1-2 中的预测误差项在$y_i^{t-1}$处进行二阶泰勒展开:

[1502415460074_684_1502415459446.png]
[1502415460074_684_1502415459446.png]

图 1-3 二阶泰勒展开

并省去常数项

[1502415481343_8894_1502415480709.png]
[1502415481343_8894_1502415480709.png]

图 1-4 省去常数项

图 1-4 就是每次迭代过程中简化的目标函数。我们的目标是在第 t 次迭代过程中获得一个使目标函数达到最小值的最优弱分类器,即$f_t(x)$。在这里累加项 n 是样本实例的个数,为了使编码更加方便,定义一个新的变量表示表示叶子 j 的所有样本实例$x_i$

[1502415529337_5700_1502415528707.png]
[1502415529337_5700_1502415528707.png]

图 1-5 新的变量

同时展开目标函数的模型惩罚项,并以叶子为纬度可以改写成

[1502415546340_6657_1502415545736.png]
[1502415546340_6657_1502415545736.png]

图 1-6 以叶子为纬度的目标函数

这里函数 f 是将对应实例归类到对应的叶子下,并返回该实例在当前叶子下的权重 w。图 1-6 对叶子权重 w 求导,便得出最优的叶子权重 w

[1502415560657_521_1502415560011.png]
[1502415560657_521_1502415560011.png]

图 1-7 最优的叶子权重

与此同时将权重代入目标函数,并且省去常量,便得到了目标函数的解析式

[1502415579258_4503_1502415578644.png]
[1502415579258_4503_1502415578644.png]

图 1-8 目标函数的解析式

我们的目标便是极小化该目标函数解析式。目标函数的解析式可以通过图 1-9 清晰形象的描绘出来

[1502415601015_8411_1502415600562.png]
[1502415601015_8411_1502415600562.png]

图 1-9 目标函数的解析式计算过程

从图 1-9 可以清晰看出目标函数解析式的计算过程。目标函数的结果可以用来评价模型的好坏。这样在模型训练过程中,当前的叶子结点是否需要继续分裂主要就看分裂后的增益损失 loss_change。

[1502415618968_9983_1502415618353.png]
[1502415618968_9983_1502415618353.png]

图 1-10 分裂增益

增益损失 loss_change 的计算公式如图 1-10 所示,它是由该结点分裂后的左孩子增益加上右孩子增益减去该父结点的增益。这样在选择分裂点时候就是选择增益损失最大的分裂点。而寻找最佳分裂点是一个非常耗时的过程,上一篇《XGBoost 源码阅读笔记(1)--代码逻辑结构》介绍了几种 XGBoost 使用的分裂算法,这里选择其中最简单的 Exact Greedy Algorithm 进行讲解:

[1502415637930_7331_1502415637391.png]
[1502415637930_7331_1502415637391.png]

图 1-11  Exact Greedy Algorithm

图 1-11 算法的大意是遍历每个特征,在每个特征中选择该特征下的每个值作为其分裂点,计算增益损失。当遍历完所有特征之后,增益损失最大的特征值将作为其分裂点。由此可以看出这其实就是一种穷举算法,而整个树构造过程最耗时的过程就是寻找最优分裂点的过程。但是由于该算法简单易于理解,所以就以该算法来向大家介绍 XGBoost 源码树构造的实现过程。

如果对推导过程读起来比较吃力的话也没关系,这里主要需要记住的是每个结点增益和权值的计算公式。增益是用来决定当前结点是否需要继续分裂下去,而结点权值的线性组合即是模型最终的输出值。所以只要记住这两个公式就不会影响源码的阅读。

二、源码分析

1 代码逻辑结构回顾

在上一篇结尾的时候说过源码最终调用过程如下:

代码语言:txt
复制
  
代码语言:txt
复制
//gbtree.cc
|--GBTree::DoBoost()
	|--GBTree::BoostNewTrees()
		|--GBTree::InitUpdater()
		|--TreeUpdater::Update()

这里简化后的源码如下:

代码语言:txt
复制
   
代码语言:txt
复制
//gbtree.cc line:452
BoostNewTrees(const std::vector<bst_gpair> &gpair,
							DMatrix *p_fmat,
							int bst_group,
							std::vector<std::unique_ptr<RegTree> >* ret) {
	this->InitUpdater();
	std::vector<RegTree*> new_tress;
	for(auto& up: updaters){
		up->Update(gpair,p_fmat, new_trees);
	}
}

gpair 是一个 vector 向量,保存了对应样本实例的一阶导数和二阶导数。p_fmat 是一个指针,指向对应样本实例的特征,new_trees 用于存储构造好的回归树。

InitUpdater() 是为了初始化 updaters, 在上一篇文章也说过 updaters 是抽象类 Class TreeUpdater 的指针对象,定义了基本的 Init 和 Update 接口,该抽象的派生类定义了一系列树构造和剪枝方法。这里主要介绍其派生类 Class ColMaker,该类使用的即使我们前面介绍的 Exact Greedy Algorithm。

2 Class ColMaker 数据结构介绍

在 Class ColMaker 定义了一些数据结构用于辅助树的构造。

代码语言:txt
复制
    
代码语言:txt
复制
//updater_colmaker.cc line:755
const TrainParam& param; //训练参数,即我们设置的一些超参数
std::vector<int> position;  //当前样本实例在回归树结中对应点的索引
std::vector<NodeEntry> snode; //回归树中的结点
std::vector<int> qexpand_;  //保存将有可能分类的结点的索引

XGBoost 的树构造类似于 BFS(Breadth First Search),它是一层一层的构造树结点。所以需要一个队列 qexpand_用来保存当前层的结点,这些结点会根据增益损失 loss_change 决定是否需要分裂形成下一层的结点。

3 Class ColMaker 树构造源码

代码语言:txt
复制
  
代码语言:txt
复制
//updater_colmaker.cc line:29
void ColMaker::Update(...)
{
	for(size_t i = 0; i < trees.size();   ){
		Builder builder(param);
		builder.Update(gpair, dmat, trees[i]);
	}
}

在 Class ColMaker 中定义了一个 Class Builder 类,所有的构造过程都由这个类完成。

代码语言:txt
复制
   
代码语言:txt
复制
//updater_colmaker.cc line:89
void ColMaker::Builder::Update(...)
{
	this -> InitData(...);    //初始化 Builder 参数
	// 初始化树根结点的权值和增益
	this -> InitNewNode(gpair, *p_fmat,*p_tree);
	for( int depth = 0; depth < param.max_depth;   depth)
	{
		//给队列中的当层结点寻找分裂特征,构造出树的下一层
		this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
		//将当层各个非叶子结点中的样本实例分类到下一层的各个结点中
		this->ResetPosition();
		//更新队列,存储下一个层结点
		this->UpdateQueueExpand();
		//计算队列中下一层结点的权值和增益
		this->InitNewNode();
		//如果当前队列中没有候选分裂点,就退出循环
		If(qexpand_.size() == 0) break;
	}
	//由于树的深度限制,将队列中剩下结点都设置为树的叶子
	for(size_t i = 0; i < qexpand_.szie();   i)
	{
		 ...
	}
	//记录构造好的回归树的一些辅助统计信息
	...
}

在以上代码中核心部分就是第一个循环里面的四个函数。我们首先来看下 Builder::InitNewNode 是如何初始化结点的增益和权值。

(1) Builder::InitNewNode()

代码语言:txt
复制
    
代码语言:txt
复制
//updater_colmaker.cc
|--Builder::InitNewNode()
	|--for(size_t  j = 0;  j < qexpand_.size();   j)
	|--{
	|--  snode[qexpand[j]].root_gain = CalGain(...)
	|--  snode[qexpand[j]].weight = CalWeight(...)
	|--}

这里点的 root_gain 就是前面说的结点增益,将用于判断该点是否需要分裂。weigtht 就是当前点的权值,最终模型输出就是叶子结点 weight 的线性组合。CalGain() 和 CalWeight() 是两个模版函数,其简化的源码如下:

代码语言:txt
复制
    
代码语言:txt
复制
//param.h  line:242
Template<typename TrainingParams, typename T>
T CalGain(const TrainingParams &p, T sum_grad, T sum_hess)
{
	return (sum_grad * sum_grad)/( sum_hess   p.reg_lambda);
}  
//param.h line:275
Template<typename TrainingParams, typename T>
T CalWeight(const TrainingParams &p, T sum_grad, T sum_hess)
{
	return -sum_grad /( sum_hess   p.reg_lambda);
}

以上两个函数就是实现了我们一开始推导的两个公式,即计算结点的增益和权重。在初始化队列中的结点后,就需要对队列中的每个结点遍历寻找最优的分裂属性。

(2)XGBoost::Builder::FindSplit()

代码语言:txt
复制
    
代码语言:txt
复制
//updater_colmaker.cc
|--XGBoost::Builder::FindSplit()
	|--//寻找特征的最佳分裂值
	|--for(size_t i = 0; i< feature_num; i  )
	|--{
	|--  XGBoost::Builder::UpdateSolution()
			| --XGBoost::Builder::EnumerateSplit()
	|--}

分裂过程最终调用了 EnumerateSplit() 函数,为了便于理解对代码变量名做了修改,其简化的代码如下

代码语言:txt
复制
//updater_colmaker.cc line:508
void EnumerateSplit(...){
	//建立个临时变量 temp 用来保存结点信息
	//空间大小为队列 qexpand_中结点的最大索引
	vector<TStats> temp( std::max(qexpand_)   1);
	TStats left_child(param) //结点分裂后左孩子的统计信息
	//遍历当前特征的所有值
	for(const ColBatch::Entry * it = begin; it != end; it  = d_step){
		//得到当前特征值所对应的样本实例索引和特征值
		const int rIndex = it -> index;
		const int fValue = it->value;
		//根据当前样本索引得到其分类到的结点索引
		const int node_id = position[rIndex ];
		//结点分裂后右孩子的统计信息
		TStats & right_child = temp[node_id]
		//以当前特征值为分裂阈值,将当前样本归类到左孩子
		left_child = snode[node_id].stats - right_child;
		//计算增益损失
		int loss_change= CalcSplitGain(param, left_child, right_child)   
										 - snode[node_id].root_gain;
		//记录下最好的特征值分裂阈值,该阈值是左右孩子相邻特征值的中间值
		right_child.best.Update(loss_change, feature_id  
														, 0.5 * (fValue   right_child.left_value) );
		//将当前样本实例归类到右孩子结点
		right_child.add(gpair, info , rIndex)
	}
}

从上述代码可以很清晰看出整个代码的流程思路就是之前介绍的 Exact Greedy Alogrithm. 这里需要说明寻找分裂点有两个方向,一个是从左到右寻找,一个是从右到左寻找。上述代码只展示了一个方向的寻找过程。在寻找特征分裂阈值的时候分裂增益的计算函数是 CalcSplitGain(),其具体代码如下:

代码语言:txt
复制
   
代码语言:txt
复制
//param.h line:365
double CalcSplitGain(const TrainParam& param
, GradStats left, GradStats right) const {
		return left.CalcGain(param)   right.CalcGain(param);
}

上述代码就是简单将左孩子和右孩子的增益相加,而增益损失 loss_change 就是将左右孩子相加的增益减去其父节点的增益。

(3)XGboost::Builder::ResetPosition()

在寻找到当前层各个结点的分裂阈值之后,便可以在对应结点上构造其左右孩子来增加当前树的深度。当树的深度增加了,就需要将分类到当前层非叶子结点的样本实例分类到下一层的结点中。这个过程就是通过 ResetPosition() 函数完成的。

(4)XGboost::Builder::UpdateQueueExpand()

XGboost::Builder::UpdateQueueExpand() 函数更新 qexpand队列中的结点为下一层结点,然后在调用 XGboost::Builder::InitNewNode() 更新 qexpand中结点的权值和增益以便下一次循环

三、总结

本篇主要详细叙述了 XGBoost 使用 Exact Greedy Algorithm 构造树的方法,并分析了对应的源码。在分析源码过程中为了便于理解对代码做了一些简化,如省去了其中多线程,并行化的操作,并修改了一些变量名。在上述的树构造完成之后,还需要对树进行剪枝操作以防止模型过拟合。由于篇幅所限,这里就不再介绍剪枝操作。本篇文章只是起一个抛砖引玉的引导作用,想要对 XGBoost 实现细节有更加深刻理解,还需要去阅读 XGBoost 源码,毕竟有些东西用文字描述远不如用代码描述清晰。最后欢迎大家一起来讨论。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、目标函数优化
  • 二、源码分析
    • 1 代码逻辑结构回顾
      • 2 Class ColMaker 数据结构介绍
        • 3 Class ColMaker 树构造源码
        • 三、总结
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档