用Python进行AI数据分析进阶教程49:
决策树的剪枝
关键词:决策树、剪枝、预剪枝、后剪枝、过拟合
摘要:本文介绍了决策树剪枝的基本概念及其重要性,重点讲解了防止决策树过拟合的两种主要方法:预剪枝和后剪枝。预剪枝在树的构建过程中提前停止划分,通过设置最大深度、最小样本数等参数控制树的复杂度,但存在欠拟合风险;后剪枝则是在树完全生成后进行自底向上的剪枝操作,常用方法为代价复杂度剪枝(CCP),利用验证集评估剪枝效果,虽计算成本较高但泛化能力更强。文中还提供了基于Scikit-Learn库实现预剪枝与后剪枝的具体Python代码示例,并通过鸢尾花数据集展示了模型训练与评估过程,突出了剪枝对提升分类准确率的实际意义。
👉 欢迎订阅🔗
《用Python进行AI数据分析进阶教程》专栏
《AI大模型应用实践进阶教程》专栏
《Python编程知识集锦》专栏
《字节跳动旗下AI制作抖音视频》专栏
《智能辅助驾驶》专栏
《工具软件及IT技术集锦》专栏
决策树剪枝是为了防止决策树过拟合,提高其泛化能力的重要手段。主要分为预剪枝和后剪枝两种方法,下面将详细讲解这两种剪枝方法。
一、预剪枝
1、原理
预剪枝是在决策树构建过程中,在每个节点划分之前,先评估划分是否能带来模型性能的提升,如果不能则停止划分。常用的评估指标包括信息增益、基尼系数等,同时也可以设置一些限制条件,如树的最大深度、最小样本数等。
2、关键点
- 评估指标:使用信息增益、基尼系数等指标来评估划分的优劣。
- 限制条件:设置树的最大深度、最小样本数、最小信息增益等限制条件。
3、注意点
- 欠拟合风险:预剪枝可能会过早停止树的生长,导致模型欠拟合。
- 参数选择:限制条件的参数选择需要根据具体数据集进行调整,不同的参数可能会导致不同的模型性能。
4、示例及代码
Python脚本
# 从 sklearn 库的 datasets 模块中导入 load_iris 函数,用于加载鸢尾花数据集
from sklearn.datasets import load_iris
# 从 sklearn 库的 model_selection 模块中导入 train_test_split 函数,
# 用于将数据集划分为训练集和测试集
from sklearn.model_selection import train_test_split
# 从 sklearn 库的 tree 模块中导入 DecisionTreeClassifier 类,用于创建决策树分类器
from sklearn.tree import DecisionTreeClassifier
# 从 sklearn 库的 metrics 模块中导入 accuracy_score 函数,用于计算分类准确率
from sklearn.metrics import accuracy_score
# 调用 load_iris 函数加载鸢尾花数据集,并将其赋值给变量 iris
# 鸢尾花数据集是一个经典的分类数据集,包含 150 个样本,分为 3 个类别
iris = load_iris()
# 从 iris 数据集中提取特征数据,赋值给变量 X
# 特征数据包含了鸢尾花的一些测量值,如花瓣长度、花瓣宽度等
X = iris.data
# 从 iris 数据集中提取标签数据,赋值给变量 y
# 标签数据表示每个样本所属的类别
y = iris.target
# 使用 train_test_split 函数将特征数据 X 和标签数据 y 划分为训练集和测试集
# test_size=0.3 表示将 30% 的数据作为测试集,70% 的数据作为训练集
# random_state=42 是随机数种子,保证每次划分的结果一致,方便结果复现
X_train, X_test, y_train, y_test = (train_test_split(X, y,
test_size=0.3, random_state=42))
# 创建一个 DecisionTreeClassifier 决策树分类器对象,并设置预剪枝参数
# max_depth=3 表示决策树的最大深度为 3,防止树生长过深导致过拟合
# min_samples_split=5 表示一个节点要进行划分时,
# 至少需要包含 5 个样本,同样是为了防止过拟合
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=5)
# 调用决策树分类器的 fit 方法,使用训练集数据 X_train 和对应的标签 y_train 对模型进行训练
# 训练过程就是让决策树学习特征和标签之间的关系
clf.fit(X_train, y_train)
# 调用训练好的决策树分类器的 predict 方法,使用测试集数据 X_test 进行预测
# 预测结果存储在变量 y_pred 中
y_pred = clf.predict(X_test)
# 调用 accuracy_score 函数,计算测试集真实标签 y_test 和预测标签 y_pred 之间的准确率
# 准确率是分类正确的样本数占总样本数的比例
accuracy = accuracy_score(y_test, y_pred)
# 使用 f-string 格式化输出准确率
# 打印出训练好的决策树模型在测试集上的分类准确率
print(f"Accuracy: {accuracy}")
输出 / 打印结果及注释
代码的输出结果会类似如下形式:
plaintext
Accuracy: 0.9555555555555556
这里的输出值是一个近似值,实际的准确率会因为浮点数的精度问题显示为多位小数。
- Accuracy:表示模型在测试集上的分类准确率。这个值越接近 1,说明模型在测试集上的分类效果越好。在这个例子中,准确率约为 0.956,意味着模型在测试集上大约 95.6% 的样本分类正确。不过,每次运行代码时,由于数据集的划分可能会有细微差异(虽然设置了 random_state 尽量保证结果一致,但不同环境可能仍有极微小差别),准确率可能会在一定范围内波动。
重点语句解读
- DecisionTreeClassifier(max_depth=3, min_samples_split=5):创建决策树分类器,max_depth=3 表示树的最大深度为 3,min_samples_split=5 表示节点划分所需的最小样本数为 5。这些参数的设置可以限制树的生长,避免过拟合。
- clf.fit(X_train, y_train):使用训练集数据对决策树模型进行训练。
- clf.predict(X_test):使用训练好的模型对测试集数据进行预测。
- accuracy_score(y_test, y_pred):计算预测结果的准确率。
二、后剪枝
1、原理
后剪枝是在决策树构建完成后,自底向上地对非叶子节点进行评估,判断将该节点对应的子树替换为叶子节点是否能提升模型的泛化能力。如果能,则进行剪枝操作。常用的后剪枝方法有代价复杂度剪枝(CCP)。
2、关键点
- 剪枝评估:使用验证集来评估剪枝前后模型的性能。
- 剪枝策略:根据评估结果决定是否进行剪枝。
3、注意点
- 计算成本:后剪枝需要在决策树构建完成后进行,计算成本相对较高。
- 验证集选择:验证集的选择会影响剪枝的效果,需要合理划分验证集。
4、示例及代码
Python脚本
# 从 sklearn 库的 datasets 模块中导入 load_iris 函数,用于加载经典的鸢尾花数据集
from sklearn.datasets import load_iris
# 从 sklearn 库的 model_selection 模块中导入 train_test_split 函数,
# 该函数可将数据集划分为不同子集
from sklearn.model_selection import train_test_split
# 从 sklearn 库的 tree 模块中导入 DecisionTreeClassifier 类,用于创建决策树分类模型
from sklearn.tree import DecisionTreeClassifier
# 从 sklearn 库的 metrics 模块中导入 accuracy_score 函数,用于计算分类模型的准确率
from sklearn.metrics import accuracy_score
# 调用 load_iris 函数加载鸢尾花数据集,并将其赋值给变量 iris
# 鸢尾花数据集包含了 150 个样本,每个样本有 4 个特征,分为 3 个类别
iris = load_iris()
# 从 iris 数据集中提取特征数据,存储在变量 X 中
# 这些特征是鸢尾花的一些属性,如花瓣长度、花瓣宽度等
X = iris.data
# 从 iris 数据集中提取标签数据,存储在变量 y 中
# 标签代表每个样本所属的类别
y = iris.target
# 使用 train_test_split 函数将数据集划分为训练集和临时集
# test_size=0.4 表示将 40% 的数据作为临时集,60% 的数据作为训练集
# random_state=42 是随机数种子,保证每次划分的结果一致,方便结果复现
X_train,X_temp,y_train,y_temp=train_test_split(X,y,test_size=0.4,random_state=42)
# 再次使用 train_test_split 函数将临时集划分为验证集和测试集
# test_size=0.5 表示将临时集的 50% 作为测试集,另外 50% 作为验证集
# 这样整体上训练集、验证集、测试集的比例大致为 60%、20%、20%
X_val,X_test,y_val,y_test=(train_test_split(X_temp,y_temp,
test_size=0.5,random_state=42))
# 创建一个 DecisionTreeClassifier 类的实例 clf,即一个决策树分类器
# 这里没有设置额外参数,将使用默认的参数来构建决策树
clf = DecisionTreeClassifier()
# 调用决策树分类器的 fit 方法,
# 使用训练集数据 X_train 和对应的标签 y_train 对模型进行训练
# 训练过程就是让决策树学习特征和标签之间的关系,构建决策树结构
clf.fit(X_train, y_train)
# 调用决策树分类器的 cost_complexity_pruning_path 方法,进行代价复杂度剪枝路径的计算
# 该方法会计算不同剪枝强度(由 ccp_alpha 参数控制)下的子树及其对应的不纯度
# 结果存储在 path 对象中
path = clf.cost_complexity_pruning_path(X_train, y_train)
# 从 path 对象中提取不同的 ccp_alpha 值,存储在 ccp_alphas 变量中
# ccp_alpha 是代价复杂度剪枝的参数,值越大,剪枝越严重
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# 初始化最优的 ccp_alpha 值为 None,用于后续存储找到的最优值
best_alpha = None
# 初始化最优准确率为 0,用于后续比较和更新
best_accuracy = 0
# 遍历所有的 ccp_alpha 值
for ccp_alpha in ccp_alphas:
# 创建一个新的决策树分类器 pruned_clf,并设置当前的 ccp_alpha 值
# 这样就得到了一个使用该剪枝强度的决策树模型
pruned_clf = DecisionTreeClassifier(ccp_alpha=ccp_alpha)
# 使用训练集数据 X_train 和对应的标签 y_train 对剪枝后的决策树模型进行训练
pruned_clf.fit(X_train, y_train)
# 调用剪枝后决策树模型的 predict 方法,使用验证集数据 X_val 进行预测
# 预测结果存储在 y_val_pred 中
y_val_pred = pruned_clf.predict(X_val)
# 调用 accuracy_score 函数,
# 计算验证集真实标签 y_val 和预测标签 y_val_pred 之间的准确率
accuracy = accuracy_score(y_val, y_val_pred)
# 如果当前的准确率大于之前记录的最优准确率
if accuracy > best_accuracy:
# 更新最优准确率为当前准确率
best_accuracy = accuracy
# 更新最优的 ccp_alpha 值为当前的 ccp_alpha 值
best_alpha = ccp_alpha
# 创建一个新的决策树分类器 final_clf,并使用最优的 ccp_alpha 值
# 这样就得到了一个使用最优剪枝强度的决策树模型
final_clf = DecisionTreeClassifier(ccp_alpha=best_alpha)
# 使用训练集数据 X_train 和对应的标签 y_train 对使用最优剪枝强度的决策树模型进行训练
final_clf.fit(X_train, y_train)
# 调用使用最优剪枝强度的决策树模型的 predict 方法,使用测试集数据 X_test 进行预测
# 预测结果存储在 y_test_pred 中
y_test_pred = final_clf.predict(X_test)
# 调用 accuracy_score 函数,
# 计算测试集真实标签 y_test 和预测标签 y_test_pred 之间的准确率
test_accuracy = accuracy_score(y_test, y_test_pred)
# 使用 f-string 格式化输出测试集上的准确率
print(f"Test Accuracy: {test_accuracy}")
输出 / 打印结果及注释
代码运行后的输出可能类似如下:
plaintext
Test Accuracy: 0.9666666666666667
- Test Accuracy:这是使用最优剪枝强度的决策树模型在测试集上的分类准确率。它反映了经过代价复杂度剪枝后,模型在未参与训练和验证的数据上的泛化能力。该值越接近 1 越好,这里约为 0.967,意味着模型在测试集上大约 96.7% 的样本分类正确。不同的运行可能会因为随机种子的影响(虽然设置了 random_state 尽量保证一致,但环境差异等因素仍可能有极小波动)导致准确率略有不同。
重点语句解读
- clf.cost_complexity_pruning_path(X_train, y_train):计算决策树的代价复杂度剪枝路径,返回不同 ccp_alpha 值对应的子树的不纯度。
- ccp_alphas, impurities = path.ccp_alphas, path.impurities:获取不同 ccp_alpha 值和对应的不纯度。
- DecisionTreeClassifier(ccp_alpha=ccp_alpha):创建一个使用指定 ccp_alpha 值的决策树分类器,ccp_alpha 是一个用于控制剪枝强度的参数,值越大,剪枝越严重。
- final_clf = DecisionTreeClassifier(ccp_alpha=best_alpha):使用最优的 ccp_alpha 值重新训练决策树模型。
通过预剪枝和后剪枝,可以有效地防止决策树过拟合,提高模型的泛化能力。在实际应用中,可以根据具体情况选择合适的剪枝方法。
——The END——
🔗 欢迎订阅专栏
序号 | 专栏名称 | 说明 |
---|---|---|
1 | 用Python进行AI数据分析进阶教程 | 《用Python进行AI数据分析进阶教程》专栏 |
2 | AI大模型应用实践进阶教程 | 《AI大模型应用实践进阶教程》专栏 |
3 | Python编程知识集锦 | 《Python编程知识集锦》专栏 |
4 | 字节跳动旗下AI制作抖音视频 | 《字节跳动旗下AI制作抖音视频》专栏 |
5 | 智能辅助驾驶 | 《智能辅助驾驶》专栏 |
6 | 工具软件及IT技术集锦 | 《工具软件及IT技术集锦》专栏 |
👉 关注我 @理工男大辉郎 获取实时更新
欢迎关注、收藏或转发。
敬请关注 我的
微信搜索公众号:cnFuJH
CSDN博客:理工男大辉郎
抖音号:31580422589