前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >李航《统计学习方法》决策树ID3算法实现

李航《统计学习方法》决策树ID3算法实现

作者头像
Coggle数据科学
发布2019-09-12 12:04:34
5620
发布2019-09-12 12:04:34
举报
文章被收录于专栏:Coggle数据科学Coggle数据科学

在开篇我们使用pandas、numpy和sklearn先对数据进行一些处理。

数据集选用《统计学习方法》中提供的,保存为csv文件。

代码语言:javascript
复制
age,work,hourse,loan,class
青年,否,否,一般,否
青年,否,否,好,否
青年,是,否,好,是
青年,是,是,一般,是
青年,否,否,一般,否
中年,否,否,一般,否
中年,否,否,好,否
中年,是,是,好,是
中年,否,是,非常好,是
中年,否,是,非常好,是
老年,否,是,非常好,是
老年,否,是,好,是
老年,是,否,好,是
老年,是,否,非常好,是
老年,否,否,一般,否

1、查看数据基本信息

代码语言:javascript
复制
import pandas as pd
import numpy as np
from sklearn import preprocessing

dataset = pd.read_csv('dataset.csv')
dataset.info()
代码语言:javascript
复制
#获取数据集的形状
n_data = dataset.shape[0]
# 得到变量列表,得到格式为list
cols = dataset.columns.tolist()

2、描述型变量转数值型变量

代码语言:javascript
复制
#创建obj_vals列表,并将描述型变量存入
for col in cols:
    if dataset[col].dtype == "object":
        obj_vars.append(col)
print(obj_vars)
代码语言:javascript
复制
# 将描述变量转化为数值型变量
# 并将转化为的数据附加到原始数据上
le = preprocessing.LabelEncoder()
for col in obj_vars:
    tran = le.fit_transform(dataset[col].tolist())
    tran_dataset = pd.DataFrame(tran, columns=['num_'+col])
    dataset = pd.concat([dataset, tran_dataset], axis=1)

当然对于决策树来说描述型转换为数值型不是必须的。


机器学习算法其实很古老,作为一个码农经常会不停的敲if, else if, else,其实就已经在用到决策树的思想了。只是你有没有想过,有这么多条件,用哪个条件特征先做if,哪个条件特征后做if比较优呢?怎么准确的定量选择这个标准就是决策树机器学习算法的关键了。1970年代,一个叫昆兰的大牛找到了用信息论中的熵来度量决策树的决策选择过程,方法一出,它的简洁和高效就引起了轰动,昆兰把这个算法叫做ID3。下面给出ID3算法的初始形式。

Decision Tree ID3算法初始形式

ID3算法:

代码语言:javascript
复制
import pandas as pd
import numpy as np
from math import log

def loadData(filename):
    '''
    输入:文件
    输出:csv数据集
    '''
    dataset = pd.read_csv("dataset.csv")
    return dataset

def calcShannonEnt(dataset):
    '''
    输入:数据集
    输出:数据集的香农熵
    描述:计算给定数据集的香农熵
    '''
    numEntries = dataset.shape[0]  
    labelCounts = {} 
    cols = dataset.columns.tolist() 
    classlabel = dataset[cols[-1]].tolist() 
    for currentlabel in classlabel:
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 1
        else:
            labelCounts[currentlabel] += 1

    ShannonEnt = 0.0

    for key in labelCounts:
        prob = labelCounts[key]/numEntries
        ShannonEnt -= prob*log(prob, 2)

    return ShannonEnt

def splitDataSet(dataset, axis, value):
    '''
    输入:数据集,所占列,选择值
    输出:划分数据集
    描述:按照给定特征划分数据集;选择所占列中等于选择值的项
    '''
    cols = dataset.columns.tolist()
    axisFeat = dataset[axis].tolist()
    #更新数据集
    retDataSet = pd.concat([dataset[featVec] for featVec in cols if featVec != axis], axis=1)
    i = 0
    dropIndex = [] #删除项的索引集
    for featVec in axisFeat:
        if featVec != value:
            dropIndex.append(i)
            i += 1
        else:
            i += 1
    newDataSet = retDataSet.drop(dropIndex)
    return newDataSet.reset_index(drop=True)


def chooseBestFeatureToSplit(dataset):
    '''
    输入:数据集
    输出:最好的划分特征
    描述:选择最好的数据集划分维度
    '''
    numFeatures = dataset.shape[1] - 1
    ShannonEnt = calcShannonEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    cols = dataset.columns.tolist()
    for i in range(numFeatures):
        equalVals = set(dataset[cols[i]].tolist())
        newEntropy = 0.0
        for value in equalVals:
            subDataSet = splitDataSet(dataset, cols[i], value)
            prob = subDataSet.shape[0] / dataset.shape[0]
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = ShannonEnt - newEntropy
        print(cols[i],infoGain)
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = cols[i]
    return bestFeature, bestInfoGain

def majorityCnt(classList):
    '''
    输入:分类类别列表
    输出:子节点的分类
    描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
          采用多数判决的方法决定该子节点的分类
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
    return sortedClassCount[0][0]

def createTree(dataset, dropCol):
    '''
    输入:数据集,删除特征
    输出:决策树
    描述:递归构建决策树,利用上述的函数
    '''
    cols = dataset.columns.tolist()[:-1]
    classList = dataset[dataset.columns.tolist()[-1]].tolist()

    #若数据集中所有实例属于同一类Ck,则为单节点树,并将Ck作为该节点的类标记
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    #若特征集为空集,则为单节点树,并将数据集中实例数最大的类Ck作为该节点的类标记
    if len(dataset[0:1]) == 0:
        return majorityCnt(classList)
    
    # dataset.drop(dropCol, axis=1, inplace=True)
    print('特征集和类别:',dataset.columns.tolist())
    bestFeature, bestInfoGain=chooseBestFeatureToSplit(dataset)
    print('bestFeture:',bestFeature)

    myTree = {bestFeature:{}}

    #del(labels[bestFeat])
    # 得到列表包括节点所有的属性值
    print(bestFeature)
    featValues = dataset[bestFeature]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        myTree[bestFeature][value] = createTree(splitDataSet(dataset, bestFeature, value), bestFeature)
    return myTree

def main():
    filename = "dataset.csv"
    dataset = loadData(filename)
    dropCol = []
    myTree = createTree(dataset, dropCol)
    print(myTree)

if __name__ == '__main__':
    main()

输出样例:

代码语言:javascript
复制
特征集和类别: ['age', 'work', 'hourse', 'loan', 'class']
age 0.08300749985576883
work 0.32365019815155627
hourse 0.4199730940219749
loan 0.36298956253708536
bestFeture: hourse
hourse
特征集和类别: ['age', 'work', 'loan', 'class']
age 0.2516291673878229
work 0.9182958340544896
loan 0.47385138961004514
bestFeture: work
work
{'hourse': {'是': '是', '否': {'work': {'是': '是', '否': '否'}}}}

ID3算法的不足:

ID3算法虽然提出了新思路,但是还是有很多值得改进的地方。  

  1. ID3没有考虑连续特征,比如长度,密度都是连续值,无法在ID3运用。这大大限制了ID3的用途。
  2. ID3采用信息增益大的特征优先建立决策树的节点。很快就被人发现,在相同条件下,取值比较多的特征比取值少的特征信息增益大。比如一个变量有2个值,各为1/2,另一个变量为3个值,各为1/3,其实他们都是完全不确定的变量,但是取3个值的比取2个值的信息增益大。
  3. ID3算法对于缺失值的情况没有做考虑
  4. 没有考虑过拟合的问题

写在最后:

由于ID3的不足,其作者昆兰对ID3算法进行了改进,并称其为C4.5算法。在后续文章将会对其进行实现。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • ID3算法:
  • 输出样例:
  • ID3算法的不足:
  • 写在最后:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档