专栏首页杨熹的专栏决策树的python实现

决策树的python实现

本文结构:
  1. 是什么?
  2. 有什么算法?
  3. 数学原理?
  4. 编码实现算法?

1. 是什么?

简单地理解,就是根据一些 feature 进行分类,每个节点提一个问题,通过判断,将数据分为几类,再继续提问。这些问题是根据已有数据学习出来的,再投入新数据的时候,就可以根据这棵树上的问题,将数据划分到合适的叶子上。


2. 有什么算法?

常用的几种决策树算法有ID3、C4.5、CART:

ID3:选择信息熵增益最大的feature作为node,实现对数据的归纳分类。 C4.5:是ID3的一个改进,比ID3准确率高且快,可以处理连续值和有缺失值的feature。 CART:使用基尼指数的划分准则,通过在每个步骤最大限度降低不纯洁度,CART能够处理孤立点以及能够对空缺值进行处理。


3. 数学原理?

ID3: Iterative Dichotomiser 3

参考

下面这个数据集,可以同时被上面两颗树表示,结果是一样的,而我们更倾向于选择简单的树。 那么怎样做才能使得学习到的树是最简单的呢?

下面是 ID3( Iterative Dichotomiser 3 )的算法:

例如下面数据集,哪个是最好的 Attribute?

用熵Entropy来衡量: E(S) 是数据集S的熵 i 指每个结果,即 No,Yes的概率

E越大意味着信息越混乱,我们的目标是要让E最小。 E在0-1之间,如果P+的概率在0.5, 此时E最大,这时候说明信息对我们没有明确的意义,对分类没有帮助。

但是我们不仅仅想要变量的E最小,还想要这棵树是 well organized。 所以用到 Gain:信息增益

意思是如果我后面要用这个变量的话,它的E会减少多少。

例如下面的数据集:

  1. 先计算四个feature的熵E,及其分支的熵,然后用Gain的公式计算信息增益。
  1. 再选择Gain最大的特征是 outlook。
  2. 第一层选择出来后,各个分支再继续选择下一层,计算Gain最大的,例如分支 sunny 的下一层节点是 humidity。

详细的计算步骤可以参考这篇博文。


C4.5

参考

ID3有个局限是对于有大量数据的feature过于敏感,C4.5是它的一个改进,通过选择最大的信息增益率 gain ratio 来选择节点。而且它可以处理连续的和有缺失值的数据。

P’ (j/p) is the proportion of elements present at the position p, taking the value of j-th test.

例如 outlook 作为第一层节点后,它有 3 个分支,分别有 5,4,5 条数据,则 SplitInfo(5,4,5) = -5/14log(5,14)-4/14log(4,14)-5/14(5,14) ,其中 log(5,14) 即为 log2(5/14)。

下面是一个有连续值和缺失值的例子:

连续值

第一步计算 Gain,除了连续值的 humudity,其他步骤和前文一样。

要计算 humudity 的 Gain 的话,先把所有值升序排列: {65, 70, 70, 70, 75, 78, 80, 80, 80, 85, 90, 90, 95, 96} 然后把重复的去掉: {65, 70, 75, 78, 80, 85, 90, 95, 96} 如下图所示,按区间计算 Gain,然后选择最大的 Gain (S, Humidity) = 0.102

因为 Gain(S, Outlook) = 0 .246,所以root还是outlook:

缺失值

处理有缺失值的数据时候,用下图的公式:

例如 D12 是不知道的。

  1. 计算全集和 outlook 的 info,
  1. 其中几个分支的熵如下,再计算出 outlook 的 Gain:

比较一下 ID3 和 C4.5 的准确率和时间:

accuracy :

execution time:


4. 编码实现算法?

代码可以看《机器学习实战》这本书和这篇博客。

完整代码可以在 github 上查看。

接下来以 C4.5 的代码为例:

** 1. 定义数据:**

def createDataSet():
    dataSet = [[0, 0, 0, 0, 'N'], 
               [0, 0, 0, 1, 'N'], 
               [1, 0, 0, 0, 'Y'], 
               [2, 1, 0, 0, 'Y'], 
               [2, 2, 1, 0, 'Y'], 
               [2, 2, 1, 1, 'N'], 
               [1, 2, 1, 1, 'Y']]
    labels = ['outlook', 'temperature', 'humidity', 'windy']
    return dataSet, labels

** 2. 计算熵:**

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1      # 数每一类各多少个, {'Y': 4, 'N': 3}
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

** 3. 选择最大的gain ratio对应的feature:**

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1                 #feature个数
    baseEntropy = calcShannonEnt(dataSet)             #整个dataset的熵
    bestInfoGainRatio = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  #每个feature的list
        uniqueVals = set(featList)                      #每个list的唯一值集合                 
        newEntropy = 0.0
        splitInfo = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  #每个唯一值对应的剩余feature的组成子集
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
            splitInfo += -prob * log(prob, 2)
        infoGain = baseEntropy - newEntropy              #这个feature的infoGain
        if (splitInfo == 0): # fix the overflow bug
            continue
        infoGainRatio = infoGain / splitInfo             #这个feature的infoGainRatio      
        if (infoGainRatio > bestInfoGainRatio):          #选择最大的gain ratio
            bestInfoGainRatio = infoGainRatio
            bestFeature = i                              #选择最大的gain ratio对应的feature
    return bestFeature

** 4. 划分数据,为下一层计算准备: **

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:                      #只看当第i列的值=value时的item
            reduceFeatVec = featVec[:axis]              #featVec的第i列给除去
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)            
    return retDataSet

** 5. 多重字典构建树:**

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]         # ['N', 'N', 'Y', 'Y', 'Y', 'N', 'Y']
    if classList.count(classList[0]) == len(classList):
        # classList所有元素都相等,即类别完全相同,停止划分
        return classList[0]                                  #splitDataSet(dataSet, 0, 0)此时全是N,返回N
    if len(dataSet[0]) == 1:                                 #[0, 0, 0, 0, 'N'] 
        # 遍历完所有特征时返回出现次数最多的
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)             #0-> 2   
        # 选择最大的gain ratio对应的feature
    bestFeatLabel = labels[bestFeat]                         #outlook -> windy     
    myTree = {bestFeatLabel:{}}                   
        #多重字典构建树{'outlook': {0: 'N'
    del(labels[bestFeat])                                    #['temperature', 'humidity', 'windy'] -> ['temperature', 'humidity']        
    featValues = [example[bestFeat] for example in dataSet]  #[0, 0, 1, 2, 2, 2, 1]     
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                                #['temperature', 'humidity', 'windy']
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
            # 划分数据,为下一层计算准备
    return myTree

** 6. 可视化决策树的结果: **

dataSet, labels = createDataSet()
labels_tmp = labels[:]
desicionTree = createTree(dataSet, labels_tmp)
treePlotter.createPlot(desicionTree)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 如何自动生成文本摘要

    学习资料: https://www.youtube.com/watch?v=ogrJaOIuBx4&list=PL2-dafEMk2A7YdKv4XfKpfb...

    杨熹
  • Q-learning 的 python 实现

    通过前面的几篇文章可以知道,当我们要用 Q-learning 解决一个问题时,首先需要知道这个问题有多少个 state,每个 state 有多少 action,...

    杨熹
  • 《提问的艺术》

    《提问的艺术》 for 沟通 练习场景 解决问题,倾听 [What] 从问句开始,而不是阐述或命令 问一些最基本的问题 [How] 封闭式:问具体行动是...

    杨熹
  • 【机器学习笔记之二】决策树的python实现

    本文结构: 是什么? 有什么算法? 数学原理? 编码实现算法? ---- 1. 是什么? 简单地理解,就是根据一些 feature 进行分类,每个节点提一个问题...

    Angel_Kitty
  • 个性化推荐系统(二)---构建推荐引擎

      当下推荐系统包含的层级特别的多,整个线上推荐系统包含:最上层线上推荐服务、中层各个推荐数据召回集(数据主题、分类池子)、底层各种推荐模型。        ...

    杉枫
  • 百度介绍测试人工智能模型稳健性的对抗工具箱

    不管人工智能和机器学习系统在生产中宣称的稳健性如何,没有一个系统能够完全抵御对手的攻击,也没有一个技术能够通过恶意输入来愚弄算法。结果表明,即使在图像上产生很小...

    AiTechYun
  • Java 正则表达式的捕获组

    从正则表达式左侧开始,每出现一个左括号"("记做一个分组,分组编号从 1 开始。0 代表整个表达式。

    编程范 源代码公司
  • 企业级 IM 工具要更加“外向”| 研报×To B

    T客汇官网:www.tikehui.com 撰文|丁兆增 ? 这里是移动信息化研究中心在 T 客汇上的研报专栏。我们每周针对企业服务领域,进行深度解读。 -...

    人称T客
  • 企业级 IM 工具要更加“外向”| 研报×To B

    T客汇官网:www.tikehui.com 撰文|杨丽 ? 移动信息化研究中心 分析师 丁兆增 这里是移动信息化研究中心在 T 客汇上的研报专栏。我们每周针对企...

    人称T客
  • 聊聊spring cloud的RequestHeaderToRequestUriGatewayFilter

    本文主要研究一下spring cloud的RequestHeaderToRequestUriGatewayFilter

    codecraft

扫码关注云+社区

领取腾讯云代金券