首页
学习
活动
专区
工具
TVP
发布

AI机器学习-决策树-python实现cart算法

CART算法主要通过基尼系数来做条件选择,基尼系数的计算公式如下:

其中pk表示第k个样本所占的比例

属性A划分的子集合的基尼系数为:

其中,V表示属性A的可能取值。

基于之前的代码,添加基尼系数的计算方法:

#计算基尼系数

defcalcGini(dataSet):

# 获取数据集的长度

numEntries = len(dataSet)

# 定义结果统计对象

resultCounts = {}

# 遍历数据集

fordataItemindataSet:

# 获取当前结果值

currentResult = dataItem[-1]

# 统计结果值个数,如果没有在集合中,统计值为0,否则+1

ifcurrentResultnot inresultCounts:

resultCounts[currentResult] = 0

resultCounts[currentResult] += 1

# 定义基尼系数

giniBase = 0.0

# 遍历结果统计集,计算基尼系数

forkeyinresultCounts:

# 计算某个结果所占的比率

prob = float(resultCounts[key]) / numEntries

# 计算基尼系数

giniBase += pow(prob,2)

print("giniBase:"+str(giniBase)+",prob:"+str(prob))

gini = 1 - giniBase

returngini

再添加计算属性基尼系数的方法:

#计算以某属性列为选择标准的基尼系数

#输入数据集dataSet,列索引columnIndex

defcalcAttrGini(dataSet,columnIndex):

# 获取第i列的数组信息

columnValues = [example[columnIndex]forexampleindataSet]

# 获取此列的无重复的属性特征值

distinctColumnValues = set(columnValues)

attrGini = 0.0

fordistinctColumnValueindistinctColumnValues:

subDataSet = splitDataSet(dataSet, columnIndex, distinctColumnValue)

prob = len(subDataSet) / float(len(dataSet))

entItem = prob * calcGini(subDataSet)

print("giniItem:"+str(entItem)+",prob:"+str(prob)+",subGini:"+str(calcGini(subDataSet)))

attrGini += entItem

returnattrGini

然后添加通过基尼系数做特征选择,选择基尼系数小的优先

#用CART算法,通过信息增益率做特征选择,返回属性列index

defchooseBestFeatureToSplit3(dataSet):

numFeatures = len(dataSet[0]) - 1 # 获取数据集列数,因为要从0开始所以减1

bestGini = 0.0

bestFeature = -1

#从0开始遍历数据集列

forcolumnIndexinrange(numFeatures):

attrGini = calcAttrGini(dataSet,columnIndex)

print("columIndex:"+str(columnIndex)+"attrGini:"+str(attrGini))

ifattrGini

bestGini = attrGini

bestFeature = columnIndex

returnbestFeature

最后重构决策树创建方法,添加CART算法支持

ifmethod =="CART":

bestFeatColumnIndex = chooseBestFeatureToSplit3(dataSet) # CART算法获取最佳分支属性列索引

最后在主方法中进行调用测试

myTree = createTree(data, label,"CART")

运行后生成的树结构,可以发现这个树结构与ID3算法一致,可见CART算法也有它的局限性

关注公众号“挨踢学霸”,获取更多技术图文,视频教程

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180510A20AV100?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券