# 回归树的原理及其 Python 实现

1. 原理篇

1.1 最简单的模型

1.2 加一点难度

1.3 最佳分割点

1.4 运用多个变量

1.5 答案揭晓

2. 实现篇

2.1 创建Node类

classNode(object):

def __init__(self,score=None):

self.score=score

self.left=None

self.right=None

self.feature=None

self.split=None

2.2 创建回归树类

classRegressionTree(object):

def __init__(self):

self.root=Node()

self.height=

2.3 计算分割点、MSE

2.4 计算最佳分割点

def _choose_split_point(self,X,y,idx,feature):

unique=set([X[i][feature]foriinidx])

iflen(unique)==1:

returnNone

unique.remove(min(unique))

mse,split,split_avg=min(

(self._get_split_mse(X,y,idx,feature,split)

forsplitinunique),key=lambdax:x[])

returnmse,feature,split,split_avg

2.5 选择最佳特征

def _choose_feature(self,X,y,idx):

m=len(X[])

split_rets=[xforxinmap(lambdax:self._choose_split_point(

X,y,idx,x),range(m))ifxisnotNone]

ifsplit_rets==[]:

returnNone

_,feature,split,split_avg=min(

split_rets,key=lambdax:x[])

idx_split=[[],[]]

whileidx:

i=idx.pop()

xi=X[i][feature]

ifxi

idx_split[].append(i)

else:

idx_split[1].append(i)

returnfeature,split,split_avg,idx_split

2.6 规则转文字

def _expr2literal(self,expr):

feature,op,split=expr

op=">="ifop==1else"

return"Feature%d %s %.4f"%(feature,op,split)

2.7 获取规则

def _get_rules(self):

que=[[self.root,[]]]

self.rules=[]

whileque:

nd,exprs=que.pop()

literals=list(map(self._expr2literal,exprs))

self.rules.append([literals,nd.score])

ifnd.left:

rule_left=copy(exprs)

rule_left.append([nd.feature,-1,nd.split])

que.append([nd.left,rule_left])

ifnd.right:

rule_right=copy(exprs)

rule_right.append([nd.feature,1,nd.split])

que.append([nd.right,rule_right])

2.8 训练模型

def fit(self,X,y,max_depth=5,min_samples_split=2):

self.root=Node()

que=[[,self.root,list(range(len(y)))]]

whileque:

depth,nd,idx=que.pop()

ifdepth==max_depth:

break

iflen(idx)

set(map(lambdai:y[i],idx))==1:

continue

feature_rets=self._choose_feature(X,y,idx)

iffeature_retsisNone:

continue

nd.feature,nd.split,split_avg,idx_split=feature_rets

nd.left=Node(split_avg[])

nd.right=Node(split_avg[1])

que.append([depth+1,nd.left,idx_split[]])

que.append([depth+1,nd.right,idx_split[1]])

self.height=depth

self._get_rules()

2.9 打印规则

def print_rules(self):

fori,ruleinenumerate(self.rules):

literals,score=rule

print("Rule %d: "%i,' | '.join(

literals)+' => split_hat %.4f'%score)

2.10 预测一个样本

def _predict(self,row):

nd=self.root

ifrow[nd.feature]

nd=nd.left

else:

nd=nd.right

returnnd.score

2.11 预测多个样本

def predict(self,X):

return[self._predict(Xi)forXiinX]

3 效果评估

3.1 main函数

@run_time

def main():

print("Tesing the accuracy of RegressionTree...")

# Split data randomly, train set rate 70%

X_train,X_test,y_train,y_test=train_test_split(

X,y,random_state=10)

# Train model

reg=RegressionTree()

reg.fit(X=X_train,y=y_train,max_depth=4)

# Show rules

reg.print_rules()

# Model accuracy

get_r2(reg,X_test,y_test)

3.2 效果展示

3.3 工具函数

https://zhuanlan.zhihu.com/p/41688007

【关于作者】

【关于投稿】

① 留言格式：

【投稿】+《 文章标题》+ 文章链接

② 示例：

【投稿】《不要自称是程序员，我十多年的 IT 职场总结》：

http://blog.jobbole.com/94148/

③ 最后请附上您的个人简介哈~

• 发表于:
• 原文链接https://kuaibao.qq.com/s/20180812B15CS700?refer=cp_1026
• 腾讯「云+社区」是腾讯内容开放平台帐号（企鹅号）传播渠道之一，根据《腾讯内容开放平台服务协议》转载发布内容。
• 如有侵权，请联系 yunjia_community@tencent.com 删除。

2021-06-14

2021-06-14

2021-06-14

2021-06-14

2018-07-04

2018-06-29

2018-07-04

2018-07-02

2018-06-19

2021-06-14

2021-06-14

2021-06-14