前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >回归树/模型树及python代码实现

回归树/模型树及python代码实现

作者头像
机器学习AI算法工程
发布2018-03-12 16:57:15
2.9K0
发布2018-03-12 16:57:15
举报

所谓回归就是数据进行曲线拟合,回归一般用来做预测,涵盖线性回归(经典最小二乘法)、局部加权线性回归、岭回归和逐步线性回归。先来看下线性回归,即经典最小二乘法,说到最小二乘法就不得说下线性代数,因为一般说线性回归只通过计算一个公式就可以得到答案,如(公式一)所示:

(公式一)

其中X是表示样本特征组成的矩阵,Y表示对应的值,比如房价,股票走势等,(公式一)是直接通过对(公式二)求导得到的,因为(公式二)是凸函数,导数等于零的点就是最小点。

(公式二)

不过并不是所有的码农能从(公式二)求导得到(公式一)的解,因此这里给出另外一个直观的解,直观理解建立起来后,后续几个回归就简单类推咯。从初中的投影点说起,如(图一)所示:

(图一)

在(图一)中直线a上离点b最近的点是点b在其上的投影,即垂直于a的交点p。p是b在a上的投影点。试想一下,如果我们把WX看成多维的a,即空间中的一个超面来代替二维空间中的直线,而y看成b,那现在要使得(公式二)最小是不是就是寻找(图一)中的e,即垂直于WX的垂线,因为只有垂直时e才最小。下面来看看如何通过寻找垂线并最终得到W。要寻找垂线,先从(图二)中的夹角theta 说起吧,因为当cos(theta)=0时,他们也就垂直了。下面来分析下直线或者向量之间的夹角,如(图二)所示:

(图二)

在(图二)中,

表示三角形

的斜边,那么:

角beta也可以得到同样的计算公式,接着利用三角形和差公式得到(公式三):

(公式三)

(公式三)表示的是两直线或者两向量之间的夹角公式,很多同学都学过。再仔细看下,发现分子其实是向量a,b之间的内积(点积),因此公式三变为简洁的(公式四)的样子:

(公式四)

接下来继续分析(图一)中的投影,为了方便观看,增加了一些提示如(图三)所示:

(图三)

在(图三)中,假设向量b在向量a中的投影为p(注意,这里都上升为向量空间,不再使用直线,因为(公式四)是通用的)。投影p和a 在同一方向上(也可以反方向),因此我们可以用一个系数乘上a来表示p,比如(图三)中的

,有了投影向量p,那么我们就可以表示向量e,因为根据向量法则,

,有因为a和e垂直,因此

,展开求得系数x,如(公式五)所示:

(公式五)

(公式五)是不是很像(公式一)?只不过公式一的分母写成了另外的形式,不过别急,现在的系数只是一个标量数字,因为a,b都是一个向量,我们要扩展一下,把a从向量扩展到子空间,因为(公式一)中的X是样本矩阵,矩阵有列空间和行空间,如(图四)所示:

(图四)

(图四)中的A表示样本矩阵X,假设它有两个列a1和a2,我们要找一些线性组合系数来找一个和(图三)一样的接受b 投影的向量,而这个向量通过矩阵列和系数的线性组合表示。求解的这个系数的思路和上面完全一样,就是寻找投影所在的向量和垂线e的垂直关系,得到系数,如(公式六)所示:

(公式六)

这下(公式六)和(公式一)完全一样了,基于最小二乘法的线性回归也就推导完成了,而局部加权回归其实只是相当于对不同样本之间的关系给出了一个权重,所以叫局部加权,如(公式七)所示:

(公式七)

而权重的计算可通过高斯核(高斯公式)来完成,核的作用就是做权重衰减,很多地方都要用到,表示样本的重要程度,一般离目标进的重要程度大些,高斯核可以很好的描述这种关系。如(公式八)所示,其中K是个超参数,根据情况灵活设置:

(公式八)

(图五)是当K分别为1.0, 0.01,0.003时的局部加权线性回归的样子,可以看出当K=1.0时,和线性回归没区别:

(图五)

而岭回归的样子如(公式九)所示:

(公式九)

岭回归主要是解决的问题就是当XX’无法求逆时,比如当特征很多,样本很少,矩阵X不是满秩矩阵,此时求逆会出错,但是通过加上一个对角为常量lambda的矩阵,就可以很巧妙的避免这个计算问题,因此会多一个参数lambda,lambda的最优选择由交叉验证(cross-validation)来决定,加上一个对角不为0的矩阵很形象的在对角上抬高了,因此称为岭。不同的lambda会使得系数缩减,如(图六)所示:

(图六)

说到系数缩减大家可能会觉得有奇怪,感觉有点类似于正则,但是这里只是相当于在(公式六)中增大分母,进而缩小系数,另外还有一些系数缩减的方法,比如直接增加一些约束,如(公式十)和(公式十一)所示:

(公式十)

(公式十一)

当线性回归增加了(公式十)的约束变得和桥回归差不多,系数缩减了,而如果增加了(公式十一)的约束时就是稀疏回归咯,(我自己造的名词,sorry),系数有一些0。

有了约束后,求解起来就不像上面那样直接计算个矩阵运算就行了,回顾第五节说中支持向量机原理,需要使用二次规划求解,不过仍然有一些像SMO算法一样的简化求解算法,比如前向逐步回归方法:

前向逐步回归的伪代码如(图七)所示,也不难,仔细阅读代码就可以理解:

(图七)

下面直接给出上面四种回归的代码:

[python] view plaincopy

  1. from numpy import *
  2. def loadDataSet(fileName): #general function to parse tab -delimited floats
  3. numFeat = len(open(fileName).readline().split('\t')) - 1 #get number of fields
  4. dataMat = []; labelMat = []
  5. fr = open(fileName)
  6. for line in fr.readlines():
  7. lineArr =[]
  8. curLine = line.strip().split('\t')
  9. for i in range(numFeat):
  10. lineArr.append(float(curLine[i]))
  11. dataMat.append(lineArr)
  12. labelMat.append(float(curLine[-1]))
  13. return dataMat,labelMat
  14. def standRegres(xArr,yArr):
  15. xMat = mat(xArr); yMat = mat(yArr).T
  16. xTx = xMat.T*xMat
  17. if linalg.det(xTx) == 0.0:
  18. print "This matrix is singular, cannot do inverse"
  19. return
  20. ws = xTx.I * (xMat.T*yMat)
  21. return ws
  22. def lwlr(testPoint,xArr,yArr,k=1.0):
  23. xMat = mat(xArr); yMat = mat(yArr).T
  24. m = shape(xMat)[0]
  25. weights = mat(eye((m)))
  26. for j in range(m): #next 2 lines create weights matrix
  27. diffMat = testPoint - xMat[j,:] #
  28. weights[j,j] = exp(diffMat*diffMat.T/(-2.0*k**2))
  29. xTx = xMat.T * (weights * xMat)
  30. if linalg.det(xTx) == 0.0:
  31. print "This matrix is singular, cannot do inverse"
  32. return
  33. ws = xTx.I * (xMat.T * (weights * yMat))
  34. return testPoint * ws
  35. def lwlrTest(testArr,xArr,yArr,k=1.0): #loops over all the data points and applies lwlr to each one
  36. m = shape(testArr)[0]
  37. yHat = zeros(m)
  38. for i in range(m):
  39. yHat[i] = lwlr(testArr[i],xArr,yArr,k)
  40. return yHat
  41. def lwlrTestPlot(xArr,yArr,k=1.0): #same thing as lwlrTest except it sorts X first
  42. yHat = zeros(shape(yArr)) #easier for plotting
  43. xCopy = mat(xArr)
  44. xCopy.sort(0)
  45. for i in range(shape(xArr)[0]):
  46. yHat[i] = lwlr(xCopy[i],xArr,yArr,k)
  47. return yHat,xCopy
  48. def rssError(yArr,yHatArr): #yArr and yHatArr both need to be arrays
  49. return ((yArr-yHatArr)**2).sum()
  50. def ridgeRegres(xMat,yMat,lam=0.2):
  51. xTx = xMat.T*xMat
  52. denom = xTx + eye(shape(xMat)[1])*lam
  53. if linalg.det(denom) == 0.0:
  54. print "This matrix is singular, cannot do inverse"
  55. return
  56. ws = denom.I * (xMat.T*yMat)
  57. return ws
  58. def ridgeTest(xArr,yArr):
  59. xMat = mat(xArr); yMat=mat(yArr).T
  60. yMean = mean(yMat,0)
  61. yMat = yMat - yMean #to eliminate X0 take mean off of Y
  62. #regularize X's
  63. xMeans = mean(xMat,0) #calc mean then subtract it off
  64. xVar = var(xMat,0) #calc variance of Xi then divide by it
  65. xMat = (xMat - xMeans)/xVar
  66. numTestPts = 30
  67. wMat = zeros((numTestPts,shape(xMat)[1]))
  68. for i in range(numTestPts):
  69. ws = ridgeRegres(xMat,yMat,exp(i-10))
  70. wMat[i,:]=ws.T
  71. return wMat
  72. def regularize(xMat):#regularize by columns
  73. inMat = xMat.copy()
  74. inMeans = mean(inMat,0) #calc mean then subtract it off
  75. inVar = var(inMat,0) #calc variance of Xi then divide by it
  76. inMat = (inMat - inMeans)/inVar
  77. return inMat
  78. def stageWise(xArr,yArr,eps=0.01,numIt=100):
  79. xMat = mat(xArr); yMat=mat(yArr).T
  80. yMean = mean(yMat,0)
  81. yMat = yMat - yMean #can also regularize ys but will get smaller coef
  82. xMat = regularize(xMat)
  83. m,n=shape(xMat)
  84. #returnMat = zeros((numIt,n)) #testing code remove
  85. ws = zeros((n,1)); wsTest = ws.copy(); wsMax = ws.copy()
  86. for i in range(numIt):
  87. print ws.T
  88. lowestError = inf;
  89. for j in range(n):
  90. for sign in [-1,1]:
  91. wsTest = ws.copy()
  92. wsTest[j] += eps*sign
  93. yTest = xMat*wsTest
  94. rssE = rssError(yMat.A,yTest.A)
  95. if rssE < lowestError:
  96. lowestError = rssE
  97. wsMax = wsTest
  98. ws = wsMax.copy()
  99. #returnMat[i,:]=ws.T
  100. #return returnMat
  101. #def scrapePage(inFile,outFile,yr,numPce,origPrc):
  102. # from BeautifulSoup import BeautifulSoup
  103. # fr = open(inFile); fw=open(outFile,'a') #a is append mode writing
  104. # soup = BeautifulSoup(fr.read())
  105. # i=1
  106. # currentRow = soup.findAll('table', r="%d" % i)
  107. # while(len(currentRow)!=0):
  108. # title = currentRow[0].findAll('a')[1].text
  109. # lwrTitle = title.lower()
  110. # if (lwrTitle.find('new') > -1) or (lwrTitle.find('nisb') > -1):
  111. # newFlag = 1.0
  112. # else:
  113. # newFlag = 0.0
  114. # soldUnicde = currentRow[0].findAll('td')[3].findAll('span')
  115. # if len(soldUnicde)==0:
  116. # print "item #%d did not sell" % i
  117. # else:
  118. # soldPrice = currentRow[0].findAll('td')[4]
  119. # priceStr = soldPrice.text
  120. # priceStr = priceStr.replace('$','') #strips out $
  121. # priceStr = priceStr.replace(',','') #strips out ,
  122. # if len(soldPrice)>1:
  123. # priceStr = priceStr.replace('Free shipping', '') #strips out Free Shipping
  124. # print "%s\t%d\t%s" % (priceStr,newFlag,title)
  125. # fw.write("%d\t%d\t%d\t%f\t%s\n" % (yr,numPce,newFlag,origPrc,priceStr))
  126. # i += 1
  127. # currentRow = soup.findAll('table', r="%d" % i)
  128. # fw.close()
  129. from time import sleep
  130. import json
  131. import urllib2
  132. def searchForSet(retX, retY, setNum, yr, numPce, origPrc):
  133. sleep(10)
  134. myAPIstr = 'AIzaSyD2cR2KFyx12hXu6PFU-wrWot3NXvko8vY'
  135. searchURL = 'https://www.googleapis.com/shopping/search/v1/public/products?key=%s&country=US&q=lego+%d&alt=json' % (myAPIstr, setNum)
  136. pg = urllib2.urlopen(searchURL)
  137. retDict = json.loads(pg.read())
  138. for i in range(len(retDict['items'])):
  139. try:
  140. currItem = retDict['items'][i]
  141. if currItem['product']['condition'] == 'new':
  142. newFlag = 1
  143. else: newFlag = 0
  144. listOfInv = currItem['product']['inventories']
  145. for item in listOfInv:
  146. sellingPrice = item['price']
  147. if sellingPrice > origPrc * 0.5:
  148. print "%d\t%d\t%d\t%f\t%f" % (yr,numPce,newFlag,origPrc, sellingPrice)
  149. retX.append([yr, numPce, newFlag, origPrc])
  150. retY.append(sellingPrice)
  151. except: print 'problem with item %d' % i
  152. def setDataCollect(retX, retY):
  153. searchForSet(retX, retY, 8288, 2006, 800, 49.99)
  154. searchForSet(retX, retY, 10030, 2002, 3096, 269.99)
  155. searchForSet(retX, retY, 10179, 2007, 5195, 499.99)
  156. searchForSet(retX, retY, 10181, 2007, 3428, 199.99)
  157. searchForSet(retX, retY, 10189, 2008, 5922, 299.99)
  158. searchForSet(retX, retY, 10196, 2009, 3263, 249.99)
  159. def crossValidation(xArr,yArr,numVal=10):
  160. m = len(yArr)
  161. indexList = range(m)
  162. errorMat = zeros((numVal,30))#create error mat 30columns numVal rows
  163. for i in range(numVal):
  164. trainX=[]; trainY=[]
  165. testX = []; testY = []
  166. random.shuffle(indexList)
  167. for j in range(m):#create training set based on first 90% of values in indexList
  168. if j < m*0.9:
  169. trainX.append(xArr[indexList[j]])
  170. trainY.append(yArr[indexList[j]])
  171. else:
  172. testX.append(xArr[indexList[j]])
  173. testY.append(yArr[indexList[j]])
  174. wMat = ridgeTest(trainX,trainY) #get 30 weight vectors from ridge
  175. for k in range(30):#loop over all of the ridge estimates
  176. matTestX = mat(testX); matTrainX=mat(trainX)
  177. meanTrain = mean(matTrainX,0)
  178. varTrain = var(matTrainX,0)
  179. matTestX = (matTestX-meanTrain)/varTrain #regularize test with training params
  180. yEst = matTestX * mat(wMat[k,:]).T + mean(trainY)#test ridge results and store
  181. errorMat[i,k]=rssError(yEst.T.A,array(testY))
  182. #print errorMat[i,k]
  183. meanErrors = mean(errorMat,0)#calc avg performance of the different ridge weight vectors
  184. minMean = float(min(meanErrors))
  185. bestWeights = wMat[nonzero(meanErrors==minMean)]
  186. #can unregularize to get model
  187. #when we regularized we wrote Xreg = (x-meanX)/var(x)
  188. #we can now write in terms of x not Xreg: x*w/var(x) - meanX/var(x) +meanY
  189. xMat = mat(xArr); yMat=mat(yArr).T
  190. meanX = mean(xMat,0); varX = var(xMat,0)
  191. unReg = bestWeights/varX
  192. print "the best model from Ridge Regression is:\n",unReg
  193. print "with constant term: ",-1*sum(multiply(meanX,unReg)) + mean(yMat)

以上各种回归方法没有考虑实际数据的噪声,如果噪声很多,直接用上述的回归不是太好,因此需要加上正则,然后迭代更新权重

参考文献:

[1] machine learning in action.Peter Harrington

[2]Linear Algebra and Its Applications_4ed.Gilbert_Strang

回归树和模型树

前一节的回归是一种全局回归模型,它设定了一个模型,不管是线性还是非线性的模型,然后拟合数据得到参数,现实中会有些数据很复杂,肉眼几乎看不出符合那种模型,因此构建全局的模型就有点不合适。这节介绍的树回归就是为了解决这类问题,它通过构建决策节点把数据数据切分成区域,然后局部区域进行回归拟合。先来看看分类回归树吧(CART:Classification And Regression Trees),这个模型优点就是上面所说,可以对复杂和非线性的数据进行建模,缺点是得到的结果不容易理解。顾名思义它可以做分类也可以做回归,至于分类前面在说决策树时已经说过了,这里略过。直接通过分析回归树的代码来理解吧:

[python] view plaincopy

  1. from numpy import *
  2. def loadDataSet(fileName): #general function to parse tab -delimited floats
  3. dataMat = [] #assume last column is target value
  4. fr = open(fileName)
  5. for line in fr.readlines():
  6. curLine = line.strip().split('\t')
  7. fltLine = map(float,curLine) #map all elements to float()
  8. dataMat.append(fltLine)
  9. return dataMat
  10. def binSplitDataSet(dataSet, feature, value):
  11. mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
  12. mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
  13. return mat0,mat1

上面两个函数,第一个函数加载样本数据,第二个函数用来指定在某个特征和维度上切分数据,示例如(图一)所示:

(图一)

注意一下,CART是一种通过二元切分来构建树的,前面的决策树的构建是通过香农熵最小作为度量,树的节点是个离散的阈值;这里不再使用香农熵,因为我们要做回归,因此这里使用计算分割数据的方差作为度量,而树的节点也对应使用使得方差最小的某个连续数值(其实是特征值)。试想一下,如果方差越小,说明误差那个节点最能表述那块数据。下面来看看树的构建代码:

[python] view plaincopy

  1. def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
  2. feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
  3. if feat == None: return val #if the splitting hit a stop condition return val(叶子节点值)
  4. retTree = {}
  5. retTree['spInd'] = feat
  6. retTree['spVal'] = val
  7. lSet, rSet = binSplitDataSet(dataSet, feat, val)
  8. retTree['left'] = createTree(lSet, leafType, errType, ops)
  9. retTree['right'] = createTree(rSet, leafType, errType, ops)
  10. return retTree

这段代码中主要工作任务就是选择最佳分割特征,然后分割,是叶子节点就返回,不是叶子节点就递归的生成树结构。其中调用了最佳分割特征的函数:chooseBestSplit,前面决策树的构建中,这个函数里用熵来度量,这里采用误差(方差)来度量,同样先看代码:

[python] view plaincopy

  1. def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
  2. tolS = ops[0]; tolN = ops[1]
  3. #if all the target variables are the same value: quit and return value
  4. if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
  5. return None, leafType(dataSet)
  6. m,n = shape(dataSet)
  7. #the choice of the best feature is driven by Reduction in RSS error from mean
  8. S = errType(dataSet)
  9. bestS = inf; bestIndex = 0; bestValue = 0
  10. for featIndex in range(n-1):
  11. for splitVal in set(dataSet[:,featIndex]):
  12. mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
  13. if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
  14. newS = errType(mat0) + errType(mat1)
  15. if newS < bestS:
  16. bestIndex = featIndex
  17. bestValue = splitVal
  18. bestS = newS
  19. #if the decrease (S-bestS) is less than a threshold don't do the split
  20. if (S - bestS) < tolS:
  21. return None, leafType(dataSet) #exit cond 2
  22. mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
  23. if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit cond 3
  24. return None, leafType(dataSet)
  25. return bestIndex,bestValue#returns the best feature to split on
  26. #and the value used for that split

这段代码的主干是:

遍历每个特征:

遍历每个特征值:

把数据集切分成两份

计算此时的切分误差

如果切分误差小于当前最小误差,更新最小误差值,当前切分为最佳切分

返回最佳切分的特征值和阈值

尤其注意最后的返回值,因为它是构建树每个节点成分的东西。另外代码中errType=regErr 调用了regErr函数来计算方差,下面给出:

[python] view plaincopy

  1. def regErr(dataSet):
  2. return var(dataSet[:,-1]) * shape(dataSet)[0]

如果误差变化不大时(代码中(S - bestS)),则生成叶子节点,叶子节点函数是:

[python] view plaincopy

  1. def regLeaf(dataSet):#returns the value used for each leaf
  2. return mean(dataSet[:,-1])

这样回归树构建的代码就初步分析完毕了,运行结果如(图二)所示:

(图二)

数据ex00.txt在文章最后给出,它的分布如(图三)所示:

(图三)

根据(图三),我们可以大概看出(图二)的代码的运行结果具有一定的合理性,选用X(用0表示)特征作为分割特征,然后左右节点各选了一个中心值来描述树回归。节点比较少,但很能说明问题,下面给出一个比较复杂数据跑出的结果,如(图四)所示:

(图四)

对应的数据如(图五)所示:

(图五)

对于树的叶子节点和节点值的合理性,大家逐个对照(图五)来验证吧。下面简单的说下树的修剪,如果特征维度比较高,很容易发生节点过多,造成过拟合,过拟合(overfit)会出现high variance, 而欠拟合(under fit)会出现high bias,这点是题外话,因为机器学习理论一般要讲这些,当出现过拟合时,一般使用正则方法,由于回归树没有建立目标函数,因此这里解决过拟合的方法就是修剪树,简单的说就是使用少量的、关键的特征来判别,下面来看看如何修剪树:很简单,就是递归的遍历一个子树,从叶子节点开始,计算同一父节点的两个子节点合并后的误差,再计算不合并的误差,如果合并会降低误差,就把叶子节点合并。说到误差,其实前面的chooseBestSplit函数里有一句代码:

[python] view plaincopy

  1. #if the decrease (S-bestS) is less than a threshold don't do the split
  2. if (S - bestS) < tolS:

tolS 是个阈值,当误差变化不太大时,就不再分裂下去,其实也是修剪树的方法,只不过它是事前修剪,而计算合并误差的则是事后修剪。下面是其代码:

[python] view plaincopy

  1. def getMean(tree):
  2. if isTree(tree['right']): tree['right'] = getMean(tree['right'])
  3. if isTree(tree['left']): tree['left'] = getMean(tree['left'])
  4. return (tree['left']+tree['right'])/2.0
  5. def prune(tree, testData):
  6. if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
  7. if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
  8. lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
  9. if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
  10. if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
  11. #if they are now both leafs, see if we can merge them
  12. if not isTree(tree['left']) and not isTree(tree['right']):
  13. lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
  14. errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
  15. sum(power(rSet[:,-1] - tree['right'],2))
  16. treeMean = (tree['left']+tree['right'])/2.0
  17. errorMerge = sum(power(testData[:,-1] - treeMean,2))
  18. if errorMerge < errorNoMerge:
  19. print "merging"
  20. return treeMean
  21. else: return tree
  22. else: return tree

说完了树回归,再简单的提下模型树,因为树回归每个节点是一些特征和特征值,选取的原则是根据特征方差最小。如果把叶子节点换成分段线性函数,那么就变成了模型树,如(图六)所示:

(图六)

(图六)中明显是两个直线组成,以X坐标(0.0-0.3)和(0.3-1.0)分成的两个线段。如果我们用两个叶子节点保存两个线性回归模型,就完成了这部分数据的拟合。实现也比较简单,代码如下:

[python] view plaincopy

  1. def linearSolve(dataSet): #helper function used in two places
  2. m,n = shape(dataSet)
  3. X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
  4. X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
  5. xTx = X.T*X
  6. if linalg.det(xTx) == 0.0:
  7. raise NameError('This matrix is singular, cannot do inverse,\n\
  8. try increasing the second value of ops')
  9. ws = xTx.I * (X.T * Y)
  10. return ws,X,Y
  11. def modelLeaf(dataSet):#create linear model and return coeficients
  12. ws,X,Y = linearSolve(dataSet)
  13. return ws
  14. def modelErr(dataSet):
  15. ws,X,Y = linearSolve(dataSet)
  16. yHat = X * ws
  17. return sum(power(Y - yHat,2))

代码和树回归相似,只不过modelLeaf在返回叶子节点时,要完成一个线性回归,由linearSolve来完成。最后一个函数modelErr则和回归树的regErr函数起着同样的作用。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2015-08-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 大数据挖掘DT数据分析 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档