前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >决策树(一)

决策树(一)

作者头像
用户6021899
发布2019-08-14 16:57:55
7000
发布2019-08-14 16:57:55
举报
文章被收录于专栏:Python编程 pyqt matplotlib

你是否玩过20个问题的游戏? 游戏的规则很简单:参与游戏的一方在脑海里想着某个事物,其它参与者想他提问题,最多只允许提20个问题,问题的答案也只能用“对”或者“错”回答。问问题的人通过推断分解,逐步缩小猜测事物的范围。

决策树的工作原理与之相似,用户输入一系列数据,机器给出分类答案。下面的流程图就是一个简单的决策树。矩形代表判断节点。椭圆代表叶子节点,表示已得出结论,可以终止运行。从判断节点引出的左右箭头称作分支,它指向另一个判断节点或者叶子节点。

决策树适用于标称型数据,因此数值型数据必须先离散化。决策树的主要优势在于数据形式非常容易理解。它的一个重要任务是提取数据中所蕴含的知识信息。因此决策树可以使用不熟悉的数据集,并从中提取一系列规则,这就是机器学习的过程。

在构造决策树时,我们需要解决的第一个问题是,当前数据集上那个特征在划分数据分类时起决定作用,即先用那个特征进行分类效率最高。为了找到决定性的特征,划分出最好的结果,我们需评估每一个特征。之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据全部属于同一类型,在该分支已完成了分类,无需做进一步分割,否则就要重复 划分数据子集的过程(递归)。直到所有具有相同类型的数据均在一个数据子集内。

我们以下面这个简单的水中生物分类的数据集为例,介绍决策树算法的基本流程。

首先创建数据集:

代码语言:javascript
复制
def createDataset():
    '''创建一个简单的数据集'''
    dataset = [ ["yes", "yes", "fish"],
                ["yes", "yes","fish"],
                ["yes", "no", "nonfish"],
                ["no", "yes", "nonfish"],
                ["no", "yes", "nonfish"]]
    featnames =["no surfacing", "flippers"]#特征 名
    return dataset, featnames

之后我们需要划分数据集。但如何寻找划当前分数据集的最好的特征呢?标准是什么?划分数据集的最大原则是:将无序的数据变得更加有序。组织杂乱无章的数据的一种方法是 使用信息论度量信息。

集合信息的度量方式成为香农熵,或者简称为(Entropy), 这个名字来源于信息论支付 克劳德·香农。熵定义为信息的期望值,在明晰熵的定义之前,我们需直到信息的定义。如果待分类的事物可能划分在多个分类之中,则对应第i个分类的信息定义为:

其中,

为选择该分类的概率。

则香农熵为所有类别包含的信息的期望值:

例如,若只有一个分类,则概率为1,熵为0,此时熵最小。若有100个事物,类别各不相同,则分到每个类别的概率均为0.01,熵为 -100*0.01*log2(0.01), 约等于6.644。

计算数据集的熵的代码如下:

代码语言:javascript
复制
from math import log

def calcEntropy(dataset):
    '''计算给定数据集的 香农熵'''
    numSamples = len(dataset) #样本(特征向量)个数
    labelCounts = dict()
    for sample in dataset :
        currentLabel = sample[-1]
        #有则 +1, 无则 0+1
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) +1 
    entropy = 0.0
    for key in labelCounts:
        prob =  float(labelCounts[key]) / numSamples #按不同分类标签的数量计算概率
        entropy -= prob * log(prob, 2) #计算香农 熵
    return entropy

经计算,上述水中生物分类的数据集的熵值为 0.97095。

划当前分数据集的最好的特征就是使信息增益(熵的减少量)最大的那个特征。下面的代码使用循环找出使信息增益最大的那个特征的索引:

代码语言:javascript
复制
def splitDataset(dataset, axis, value):
    '''划分数据集
    3个输入参数分别为:待划分的数据集、待划分特征的索引,用于划分的特征的值'''
    retDataset = []
    for featVec in dataset: #for 数据集中每个样本(特征向量)
        if featVec [axis] == value:
            reducedFeatVec = featVec[: axis]
            reducedFeatVec.extend(featVec[axis+1 : ])
            retDataset.append(reducedFeatVec)
    return retDataset
def chooseBestFeatureToSplit(dataset):
    '''选择最好的(最大化信息增益)数据集划分方式'''
    numFeatures =len(dataset[0]) -1 # 特征个数(列数 减掉分类标签所占一列)
    baseEntropy =  calcEntropy(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [sample[i] for sample in dataset]#第i个特征的所有特征的值
        uniqueValues = set(featList) # 通过列表转集合去重,得到第i个特征的值的集合
        newEntropy = 0.0
        for value in uniqueValues :
            subDataset =  splitDataset(dataset, i, value)
            prob = len(subDataset)/ float(len(dataset))
            newEntropy += prob * calcEntropy(subDataset)
            #print(value, prob, newEntropy)
            
        infoGain = baseEntropy  - newEntropy
        if  infoGain > bestInfoGain :
            bestInfoGain = infoGain
            bestFeatureAxis = i
    return bestFeatureAxis

现在,我们依据最好的特征就可以依靠递归调用得出决策树的全部结构。本例中,决策树的数据结构用 嵌套的字典来表示。

代码语言:javascript
复制
def majorityCnt(classList):
    classCount = dict()
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) +1 #有则 +1, 无则 0+1
    import operator
    #对键值对组成的列表按值从大到小排序
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse =True)
    #返回投票最多的分类标签名 。(其实不必全部排序)
    return sortedClassCount[0][0]
    
def createTree(dataset, featnames):
    classList = [sample[-1] for sample in dataset] # 类别 列表
    if classList.count(classList[0]) == len(classList) :#类别完全相同则停止划分
        return classList[0]
    if len(dataset[0]) == 1: # 遍历完所有特征,则返回出现次数最多的类别
        return majorityCnt(classList)
    bestFeatureAxis = chooseBestFeatureToSplit(dataset)
    bestFeatureName = featnames[bestFeatureAxis]
    myTree  = {bestFeatureName: {}}
    del featnames[bestFeatureAxis] #删除最佳特征名
    featValues = [sample[bestFeatureAxis] for sample in dataset]
    uniqueValues = set(featValues) #集合去重
    for value in uniqueValues:
        subFeatnames = featnames[:] #深拷贝
        myTree[bestFeatureName][value] = createTree(splitDataset(dataset, bestFeatureAxis, value),
                                                    subFeatnames)
    return myTree

调用createTree() 函数,即可得到本例数据集对应的决策树字典为:

不够直观对不对?下面的代码是用matplotlib画出决策树(入口函数是 createPlot()):

代码语言:javascript
复制
def getNumLeafs(myTree):
    '''返回叶子节点的数目(树的最大宽度)'''
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":#数据类型为字典(还有子树)
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs                                                 
   
def getTreeDepth(myTree):
    '''返回树的最大深度'''
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":#数据类型为字典(还有子树)
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
    
import matplotlib.pyplot as plt
# maptlot annotate 的 bbox的 属性字典
decisionNode = dict(boxstyle ="sawtooth", facecolor = "orange",edgecolor = "orange")
leafNode = dict(boxstyle = "round4", facecolor = "lime")
arrow_args = dict(arrowstyle = "<-", color ="r")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy =parentPt, xycoords = "axes fraction",
                            xytext = centerPt, textcoords ="axes fraction", va ="center",
                            ha = "center", bbox = nodeType, color ="black",weight ="bold",
                            arrowprops = arrow_args)
                            
def plotMidText(cntrPt, parentPt, textString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, textString)
    
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + numLeafs) /2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr +" ?", cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()
    axprops = dict(xticks = [], yticks = []) #不显示x轴和y轴的刻度
    createPlot.ax1 = plt.subplot(111, frameon= False, ** axprops)
    
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show() 
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-05-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档