机器学习--决策树(ID3)算法及案例

1

基本原理

决策树是一个预测模型。它代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,每个分支路径代表某个可能的属性值,每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。一般情况下,决策树由决策结点、分支路径和叶结点组成。在选择哪个属性作为结点的时候,采用信息论原理,计算信息增益,获得最大信息增益的属性就是最好的选择。信息增益是指原有数据集的熵减去按某个属性分类后数据集的熵所得的差值。然后采用递归的原则处理数据集,并得到了我们需要的决策树。

2

算法流程

检测数据集中的每个子项是否属于同一分类:

If 是,则返回类别标签;

Else

计算信息增益,寻找划分数据集的最好特征

划分数据数据集

创建分支节点(叶结点或决策结点)

for 每个划分的子集

递归调用,并增加返回结果到分支节点中

return 分支结点

算法的基本思想可以概括为:

1)树以代表训练样本的根结点开始。

2)如果样本都在同一个类.则该结点成为树叶,并记录该类。

3)否则,算法选择最有分类能力的属性作为决策树的当前结点.

4 )根据当前决策结点属性取值的不同,将训练样本根据该属性的值分为若干子集,每个取值形成一个分枝,有几个取值形成几个分枝。匀针对上一步得到的一个子集,重复进行先前步骤,递归形成每个划分样本上的决策树。一旦一个属性只出现在一个结点上,就不必在该结点的任何后代考虑它,直接标记类别。

5)递归划分步骤仅当下列条件之一成立时停止:

①给定结点的所有样本属于同一类。

②没有剩余属性可以用来进一步划分样本.在这种情况下.使用多数表决,将给定的结点转换成树叶,并以样本中元组个数最多的类别作为类别标记,同时也可以存放该结点样本的类别分布[这个主要可以用来剪枝]。

③如果某一分枝tc,没有满足该分支中已有分类的样本,则以样本的多数类生成叶子节点。

算法中2)步所指的最优分类能力的属性。这个属性的选择是本算法种的关键点,分裂属性的选择直接关系到此算法的优劣。

一般来说可以用比较信息增益和信息增益率的方式来进行。

其中信息增益的概念又会牵扯出熵的概念。熵的概念是香农在研究信息量方面的提出的。它的计算公式是:

Info(D)=-p1log(p1)/log(2.0)-p2log(p2)/log(2.0)-p3log(p3)/log(2.0)+...-pNlog(pN)/log(2.0) (其中N表示所有的不同类别)

而信息增益为:

Gain(A)=Info(D)-Info(Da) 其中Info(Da)数据集在属性A的情况下的信息量(熵)。

3

算法的特点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题

适用数据范围:数值型和标称型。

4

python代码实现

1、创建初始数据集,用于测试

###################################### #功能:创建数据集 #输入变量:空 #输出变量:data_set, labels 数据集,标签 ###################################### def create_data_set():

data_set = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] # 数据集最后一列表示叶结点,也称类别标签 labels = ['no surfacing', 'flippers'] # 表示决策结点,也称特征标签

return data_set, labels

2、计算给定数据集的信息熵

############################# #功能:计算信息熵 #输入变量:data_set 数据集 #输出变量:shannon_ent 信息熵 ############################# def calc_shannon_ent(data_set):

num_entries = len(data_set) label_counts = {}

for feat_vec in data_set: current_label = feat_vec[-1] # get相当于一条if...else...语句 # 如果参数current_label不在字典中则返回参数0, # 如果current_label在字典中则返回current_label对应的value值 label_counts[current_label] = label_counts.get(current_label, 0) + 1

shannon_ent = 0.0

for key in label_counts: prob = float(label_counts[key])/num_entries shannon_ent -= prob*log(prob, 2)

return shannon_ent

3、按照给定特征划分数据集。分类算法除了需要测量信息熵,还需要划分数据集,这就需要对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。

################################# #功能:划分数据集 #输入变量:data_set, axis, value # 数据集,数据集的特征,特征的值 #输出变量:ret_data_set, 划分后的数据集 #################################

def split_data_set(data_set, axis, value):

ret_data_set = []

for feat_vec in data_set: if feat_vec[axis] == value:

# 把axis特征位置之前和之后的特征值切出来 # 没有使用del函数的原因是,del会改变原始数据 reduced_feat_vec = feat_vec[:axis] reduced_feat_vec.extend(feat_vec[axis+1:]) ret_data_set.append(reduced_feat_vec)

return ret_data_set

4、遍历整个数据集,循环划分数据并计算信息熵,通过计算最大信息增益来找到最好的特征划分方式。

具体做法是,遍历当前特征中的所有唯一属性值,对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和。最后用所求的和值与原始信息熵相减,计算寻找最大信息增益。

###################################### #功能:选择最好的数据集划分方式 #输入变量:data_set 待划分的数据集 #输出变量:best_feature 计算得出最好的划分数据集的特征 ###################################### def choose_best_feature_to_split(data_set):

num_features = len(data_set[0]) - 1 # 最后一个是类别标签,所以特征属性长度为总长度减1 base_entropy = calc_shannon_ent(data_set) # 计算数据集原始信息熵

best_info_gain = 0.0 best_feature = -1

for i in xrange(num_features):

# feat_vec[i]代表第i列的特征值,在for循环获取这一列的所有值 feat_list = [feat_vec[i] for feat_vec in data_set] unique_vals = set(feat_list) # set函数得到的是一个无序不重复数据集

new_entropy = 0.0

# 计算每种划分方式的信息熵 for value in unique_vals:

sub_data_set = split_data_set(data_set, i, value) prob = len(sub_data_set)/float(len(data_set)) new_entropy += prob*calc_shannon_ent(sub_data_set)

info_gain = base_entropy - new_entropy

if info_gain > best_info_gain:

best_info_gain = info_gain best_feature = i

return best_feature

5

递归构建决策树

工作原理:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。

递归结束条件是:第一、所有的类别标签(叶结点)完全相同。第二、使用完了所有的特征,仍然不能将数据集划分成仅包含唯一类别的分组,则挑选出次数最多的类别作为返回值。

###################################### #功能:多数表决分类 #输入变量:class_list 所有数据的标签数组 #输出变量:sorted_class_count[0][0] 出现次数最多的分类名称 ###################################### def majority_vote_sort(class_list):

class_count = {}

for vote in class_list: class_count[vote] = class_count.get(vote, 0) + 1

# items以列表方式返回字典中的键值对,iteritems以迭代器对象返回键值对,而键值对以元组方式存储,即这种方式[(), ()] # operator.itemgetter(0)获取对象的第0个域的值,即返回的是key值 # operator.itemgetter(1)获取对象的第1个域的值,即返回的是value值 # operator.itemgetter定义了一个函数,通过该函数作用到对象上才能获取值 # reverse=True是按降序排序 sorted_class_count = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)

return sorted_class_count[0][0]

###################################### #功能:创建数 #输入变量:data_set, labels 待分类数据集,标签 #输出变量:my_tree 决策树 ###################################### def create_tree(data_set, labels):

class_list = [example[-1] for example in data_set]

# 判断类别标签是否完全相同 # count()是列表内置的方法,可以统计某个元素在列表中出现的次数 if class_list.count(class_list[0]) == len(class_list): return class_list[0]

# 遍历完所有特征时返回出现次数最多的 if len(data_set[0]) == 1: return majority_vote_sort(class_list)

best_feat = choose_best_feature_to_split(data_set) best_feat_label = labels[best_feat] my_tree = {best_feat_label: {}} del(labels[best_feat])

# 得到列表包含的所有属性值 feat_values = [example[best_feat] for example in data_set] unique_vals = set(feat_values)

for value in unique_vals:

sub_labels = labels[:] # :复制特征标签,为了保证循环调用函数create_tree()不改变原始的内容 ret_data_set = split_data_set(data_set, best_feat, value) my_tree[best_feat_label][value] = create_tree(ret_data_set, sub_labels)

return my_tree

6

测试代码

def main():

my_data, my_labels = create_data_set() #my_data[0][-1] = 'maybe' print 'my_data=', my_data print 'my_labels=', my_labels

shannon_ent = calc_shannon_ent(my_data) print 'shannon_ent=', shannon_ent

ret_data_set = split_data_set(my_data, 1, 1) # 由第1个特征且特征值为1的数据集划分出来 print 'ret_data_set=', ret_data_set

best_feature = choose_best_feature_to_split(my_data) print 'best_feature=', best_feature

my_tree = create_tree(my_data, my_labels) print 'my_tree=', my_tree

if __name__ == '__main__': main()

在进行案例分析前,先对决策树算法的分类函数进行测试。考虑到构造决策树非常耗时,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。这就需要利用python模块pickle序列化对象将决策树分类算法保存在磁盘中,并在需要的时候读取出来。

1、测试决策树分类算法性能

###################################### #功能:决策树的分类函数 #输入变量:input_tree, feat_labels, test_vec # 决策树,分类标签,测试数据 #输出变量:class_label 类标签 ###################################### def classify(input_tree, feat_labels, test_vec): first_str = input_tree.keys()[0] second_dict = input_tree[first_str] class_label = -1

# index方法用于查找当前列表中第一个匹配first_str变量的索引 feat_index = feat_labels.index(first_str)

for key in second_dict.keys(): if test_vec[feat_index] == key: if type(second_dict[key]).__name__ == 'dict': class_label = classify(second_dict[key], feat_labels, test_vec) else: class_label = second_dict[key] return class_label

2、对决策树算法进行存储 ###################################### #功能:将决策树存储到磁盘中 #输入变量:input_tree, filename 决策树,存储的文件名 ###################################### def store_tree(input_tree, filename):

import pickle fw = open(filename, 'w') pickle.dump(input_tree, fw) # 序列化,将数据写入到文件中 fw.close()

3、对决策树算法进行读取 ###################################### #功能:从磁盘中读取决策树信息 #输入变量:filename 存储的文件名 ###################################### def grab_tree(filename):

import pickle fr = open(filename, 'r') return pickle.load(fr) # 反序列化

4、代码测试 def main():

my_data, my_labels = create_data_set() print 'my_data=', my_data print 'my_labels=', my_labels

class_label = classify(my_tree, my_labels, [1, 1]) print 'class_label=', class_label

store_tree(my_tree, 'classifierStorage.txt') tree = grab_tree('classifierStorage.txt') print 'tree=', tree

if __name__ == '__main__': main()

案例分析:使用决策树预测隐形眼镜类型

隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。而眼科医生需要从age、prescript、astigmatic和tearRate这四个方面对患者进行询问,以此来判断患者佩戴的镜片类型。利用决策树算法,我们甚至也可以帮助人们判断需要佩戴的镜片类型。

在构造决策树前,我们需要获取隐形眼镜数据集,从lenses.txt文件读取。还需要获取特征属性(或者说决策树的决策结点),从代码输入。将数据集和特征属性代入决策树分类算法,就能构造出隐形眼镜决策树,沿着不同分支,我们可以得到不同患者需要的眼镜类型。

代码如下:

fr = open('lenses.txt', 'r')

lenses = [line.strip().split('\t') for line in fr.readlines()] lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']

lenses_tree = create_tree(lenses, lenses_labels)

print 'lenses_tree=', lenses_tree

原文发布于微信公众号 - 大数据挖掘DT数据分析(datadw)

原文发表时间:2015-07-01

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

1113. 括号匹配

题目描述 给定一个只包含左右括号的合法括号序列,按右括号从左到右的顺序输出每一对配对的括号出现的位置(括号序列以0开始编号)。 输入 仅一行,表示一个合法的括号...

2505
来自专栏钱塘大数据

R语言的常用函数速查

一、基本 1.数据管理 vector:向量 numeric:数值型向量 logical:逻辑型向量character;字符型向量 list:列表 data....

2859
来自专栏Petrichor的专栏

tensorflow编程: Constants, Sequences, and Random Values

  注意: start 和 stop 参数都必须是 浮点型;     取值范围也包括了 stop; tf.lin_space 等同于 tf.lins...

382
来自专栏崔庆才的专栏

TensorFlow layers模块用法

TensorFlow 中的 layers 模块提供用于深度学习的更高层次封装的 API,利用它我们可以轻松地构建模型,这一节我们就来看下这个模块的 API 的具...

7278
来自专栏窗户

平方根的C语言实现(二) —— 手算平方根的原理

  一个函数从数学上来说可以有无数个函数列收敛于这个函数,那么程序逼近实现来说可以有无数种算法,平方根自然也不例外。   不知道有多少人还记得手算平方根,那是满...

1999
来自专栏专知

【干货】计算机视觉实战系列03——用Python做图像处理

【导读】专知成员Hui上一次为大家介绍Matplotlib的使用,包括绘图,绘制点和线,以及图像的轮廓和直方图,这一次为大家详细讲解Numpy工具包中的各种工具...

39210
来自专栏木东居士的专栏

Bloom Filter 的数学背景

1223
来自专栏抠抠空间

python常见模块之random模块

python常见模块之random模块 import random print(random.random()) #随机产生一个0-1之间的小数 p...

26710
来自专栏云瓣

探寻 JavaScript 精度问题

阅读完本文可以了解到 0.1 + 0.2 为什么等于 0.30000000000000004 以及 JavaScript 中最大安全数是如何来的。

632
来自专栏用户2442861的专栏

openCV-图像算数与逻辑运算

621

扫码关注云+社区