首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

AI机器学习-决策树-python实现CART算法

CART算法主要通过基尼系数来做条件选择,基尼系数的计算公式如下:

其中pk表示第k个样本所占的比例

属性A划分的子集合的基尼系数为:

其中,V表示属性A的可能取值。

基于之前的代码,添加基尼系数的计算方法:

#计算基尼系数

defcalcGini(dataSet):

#获取数据集的长度

numEntries =len(dataSet)

#定义结果统计对象

resultCounts = {}

#遍历数据集

fordataItemindataSet:

#获取当前结果值

currentResult = dataItem[-1]

#统计结果值个数,如果没有在集合中,统计值为0,否则+1

ifcurrentResultnot inresultCounts:

resultCounts[currentResult] =

resultCounts[currentResult] +=1

#定义基尼系数

giniBase =0.0

#遍历结果统计集,计算基尼系数

forkeyinresultCounts:

#计算某个结果所占的比率

prob =float(resultCounts[key])/ numEntries

#计算基尼系数

giniBase +=pow(prob,2)

print("giniBase:"+str(giniBase)+",prob:"+str(prob))

gini =1-giniBase

returngini

再添加计算属性基尼系数的方法:

#计算以某属性列为选择标准的基尼系数

#输入数据集dataSet,列索引columnIndex

defcalcAttrGini(dataSet,columnIndex):

#获取第i列的数组信息

columnValues = [example[columnIndex]forexampleindataSet]

#获取此列的无重复的属性特征值

distinctColumnValues =set(columnValues)

attrGini =0.0

fordistinctColumnValueindistinctColumnValues:

subDataSet = splitDataSet(dataSet, columnIndex, distinctColumnValue)

prob =len(subDataSet) /float(len(dataSet))

entItem = prob * calcGini(subDataSet)

print("giniItem:"+str(entItem)+",prob:"+str(prob)+",subGini:"+str(calcGini(subDataSet)))

attrGini += entItem

returnattrGini

然后添加通过基尼系数做特征选择,选择基尼系数小的优先

#用CART算法,通过信息增益率做特征选择,返回属性列index

defchooseBestFeatureToSplit3(dataSet):

numFeatures =len(dataSet[]) -1#获取数据集列数,因为要从0开始所以减1

bestGini =0.0

bestFeature = -1

#从0开始遍历数据集列

forcolumnIndexinrange(numFeatures):

attrGini = calcAttrGini(dataSet,columnIndex)

print("columIndex:"+str(columnIndex)+"attrGini:"+str(attrGini))

ifattrGini

bestGini = attrGini

bestFeature = columnIndex

returnbestFeature

最后重构决策树创建方法,添加CART算法支持

ifmethod =="CART":

bestFeatColumnIndex = chooseBestFeatureToSplit3(dataSet)# CART算法获取最佳分支属性列索引

最后在主方法中进行调用测试

myTree = createTree(data, label,"CART")

运行后生成的树结构,可以发现这个树结构与ID3算法一致,可见CART算法也有它的局限性

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180511G0CEVW00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券