前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python人工智能:基于sklearn的决策树分类算法实现总结

Python人工智能:基于sklearn的决策树分类算法实现总结

作者头像
用户1143655
发布2023-03-21 20:16:29
1.3K0
发布2023-03-21 20:16:29
举报
文章被收录于专栏:全栈之殇

一、sklearn实现决策树简介

!! ✨ sklearn决策树模块包括回归与分类决策树两大类,另外还包括三个决策树结构绘制函数。对于初学者我们的重点可以先放在结合决策树的基本原理的基础上,学会对这些接口的灵活应用,本文以分类决策树为例进行介绍。

  • sklearn官方给出的sklearn.tree 相关API接口如下表所示:

Sklearn决策树API接口

功能

tree.DecisionTreeClassifier

决策树分类器

tree.DecisionTreeRegressor

决策树回归器

tree.ExtraTreeClassifier

An extremely randomized tree classifier.

tree.ExtraTreeRegressor

An extremely randomized tree regressor.

tree.export_graphviz

导出一个决策树为DOT格式

tree.export_text

生成一份展示决策树规则的文本报告

tree.plot_tree

绘制决策树

  • sklearn基本流程如下图所示:

二、通过sklearn实现一个分类决策树实例

本文通过sklearn实现一个分类决策树包括如下四个步骤:

  • (1) 数据集信息查看
  • (2) 数据集的获取与预处理
  • (3) 分类决策树模型构建
  • (4) 模型结构图可视化
  • (5) 特征重要性结果查看

2.1 数据集信息查看

!! ✨ 只有对数据集有了充分了解才能很好地使用它.

通过下面的代码获取本文使用的数据集,并查看数据集的具体信息:

代码语言:javascript
复制
from sklearn.datasets import load_wine

# 获取数据集
wine = load_wine()

# 查看wine数据集的整体信息
print("wine数据的类型:\n", type(wine))
print("\nwine数据包含的内容:\n", wine.keys())

# 查看data数据的信息
print("\ndata数据的类型:\n", type(wine.data))
print("\ndata数据的形状:\n", wine.data.shape)

# 查看target数据的信息
print("\ntarget数据的类型:\n", type(wine.target))
print("\ntarget数据的形状:\n", wine.target.shape)

代码执行结果如下图所示:

由此,可以看出wine数据集是一种字典格式的数据,其主要内容如下所示:

  • (1) data:特征数据,用作X的值。其数据类型为ndarray格式,其形状为(178, 13)即具有178个数据,每个数据包含13个特征。
  • (2) target:目标数据(酒的类别),用作y的值,其形状为(178,);
  • (3) frame:None;
  • (4) target_names:目标值标签,其为array(['class_0', 'class_1', 'class_2'])
  • (5) DESCR:数据集描述信息;
  • (6) feature_names:13个特征对应实际含义。

2.2 数据集的获取与预处理

数据集的获取与预处理如下所示:

代码语言:javascript
复制
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# 获取数据集并对数据进行训练/测试集的切分
wine = load_wine()
# 将数据集按照7:3划分为训练数据集与测试数据集
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.3
)

2.3 分类决策树模型构建

仅需要三行代码就可以简单的实现分类决策树模型的构建:

代码语言:javascript
复制
clf = DecisionTreeClassifier()  # 分类决策树模型实例化
clf.fit(X_train, y_train)       # 使用训练集进行模型训练
score = clf.score(X_test, y_test)   # 测试模型在数据集上的性能表现

# 查看分类决策树预测结果
print("分类决策树在测试集上的分类精度:", score)

代码执行结果如下图所示:

由此可见,使用sklearn默认参数的分类决策树分类精度高达90.7%。

2.4 模型结构图可视化

本文使用sklearn的sklearn.tree.export_graphviz类函数实现分类决策树的可视化。需要注意的的是我们需要首先配置graphviz软件,具体配置方法可以自行百度,我前面写了一篇文章可供参考Python人工智能:Ubuntu系统中网络结构绘图工具库Graphviz的使用方法简介

本文的分类决策树可视化代码如下所示:

代码语言:javascript
复制
# 模型结构图可视化
from sklearn.tree import export_graphviz
import graphviz

# 将英文特征名字映射为中文
feature_names = [
    '酒精',         # alcohol
    '苹果酸',       # malic_acid
    '灰',          # ash
    '灰的碱性',     # alcalinity_of_ash
    '镁',          # magnesium
    '总酚',        # total_phenols
    '类黄酮',       # flavanoids
    '非黄烷类酚类',  # nonflavanoid_phenols
    '花青素',       # proanthocyanins
    '颜色强度',     # color_intensity
    '色调',        # hue
    'od280/od315稀释葡萄酒', # od280/od315_of_diluted_wines
    '脯氨酸'       # proline
]

dot_data = export_graphviz(
    clf,                        # 需要绘制的模型
    feature_names=feature_names,# 特征名
    class_names=["琴酒", "雪梨", "贝尔摩德"],   # 分类名称
    filled=True,                # 颜色填充
    rounded=True                # 边框圆角
)

# 绘图
graph = graphviz.Source(dot_data)
graph

代码执行结果如下所示:

2.5 特征重要性结果查看

查看模型每个特征对于决策树分类重要性的代码如下所示:

代码语言:javascript
复制
clf.feature_importances_

代码的执行结果如下图所示:

我们还可以通过下面的命令,更加直观的展示各个特征对于模型的重要性:

代码语言:javascript
复制
[*zip(feature_names, clf.feature_importances_)]

代码执行结果如下图所示:

三、分类决策树实例化的主要参数

DecisionTreeClassifier分类决策树实例化类方法为例,其经常用的7个参数如下所示:

代码语言:javascript
复制
from sklearn.tree import DecisionTreeClassifier

clf_tree = DecisionTreeClassifier(
    criterion='gini',           # 不纯度衡量指标计算方法的选择参数
    splitter='best',            # 决策树特征组合随机选择参数
    max_depth=5,                # 限制决策树最大深度
    min_samples_leaf=10,        # 设定一个节点在分枝后的每个子节点包含的最小的训练样本个数
    min_samples_split=10,       # 设定一个节点必须包含的最小训练样本个数
    max_features=5,             # 限制分枝时考虑的特征个数
    min_impurity_decrease=1e-2  # 限制信息增益的大小
)

下面详细介绍上述7个主要参数的具体含义:

3.1 criterion:不纯度衡量指标计算方法的选择参数

对于决策树,我们通常使用不纯度来衡量决策树能够找到最佳节点与最佳分枝方法的程度。通常,不纯度越低,决策树对训练集的拟合越好。决策树分枝算法的设计的核心就在于对与不纯度相关指标的优化上。

criterion是用来决定不纯度的计算方法,sklearn提供了两种计算方法:

  • (1) entropy:信息熵方法;
  • (2) gini:基尼系数方法(默认方法)。

两种方法的比较:比起基尼系数,信息熵对不纯度更加敏感,对不纯度的惩罚更强,但是实际应用,两种方法的效果基本相同。由于信息熵的计算涉及对数计算,所以信息熵的计算比基尼系数缓慢一些。另外,由于信息熵对不纯度更加敏感,所以基于信息熵的决策树通常会生长的更加精细,通常情况,对于高维度数据或者噪声很多的数据,其更容易出现过拟合现象。

3.2 splitter:决策树特征组合随机选择参数

splitter是用来控制决策树特征组合随机选择方法的参数,其包括两种方法:

  • (1) best:(默认方法)使用该方法时,决策树在分枝时虽然随机,但是其会优先选择最重要的特征进行分枝;
  • (2) random:决策树在分枝时更加随机,树相应的会更深,从而降低了对训练数据的过拟合程度。

3.3 剪枝策略控制参数

在不加限制的情况下,一颗决策树通常会生长到不纯度指标达到最优,或者没有更多的特征可用为止,这很容易导致决策树出现过拟合现象。此时我们就需要考虑如下一个关键问题:

!! 🤔 决策树对训练集的拟合程度如何控制,才能在测试集上表现出同样的预测效果?即如何对决策树进行合理剪枝,以防止过拟合线性和提高模型的泛化能力。

因此,剪枝策略对决策树的影响巨大,合理的剪枝策略是优化决策树算法的核心。

  • sklearn中提供的决策树包括的剪枝策略如下表所示:

剪枝策略参数

作用

max_dapth

(最常用的参数)用于限制决策树最大深度,超过设定深度的树枝全部剪掉

min_samples_leaf

用于设定一个节点在分枝后的每个子节点包含的最小的训练样本个数,小于该设定值则结束该节点的分枝

min_samples_split

用于设定一个节点必须包含的最小训练样本个数,小于该设定值则结束该节点的分枝

max_features

用于限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃

min_impurity_decrease

限制信息增益的大小,信息增益小于设定值时停止剪枝

!! 🫧 确定最优的剪枝参数编程技巧:

通常,我们可以使用超参数曲线法来确定最优的剪枝参数。超参数的学习曲线是一条以超参数的取值为横坐标,模型的衡量指标为纵坐标的曲线,通过超参数可以量化不同超参数取值下模型的表现曲线。

下面以max_depth参数的学习曲线为例,给出常用的代码模板:

代码语言:javascript
复制
import matplotlib.pyplot as plt

test = []       # 用于存放不同max_depth下的评价结果

# 通过for循环获得不同max_depth下的模型的评价指标结果
for i in range(10):
    clf = DecisionTreeClassifier(
        max_depth=i+1,
        criterion='gini',
        splitter="random"
    )                                   # 实例化分类决策树模型对象
    clf = clf.fit(X_train, y_train)     # 模型训练
    score = clf.score(X_test, y_test)   # 获取模型在测试集上的评价结果
    test.append(score)

# 绘制超参数的学习曲线
plt.plot(
    range(1,11), test,
    color="red", label="max_depth"
)
plt.legend()
plt.show()

代码执行结果如下图所示:


本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-03-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能技术栈 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、sklearn实现决策树简介
  • 二、通过sklearn实现一个分类决策树实例
  • 三、分类决策树实例化的主要参数
相关产品与服务
灰盒安全测试
腾讯知识图谱(Tencent Knowledge Graph,TKG)是一个集成图数据库、图计算引擎和图可视化分析的一站式平台。支持抽取和融合异构数据,支持千亿级节点关系的存储和计算,支持规则匹配、机器学习、图嵌入等图数据挖掘算法,拥有丰富的图数据渲染和展现的可视化方案。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档