# 回归树的原理及Python实现

github：https://github.com/tushushu/Imylu/blob/master/regression_tree.py

## 2.1 创建Node类

class Node(object): def __init__(self, score=None): self.score = score self.left = None self.right = None self.feature = None self.split = None

2.2 创建回归树类

class RegressionTree(object): def __init__(self): self.root = Node() self.height = 0

2.3 计算分割点、MSE

## 2.4 计算最佳分割点

def _choose_split_point(self, X, y, idx, feature): unique = set([X[i][feature] for i in idx]) if len(unique) == 1: return None unique.remove(min(unique)) mse, split, split_avg = min( (self._get_split_mse(X, y, idx, feature, split) for split in unique), key=lambda x: x[0]) return mse, feature, split, split_avg

## 2.5 选择最佳特征

def _choose_feature(self, X, y, idx): m = len(X[0]) split_rets = [x for x in map(lambda x: self._choose_split_point( X, y, idx, x), range(m)) if x is not None] if split_rets == []: return None _, feature, split, split_avg = min( split_rets, key=lambda x: x[0]) idx_split = [[], []] while idx: i = idx.pop() xi = X[i][feature] if xi < split: idx_split[0].append(i) else: idx_split[1].append(i) return feature, split, split_avg, idx_split

## 2.6 规则转文字

def _expr2literal(self, expr): feature, op, split = expr op = ">=" if op == 1 else "<" return "Feature%d %s %.4f" % (feature, op, split)

2.7 获取规则

def _get_rules(self): que = [[self.root, []]] self.rules = [] while que: nd, exprs = que.pop(0) if not(nd.left or nd.right): literals = list(map(self._expr2literal, exprs)) self.rules.append([literals, nd.score]) if nd.left: rule_left = copy(exprs) rule_left.append([nd.feature, -1, nd.split]) que.append([nd.left, rule_left]) if nd.right: rule_right = copy(exprs) rule_right.append([nd.feature, 1, nd.split]) que.append([nd.right, rule_right])

## 2.8 训练模型

1. 控制树的最大深度max_depth；
2. 控制分裂时最少的样本量min_samples_split；
3. 叶子结点至少有两个不重复的y值；
4. 至少有一个特征是没有重复值的。

def fit(self, X, y, max_depth=5, min_samples_split=2): self.root = Node() que = [[0, self.root, list(range(len(y)))]] while que: depth, nd, idx = que.pop(0) if depth == max_depth: break if len(idx) < min_samples_split or \ set(map(lambda i: y[i], idx)) == 1: continue feature_rets = self._choose_feature(X, y, idx) if feature_rets is None: continue nd.feature, nd.split, split_avg, idx_split = feature_rets nd.left = Node(split_avg[0]) nd.right = Node(split_avg[1]) que.append([depth+1, nd.left, idx_split[0]]) que.append([depth+1, nd.right, idx_split[1]]) self.height = depth self._get_rules()

## 2.9 打印规则

def print_rules(self): for i, rule in enumerate(self.rules): literals, score = rule print("Rule %d: " % i, ' | '.join( literals) + ' => split_hat %.4f' % score)

## 2.10 预测一个样本

def _predict(self, row): nd = self.root while nd.left and nd.right: if row[nd.feature] < nd.split: nd = nd.left else: nd = nd.right return nd.score

## 2.11 预测多个样本

def predict(self, X): return [self._predict(Xi) for Xi in X]

3 效果评估

## 3.1 main函数

@run_time def main(): print("Tesing the accuracy of RegressionTree...") # Load data X, y = load_boston_house_prices() # Split data randomly, train set rate 70% X_train, X_test, y_train, y_test = train_test_split( X, y, random_state=10) # Train model reg = RegressionTree() reg.fit(X=X_train, y=y_train, max_depth=4) # Show rules reg.print_rules() # Model accuracy get_r2(reg, X_test, y_test)

## 总结

【关于作者】

Python中文社区作为一个去中心化的全球技术社区，以成为全球20万Python中文开发者的精神部落为愿景，目前覆盖各大主流媒体和协作平台，与阿里、腾讯、百度、微软、亚马逊、开源中国、CSDN等业界知名公司和技术社区建立了广泛的联系，拥有来自十多个国家和地区数万名登记会员，会员来自以公安部、工信部、清华大学、北京大学、北京邮电大学、中国人民银行、中科院、中金、华为、BAT、谷歌、微软等为代表的政府机关、科研单位、金融机构以及海内外知名公司，全平台近20万开发者关注。

338 篇文章134 人订阅

0 条评论

## 相关文章

37050

21390

### 最新综述文章推荐：自然语言生成、深度学习算法、多媒体大数据分析

【导读】专知内容组整理了最近人工智能领域相关期刊的5篇最新综述文章，为大家进行介绍，欢迎查看! 1 ▌自然语言生成综述：任务，应用，评价 ---- ---- ...

66770

8520

44280

10830

33590

207110

444110

38690