
Regression Trees (Part 1)

用户6021899
Published 2019-08-14 17:29:16
From the column: Python编程 pyqt matplotlib

A linear regression model has to fit all of the sample points (locally weighted linear regression being the exception). When the data has many features with complex relationships among them, building one global model is unrealistic. A feasible alternative is to partition the dataset into many subsets that are easy to model, and then apply linear regression within each subset. If a subset is still hard to fit with a linear model, partition it further. For this kind of recursive partitioning, a tree structure is a natural fit.

This article introduces the CART (Classification And Regression Trees) algorithm. We start with the simplest kind of regression tree, which predicts with the mean of y at each leaf node.

First, load a 200x3 dataset:

from numpy import *  # nonzero, mean, var, shape, inf, mat, ... used throughout below

def loadDataSet(fileName):
    '''Parse a tab-delimited file of floats; the last column is the target value.'''
    dataMat = []
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))  # map all elements to float (py3: map returns an iterator)
            dataMat.append(fltLine)
    return dataMat
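As a quick sanity check, the loader can be exercised on a tiny stand-in file; the two rows below are made-up values, not taken from the actual ex0.txt:

```python
import os
import tempfile

def loadDataSet(fileName):
    '''Parse a tab-delimited file of floats; the last column is the target value.'''
    dataMat = []
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip().split('\t')
            dataMat.append(list(map(float, curLine)))
    return dataMat

# write a tiny stand-in for ex0.txt (hypothetical values)
fd, path = tempfile.mkstemp(suffix='.txt')
with os.fdopen(fd, 'w') as f:
    f.write("1.0\t0.036\t0.155\n1.0\t0.993\t1.077\n")
data = loadDataSet(path)
os.remove(path)
print(data)  # [[1.0, 0.036, 0.155], [1.0, 0.993, 1.077]]
```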

The dataset is 200x3: the first two columns hold x0 (always 1) and x1, and the last column holds y. (Figure: scatter plot of x1 against y.)

A regression tree handles continuous variables with binary splits: if a sample's feature value is greater than the split threshold, it goes to the left subtree; otherwise it goes to the right subtree.

def binSplitDataSet(dataSet, feature, value):
    # The book's original version appended a trailing [0], which is a bug:
    # mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    matLeft = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    matRight = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return matLeft, matRight
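To see the split rule in action, here is a toy 3x3 matrix (made-up values) split on feature 1 at threshold 0.5; the function is repeated so the snippet runs standalone:

```python
from numpy import mat, nonzero

def binSplitDataSet(dataSet, feature, value):
    matLeft = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    matRight = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return matLeft, matRight

testMat = mat([[1.0, 0.2, 0.1],
               [1.0, 0.6, 1.1],
               [1.0, 0.9, 0.9]])
left, right = binSplitDataSet(testMat, 1, 0.5)
print(left)   # the two rows with x1 > 0.5
print(right)  # the one row with x1 <= 0.5
```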

Build the regression tree recursively:

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # assume dataSet is a NumPy matrix so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # pick the best split feature
    if feat == None: return val  # splitting hit a stop condition: return a leaf value
    retTree = {}
    retTree['spInd'] = feat  # index of the feature to split on
    retTree['spVal'] = val   # threshold value the split compares against
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

The tree is stored as nested dictionaries with four keys:

"spInd" : index of the split feature

"spVal" : threshold for that feature

"left" : the left subtree; at a leaf, the mean of y over that group of samples

"right" : the right subtree; at a leaf, the mean of y over that group of samples

The prediction at a leaf is the mean of the y values that fall into it:

def regLeaf(dataSet):#returns the value used for each leaf    
    return mean(dataSet[:,-1])

The error function here is the total squared error (sum of squared deviations from the mean):

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]
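Since Var(y) is the mean squared deviation, multiplying it by the sample count gives exactly the sum of squared errors. A quick check on made-up numbers:

```python
from numpy import mat, var, mean, shape, sum, power

def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]

d = mat([[1.0, 1.0], [1.0, 2.0], [1.0, 6.0]])  # y = 1, 2, 6; mean = 3
sse = sum(power(d[:, -1] - mean(d[:, -1]), 2))  # explicit sum of squared deviations
print(float(regErr(d)), float(sse))  # both equal 14.0
```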

Pseudocode for finding the best split feature:

For every feature:
    For every unique value of that feature:
        Split the dataset in two
        Measure the error (total squared error) of the split
        If the error is lower than the current best error, record this split as the best one
If the best split reduces the error by less than tolS, stop splitting and return a leaf
If either side of the best split has fewer than tolN samples, stop splitting and return a leaf
Return the feature and threshold of the best split

The implementation:

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1.0, 4)):
    # tolS: minimum error reduction required to make a split
    # tolN: minimum number of samples on each side of a split
    tolS, tolN = ops
    # exit cond 1: all target values are identical, so return a leaf
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the best feature is the one giving the largest reduction in RSS from the mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # use set() to deduplicate; flatten first, since set() cannot take nested lists
        for splitVal in set(array(dataSet[:,featIndex]).flatten().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit cond 2: the decrease (S - bestS) is below the threshold, so don't split
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit cond 3: one side of the best split has too few samples
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue  # the best feature to split on, and its threshold
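To see the search in action, here is a standalone run on a made-up step function (condensed copies of the helpers above are included so it runs on its own); the best split should land exactly at the step:

```python
from numpy import mat, mean, var, shape, nonzero, inf, array

# condensed copies of the article's helpers
def regLeaf(dataSet):
    return mean(dataSet[:, -1])

def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]

def binSplitDataSet(dataSet, feature, value):
    matLeft = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    matRight = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return matLeft, matRight

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1.0, 4)):
    tolS, tolN = ops
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all y identical
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(array(dataSet[:, featIndex]).flatten().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex, bestValue, bestS = featIndex, splitVal, newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

# a step function: y jumps from 0 to 10 where x1 crosses 0.5 (hypothetical data)
step = mat([[1.0, 0.1 * i, 0.0 if i <= 5 else 10.0] for i in range(1, 11)])
feat, val = chooseBestSplit(step)
print(feat, val)  # splits on feature 1, right at the step
```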

Calling the functions above builds the regression tree:

myData = loadDataSet('ex0.txt')
myMat = mat(myData)
tree = createTree(myMat)
print(tree)
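The article stops at printing the tree, but the same nested-dict structure is easy to use for prediction. The helper below is my own sketch (not part of the original code), following the same "greater than the threshold goes left" rule as binSplitDataSet; the tree values are hypothetical:

```python
def treeForecast(tree, inX):
    """Descend the nested-dict tree for one sample (a row of feature values)
    and return the mean stored at the leaf as the prediction."""
    if not isinstance(tree, dict):  # reached a leaf: it holds the mean of y
        return float(tree)
    if inX[tree['spInd']] > tree['spVal']:  # same rule as binSplitDataSet
        return treeForecast(tree['left'], inX)
    return treeForecast(tree['right'], inX)

# a hand-built one-split tree (hypothetical numbers)
toyTree = {'spInd': 1, 'spVal': 0.5, 'left': 1.98, 'right': 0.02}
print(treeForecast(toyTree, [1.0, 0.73]))  # x1 > 0.5, so 1.98
print(treeForecast(toyTree, [1.0, 0.20]))  # x1 <= 0.5, so 0.02
```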

The printed dictionary is not very readable, so we can draw the tree structure with matplotlib. (Figure: the plotted regression tree.)

The plotting code is given below:

from plotRegTree import createPlot
createPlot(tree, title="Regression tree\npiecewise-constant prediction of y")

The implementation lives in the plotRegTree module and uses recursion in several places:

def getNumLeafs(regTree):
    '''Return the number of leaf nodes (the tree's maximum width).'''
    numLeafs = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":  # a dict means the left branch has subtrees
        numLeafs += getNumLeafs(leftTree)  # recurse
    else:
        numLeafs += 1
    if type(rightTree).__name__ == "dict":  # a dict means the right branch has subtrees
        numLeafs += getNumLeafs(rightTree)  # recurse
    else:
        numLeafs += 1

    return numLeafs
    
def getTreeDepth(regTree):
    '''Return the maximum depth of the tree.'''
    maxDepth = 0
    leftTree = regTree['left']
    rightTree = regTree['right']
    if type(leftTree).__name__ == "dict":  # a dict means the left branch has subtrees
        thisDepth = 1 + getTreeDepth(leftTree)  # recurse
    else:
        thisDepth = 1
    if thisDepth > maxDepth:
        maxDepth = thisDepth

    if type(rightTree).__name__ == "dict":  # a dict means the right branch has subtrees
        thisDepth = 1 + getTreeDepth(rightTree)  # recurse
    else:
        thisDepth = 1
    if thisDepth > maxDepth:
        maxDepth = thisDepth

    return maxDepth
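A quick sanity check of the two helpers on a hand-built tree (the versions below are compact but equivalent re-implementations, and the numbers are made up):

```python
def getNumLeafs(regTree):
    '''Number of leaves (the tree's maximum width).'''
    numLeafs = 0
    for branch in (regTree['left'], regTree['right']):
        if isinstance(branch, dict):  # branch has subtrees: recurse
            numLeafs += getNumLeafs(branch)
        else:                         # branch is a leaf value
            numLeafs += 1
    return numLeafs

def getTreeDepth(regTree):
    '''Maximum depth of the tree.'''
    maxDepth = 0
    for branch in (regTree['left'], regTree['right']):
        thisDepth = 1 + getTreeDepth(branch) if isinstance(branch, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth

# a hand-built two-level tree: three leaves, depth two
toy = {'spInd': 1, 'spVal': 0.5,
       'left': {'spInd': 1, 'spVal': 0.8, 'left': 2.9, 'right': 1.9},
       'right': 0.02}
print(getNumLeafs(toy), getTreeDepth(toy))  # 3 2
```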
    
yTop = 0.97  # axes coordinates run 0~1 in both X and Y; 0.97 leaves room for the title
decisionNode = dict(boxstyle="sawtooth", facecolor="orange", edgecolor="orange")
leafNode = dict(boxstyle="round4", facecolor="lime")
arrow_args = dict(arrowstyle="<-", color="r")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction", va="center",
                            ha="center", bbox=nodeType, color="black", weight="bold",
                            arrowprops=arrow_args)
                            
def plotMidText(cntrPt, parentPt, textString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, textString)
    
def plotTree(regTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(regTree)
    depth = getTreeDepth(regTree)
    leftTree = regTree['left']
    rightTree = regTree['right']

    cntrPt = (plotTree.xOff + (1.0 + numLeafs) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode("split on X%d" % regTree['spInd'], cntrPt, parentPt, decisionNode)

    plotTree.yOff -= yTop / plotTree.totalD  # descend one level
    specLimit = regTree['spVal']
    if type(leftTree).__name__ == "dict":  # the left branch has subtrees
        plotTree(leftTree, cntrPt, ">%.6f" % specLimit)  # recurse
        # thresholds are displayed with 6 decimal places
    else:
        plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
        plotNode("predicted y: %.3f" % leftTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
        plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, ">%.6f" % specLimit)
    if type(rightTree).__name__ == "dict":  # the right branch has subtrees
        plotTree(rightTree, cntrPt, "<=%.6f" % specLimit)  # recurse
    else:
        plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
        plotNode("predicted y: %.3f" % rightTree, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
        plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, "<=%.6f" % specLimit)
    plotTree.yOff += yTop / plotTree.totalD  # back up one level

def createPlot(inTree, title="Regression tree"):
    from matplotlib import pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the x- and y-axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)

    plotTree.totalW = getNumLeafs(inTree)
    plotTree.totalD = getTreeDepth(inTree)
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = yTop
    plotTree(inTree, (0.5, yTop), '')

    plt.title(title, fontsize=14, color="b")
    plt.show()
Originally published 2019-08-10 on the WeChat official account Python可视化编程机器学习OpenCV.
