前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【ML】回归树算法原理及实现

【ML】回归树算法原理及实现

作者头像
yuquanle
发布2020-05-25 17:09:15
6400
发布2020-05-25 17:09:15
举报

由于现实中的很多问题是非线性的,当处理这类复杂的数据的回归问题时,特征之间的关系并不是简单的线性关系,此时,不可能利用全局的线性回归模型拟合这类数据。在上一篇文章"分类树算法原理及实现"中,分类树算法可以解决现实中非线性的分类问题,那么本文要讲的就是可以解决现实中非线性回归问题的回归树算法。

本文以决策树中的CART树为例介绍回归树的原理及实现。

叶节点分裂指标

通常在CART回归树中,样本的标签是一系列的连续值的集合,不能再使用基尼指数作为划分树的指标。在回归问题中我们可以发现,对于连续数据, 当数据分布比较分散时,各个数据与平均数的差的平方和较大,方差就较大;当数据分布比较集中时,各个数据与平均数的差的平方和较小。方差越大,数据的波动越大;方差越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的差的平方和作为划分回归树的指标。

方差是度量数据分布离散程度最常用的一种指标,对于包含m个训练样本的数据集D{(X(1),y(1)),(X(2),y(2)),…,(X(m),y(m))},则指标为数据集D中所有样本标签与均值的差的平方和:

在回归树中即用该指标来进行叶节点分裂。现在让我们用代码将其实现。

import numpy as np
def err_cnt(dataSet):
    '''input: dataSet训练数据
    output: m*s^2总方差'''
    data = np.mat(dataSet)
    return np.var(data[:, -1]) * np.shape(data)[0]

回归树

先定义样本被划分到左右子树的过程函数,原理为根据特征fea位置处的特征,按照值value将样本划分到左右子树中,当样本在特征fea处的值大于或者等于value时,将其划分到右子树中;否则,将其划分到左子树中。用代码实现如下:

def split_tree(data, fea, value):
    '''input: data训练样本
            fea需要划分的特征编号
            value指定的划分的值
    output: (set_1, set_2)左右子树的聚合'''
    set_1 = []  # 右子树的集合
    set_2 = []  # 左子树的集合
    for x in data:
        if x[fea] >= value:
            set_1.append(x)
        else:
            set_2.append(x)
    return (set_1, set_2) 

另外需要定义计算当前叶子节点的值,计算的方法是使用划分到该叶子节点的所有样本的标签均值,代码如下:

def leaf(dataSet):
    '''input: dataSet训练样本
    output: 均值'''
    data = np.mat(dataSet)
    return np.mean(data[:, -1])

在按照特征对上述的数据进行划分的过程中,需要设置划分的终止条件和分类树比较类似。其构建过程可以分为以下几个步骤:

  • 对于当前训练数据集,遍历所有特征及其对应的所有可能切分点,寻找最佳切分特征及其最佳切分点,使得切分之后的各子集方差和最小,利用该最佳切分特征及其最佳切分点将训练数据集切分成两个子集,分别对应判别结果为左子树和判别结果为右子树。
  • 重复以下的步骤直至满足停止条件:为每一个叶子节点寻找最佳切分特征及其最佳切分点,将其划分为左右子树。
  • 生成回归树。

现在先为树中的节点定义一个结构类,代码如下:

class node:
    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
        self.fea = fea  # 用于切分数据集的特征的列索引值
        self.value = value  # 设置划分的值
        self.results = results  # 存储叶节点的值
        self.right = right  # 右子树
        self.left = left  # 左子树

然后我们可以利用递归的方法开始构建树了,在构建树的过程中,如果节点中的样本个数小于或者等于指定的最小样本数min_sample,则该节点不再划分。当节点需要划分时,首先计算当前节点的error值,划分后产生左子树和右子树,此时,计算左右子树的error值,若此时的error值小于最优的error值,则更新最优划分,当该节点划分完成后,继续对其左右子树进行划分。

def build_tree(data, min_sample, min_err):
    '''input: data训练样本
            min_sample叶子节点中最少样本数
            min_err最小的error
    output: node:树的根结点'''
    # 构建回归树,函数返回该树的根节点
    if len(data) <= min_sample:
        return node(results=leaf(data))
    
    # 1、初始化
    best_err = err_cnt(data)
    bestCriteria = None  # 存储最佳切分特征以及最佳切分点
    bestSets = None  # 存储切分后的两个数据集
    
    # 2、开始构建回归树
    feature_num = len(data[0]) - 1
    for fea in range(0, feature_num):
        feature_values = {}
        for sample in data:
            feature_values[sample[fea]] = 1
        
        for value in feature_values.keys():
            # 2.1、尝试划分
            (set_1, set_2) = split_tree(data, fea, value)
            if len(set_1) < 2 or len(set_2) < 2:
                continue
            # 2.2、计算划分后的error值
            now_err = err_cnt(set_1) + err_cnt(set_2)
            # 2.3、更新最优划分
            if now_err < best_err and len(set_1) > 0 and len(set_2) > 0:
                best_err = now_err
                bestCriteria = (fea, value)
                bestSets = (set_1, set_2)

    # 3、判断划分是否结束
    if best_err > min_err:
        right = build_tree(bestSets[0], min_sample, min_err)
        left = build_tree(bestSets[1], min_sample, min_err)
        return node(fea=bestCriteria[0], value=bestCriteria[1], \
                    right=right, left=left)
    else:
        return node(results=leaf(data))     

剪枝

树回归中,当树中的节点对样本一直划分下去时,会出现的最极端的情况是:每一个叶子节点中仅包含一个样本,此时,叶子节点的值即为该样本的标签的值。这种情况易对训练样本"过拟合",通过这样方式训练出来的样本可以对训练样本拟合得很好,但是对新样本的预测效果将会较差,而这种问题一般大多发生在回归问题中。为了防止构建好的树模型过拟合,通常需要对回归树进行剪枝,剪枝的目的是防止回归树生成过多的叶子节点。在剪枝中主要分为:前剪枝和后剪枝。

前剪枝是指在生成回归树的过程中对树的深度进行控制,防止生成过多的叶子节点。

后剪枝是指将训练样本分成两个部分,一部分用来训练树模型,这部分数据被称为训练数据,另一部分用来对生成的树模型进行剪枝,这部分数据被称为验证数据。如果出现过拟合的现象,则合并一些叶子节点来达到对树模型的剪枝。

到这里整个流程基本就结束了~

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI小白入门 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 叶节点分裂指标
  • 回归树
  • 剪枝
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档