Python写算法：二元决策树

# Train a depth-3 regression tree on the UCI red-wine-quality data set
# and export its structure to a GraphViz "dot" file.
#
# Fixes vs. the original listing:
#   * `urllib2` is Python-2-only; use `urllib.request` instead.
#   * `urlopen` yields bytes in Python 3, so each line is decoded before
#     splitting (the original str-based split would raise TypeError).
#   * The deprecated, unused `sklearn.externals.six.StringIO` import
#     (removed from modern scikit-learn) and other unused imports are dropped.

import urllib.request

from sklearn import tree
from sklearn.tree import DecisionTreeRegressor

# Read the semicolon-separated data set straight from the UCI repository.
target_url = ("http://archive.ics.uci.edu/ml/machine-learning-"
              "databases/wine-quality/winequality-red.csv")
data = urllib.request.urlopen(target_url)

xList = []   # attribute rows (floats)
labels = []  # wine-quality score: the last column of each row
names = []   # column names from the header line
firstLine = True
for line in data:
    # Decode the raw bytes before any string processing.
    line = line.decode("utf-8")
    if firstLine:
        names = line.strip().split(";")
        firstLine = False
    else:
        # Split on semicolon.
        row = line.strip().split(";")
        # Put labels in a separate array; the label is the last field.
        labels.append(float(row[-1]))
        # Remove the label from the row, then convert the rest to floats.
        row.pop()
        floatRow = [float(num) for num in row]
        xList.append(floatRow)

nrows = len(xList)
ncols = len(xList[0])

wineTree = DecisionTreeRegressor(max_depth=3)
wineTree.fit(xList, labels)

# Export the trained tree to GraphViz "dot" format.  Render it with:
#   dot -Tpng wineTree.dot -o wineTree.png
# On Windows, the .dot file can also be opened in the GraphViz GUI (GVedit.exe).
with open("wineTree.dot", 'w') as f:
    tree.export_graphviz(wineTree, out_file=f)

1.2　如何训练一个二元决策树

# Demonstrate how a binary regression tree fits data of the form
# y = x + noise at depths 1, 2 and 6, and how the best split point is
# found by exhaustively minimizing the sum of squared errors.
#
# Fixes vs. the original listing:
#   * `linestyle='–'` (a Unicode en-dash, a copy/paste artifact) is
#     rejected by matplotlib; the dashed style is the ASCII '--'.
#   * The deprecated, unused `sklearn.externals.six.StringIO` import
#     (removed from modern scikit-learn) is dropped.

import numpy
import matplotlib.pyplot as plot
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor

# Build a simple data set with y = x + random noise.
nPoints = 100

# x values for plotting: nPoints + 1 evenly spaced points in [-0.5, 0.5].
xPlot = [(float(i) / float(nPoints) - 0.5) for i in range(nPoints + 1)]

# sklearn expects x as a list of rows (list of lists).
x = [[s] for s in xPlot]

# y (labels) is the x-value plus Gaussian noise; fixed seed for repeatability.
numpy.random.seed(1)
y = [s + numpy.random.normal(scale=0.1) for s in xPlot]

plot.plot(xPlot, y)
plot.axis('tight')
plot.xlabel('x')
plot.ylabel('y')
plot.show()

# Depth-1 tree: a single split, so the prediction is a two-level step.
simpleTree = DecisionTreeRegressor(max_depth=1)
simpleTree.fit(x, y)

# Draw the tree (GraphViz "dot" format).
with open("simpleTree.dot", 'w') as f:
    tree.export_graphviz(simpleTree, out_file=f)

# Compare the tree's prediction with the true values.
yHat = simpleTree.predict(x)

plot.figure()
plot.plot(xPlot, y, label='True y')
plot.plot(xPlot, yHat, label='Tree Prediction ', linestyle='--')
plot.legend(bbox_to_anchor=(1, 0.2))
plot.axis('tight')
plot.xlabel('x')
plot.ylabel('y')
plot.show()

# Depth-2 tree: up to four leaves, a finer step approximation.
simpleTree2 = DecisionTreeRegressor(max_depth=2)
simpleTree2.fit(x, y)

# Draw the tree.
with open("simpleTree2.dot", 'w') as f:
    tree.export_graphviz(simpleTree2, out_file=f)

# Compare the tree's prediction with the true values.
yHat = simpleTree2.predict(x)

plot.figure()
plot.plot(xPlot, y, label='True y')
plot.plot(xPlot, yHat, label='Tree Prediction ', linestyle='--')
plot.legend(bbox_to_anchor=(1, 0.2))
plot.axis('tight')
plot.xlabel('x')
plot.ylabel('y')
plot.show()

# Split-point calculation: try every possible split point to find the
# one that minimizes the total sum of squared errors.
sse = []
xMin = []
for i in range(1, len(xPlot)):
    # Divide the points into those left and right of the split point.
    lhList = list(xPlot[0:i])
    rhList = list(xPlot[i:len(xPlot)])

    # The average on each side is the constant each leaf would predict.
    lhAvg = sum(lhList) / len(lhList)
    rhAvg = sum(rhList) / len(rhList)

    # Sum of squared errors on the left, on the right, and their total.
    lhSse = sum([(s - lhAvg) * (s - lhAvg) for s in lhList])
    rhSse = sum([(s - rhAvg) * (s - rhAvg) for s in rhList])
    sse.append(lhSse + rhSse)
    xMin.append(max(lhList))

plot.plot(range(1, len(xPlot)), sse)
plot.xlabel('Split Point Index')
plot.ylabel('Sum Squared Error')
plot.show()

minSse = min(sse)
idxMin = sse.index(minSse)
print(xMin[idxMin])

# What happens if the depth is really high?  The tree overfits the noise.
simpleTree6 = DecisionTreeRegressor(max_depth=6)
simpleTree6.fit(x, y)

# Too many nodes to usefully draw this tree.

# Compare the deep tree's prediction with the true values.
yHat = simpleTree6.predict(x)

plot.figure()
plot.plot(xPlot, y, label='True y')
plot.plot(xPlot, yHat, label='Tree Prediction ', linestyle='--')
plot.legend(bbox_to_anchor=(1, 0.2))
plot.axis('tight')
plot.xlabel('x')
plot.ylabel('y')
plot.show()

1.3　决策树的训练等同于分割点的选择

通过递归分割获得更深的决策树

<p class="图题">图6-6　深度为2的决策树的预测曲线</p>

1.4　二元决策树的过拟合

二元决策树过拟合的度量

<p class="图题">图6-8　深度为6的决策树的预测曲线</p>

权衡二元决策树复杂度以获得最佳性能

# Use 10-fold cross-validation to measure out-of-sample MSE for trees of
# depth 1..7 on the synthetic y = x + noise data, then plot MSE versus
# depth to find the tree complexity that generalizes best.
#
# Fixes vs. the original listing:
#   * `a % nxval == ixval % nxval` — the second modulus was redundant
#     since ixval is already in range(nxval).
#   * The squared-error accumulator is zero-initialized per depth instead
#     of using a special-cased `if ixval == 0` branch (whose follow-up
#     comment also wrongly said "accumulate predictions").
#   * The deprecated, unused `sklearn.externals.six.StringIO` import
#     (removed from modern scikit-learn) is dropped.

import numpy
import matplotlib.pyplot as plot
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor

# Build a simple data set with y = x + random noise.
nPoints = 100

# x values for plotting: nPoints + 1 evenly spaced points in [-0.5, 0.5].
xPlot = [(float(i) / float(nPoints) - 0.5) for i in range(nPoints + 1)]

# sklearn expects x as a list of rows (list of lists).
x = [[s] for s in xPlot]

# y (labels) is the x-value plus Gaussian noise; fixed seed for repeatability.
numpy.random.seed(1)
y = [s + numpy.random.normal(scale=0.1) for s in xPlot]

nrow = len(x)

# Fit trees with several different depths and use cross-validation to
# see which depth works best on out-of-sample data.
depthList = [1, 2, 3, 4, 5, 6, 7]
xvalMSE = []
nxval = 10

for iDepth in depthList:
    # Accumulate squared out-of-sample (oos) error over all folds.
    oosErrors = 0.0
    for ixval in range(nxval):
        # Every nxval-th row (offset by the fold index) is held out for
        # testing; the rest form the training set.
        idxTest = [a for a in range(nrow) if a % nxval == ixval]
        idxTrain = [a for a in range(nrow) if a % nxval != ixval]

        # Define the test and training attribute and label sets.
        xTrain = [x[r] for r in idxTrain]
        xTest = [x[r] for r in idxTest]
        yTrain = [y[r] for r in idxTrain]
        yTest = [y[r] for r in idxTest]

        # Train a tree of the given depth and accumulate its squared
        # out-of-sample errors.
        treeModel = DecisionTreeRegressor(max_depth=iDepth)
        treeModel.fit(xTrain, yTrain)
        treePrediction = treeModel.predict(xTest)
        error = [yTest[r] - treePrediction[r] for r in range(len(yTest))]
        oosErrors += sum([e * e for e in error])

    # Average the squared errors for this depth.  Every row appears in
    # exactly one test fold, so dividing by nrow gives the MSE.
    mse = oosErrors / nrow
    xvalMSE.append(mse)

plot.plot(depthList, xvalMSE)
plot.axis('tight')
plot.xlabel('Tree Depth')
plot.ylabel('Mean Squared Error')
plot.show()

822 篇文章227 人订阅

0 条评论

相关文章

文本与序列的深度模型 | 深度学习笔记

Rare Event 与其他机器学习不同，在文本分析里，陌生的东西（rare event）往往是最重要的，而最常见的东西往往是最不重要的。 语法多义性 一个东西...

479100

39080

14330

250100

机器学习概念总结笔记（二）

logistic回归又称logistic回归分析，是一种广义的线性回归分析模型，常用于数据挖掘，疾病自动诊断，经济预测等领域。例如，探讨引发疾病的危险因素，并根...

88900

51940

12020

16540

11920