Python实战开发及案例分析(9)—— 决策树

        决策树是一种用于分类和回归的机器学习模型。它通过学习一系列的决策规则将数据分成不同的类别或预测数值。决策树在构建时依赖于属性选择度量,如信息增益、基尼系数等。

        在Python中,我们可以使用scikit-learn库来快速构建和使用决策树模型。下面是一个基于决策树的分类和回归的案例分析。

案例分析:决策树分类

        我们将使用scikit-learn的决策树分类器对鸢尾花数据集进行分类。鸢尾花数据集包含了三种鸢尾花的四个特征(花萼和花瓣的长度和宽度),并需要根据这些特征对鸢尾花的种类进行分类。

Python 实现:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
from sklearn import tree

# 加载鸢尾花数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器并训练
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 绘制决策树
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()

解释:

  • DecisionTreeClassifier:用于创建决策树分类模型。
  • plot_tree:绘制决策树,展示决策路径。

案例分析:决策树回归

        决策树也可以用于回归问题。在这个案例中,我们将使用波士顿房价数据集来预测房屋的价格。

Python 实现:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# 加载波士顿房价数据集
boston = datasets.load_boston()
X, y = boston.data, boston.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树回归模型并训练
regressor = DecisionTreeRegressor(random_state=42)
regressor.fit(X_train, y_train)

# 预测测试集结果
y_pred = regressor.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")

# 绘制特征重要性
plt.figure(figsize=(8, 6))
plt.barh(boston.feature_names, regressor.feature_importances_)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Decision Tree Regression")
plt.show()

解释:

  • DecisionTreeRegressor:用于创建决策树回归模型。
  • mean_squared_error:计算模型预测的均方误差,用于评估回归模型的性能。

结论

        决策树模型通过构建一系列决策规则为分类和回归问题提供了强大的模型能力。它的优势包括可解释性强和适用于处理类别或数值型数据。

  • 分类问题:通过分割不同特征空间,可以有效地分类鸢尾花数据集。
  • 回归问题:通过预测连续数值,为房价预测提供了简单有效的方法。

        然而,决策树模型容易过拟合,需要通过剪枝、设置深度和样本数量等参数进行调节。在实际应用中,结合交叉验证和其他技术可以提高模型的泛化能力。

        继续深入探讨决策树模型,我们可以讨论更多的决策树相关技术,如剪枝、特征重要性以及基于集成学习的随机森林和梯度提升树。

决策树剪枝

        决策树容易过拟合,为了解决这一问题,可以进行预剪枝或后剪枝。

  • 预剪枝:在构建过程中通过设置参数(如max_depth, min_samples_split等)限制树的生长。
  • 后剪枝:先生成完整的树,然后在验证集上进行剪枝。
预剪枝示例

        通过设置max_depth限制树的最大深度,避免过拟合。

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
from sklearn import tree

# 加载鸢尾花数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 设置预剪枝参数
clf = DecisionTreeClassifier(random_state=42, max_depth=3, min_samples_split=4, min_samples_leaf=2)
clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 绘制决策树
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()

决策树特征重要性

        通过查看特征的重要性,可以了解哪些特征对分类结果影响最大。

# 使用鸢尾花数据集和决策树分类器
clf = DecisionTreeClassifier(random_state=42, max_depth=3, min_samples_split=4, min_samples_leaf=2)
clf.fit(X_train, y_train)

# 打印特征重要性
print("Feature importances:", clf.feature_importances_)

# 可视化特征重要性
plt.figure(figsize=(8, 6))
plt.barh(iris.feature_names, clf.feature_importances_)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Decision Tree Classification")
plt.show()

案例分析:使用随机森林进行分类

        随机森林是由多个决策树构成的集成模型,可以通过组合多个决策树的预测结果来提高模型性能和稳定性。它还可以帮助减小单个决策树的过拟合风险。

Python 实现:

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 加载鸢尾花数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建随机森林分类器并训练
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42)
rf_clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = rf_clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 可视化特征重要性
plt.figure(figsize=(8, 6))
sns.barplot(x=rf_clf.feature_importances_, y=iris.feature_names)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Random Forest Classification")
plt.show()

案例分析:使用梯度提升树(Gradient Boosting Trees)进行分类

        梯度提升树是一种集成学习方法,通过逐步构建多个弱模型(决策树)来提高预测精度。

Python 实现:

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 加载鸢尾花数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建梯度提升分类器并训练
gb_clf = GradientBoostingClassifier(n_estimators=100, random_state=42)
gb_clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = gb_clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 可视化特征重要性
plt.figure(figsize=(8, 6))
sns.barplot(x=gb_clf.feature_importances_, y=iris.feature_names)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Gradient Boosting Classification")
plt.show()

结论

  • 剪枝:通过预剪枝或后剪枝技术可以减小决策树的过拟合风险,得到更稳健的模型。
  • 特征重要性:决策树模型可以用于评估不同特征对分类结果的重要性。
  • 集成模型
    • 随机森林:通过组合多个决策树,减少了单一决策树的过拟合风险,通常能获得更好的预测效果。
    • 梯度提升树:通过逐步构建多个弱模型来提高预测精度,适用于复杂的分类和回归问题。

        这些技术使得决策树及其变种模型在实际机器学习问题中具有广泛的应用。

        继续深入探讨更多与决策树相关的技术和案例,我们可以学习基于决策树的其他集成方法,如极端随机树(ExtraTrees),并探讨决策树的具体应用案例。

案例分析:使用极端随机树(ExtraTrees)进行分类

        极端随机树是一种随机森林的变种,与随机森林不同,极端随机树在构建树时不是选择最佳分割点,而是随机选择分割点,从而增加多样性并加速计算。

Python 实现:

from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 加载鸢尾花数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建极端随机树分类器并训练
et_clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
et_clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = et_clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 可视化特征重要性
plt.figure(figsize=(8, 6))
sns.barplot(x=et_clf.feature_importances_, y=iris.feature_names)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Extra Trees Classification")
plt.show()

案例分析:决策树在客户流失预测中的应用

项目背景:客户流失是指客户停止使用某种产品或服务。为了保留更多的客户,可以通过决策树模型对客户进行分类,预测哪些客户更可能流失。

数据集:使用著名的 Telco Customer Churn 数据集。

Python 实现:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 加载数据集
url = 'https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv'
df = pd.read_csv(url)

# 删除不必要的列
df.drop(['customerID'], axis=1, inplace=True)

# 处理分类数据
label_encoders = {}
for column in df.select_dtypes(include='object').columns:
    le = LabelEncoder()
    df[column] = le.fit_transform(df[column])
    label_encoders[column] = le

# 定义特征和标签
X = df.drop(['Churn'], axis=1)
y = df['Churn']

# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建随机森林模型并训练
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42)
rf_clf.fit(X_train, y_train)

# 预测测试集结果
y_pred = rf_clf.predict(X_test)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 可视化特征重要性
plt.figure(figsize=(8, 6))
feature_names = df.drop(['Churn'], axis=1).columns
sns.barplot(x=rf_clf.feature_importances_, y=feature_names)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Customer Churn Prediction")
plt.show()

案例分析:使用LightGBM进行分类

项目背景:LightGBM(Light Gradient Boosting Machine)是由微软开发的梯度提升框架,具有更快的训练速度和更好的准确性。

Python 实现:

安装 lightgbm

pip install lightgbm

代码实现:

import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 加载数据集
url = 'https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv'
df = pd.read_csv(url)

# 删除不必要的列
df.drop(['customerID'], axis=1, inplace=True)

# 处理分类数据
label_encoders = {}
for column in df.select_dtypes(include='object').columns:
    le = LabelEncoder()
    df[column] = le.fit_transform(df[column])
    label_encoders[column] = le

# 定义特征和标签
X = df.drop(['Churn'], axis=1)
y = df['Churn']

# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建LightGBM模型并训练
lgb_train = lgb.Dataset(X_train, y_train)
params = {
    'objective': 'binary',
    'boosting_type': 'gbdt',
    'metric': 'binary_logloss',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'verbose': 0
}
gbm = lgb.train(params, lgb_train, num_boost_round=100)

# 预测测试集结果
y_pred_prob = gbm.predict(X_test)
y_pred = (y_pred_prob > 0.5).astype(int)

# 输出混淆矩阵和分类报告
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

# 可视化特征重要性
plt.figure(figsize=(8, 6))
feature_importance = gbm.feature_importance()
sns.barplot(x=feature_importance, y=feature_names)
plt.xlabel("Feature Importance")
plt.ylabel("Feature Name")
plt.title("Feature Importance in Customer Churn Prediction using LightGBM")
plt.show()

结论

        在这几种案例分析中,我们展示了不同的集成学习方法在决策树上的扩展和应用:

  • 极端随机树:通过随机分割点和数据采样构建,增加模型多样性并提高效率。
  • 客户流失预测:决策树模型在预测分类问题中具有较好的表现,适用于客户流失预测等场景。
  • LightGBM:一种高效的梯度提升方法,能显著提高训练速度和预测性能。

        通过这些不同的集成方法和决策树扩展模型,可以更有效地解决分类和回归问题,特别是对大型数据集和复杂特征的预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/596221.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据分析之Tebleau可视化:树状图、日历图、气泡图

树状图(适合子分类比较多的) 1.基本树状图的绘制 同时选择产品子分类和销售金额----选择智能推荐----选择树状图 2.双层树状图的绘制 将第二个维度地区拖到产品分类的下面---大的划分区域是上面的维度(产品分类),看着…

cmake进阶:文件操作

一. 简介 前面几篇文章学习了 cmake的文件操作,写文件,读文件。文章如下: cmake进阶:文件操作之写文件-CSDN博客 cmake进阶:文件操作之读文件-CSDN博客 本文继续学习文件操作。主要学习 文件重命名,删…

C++引用2

什么是引用变量? 引用实际上是已定义变量的别名,使一个变量拥有多个名字 c给&符号赋予了另一个意义,将其用来声明引用 int a9;int&ba; 此时b成为a的一个别名,a就是b,b就是a.它们均指向同一片内存 int a99; in…

虚拟键代码

虚拟键代码 虚拟键码 (Winuser.h) - Win32 apps | Microsoft Learn 在Windows操作系统中,虚拟键代码(Virtual-Key Codes)是一组用来表示键盘上按键的数值。这些代码通常用于Windows API函数,以便程序能够识别和处理键盘输入。 虚拟…

OSEK任务管理

1 前言 RTOS通过任务(task)来组织应用层程序框架(framework),支持任务的并发和同步执行(concurrent and asynchronous execution of tasks),并通过调度器(scheduler&…

[方法] Unity 实现仿《原神》第三人称跟随相机 v1.1

参考网址:【Unity中文课堂】RPG战斗系统Plus 在Unity游戏引擎中,实现类似《原神》的第三人称跟随相机并非易事,但幸运的是,Unity为我们提供了强大的工具集,其中Cinemachine插件便是实现这一目标的重要工具。Cinemachi…

超分辨率重建——BSRN网络训练自己数据集并推理测试(详细图文教程)

目录 一、BSRN网络总结二、源码包准备三、环境准备3.1 报错KeyError: "No object named BSRN found in arch registry!"3.2 安装basicsr源码包3.3 参考环境 四、数据集准备五、训练5.1 配置文件参数修改5.2 启动训练5.2.1 命令方式训练5.2.2 配置Configuration方式训…

vivado UltraScale 比特流设置

下表所示 UltraScale ™ 器件的器件配置设置可搭配 set_property <Setting> <Value> [current_design] Vivado 工具 Tcl 命令一起使用。

翻译《The Old New Thing》 - What’s the point of DeferWindowPos?

Whats the point of DeferWindowPos? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20050706-26/?p35023 Raymond Chen 在 2005年7月6日 DeferWindowPos 的作用是什么&#xff1f; 简要 文章讨论了 DeferWindowPos 函数的用途&#xff…

LangChain框架学习总结

目录 一、简介 二、概念 三、组件具体介绍 3.1 Models 3.1.1 LLMs 3.1.2 Chat Models 3.1.3 Text Embedding Modesl 3.1.4 总结 3.2 Prompts 3.2.1 LLM Prompt Template 3.2.1.1 自定义PromptTemplate 3.2.1.2 partial PromptTemplate 3.2.1.3 序列化PromptTemplat…

IMEI引起的无法驻网问题

这篇内容没什么意思&#xff0c;仅仅是做个简单记录。 问题不复杂&#xff0c;场景很简单&#xff0c;如上图&#xff0c;UE在进行LTE attach过程时&#xff0c;在送完NAS security mode complete后&#xff0c;就立刻收到了网络attach reject 带cause 6 Illegal ME&#xff0c…

Chrome浏览器安装React工具

一、如果网络能访问Google商店&#xff0c;直接安装官方插件即可 二、网络不能访问Google商店&#xff0c;使用安装包进行安装 1、下载react工具包 链接&#xff1a;https://pan.baidu.com/s/1qAeqxSafOiNV4CG3FVVtTQ 提取码&#xff1a;vgwj 2、chrome浏览器安装react工具…

设置定位坐标+请按任意键继续

设置定位坐标 目的 在编程和游戏开发中&#xff0c;设置定位坐标的目的是为了确定对象在屏幕或游戏世界中的具体位置。坐标通常由一对数值表示&#xff0c;例如 (x, y)&#xff0c;其中 x 表示水平位置&#xff0c;y 表示垂直位置。设置定位坐标的目的包括&#xff1a; 1. **精…

【云原生】Pod 的生命周期(二)

【云原生】Pod 的生命周期&#xff08;一&#xff09;【云原生】Pod 的生命周期&#xff08;二&#xff09; Pod 的生命周期&#xff08;二&#xff09; 6.容器探针6.1 检查机制6.2 探测结果6.3 探测类型 7.Pod 的终止7.1 强制终止 Pod7.2 Pod 的垃圾收集 6.容器探针 probe 是…

MATLAB 变换

MATLAB 变换&#xff08;Transforms&#xff09; MATLAB提供了用于处理诸如Laplace和Fourier变换之类的变换的命令。转换在科学和工程中用作简化分析和从另一个角度查看数据的工具。 例如&#xff0c;傅立叶变换允许我们将表示为时间函数的信号转换为频率函数。拉普拉斯变换使…

Linux驱动开发——(十一)INPUT子系统

目录 一、input子系统简介 二、input驱动API 2.1 input字符设备 2.2 input_dev结构体 2.3 上报输入事件 2.4 input_event结构体 三、代码 3.1 驱动代码 3.2 测试代码 四、平台测试 一、input子系统简介 input子系统是管理输入的子系统&#xff0c;和pinctrl、gpio子…

#9松桑前端后花园周刊-React19beta、TS5.5beta、Node22.1.0、const滥用、jsDelivr、douyin-vue

行业动态 Mozilla 提供 Firefox 的 ARM64 Linux二进制文件 此前一直由发行版开发者或其他第三方提供&#xff0c;目前Mozilla提供了nightly版本&#xff0c;正式版仍需要全面测试后再推出。 发布 React 19 Beta 此测试版用于为 React 19 做准备的库。React团队概述React 19…

【driver4】锁,错误码,休眠唤醒,中断,虚拟内存,tasklet

文章目录 1.互斥锁和自旋锁选择&#xff1a;自旋锁&#xff08;开销少&#xff09;的自旋时间和被锁住的代码执行时间成正比关系2.linux错误码&#xff1a;64位系统内核空间最后一页地址为0xfffffffffffff000~0xffffffffffffffff&#xff0c;这段地址是被保留的&#xff0c;如果…

全新桥隧坡安全监测解决方案,24h监测效率提升30%

4月26日&#xff0c;交通运输部党组书记、部长李小鹏在部务会上强调&#xff0c;要高度重视公路桥梁隧道结构监测工作&#xff0c;抓紧推进公路桥梁隧道结构监测系统建设&#xff0c;进一步健全完善公路桥梁隧道结构监测长效运行机制。 中海达积极参与公路桥梁隧道结构监测工作…

触摸OpenNJet,感悟云原生

小程一言 云原生使得应用充分利用云计算、容器化和微服务架构等现代技术来构建和运行应用程序。 云原生技术的用处在于提高应用程序的可靠性、可伸缩性和灵活性&#xff0c;加快开发和部署速度&#xff0c;降低成本&#xff0c;提升整体的效率和竞争力。通过采用云原生技术&a…