
想象一个生活中的场景,我们去水果店买一个西瓜,该怎么判断一个西瓜是不是又甜又好的呢?我们可能会问自己一系列问题:
这个过程,就是你大脑中的一棵“决策树”。决策树算法,就是让计算机从数据中自动学习出这一系列问题和判断规则的方法。
它的核心思想非常简单:通过提出一系列问题,对数据进行层层筛选,最终得到一个结论(分类或预测)。每一个问题都是关于某个特征的判断(例如:“纹路是否清晰?”),而每个答案都会引导我们走向下一个问题,直到得到最终答案。
一棵成熟的决策树包含以下部分:


现在我们来解决最关键的问题:计算机如何从一堆数据中自动找出最好的提问顺序?
1. 关键问题:根据哪个特征进行分裂?
假设我们有一个西瓜数据集,包含很多西瓜的特征(纹路、根蒂、声音、触感...)和标签(好瓜/坏瓜)。 在根节点,我们有所有数据。算法需要决定:第一个问题应该问什么? 是问“纹路清晰吗?”还是“声音清脆吗?”?
选择的标准是:哪个特征能最好地把数据分开,使得分裂后的子集尽可能纯净。所谓纯净,就是同一个子集里的西瓜尽可能都是好瓜,或者都是坏瓜。
2. 衡量标准:“不纯度”的度量
我们如何量化“纯度”呢?科学家们设计了几种指标来衡量“不纯度”:
3. 核心概念:信息增益
决策树算法通过计算信息增益来决定用什么特征分裂。
信息增益 = 分裂前的不纯度 - 分裂后的不纯度
信息增益越大,说明这个特征分裂后,数据的纯度提升得越多,这个特征就越应该被用来做分裂。
简单比喻:
4. 核心算法
5. 停止条件
不能无限地分下去,否则每个叶节点可能只有一个样本(过拟合)。停止条件包括:
理论深奥让人难以琢磨,我们来点实际的。用经典的scikit-learn库,建一棵决策树,细细的分析一下里面的每个步骤;
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from io import StringIO
import pydotplus
from IPython.display import Image
# 1. 设置中文字体支持
# 尝试使用系统中已有的中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 2. 创建示例数据集(使用鸢尾花数据集,但用中文重命名)
iris = load_iris()
X = iris.data
y = iris.target
# 创建中文特征名称和类别名称
chinese_feature_names = ['花萼长度', '花萼宽度', '花瓣长度', '花瓣宽度']
chinese_class_names = ['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾']
# 3. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 4. 创建并训练决策树模型
clf = tree.DecisionTreeClassifier(
criterion='gini', # 使用基尼不纯度
max_depth=3, # 限制树深度,防止过拟合
min_samples_split=2, # 节点最小分裂样本数
min_samples_leaf=1, # 叶节点最小样本数
random_state=42 # 随机种子,确保结果可重现
)
clf.fit(X_train, y_train)
# 5. 评估模型
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2%}")
# 6. 可视化决策树 - 方法1:使用Matplotlib(简单但不支持中文特征名)
plt.figure(figsize=(20, 12))
tree.plot_tree(
clf,
feature_names=chinese_feature_names, # 使用中文特征名
class_names=chinese_class_names, # 使用中文类别名
filled=True, # 填充颜色表示类别
rounded=True, # 圆角节点
proportion=True, # 显示比例而非样本数
precision=2 # 数值精度
)
plt.title("决策树可视化 - 鸢尾花分类", fontsize=16)
plt.savefig('decision_tree_chinese.png', dpi=300, bbox_inches='tight')
plt.show()
这是一个鸢尾花分类的决策过程,首先简单描述一下鸢尾花分类的基础知识,鸢尾花分类是一个经典的机器学习入门问题,也是一个多类别分类任务。其目标是根据一朵鸢尾花的花瓣和花萼的测量数据,自动判断它属于三个品种中的哪一种。
接下来这些很重要,很重要,很重要!

整个数据集就是一个大表格,有150行(代表150朵不同的花)和5列。
品种 (标签) | 花萼长度 | 花萼宽度 | 花瓣长度 | 花瓣宽度 |
|---|---|---|---|---|
Iris-setosa | 5.1 | 3.5 | 1.4 | 0.2 |
Iris-versicolor | 7.0 | 3.2 | 4.7 | 1.4 |
Iris-virginica | 6.3 | 3.3 | 6.0 | 2.5 |
... | ... | ... | ... |
数据集包含150个样本,每个样本有4个特征和1个标签:
它之所以成为经典,是因为它具备了一个完美教学数据集的所有特点:
看到这个图,首先要明白这是在决策鸢尾花具体是属于哪一种类型,每一层都有几个值:花瓣长度、gini、samples、value、class,其中:
先了解概念,在了解具体的公式和推导值;
4.1 第一步:判断“花瓣长度 <= 2.45”
此处非常有意思,2.45是怎么来的,是固定的还是随机抽取的,首先这不是随意的,而是通过严密的数学计算和优化选择得出的结果:
从最直观理解,数据分布的角度,花瓣长度的数据分布一般在:
由上得知可以看到:
这个点能够完美地将Setosa从其他两种花中分离出来
其次,最有依据的是基于准确的数学方法,最佳分裂点(先了解,后面会细讲):
4.2 samples(样本总量)100%
这个很好理解,此时抽取的是所有样本,所有数量为100%。
4.3 value = [0.33, 0.34, 0.32]
value表示的是每个样本在三类鸢尾花中的分布,这里也比较有趣味性了,按常理来说应该都是均分,都应该是0.3333,为什么会有差异呢,与训练集和测试集的分布有关系,这个比例会随着你划分训练集和测试集的方式不同而发生微小的变化。
简单看看这个变化的过程:
4.3.1 数据的初始状态:理论上应该是固定的
鸢尾花数据集本身有150个样本,每个品种(Setosa, Versicolor, Virginica)各50个。因此,在整个数据集中,每个类别的比例是精确的:
value = [50/150, 50/150, 50/150] = [0.333..., 0.333..., 0.333...]
所以,如果你在根节点看到 value = [0.33, 0.34, 0.32] 而不是完美的 [0.333, 0.333, 0.333],这已经暗示了我们没有使用全部150个样本。
4.3.2 为什么我们看到的不是固定值?—— 训练集与测试集的划分
在机器学习中,我们不会用全部数据来训练模型。为了评估模型的真实性能,我们通常会将数据划分为训练集和测试集。
最常用的划分比例是 80% 的数据用于训练,20% 用于测试。
关键点就在这里:150 * 0.8 = 120。现在,训练集只剩下120个样本。原来每个类别有50个,但在随机抽取80%后,每个类别在训练集中的数量几乎是 50 * 0.8 = 40,但不会那么精确。
4.3.3 random_state 参数的作用
您可能注意到了上面代码中的 random_state=42。这个参数控制了随机抽样的“种子”。
所以,value 的值是“固定”还是“变化”,完全取决于你的代码配置。
4.3.4 抽取的流程
下面这张图展示了数据如何从原始全集被随机划分到训练集,从而导致节点中类别比例发生微小变化的过程:

所以,看到的 [0.33, 0.34, 0.32] 是一个在随机划分训练集后,各类别比例的正常、微小的波动表现,并不意味着数据或代码有问题。
4.4 gini = 0.67
gini值计算的公式:

其中:
根节点参数:
4.5 class = 变色鸢尾
对应的占比,中间的Iris Versicolor变色鸢尾比例为0.34居多,所以当前的预测类别偏重于变色鸢尾。
4.6 花瓣长度<=2.45的结果
4.6.1 结果成立
如果结果成立则走第二次的左侧节点,直接判定为山鸢尾,流程结束。
强化值计算:
4.6.2 如果结果不成立
如果结果成立则走第二次的右侧节点,继续下一步的决策,调整判断参数,判断“花瓣长度<=4.75”,观察对应的参数值:
按照这样的思路,逐步分析决策,最终匹配到最适合的类型;如果还是有疑问,可以从根节点开始,跟着它的条件一步步走,看看模型是如何根据花的尺寸来分类的。这就像看到了模型的“思考过程”,非常直观!
决策树是机器学习中最基础、最直观的算法之一:
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。