前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >从零开始在Python中实现决策树算法

从零开始在Python中实现决策树算法

作者头像
Steve Wang
发布2018-02-02 16:13:38
3.2K1
发布2018-02-02 16:13:38
举报
文章被收录于专栏:从流域到海域从流域到海域

How To Implement The Decision Tree Algorithm From Scratch In Python

原文作者:Jason Brownlee

原文地址:https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/

译者微博:@从流域到海域

译者博客:blog.csdn.net/solo95

(译者注:本文涉及到的所有split point,绝大部分翻译成了分割点,因为根据该点的值会做出逻辑上的分割,但其实在树的概念中就是一个分支点。撇开专业知识不谈,仅就英语的层面来说翻译成分裂点也是可以的,因为将从该点分裂出左孩子或右孩子结点)

从零开始在Python中实现决策树算法

决策树是一个强大的预测方法,非常受欢迎。

决策树很受欢迎,因为最终的模型很容易被从业者和领域专家所理解。最终的决策树可以准确地解释为什么进行了具体的预测,使其在业务使用上非常有吸引力。

决策树还为更先进的集合方法(如装袋(baging),随机森林(random forests)和梯度提升(gradient boosting))提供了基础。

在本教程中,您将了解如何使用Python从头开始实现分类回归树算法(Classification And Regression Tree algorithm)

读完本教程后,您将知道:

  • 如何计算和评估数据中的候选分割(split points)点。
  • 如何将分支安排到决策树结构中。
  • 如何将分类回归树算法应用于实际问题。

让我们开始吧。

  • 2017年1月更新:将cross_validation_split()中fold_size的计算更改为始终为整数。修复了Python 3的一些问题。
  • 2017年2月更新:修复了build_tree中的一个bug。
  • 2017年8月更新:修正了Gini计算中的一个bug,增加了缺失的根据群组大小给出的群组权重Gini得分(感谢Michael)!

从零开始在Python中实现来自Scratch的决策树算法

照片由马丁Cathrae提供,保留某些权利。

说明

本节简要介绍分类回归树算法以及本教程中使用的Banknote数据集。

分类回归树

分类回归树或简称CART是Leo Breiman提出的可用于分类或回归预测建模问题的决策树算法。

本教程将重点介绍如何使用CART进行分类。

CART模型的表示是二叉树。这里说的二叉树是与算法和数据结构相同的二叉树,没有什么特别的(每个结点可以有零个,一个或两个子结点)。

一个结点表示一个单一的输入变量(X)和该变量上的一个分割点,假定变量是数字的。树的叶子结点(也称为终端结点)包含用于进行预测的输出变量(y)。

一旦创建完成,就可以在每个分支之后使用新的一行数据对一棵树进行导航直到最终的预测。

创建一个二叉决策树实际上是一个划分输入空间的过程。有一个贪婪方法被用来划分空间,它被称为递归二进制。这是一个数值过程,其中所有值排列在一起,尝试不同的分割点并使用成本函数进行测试。

成本最低的分割点(之所以成本低是因为我们降低了它的成本)最终会被选择出来。所有输入变量和所有可能的分割点都是基于成本函数以贪婪的方式评估和选择的。

  • 回归(Regression):成本函数是用来最小化选择分割点的(成本),即落在矩形内的所有训练样本的(平方和)误差。
  • 分类(Classification):使用基尼(Gini)成本函数时,需要其提供结点的纯度的指示,其中结点纯度指的是分配给每个结点的训练数据是如何混合到何种程度。(作者想表达的意思是,在需要提供结点纯度指示的条件下,才会使用Gini函数,译者注。)

继续进行,直到结点包含最少数量的训练示例或达到最大树深度。

钞票数据集(Banknote)

钞票数据集涉及的是通过从照片中提取的一系列方法来预测给定的钞票是否是真实的。

数据集包含1,372个5个数字变量。这是一个两个类分类问题(二元分类)。

下面提供了数据集中五个变量的列表。

  1. 小波变换图像的方差(连续)。
  2. 小波变换图像的偏度(连续)。
  3. 小波变换图像的峭度(连续)。
  4. 图像熵(连续)。
  5. 类(整数)。

以下是数据集前5行的示例

代码语言:txt
复制
3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,0
0.32924,-4.4552,4.5718,-0.9888,0
4.3684,9.6718,-3.9606,-3.1625,0

使用零规则算法来预测最常见的类的值,问题的基线准确度约为50%。

您可以了解更多信息并从UCI Machine Learning Repository下载数据集。

下载数据集并将其放在当前工作目录中,文件命名为data_banknote_authentication.csv

教程

本教程分为5个部分:

  1. 基尼(Gini)指数。
  2. 分割点创建。
  3. 建立一棵树。
  4. 做一次预测。
  5. 钞票案例的研究。

这些步骤将为您提供从零开始实施CART算法的基础,并将其应用于您自己的预测建模问题。

1.基尼(Gini)指数

基尼指数是成本函数的名称,用于评估数据集中的分割(的成本)。

数据集中分割的涉及该属性的一个输入属性和一个值。它可以用来将训练模式分成两组。

一个基尼评分给出了一个(定义)分割有多好的想法,这个通过(观察)由这个分割创造的两组数据的混合程度如何来实现。一个完美的分割结果会导致基尼评分为0,而最差情况下的分割导致每个组的50/50类基尼评分为0.5(对于有2个类的问题)(50/50即五五开)。

计算基尼(Gini)最好用一个例子来演示。

我们有两组数据,每组有两行。第一组中的行全部属于类0,第二组中的行属于类1,所以它的分割是完美的。

我们首先需要计算每组中类的比例。

代码语言:txt
复制
proportion = count(class_value) / count(rows)

这个例子的比例是:

代码语言:txt
复制
group_1_class_0 = 2 / 2 = 1
group_1_class_1 = 0 / 2 = 0
group_2_class_0 = 0 / 2 = 0
group_2_class_1 = 2 / 2 = 1

然后计算每个子结点的Gini如下:

代码语言:txt
复制
gini_index = sum(proportion * (1.0 - proportion))
gini_index = 1.0 - sum(proportion * proportion)

然后,每个组的基尼指数必须按组的大小加权,相对于双亲(结点)中的所有样本,例如当前正在分组的所有样本。我们可以把这个权重加到一个组的基尼计算上,如下所示:

代码语言:txt
复制
gini_index = (1.0 - sum(proportion * proportion)) * (group_size/total_samples)

在这个例子中,每组的基尼分数计算如下:

代码语言:txt
复制
Gini(group_1) = (1 - (1*1 + 0*0)) * 2/4
Gini(group_1) = 0.0 * 0.5
Gini(group_1) = 0.0
Gini(group_2) = (1 - (0*0 + 1*1)) * 2/4
Gini(group_2) = 0.0 * 0.5
Gini(group_2) = 0.0

然后将分数添加到分割点处的每个子结点上,以给出可以与其他候选(分割)点相比较的最终基尼分数。

这个分割点的基尼将被计算为0.0 + 0.0或完美的基尼分数为0.0。

下面是一个名为gini_index()的函数,用于计算一系列组和一系列已知类的值的Gini指数。

你可以看到在那里有一些安全检查,以避免空组被零除(零除即除数为零)。

代码语言:txt
复制
# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

我们可以用我们上面的工作示例来测试这个函数。我们也可以测试每组50/50划分的最差情况。完整的例子如下所示。

代码语言:txt
复制
# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# test Gini values
print(gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))
print(gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1]))

运行示例将打印出两个基尼评分,首先是最差情况评分为0.5,最后一个评分为0.0。

代码语言:txt
复制
0.5
0.0

现在我们知道如何评估一种分割的结果,我们来看看如何创建分割。

2.创建分割

一个分割由数据集中的一个属性和一个值组成。

我们可以将其归纳为要拆分的属性的索引(index)和该属性上拆分行的值。这只是索引数据行的一个有用的速记。

创建一个分割涉及三个部分,第一个我们已经看过哪个是计算基尼分数。其余两部分是:

  1. 划分一个数据集。
  2. 评估所有划分(方法)。

我们来看看每个。

2.1.一个数据集

拆分数据集意味着将数据集分成两个行的数据列,给定属性的索引和该属性的拆分值。

一旦我们有了这两个组,我们就可以用我们的基尼分数来评估拆分的成本。

拆分数据集涉及遍历每一行,检查属性值是否低于或高于拆分值,并分别将其分配给左侧组或右侧组。

下面是一个名为test_split()的函数,它实现了这个过程。

代码语言:txt
复制
# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

没有太复杂的东西。

请注意,右边的组包含索引所指向的值大于或等于分割值的所有行。

2.2. 评估所有分割

通过上面的Gini函数和测试函数,我们现在拥有了评估分割所需的一切。

给定一个数据集,我们必须检查每个属性的每个值作为候选,评估分割的成本并找到可能实现的最佳分割。

一旦找到最佳分割,我们可以将它用作决策树中的一个结点。

这是一个详尽而贪婪的算法。

我们将使用字典来表示决策树中的一个结点,因为我们可以按名称存储数据。当选择最佳分割并将其用作树的新结点时,我们将通过从哪一个(结点)开始分割和所选择的分割点分割的两组数据来存储所选属性的索引和属性的值。

每组数据都是它自己的小数据集,这些数据集只是由分割过程分配给左侧或右侧组的那些行。你可以想象,当我们建立我们的决策树时可以再次分割每一个组,这是一个递归的过程。

下面是一个名为get_split()的函数,它实现了这个过程。你可以看到它遍历每个属性(除了类的值),然后每个属性的值,正如它的走向那样拆分和评估分割。

最好的分割将会被记录下来,然后在所有检查完成后返回。

代码语言:txt
复制
# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

我们可以设计一个小数据集来测试这个函数和我们整个数据集的分割过程。

代码语言:txt
复制
X1			X2			Y
2.771244718		1.784783929		0
1.728571309		1.169761413		0
3.678319846		2.81281357		0
3.961043357		2.61995032		0
2.999208922		2.209014212		0
7.497545867		3.162953546		1
9.00220326		3.339047188		1
7.444542326		0.476683375		1
10.12493903		3.234550982		1
6.642287351		3.319983761		1

我们可以为每个类使用不同的颜色绘制出这个数据集。您可以看到,手动选择X1的值(图上的x轴)来拆分该数据集并不困难。

CART Contrived数据集

下面的例子将所有这些放在一起。

代码语言:txt
复制
# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			print('X%d < %.3f Gini=%.3f' % ((index+1), row[index], gini))
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]
split = get_split(dataset)
print('Split: [X%d < %.3f]' % ((split['index']+1), split['value']))

get_split()函数被修改,以打印出每一点和它的基尼系数,因为它已经被评估过了。

运行示例打印所有基尼得分,然后打印X1 <6.642的数据集中的最佳分割分数,基尼系数为0.0或称作完美分割。

代码语言:txt
复制
X1 < 2.771 Gini=0.444
X1 < 1.729 Gini=0.500
X1 < 3.678 Gini=0.286
X1 < 3.961 Gini=0.167
X1 < 2.999 Gini=0.375
X1 < 7.498 Gini=0.286
X1 < 9.002 Gini=0.375
X1 < 7.445 Gini=0.167
X1 < 10.125 Gini=0.444
X1 < 6.642 Gini=0.000
X2 < 1.785 Gini=0.500
X2 < 1.170 Gini=0.444
X2 < 2.813 Gini=0.320
X2 < 2.620 Gini=0.417
X2 < 2.209 Gini=0.476
X2 < 3.163 Gini=0.167
X2 < 3.339 Gini=0.444
X2 < 0.477 Gini=0.500
X2 < 3.235 Gini=0.286
X2 < 3.320 Gini=0.375
Split: [X1 < 6.642]

现在我们知道如何在数据集或行的列表中找到最佳分割点,让我们看看如何使用它来构建决策树。

3.构建一棵树

创建树的根结点很容易。

我们对整个数据集调用上面的get_split()函数。

添加更多的结点到我们的树会更有趣。

建树可分为三个主要部分:

  1. 终端结点。
  2. 递归划分(分割)。
  3. 建立一棵树。
3.1. 终端结点

我们需要决定何时停止一棵树的增长。

我们可以使用深度和结点在训练数据集中对应的函数的来做到这一点。

  • 树深度的最大值。这是树从根结点开始的最大结点数。一旦树的最大深度得到满足,我们必须停止分割添加新的结点。更深的树更复杂,更有可能过度训练数据。
  • 最小结点记录数。这是给定结点负责的训练模式的最小数。一旦达到或低于这个最低限度,我们必须停止分割和增加新的结点。考虑到训练模式太少的结点预期会过于具体,并且可能过度训练数据。

这两种方法将是我们树的构建过程中由用户指定的参数。

还有一个条件。可以选择所有行都属于一个组的分割(方式)。在这种情况下,我们将无法继续拆分和添加子结点,因为我们将没有在一侧或另一侧的记录来进行(进一步)拆分。

现在我们有了什么时候停止树增长的方法。当我们在一个给定的点停止生长时,这个结点被称为终端结点,并被用来作出最终的预测。

这是通过处理分配给该结点的行的组并选择该组中最常见的类的值来完成的。这将被用来做出预测。

下面是一个名为to_terminal()的函数,它将为一组行选择一个类的值。它将返回行列表中最常见的输出值。

代码语言:txt
复制
# Create a terminal node value
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)
3.2. 递归分割

现在我们知道如何以及何时创建终端结点,可以建立我们的树了。

构建决策树涉及到重复调用上面开发的get_split()函数为每个结点创建组。

添加到现有结点的新结点称为子结点。一个结点可以有零个孩子(这是一个终端结点),一个孩子(一方直接进行预测)或者两个孩子结点。我们将在给定结点的字典表示中将子结点称为左(结点)和右(结点)。

一旦创建了一个结点,我们就可以通过再次调用相同的函数在分割点出来对每组数据递归地创建子结点。

下面是一个实现这个递归过程的函数。它以一个结点,以及结点的最大深度,最小模式数和当前结点深度作为参数。

你可以想象这被称为在根结点进行传递的深度为1的第一次调用是如何进行的,.这个功能最好用以下步骤来解释:

  1. 首先,由结点拆分的两组数据被提取使用并从结点中删除。当我们在这些组上工作时,结点不再需要访问这些数据。
  2. 接下来,我们检查左边或右边的行是否是空的,如果是这样的话,我们使用我们所拥有的记录创建一个终端结点。
  3. 然后检查我们是否达到了最大深度,如果是,我们创建一个终端结点。
  4. 然后,我们处理左侧子结点,如果行组太小,则创建终端结点,否则以深度优先方式创建并添加左侧结点,直到树的底部到达此分支。
  5. 然后以相同的方式对右侧进行处理,因为我们要将构建的树回溯到根。
代码语言:txt
复制
# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)
3.3. 构建一棵树

现在我们可以把所有的模块放在一起(组成一个整体)。

构建树包括创建根结点并调用split()函数,然后递归调用自身来构建整个树。

下面是实现这个过程的build_tree()函数。

代码语言:txt
复制
# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

我们可以使用我们上面设计的小型数据集来测试整个过程。

以下是完整的例子。

还包括一个小的print_tree()函数,这个函数每个结点一行递归地打印决策树的结点。虽然不像真正的决策树的图那样引人注目,但它给出了整个树形结构和做出决策的大致过程。

代码语言:txt
复制
# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

# Print a decision tree
def print_tree(node, depth=0):
	if isinstance(node, dict):
		print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
		print_tree(node['left'], depth+1)
		print_tree(node['right'], depth+1)
	else:
		print('%s[%s]' % ((depth*' ', node)))

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]
tree = build_tree(dataset, 1, 1)
print_tree(tree)

我们可以在运行此示例时更改最大深度参数,并在打印的树上查看(改动带来)效果。

最大深度为1(调用build_tree()函数中的第二个参数),我们可以看到该树使用了我们在上一节中发现的完美分割。这是一棵只有一个结点的树,也称为决策桩(decision stump)。

代码语言:txt
复制
[X1 < 6.642]
[0]
[1]

将最大深度增加到2之后,即使不需要,也会迫使树进行分割。该X1属性在之后被根结点的左,右两个孩子再次使用(进一步)分割已经完美组合的类(意思是,当前已经达到了最佳组合,再次分割反而导致结果变差,译者注)。

代码语言:txt
复制
[X1 < 6.642]
[X1 < 2.771]
  [0]
  [0]
[X1 < 7.498]
  [1]
  [1]

最后,执迷不悟的话,我们也可以强制进行一个最大深度为3的分割。

代码语言:txt
复制
[X1 < 6.642]
[X1 < 2.771]
  [0]
  [X1 < 2.771]
   [0]
   [0]
[X1 < 7.498]
  [X1 < 7.445]
   [1]
   [1]
  [X1 < 7.498]
   [1]
   [1]

这些测试表明,有很大的机会来完善(算法的)实现,以避免不必要的分割。这被留作一个扩展。

现在我们可以创建一个决策树,让我们看看我们如何使用它来新的数据进行预测。

做一次预测

使用决策树进行预测涉及使用专门提供的数据行来对树进行导航(navigating)(意思是将数据行顺着树的分支走向终端结点,得到一个由起始结点到终端结点的路径)。

再次说明,我们可以使用递归函数来实现这一点,实现时左侧或右侧子结点会调用相同的预测例程,这(左侧还是右侧)取决于分割如何影响提供的数据。

我们必须检查一个子结点是否是作为预测返回的终端值,或者是包含另一层树的字典结点。

下面是执行这个过程的predict*()函数。您可以看到结点中的索引和值是如何给出的。

您可以看到如何使用给定结点中的索引和值来评估提供的数据行是落在左侧还是右侧。

代码语言:txt
复制
# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

我们可以使用我们设计好的数据集来测试这个功能。下面是一个使用单个结点对数据进行最佳分割的硬编码决策树(即决策桩)的示例。

该示例对数据集中的每一行都进行了预测。

代码语言:txt
复制
# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]

#  predict with a stump
stump = {'index': 0, 'right': 1, 'value': 6.642287351, 'left': 0}
for row in dataset:
	prediction = predict(stump, row)
	print('Expected=%d, Got=%d' % (row[-1], prediction))

按运行该示例将打印每行的正确预测,就像我们所期望的那样。

代码语言:txt
复制
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1

我们现在知道如何创建一个决策树并用它来做出预测。现在,让我们将其应用到一个真实的数据集。

5.钞票案例研究

本节将CART算法应用于钞票数据集*(ank Note dataset)。

第一步是加载数据集并将加载的数据转换为我们可用来计算分割点的数字值。为此,我们将使用helper函数load_csv()加载文件,使用str_column_to_float()将函数字符串数字转换为浮点数。

我们将使用5层(flod)的k-fold交叉验证来评估算法。这意味着1372/5 = 274.4或者超过270个记录将被用于每个层(fold)。我们将使用helper函数evaluate_algorithm()来评估交叉验证算法并使用accuracy_metric()函数来计算预测的准确性。

一个名为decision_tree()的新函数被开发了出来,用于管理CART算法的应用,(它)首先从训练数据集中创建树,然后使用树对测试数据集进行预测。

完整的例子如下所示。

代码语言:txt
复制
# CART on the Bank Note dataset# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader

# Load a CSV file
def load_csv(filename):
	file = open(filename, "rb")
	lines = reader(file)
	dataset = list(lines)
	return dataset

# Convert string column to float
def str_column_to_float(dataset, column):
	for row in dataset:
		row[column] = float(row[column].strip())

# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):
	dataset_split = list()
	dataset_copy = list(dataset)
	fold_size = int(len(dataset) / n_folds)
	for i in range(n_folds):
		fold = list()
		while len(fold) < fold_size:
			index = randrange(len(dataset_copy))
			fold.append(dataset_copy.pop(index))
		dataset_split.append(fold)
	return dataset_split

# Calculate accuracy percentage
def accuracy_metric(actual, predicted):
	correct = 0
	for i in range(len(actual)):
		if actual[i] == predicted[i]:
			correct += 1
	return correct / float(len(actual)) * 100.0

# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
	folds = cross_validation_split(dataset, n_folds)
	scores = list()
	for fold in folds:
		train_set = list(folds)
		train_set.remove(fold)
		train_set = sum(train_set, [])
		test_set = list()
		for row in fold:
			row_copy = list(row)
			test_set.append(row_copy)
			row_copy[-1] = None
		predicted = algorithm(train_set, test_set, *args)
		actual = [row[-1] for row in fold]
		accuracy = accuracy_metric(actual, predicted)
		scores.append(accuracy)
	return scores

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):
	tree = build_tree(train, max_depth, min_size)
	predictions = list()
	for row in test:
		prediction = predict(tree, row)
		predictions.append(prediction)
	return(predictions)

# Test CART on Bank Note dataset
seed(1)
# load and prepare data
filename = 'data_banknote_authentication.csv'
dataset = load_csv(filename)
# convert string attributes to integers
for i in range(len(dataset[0])):
	str_column_to_float(dataset, i)
# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))

该示例使用5层的最大树深度以及每个结点的最小行数设定为10个。这些用于CART的参数是通过一些实验来选择的,但绝不是最佳的。

运行该示例将打印每个层(folder)的平均分类准确度以及所有层(folder)的平均性能。

您可以看到,CART和所选配置(参数)的平均分类准确率达到了97%左右,远远超过了达到50%准确度的零规则算法。

代码语言:txt
复制
Scores: [96.35036496350365, 97.08029197080292, 97.44525547445255, 98.17518248175182, 97.44525547445255]
Mean Accuracy: 97.299%

扩展

本小节列出了您可能希望探索的本教程的扩展。

  • 算法调优(Algorithm Tuning)。CART在Bank Note数据集中的应用没有被调整过。(你可以)尝试使用不同的参数值,看看能否取得更好的表现。
  • 交叉熵(Cross Entropy)。另一个评估分割的成本函数是交叉熵(logloss)。你可以实现这个替代成本函数来进行实验。
  • 树枝修剪(Tree Pruning)。减少训练数据集过度拟合的一个重要技术是树枝修剪。(你可以)查一查并实现树枝修剪的方法。
  • 分类数据集(Categorical Dataset)。这个例子是为具有数字或有序的输入属性的输入数据而设计的,可以使用分类输入数据对它进行实验并且使用对等(equality)而不是排名(Ranking)方式对数据进行分割。
  • 回归(Regression)。使用不同的成本函数和创建终端结点的方法来修改树进行回归操作。
  • 更多数据集(More Datasets)。将算法应用于UCI机器学习库(UCI Machine Learning Repository)中的更多数据集。

你有没有探索这些扩展?

在下面的评论中分享你的经验。

评论

在本教程中,您了解了如何从零开始使用Python实现决策树算法。

具体来说,你学到了:

  • 如何选择和评估训练数据集中的分割点。
  • 如何从多次分割中递归地构建决策树。
  • 如何将CART算法应用于真实世界分类预测建模问题。
评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • How To Implement The Decision Tree Algorithm From Scratch In Python
  • 从零开始在Python中实现决策树算法
  • 说明
    • 分类回归树
      • 钞票数据集(Banknote)
      • 教程
        • 1.基尼(Gini)指数
          • 2.创建分割
            • 2.1.一个数据集
            • 2.2. 评估所有分割
          • 3.构建一棵树
            • 3.1. 终端结点
            • 3.2. 递归分割
            • 3.3. 构建一棵树
          • 做一次预测
            • 5.钞票案例研究
            • 扩展
            • 评论
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档