决策树

一、 决策树简介

决策树是一种特殊的树形结构,一般由节点和有向边组成。其中,节点表示特征、属性或者一个类。而有向边包含有判断条件。如图所示,决策树从根节点开始延伸,经过不同的判断条件后,到达不同的子节点。而上层子节点又可以作为父节点被进一步划分为下层子节点。一般情况下,我们从根节点输入数据,经过多次判断后,这些数据就会被分为不同的类别。这就构成了一颗简单的分类决策树。

1.jpg

2.jpg

二、 相关知识

请参考周志华《机器学习》第4章:决策树

注意,第75页有一行内容:信息熵的值越小,则样本集合的纯度越高。 这怎么理解呢?可举几个例子来帮助理解 例1: p1 = 1, p2 = 0, 则Ent(D) = - (1 * log1 + 0 * log0) = 0 这表示样本集合里只有一种样本。纯度最高。

例2: p1 = p2 = 0.5,则Ent(D) = - (0.5 * log0.5 + 0.5 * log0.5) = - (-0.5 - 0.5) = 1 样本集合中有两种样本,每种样本的纯度都是50%

例3: p1 = 0.5, p2 = p3 = 0.25,则Ent(D) = -(0.5 * log0.5 + 0.25 * log0.25 + 0.25 * log0.25) = -(-0.5 - 0.5 - 0.5) = 1.5

例4: p1 = p2 = p3 = p4 = 0.25,则Ent(D) = 2

例5: p1 = p2 = p3 = p4 = 1/1024,则Ent(D) = 10

通过上面几个可以看出来,熵值越小,纯度越高,熵值越大,纯度越低。纯度范围为[0, log|y|]

三、 数据集

为了更好地理解程序,咱们只采用15行数据来实现决策树二分类。这些数据中,前6行是特征属性,最后一行是分类结果。

2,2,2,4,0,2,1
2,2,2,2,0,1,0
2,2,2,2,0,2,0
2,2,2,2,1,0,0
2,2,2,2,1,1,0
2,2,2,2,1,2,0
2,2,2,2,2,0,0
2,2,2,2,2,1,0
2,2,2,2,2,2,0
2,2,2,4,0,0,0
2,2,2,4,0,1,0
2,2,2,4,0,2,1
2,2,2,4,1,0,0
2,2,2,4,1,1,0
2,2,2,4,1,2,1

对于上面的十五行数据,咱们可以把前1/3(即前5行)数据做为测试集,把后2/3(即后10行)数据作为训练集。 注意:这里是为了好理解,划分成前1/3和后2/3,更好的做法是随机挑选一部分作为测试集,剩余的部分作为训练集。

四、 算法思路

可直接查看周志华《机器学习》中的思路

3.jpg

五、 实现过程

(一)利用信息熵和信息增益

信息熵越小越好,信息增益越大越好。 信息熵的公式为:

4.png

对于各属性,信息增益的公式为:

5.png

因为D都是一样的,程序中可以省略掉D。用信息增益最大的属性作为节点划分,即要也上式最大,也即要求下式最小:

6.png

实现代码:

def __get_info_entropy(label, attr):
    result = 0.0
    for this_attr in np.unique(attr):
        sub_label, entropy = label[np.where(attr == this_attr)[0]], 0.0
        for this_label in np.unique(sub_label):
            p = len(np.where(sub_label == this_label)[0]) / len(sub_label)
            entropy -= p * np.log2(p)
        result += len(sub_label) / len(label) * entropy
    print("result:", result)
    return result

(二)划分属性

训练集的数据(后10行)为

2,2,2,2,1,2,0
2,2,2,2,2,0,0
2,2,2,2,2,1,0
2,2,2,2,2,2,0
2,2,2,4,0,0,0
2,2,2,4,0,1,0
2,2,2,4,0,2,1
2,2,2,4,1,0,0
2,2,2,4,1,1,0
2,2,2,4,1,2,1

根据上面的代码,计算出的各属性的信息增益分别为:

第0个属性的信息增益为:0.7219280948873623
第1个属性的信息增益为:0.7219280948873623
第2个属性的信息增益为:0.7219280948873623
第3个属性的信息增益为:0.5509775004326937
第4个属性的信息增益为:0.6
第5个属性的信息增益为:0.4

上述中最后一个值0.4最小,所以取该值所对应的第5列(从0开始计数)属性作为根结点。 用第5列属性进行划分,属性0对应着分类0,属性1对应着分类0,属性2对应着分类0和1,此时树的结构是这样的:

7.png

第二个属性如何计算呢? 找出第5列属性的值为2所对应的类别。再找出这些类别对应的各列属性值(不包含第5列属性),得

2,2,2,2,1,0
2,2,2,2,2,0
2,2,2,4,0,1
2,2,2,4,1,1

这里前五列表示属性0,1,2,3,4;最后一列表示分类。 计算出的各属性的信息增益分别为:

第0个属性的信息增益为:1.0
第1个属性的信息增益为:1.0
第2个属性的信息增益为:1.0
第3个属性的信息增益为:0.0
第4个属性的信息增益为:0.5

则取第3个属性来划分。第3个属性取值为2时,分类为0;取值为4时,分类为1。 此时的决策树为

8.png

(三)预测结果

需要预测的五条数据为 2,2,2,4,0,2,1 2,2,2,2,0,1,0 2,2,2,2,0,2,0 2,2,2,2,1,0,0 2,2,2,2,1,1,0 前六列是特征属性,最后一列是实际结果,用来和预测结果做比较。 第一条数据,第5个属性值是2,需要再判断第3个属性,第3个属性的值为4,根据决策树得出的预测分类为1,与实际结果吻合 第二条数据,第5个属性值是1,根据决策树得出的预测分类为0,与实际结果吻合 第三条数据,第5个属性值是2,需要再判断第3个属性,第3个属性的值为2,根据决策树得出的预测分类为0,与实际结果吻合 第四条数据,第5个属性值是0,根据决策树得出的预测分类为0,与实际结果吻合 第五条数据,第5个属性值是1,根据决策树得出的预测分类为0,与实际结果吻合

六、 完整代码

(1)DecisionTree.py

# 具有两种剪枝功能的简单决策树
# 使用信息熵进行划分,剪枝时采用激进策略(即使剪枝后正确率相同,也会剪枝)
import numpy as np


class Tree:
    def __init__(self, label, attr, pruning=None):
        self.__root = None
        boundary = len(label) // 3

        if pruning is None:
            self.__root = self.__run_build(label[boundary:], attr[boundary:],
                                           np.array(range(len(attr.transpose()))), False)
            return
        if pruning == 'Pre':
            self.__root = self.__run_build(label[boundary:], attr[boundary:],
                                           np.array(range(len(attr.transpose()))),
                                           True, attr[0:boundary], label[0:boundary])
        elif pruning == 'Post':
            self.__root = self.__run_build(label[boundary:], attr[boundary:],
                                           np.array(range(len(attr.transpose()))), False)
            self.print_tree()
            self.__post_pruning(self.__root, attr[0:boundary], label[0:boundary])
        else:
            raise RuntimeError('未能识别的参数:%s' % pruning)

    @staticmethod
    # 返回使用特定属性划分下的信息熵之和
    # label: 类别标签
    # attr: 用于进行数据划分的属性
    def __get_info_entropy(label, attr):
        result = 0.0
        for this_attr in np.unique(attr):
            sub_label, entropy = label[np.where(attr == this_attr)[0]], 0.0
            for this_label in np.unique(sub_label):
                p = len(np.where(sub_label == this_label)[0]) / len(sub_label)
                entropy -= p * np.log2(p)
            result += len(sub_label) / len(label) * entropy
        return result

    # 递归构建一颗决策树
    # label: 维度为1 * N的数组。第i个元素表示第i行数据所对应的标签
    # attr: 维度为 N * M 的数组,每行表示一条数据的属性,列数随着决策树的构建而变化
    # attr_idx: 表示每个属性在原始属性集合中的索引,用于决策树的构建
    # pre_pruning: 表示是否进行预剪枝
    # check_attr: 在预剪枝时,用作测试数据的属性集合
    # check_label: 在预剪枝时,用作测试数据的验证标签
    def __run_build(self, label, attr, attr_idx, pre_pruning, check_attr=None, check_label=None):
        node, right_count = {}, None
        max_type = np.argmax(np.bincount(label))
        if len(np.unique(label)) == 1:
            # 如果所有样本属于同一类C,则将结点标记为C
            node['type'] = label[0]
            return node
        if attr is None or len(np.unique(attr, axis=0)) == 1:
            # 如果 attr 为空或者 attr 上所有元素取值一致,则将结点标记为样本数最多的类
            node['type'] = max_type
            return node
        attr_trans = np.transpose(attr) #每一行就是原先的属性列,转置是为了计算方便
        min_entropy, best_attr = np.inf, None
        # 获取各种划分模式下的信息熵之和(作用和信息增益类似)
        # 并以此为信息,找出最佳的划分属性
        if pre_pruning:
            right_count = len(np.where(check_label == max_type)[0])
        for this_attr in attr_trans:
            entropy = self.__get_info_entropy(label, this_attr)
            if entropy < min_entropy:
                min_entropy = entropy
                best_attr = this_attr
        # branch_attr_idx 表示用于划分的属性是属性集合中的第几个
        branch_attr_idx = np.where((attr_trans == best_attr).all(1))[0][0]
        if pre_pruning:
            sub_right_count = 0
            check_attr_trans = check_attr.transpose()
            # branch_attr_idx 表示本次划分依据的属性属于属性集中的哪一个
            for val in np.unique(best_attr):
                # 按照预划分的特征进行划分,并统计划分后的正确率
                # branch_data_idx 表示数据集中,被划分为 idx 的数据的索引
                branch_data_idx = np.where(best_attr == val)[0]
                # predict_label 表示一次划分以后,该分支数据的预测类别
                print(label[branch_data_idx])
                print(np.bincount(label[branch_data_idx]))
                predict_label = np.argmax(np.bincount(label[branch_data_idx]))
                # check_data_idx 表示验证集中,属性编号为 branch_attr_idx 的属性值等于 val 的项的索引
                check_data_idx = np.where(check_attr_trans[branch_attr_idx] == val)[0]
                # check_branch_label 表示按照当前特征划分以后,被分为某一类的数据的标签
                check_branch_label = check_label[check_data_idx]
                # 随后判断这些标签是否等于前面计算得到的类别,如果相等,则分类正确
                sub_right_count += len(np.where(check_branch_label == predict_label)[0])
            if sub_right_count <= right_count:
                # 如果划分后的正确率小于等于不划分的正确率,则剪枝
                node['type'] = max_type
                return node
        values = []
        for val in np.unique(best_attr):
            values.append(val)
            branch_data_idx = np.where(best_attr == val)[0]
            if len(branch_data_idx) == 0:
                new_node = {'type': np.argmax(np.bincount(label))}
            else:
                # 按照划分构造新数据,并开始递归
                branch_label = label[branch_data_idx]

                # 哪几行branch_attr对应着上面的branch_label数组
                branch_attr = np.delete(attr_trans, branch_attr_idx, axis=0).transpose()[branch_data_idx]
                new_node = self.__run_build(branch_label, branch_attr,
                                            np.delete(attr_idx, branch_attr_idx, axis=0),
                                            pre_pruning, check_attr, check_label)
            node[str(val)] = new_node
        node['attr'] = attr_idx[branch_attr_idx]
        node['type'] = max_type
        node['values'] = values
        return node

    # 后剪枝
    # node: 当前进行判断和剪枝操作的结点
    # check_attr: 用于验证的数据属性集
    # check_label: 用于验证的数据标签集
    def __post_pruning(self, node, check_attr, check_label):
        check_attr_trans = check_attr.transpose()
        if node.get('attr') is None:
            # attr 为 None 代表叶节点
            return len(np.where(check_label == node['type'])[0])
        sub_right_count = 0
        for val in node['values']:
            sub_node = node[str(val)]
            # 找到当前分支点中,数据属于 idx 这一分支的数据的索引
            idx = np.where(check_attr_trans[node['attr']] == val)[0]
            # 使用上述数据,从子节点开始新的递归
            sub_right_count += self.__post_pruning(sub_node, check_attr[idx], check_label[idx])
        if sub_right_count <= len(np.where(check_label == node['type'])[0]):
            for val in node['values']:
                del node[str(val)]
            del node['values']
            del node['attr']
            return len(np.where(check_label == node['type'])[0])
        return sub_right_count

    # 根据构建的决策树预测结果
    # data: 用于预测的数据,维度为M
    # return: 预测结果
    def predict(self, data):
        node = self.__root
        while node.get('attr') is not None:
            attr = node['attr']
            node = node.get(str(data[attr]))
            if node is None:
                return None
        return node.get('type')

    # 以文本形式(类JSON)打印出决策树
    def print_tree(self):
        print(self.__root)

(2)Main.py

import DecisionTree
import numpy as np


if __name__ == '__main__':
    print('正在准备数据并种树……')
    file = open('Data/car.data')
    lines = file.readlines()
    raw_data = np.zeros([len(lines), 7], np.int32)
    for idx in range(len(lines)):
        raw_data[idx] = np.array(lines[idx].split(','), np.int32)
    file.close()
    #np.random.shuffle(raw_data)
    data =  raw_data.transpose()[0:6].transpose()
    label = raw_data.transpose()[6]
    # tree_no_pruning = DecisionTree.Tree(label, data, None)
    # tree_no_pruning.print_tree()
    # tree_pre_pruning = DecisionTree.Tree(label, data, 'Pre')
    # tree_pre_pruning.print_tree()
    tree_post_pruning = DecisionTree.Tree(label, data, 'Post')
    tree_post_pruning.print_tree()
    test_count = len(label) // 3
    test_data, test_label = data[0:test_count], label[0:test_count]
    times_no_pruning, times_pre_pruning, times_post_pruning = 0, 0, 0
    print('正在检验结果(共 %d 条验证数据)' % test_count)
    for idx in range(test_count):
        # if tree_no_pruning.predict(test_data[idx]) == test_label[idx]:
        #     times_no_pruning += 1
        # if tree_pre_pruning.predict(test_data[idx]) == test_label[idx]:
        #     times_pre_pruning += 1
        if tree_post_pruning.predict(test_data[idx]) == test_label[idx]:
            times_post_pruning += 1
    #print('【未剪枝】:命中 %d 次,命中率 %.2f%%' % (times_no_pruning, times_no_pruning * 100 / test_count))
    #print('【预剪枝】:命中 %d 次,命中率 %.2f%%' % (times_pre_pruning, times_pre_pruning * 100 / test_count))
print('【后剪枝】:命中 %d 次,命中率 %.2f%%' % (times_post_pruning, times_post_pruning * 100 / test_count))

运行结果:

正在准备数据并种树……
决策树的结构为:
 {'0': {'type': 0}, '1': {'type': 0}, '2': {'2': {'type': 0}, '4': {'type': 1}, 'attr': 3, 'type': 0, 'values': [2, 4]}, 'attr': 5, 'type': 0, 'values': [0, 1, 2]}
正在检验结果(共 5 条验证数据)
【未剪枝】:命中 5 次,命中率 100.00%

七、参考资料

https://blog.csdn.net/dapanbest/article/details/78281201

原文发布于微信公众号 - 海天一树(gh_de7b45c40e8b)

原文发表时间:2018-08-04

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数值分析与有限元编程

可视化 | MATLAB划分均匀矩形网格

之前发过一个划分均匀三角形网格的例子。下面结合一个悬臂梁说说如何在规则区域划分均匀矩形网格。 ? 将一个矩形平面区域划分成相同大小的矩形。X方向等分nex,Y方...

5669
来自专栏李智的专栏

Python针对图像的基础操作

5. 返回目录中所有JPG 图像的文件名列表,直方图均衡化,平均图像,主成分分析等

1642
来自专栏iOSDevLog

估计器接口小结摘自:《Python 机器学习基础教程》 第3章 无监督学习与预处理(三)

scikit-learn 中的所有算法——无论是预处理、监督学习还是无监督学习算法——都被实现为类。这些类在 scikit-learn 中叫作估计器(estim...

1532
来自专栏along的开发之旅

glLoadIdentity()与glTranslatef()和glRotatef()--坐标变换

初学OpenGL,对它的矩阵变换不甚了解,尤其是glTranslatef和glRotatef联合使用,立即迷得不知道东西南北。在代码中改变数据多次,终于得到了相...

1304
来自专栏iOSDevLog

决策树

1194
来自专栏深度学习那些事儿

风格迁移(Style Transfer)中直方图匹配(Histogram Match)的作用

风格迁移是神经网络深度学习中比较重要且有趣的一个项目。如果不知道什么是风格迁移的请参考这篇文章:https://oldpan.me/archives/pytor...

60014
来自专栏杨熹的专栏

神经网络 之 线性单元

本文结构: 什么是线性单元 有什么用 代码实现 ---- 1. 什么是线性单元 线性单元和感知器的区别就是在激活函数: ? 感知器的 f 是阶越函数: ? 线性...

3374
来自专栏机器学习原理

图像处理和数据增强图片处理数据增强颜色空间转换噪音数据的加入样本不均衡

7844
来自专栏null的专栏

简单易学的机器学习算法——K-Means++算法

一、K-Means算法存在的问题 由于K-Means算法的简单且易于实现,因此K-Means算法得到了很多的应用,但是从K-Means算法的过程中发现,K-Me...

3735
来自专栏mantou大数据

[机器学习实战]K-近邻算法

1. K-近邻算法概述(k-Nearest Neighbor,KNN) K-近邻算法采用测量不同的特征值之间的距离方法进行分类。该方法的思路是:如果一个样本在特...

4565

扫码关注云+社区

领取腾讯云代金券