专栏首页机器学习算法与Python学习机器学习(9)之ID3算法详解及python实现

机器学习(9)之ID3算法详解及python实现

关键字全网搜索最新排名

【机器学习算法】:排名第一

【机器学习】:排名第二

【Python】:排名第三

【算法】:排名第四

前言

决策树算法在机器学习中算是很经典的算法系列。它既可以作为分类算法,也可以作为回归算法,同时也特别适合集成学习比如随机森林。本文就对决策树算法ID3思想做个总结。

ID3算法的信息论基础

1970年代,一个叫昆兰的大牛找到了用信息论中的熵来度量决策树的决策选择过程,它的简洁和高效就引起了轰动,昆兰把这个算法叫做ID3。下面我们就看看ID3算法是怎么选择特征的。

首先,我们需要熟悉信息论中熵的概念。熵度量了事物的不确定性,越不确定的事物,它的熵就越大。具体的,随机变量X的熵的表达式如下:

其中n代表X的n种不同的离散取值。而pi代表了X取值为i的概率,log为以2或者e为底的对数。

熟悉了一个变量X的熵,很容易推广到多个个变量的联合熵,这里给出两个变量X和Y的联合熵表达式:

有了联合熵,又可以得到条件熵的表达式H(X|Y),条件熵类似于条件概率,它度量了我们的X在知道Y以后剩下的不确定性。表达式如下:

上面刚才提到H(X)度量了X的不确定性,条件熵H(X|Y)度量了我们在知道Y以后X剩下的不确定性,那么H(X)-H(X|Y)呢?它度量了X在知道Y以后不确定性减少程度,这个度量我们在信息论中称为互信息,记为I(X,Y)。在决策树ID3算法中叫做信息增益。ID3算法就是用信息增益来判断当前节点应该用什么特征来构建决策树。信息增益大,则越适合用来分类。

下面这个图可以比较清晰的反映他们之间的关系。左边的椭圆代表H(X),右边的椭圆代表H(Y),中间重合的部分就是我们的互信息或者信息增益I(X,Y), 左边的椭圆去掉重合部分就是H(X|Y),右边的椭圆去掉重合部分就是H(Y|X)。两个椭圆的并就是H(X,Y)。

ID3算法的思路

上面提到ID3算法就是用信息增益大小来判断当前节点应该用什么特征来构建决策树,用计算出的信息增益最大的特征来建立决策树的当前节点。这里我们举一个信息增益计算的具体的例子。比如我们有15个样本D,输出为0或者1。其中有9个输出为0, 6个输出为1。 样本中有个特征A,取值为A1,A2和A3。在取值为A1的样本的输出中,有3个输出为1, 2个输出为0,取值为A2的样本输出中,2个输出为1,3个输出为0, 在取值为A3的样本中,4个输出为1,1个输出为0.

输入的是m个样本,样本输出集合为D,每个样本有n个离散特征,特征集合即为A,输出为决策树T。

算法的过程为:

  1)初始化信息增益的阈值ϵ

  2)判断样本是否为同一类输出Di,如果是则返回单节点树T。标记类别为Di

  3) 判断特征是否为空,如果是则返回单节点树T,标记类别为样本中输出类别D实例数最多的类别。

  4)计算A中的各个特征(一共n个)对输出D的信息增益,选择信息增益最大的特征Ag

  5) 如果Ag的信息增益小于阈值ϵ,则返回单节点树T,标记类别为样本中输出类别D实例数最多的类别。

  6)否则,按特征Ag的不同取值Agi将对应的样本输出D分成不同的类别Di。每个类别产生一个子节点。对应特征值为Agi。返回增加了节点的数T。

  7)对于所有的子节点,令D=Di,A=A−{Ag}递归调用2-6步,得到子树Ti并返回。

ID3算法的不足

ID3算法虽然提出了新思路,但是还是有很多值得改进的地方。  

  a) ID3没有考虑连续特征,比如长度,密度都是连续值,无法在ID3运用。这大大限制了ID3的用途。

  b) ID3采用信息增益大的特征优先建立决策树的节点。很快就被人发现,在相同条件下,取值比较多的特征比取值少的特征信息增益大。比如一个变量有2个值,各为1/2,另一个变量为3个值,各为1/3,其实他们都是完全不确定的变量,但是取3个值的比取2个值的信息增益大。如果校正这个问题呢?

  c) ID3算法对于缺失值的情况没有做考虑

  d) 没有考虑过拟合的问题

ID3 算法的作者昆兰基于上述不足,对ID3算法做了改进,这就是C4.5算法。

python实现

实验数据来自【美】Peter Harrington 写的《Machine Learning in Action》

下载链接:http://pan.baidu.com/s/1jIR4wdg

或者加入微信交流群下载(二维码在文末)

trees.py

from math import log

import operator

def createDataSet():

dataSet = [[1, 1, 'yes'],

[1, 1, 'yes'],

[1, 0, 'no'],

[0, 1, 'no'],

[0, 1, 'no']]

labels = ['no surfacing','flippers']

#change to discrete values

return dataSet, labels

def calcShannonEnt(dataSet):

numEntries = len(dataSet)

labelCounts = {}

for featVec in dataSet: #the the number of unique elements and their occurance

currentLabel = featVec[-1]

if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0

labelCounts[currentLabel] += 1

shannonEnt = 0.0

for key in labelCounts:

prob = float(labelCounts[key])/numEntries

shannonEnt -= prob * log(prob,2) #log base 2

return shannonEnt

def splitDataSet(dataSet, axis, value):

retDataSet = []

for featVec in dataSet:

if featVec[axis] == value:

reducedFeatVec = featVec[:axis] #chop out axis used for splitting

reducedFeatVec.extend(featVec[axis+1:])

retDataSet.append(reducedFeatVec)

return retDataSet

def chooseBestFeatureToSplit(dataSet):

numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels

baseEntropy = calcShannonEnt(dataSet)

bestInfoGain = 0.0; bestFeature = -1

for i in range(numFeatures): #iterate over all the features

featList = [example[i] for example in dataSet]#create a list of all the examples of this feature

uniqueVals = set(featList) #get a set of unique values

newEntropy = 0.0

for value in uniqueVals:

subDataSet = splitDataSet(dataSet, i, value)

prob = len(subDataSet)/float(len(dataSet))

newEntropy += prob * calcShannonEnt(subDataSet)

infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy

if (infoGain > bestInfoGain): #compare this to the best gain so far

bestInfoGain = infoGain #if better than current best, set to best

bestFeature = i

return bestFeature #returns an integer

def majorityCnt(classList):

classCount={}

for vote in classList:

if vote not in classCount.keys(): classCount[vote] = 0

classCount[vote] += 1

sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)

return sortedClassCount[0][0]

def createTree(dataSet,labels):

classList = [example[-1] for example in dataSet]

if classList.count(classList[0]) == len(classList):

return classList[0]#stop splitting when all of the classes are equal

if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet

return majorityCnt(classList)

bestFeat = chooseBestFeatureToSplit(dataSet)

bestFeatLabel = labels[bestFeat]

myTree = {bestFeatLabel:{}}

del(labels[bestFeat])

featValues = [example[bestFeat] for example in dataSet]

uniqueVals = set(featValues)

for value in uniqueVals:

subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels

myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)

return myTree

def classify(inputTree,featLabels,testVec):

firstStr = inputTree.keys()[0]

secondDict = inputTree[firstStr]

featIndex = featLabels.index(firstStr)

key = testVec[featIndex]

valueOfFeat = secondDict[key]

if isinstance(valueOfFeat, dict):

classLabel = classify(valueOfFeat, featLabels, testVec)

else: classLabel = valueOfFeat

return classLabel

def storeTree(inputTree,filename):

import pickle

fw = open(filename,'w')

pickle.dump(inputTree,fw)

fw.close()

def grabTree(filename):

import pickle

fr = open(filename)

return pickle.load(fr)

treePlotter.py

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")

leafNode = dict(boxstyle="round4", fc="0.8")

arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):

numLeafs = 0

firstStr = myTree.keys()[0]

secondDict = myTree[firstStr]

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes

numLeafs += getNumLeafs(secondDict[key])

else: numLeafs +=1

return numLeafs

def getTreeDepth(myTree):

maxDepth = 0

firstStr = myTree.keys()[0]

secondDict = myTree[firstStr]

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes

thisDepth = 1 + getTreeDepth(secondDict[key])

else: thisDepth = 1

if thisDepth > maxDepth: maxDepth = thisDepth

return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):

createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',

xytext=centerPt, textcoords='axes fraction',

va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def plotMidText(cntrPt, parentPt, txtString):

xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]

yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]

createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on

numLeafs = getNumLeafs(myTree) #this determines the x width of this tree

depth = getTreeDepth(myTree)

firstStr = myTree.keys()[0] #the text label for this node should be this

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

plotMidText(cntrPt, parentPt, nodeTxt)

plotNode(firstStr, cntrPt, parentPt, decisionNode)

secondDict = myTree[firstStr]

plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

for key in secondDict.keys():

if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes

plotTree(secondDict[key],cntrPt,str(key)) #recursion

else: #it's a leaf node print the leaf node

plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW

plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)

plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):

fig = plt.figure(1, facecolor='white')

fig.clf()

axprops = dict(xticks=[], yticks=[])

createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks

#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses

plotTree.totalW = float(getNumLeafs(inTree))

plotTree.totalD = float(getTreeDepth(inTree))

plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;

plotTree(inTree, (0.5,1.0), '')

plt.show()

#def createPlot():

# fig = plt.figure(1, facecolor='white')

# fig.clf()

# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses

# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)

# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)

# plt.show()

def retrieveTree(i):

listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},

{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}

]

return listOfTrees[i]

#createPlot(thisTree)

测试代码:

import trees

import treePlotter

fr = open('lenses.txt')

lenses = [inst.strip().split('\t') for inst in fr.readlines()]

lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

lensesTree = trees.createTree(lenses, lensesLabels)

lensesTree

treePlotter.createPlot(lensesTree)

参考:

1. 周志华《机器学习》

2. http://www.cnblogs.com/pinard/p/6018889.html

本文分享自微信公众号 - 机器学习算法与Python学习(guodongwei1991)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-08-08

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 7 种简洁 Python 语法,教你码出一手好代码

    Python 是一门用途广泛、易读、而且容易入门的编程语言。但同时 python 语法也允许我们做一些很奇怪的事情。

    昱良
  • 1 行Python代码能干哪些事,这 13个你知道吗?

    来源:http://www.techug.com/post/what-can-a-line-of-python-code-do.html

    昱良
  • 机器学习(12)之决策树总结与python实践(~附源码链接~)

    关键字全网搜索最新排名 【机器学习算法】:排名第一 【机器学习】:排名第二 【Python】:排名第三 【算法】:排名第四 前言 在(机器学习(9)之ID3算法...

    昱良
  • 90% 的人说 Python 程序慢,5 大神招让你的代码像赛车一样跑起来

    很多人抱怨说自己写的 Python 代码跑的慢,尤其是当处理的数据集比较大的时候,其实稍微改动几行代码就可以让你的代码性能提高好几倍,不信一起来看下面这个 5 ...

    崔庆才
  • EOS 区块链数据实时异构到 MongoDB

    执行 eosio_build.sh 脚本编译 nodeos 会默认安装 mongodb,但是从 Dawn 4.0 开始, mongo_db_plugin 插件不...

    robinwen
  • LeCun:就通用智能而言,人工智能甚至还不如老鼠

    新智元编译 来源:PJ Media 编译整理: 张黔 弗格森 【新智元导读】AI到底有多智能?斯坦福大学、麻省理工学院、SRI International和其...

    企鹅号小编
  • tomcat nio源码分析

    1 NioEndpoint.Acceptor等待客户端连接,客户端连接之后将SocketChannel转发给Poller

    东营浪人
  • 快速学习-电影推荐系统设计(离线推荐模块)

    cwl_java
  • Android实现图文垂直跑马灯效果

    最近在维护老项目,老项目有一个地方需要修改,就是垂直跑马灯的问题,之前的垂直跑马灯是只有文字跑马灯,新版需要加上。

    砸漏
  • 移动互联网进入小巨头时代!张一鸣、王兴和程维谁是下一个大佬

    TMD是今日头条、美团点评和滴滴的代称,也是移动互联网时代崛起的三个“小巨头”。其中,美团点评和滴滴都由腾讯主要投资,而且借助腾讯的流量资源快速发展起来。今日头...

    光荣与梦想1987

扫码关注云+社区

领取腾讯云代金券