机器学习与深度学习——通过决策树算法分类鸢尾花数据集iris求出错误率画出决策树并进行可视化

什么是决策树?

决策树是一种常用的机器学习算法,它可以对数据集进行分类或回归分析。决策树的结构类似于一棵树,由节点和边组成。每个节点代表一个特征或属性,每个边代表一个判断或决策。从根节点开始,根据特征的不同取值,不断向下遍历决策树,直到达到叶子节点,即最终的分类或回归结果。

在分类问题中,决策树通过将数据集分成不同的类别来进行分类。在回归问题中,决策树通过将数据集分成不同的区域来进行回归分析。

决策树的优点包括易于理解和解释、能够处理具有非线性关系的数据、对缺失数据具有容忍性等。然而,决策树也存在一些缺点,例如容易过拟合、对噪声数据敏感等。为了解决这些问题,常常需要对决策树进行剪枝或使用集成学习算法如随机森林来提高预测准确性。

两个目标

1、通过决策树算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分。
2、通过决策树算法对iris数据集总共四个维度的数据进行模型训练并求出错误率。

基本思路:

1、先载入iris数据集 Load Iris data
2、分离训练集和设置测试集split train and test sets
3、使用决策树模型进行训练Train using clf
4、画出决策树
5、然后二维进行可视化处理Visualization
6、最后通过绘图决策平面plot decision plane

程序代码

1、通过决策树算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分:

from sklearn import datasets
import numpy as np
### Load Iris data 加载数据集
iris = datasets.load_iris()
x = iris.data[:,:2]#前2个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape

### split train and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))

### Train using clf 决策树模型
from sklearn import tree # 决策树算法
clf = tree.DecisionTreeClassifier() # 决策树分类器 # 设置决策树分类器  选择用信息熵作为测量标准
clf = clf.fit(x_train,y_train)  #训练模型
pred_test=clf.predict(x_test)#预测模型
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))#计算错误率

#画出决策树
import matplotlib.pyplot as plt
plt.figure(dpi=200)
# feature_names=iris.feature_names设置决策树中显示的特征名称
tree.plot_tree(clf,feature_names=iris.feature_names,class_names=iris.target_names)

### Visualization
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
def plot_decision_regions(x, y, classifier, test_idx=None, resolution=0.02):
    markers = ('s', 'x', 'o', '^', 'v')
    colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')
    cmap = ListedColormap(colors[:len(np.unique(y))])
    x1_min, x1_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    x2_min, x2_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
    np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    plt.contourf(xx1, xx2, Z, alpha=0.3, cmap=cmap)
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())
    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(x=x[y == cl, 0], y=x[y == cl, 1], alpha=0.8, c=colors[idx], marker=markers[idx], label=cl, edgecolor='black')
    if test_idx:
        x_test, y_test = x[test_idx, :], y[test_idx]
        plt.scatter(x_test[:, 0], x_test[:, 1], c=colors[4], edgecolor='black', alpha=1.0, linewidth=1, marker='.', s=100, label='test set')
        
#### plot decision plane
x_combined_std = np.vstack((x_train, x_test))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(x_combined_std, y_combined,
classifier=clf, test_idx=range(105,150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()

运行截图
在这里插入图片描述

在这里插入图片描述

2、通过决策树算法对iris数据集总共四个维度的数据进行模型训练并求出错误率:

from sklearn import datasets
import numpy as np
iris = datasets.load_iris()
x = iris.data[:,:4]#前4个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape

### split train and test sets 分割数据集和设置测试集
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))

### Train using clf 决策树模型
from sklearn import tree # 决策树算法
clf = tree.DecisionTreeClassifier(criterion="entropy") # 决策树分类器 # 设置决策树分类器  选择用信息熵作为测量标准
clf = clf.fit(x_train,y_train)  #训练模型
pred_test=clf.predict(x_test)#预测模型
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))

#画决策树图
import matplotlib.pyplot as plt
plt.figure(dpi=200)
# feature_names=iris.feature_names设置决策树中显示的特征名称
tree.plot_tree(clf,feature_names=iris.feature_names,class_names=iris.target_names)

代码运行截图

在这里插入图片描述

在这里插入图片描述

决策树算法既可以做分类也可以做回归,决策树也存在一些缺点,例如容易过拟合、对噪声数据敏感等。为了解决这些问题,我们可以对决策树进行剪枝或使用集成学习算法如随机森林来提高预测准确性。

过拟合是指模型在训练数据上表现很好,但在测试数据上表现不佳的现象。以下是一些解决决策树算法过拟合问题的方法:

  1. 剪枝:决策树剪枝是一种常用的方法,用于减少决策树的复杂度并避免过拟合。常见的剪枝方法包括预剪枝和后剪枝。预剪枝是在构建树时停止拆分某些节点,后剪枝是在构建完整的树之后,再去掉一些子树。

  2. 正则化:与其他机器学习算法一样,决策树也可以使用正则化技术来防止过拟合。例如,可以使用L1或L2正则化来约束树的复杂度。

  3. 随机化:通过引入随机化,可以减少决策树的方差并提高模型的鲁棒性。例如,可以在构建树时随机选择特征,而不是根据特定的规则进行选择。

  4. 集成学习:将多个决策树组合成一个集成模型,例如随机森林和梯度提升树等。这种方法可以通过减少单个决策树的方差来提高整体模型的泛化性能。

  5. 数据扩充:通过增加数据样本数量或改变数据样本的特征值,可以减少过拟合。例如,可以使用数据增强技术,如旋转、平移和缩放等来生成更多的训练样本。

祝各位五一劳动节快乐!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/16649.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3的props和defineProps

文章目录 1. Props 声明1.1 props用字符串数组来声明Blog.vueBlogPost.vue 1.2 props使用对象来声明Blog.vueBlogPost.vue 2. 传递 prop 的细节2.1 Prop 名字格式2.1 静态Prop & 动态 Prop静态prop动态prop示例Blog.vueBlogPost.vue 2.3 传递不同的值类型NumberBooleanArra…

基于YOLOv4的目标检测系统(附MATLAB代码+GUI实现)

摘要:本文介绍了一种MATLAB实现的目标检测系统代码,采用 YOLOv4 检测网络作为核心模型,用于训练和检测各种任务下的目标,并在GUI界面中对各种目标检测结果可视化。文章详细介绍了YOLOv4的实现过程,包括算法原理、MATLA…

C++知识点 -- 异常

C知识点 – 异常 文章目录 C知识点 -- 异常一、异常概念二、异常的使用1.异常的抛出和捕获2.异常的重新抛出3.异常安全4.异常规范 三、自定义异常体系四、C标准库的异常体系五、C异常的优缺点 一、异常概念 当一个函数发现自己无法处理错误时,就可以抛出异常&#…

14-3-进程间通信-消息队列

前面提到的管道pipe和fifo是半双工的,在某些场景不能发挥作用; 接下来描述的是消息队列(一种全双工的通信方式); 比如消息队列可以实现两个进程互发消息(不像管道,只能1个进程发消息&#xff…

kali: kali工具-Ettercap

kali工具-Ettercap ettercap工具: 用来进行arp欺骗,可以进行ARP poisoning(arp投毒),除此之外还可以其他功能: ettercap工具的arp投毒可以截取web服务器、FTP服务器账号密码等信息,简略后打印出…

前端学习之使用JavaScript

前情回顾:网页布局 JavaScript 简介 avaScript诞生于1995年,它的出现主要是用于处理网页中的前端验证。所谓的前端验证,就是指检查用户输入的内容是否符合一定的规则。比如:用户名的长度,密码的长度,邮箱的…

SQL中去除重复数据的几种方法,我一次性都告你​

使用SQL对数据进行提取和分析时,我们经常会遇到数据重复的场景,需要我们对数据进行去重后分析。 以某电商公司的销售报表为例,常见的去重方法我们用到distinct 或者group by 语句, 今天介绍一种新的方法,利用窗口函数…

Github 的使用

3. Github 在版本控制系统中,大约90%的操作都是在本地仓库中进行的:暂存,提交,查看状态或者历史记录等等。除此之外,如果仅仅只有你一个人在这个项目里工作,你永远没有机会需要设置一个远程仓库。只有当你…

2001-2021年全国30省就业人数数据

2001-2021年全国30省就业人数数据/各省就业人数数据 1、时间:2001-2021年 2、范围:包括30个省市不含西藏 3、指标:就业人数 4、来源:各省NJ、社会统计NJ 5、缺失情况说明:无缺失 6、指标说明: 就业人…

实在智能出席第六届数字中国建设峰会,入围2022年信息技术应用创新优秀解决方案榜单

最美榕城四月天,山海之间尽显数字澎湃。这一周来,实在智能来到了“有福之州”,为数字中国建设增添实在色彩。 4月25日,实在华夏行抵达福州站,与众多生态合作伙伴携手共话数字发展新未来; 4月26日&#xff…

分布式事务 --- Seata事务模式、高可用

一、事务模式 1.1、XA模式 XA 规范 是 X/Open 组织定义的分布式事务处理(DTP,Distributed Transaction Processing)标准,XA 规范 描述了全局的TM与局部的RM之间的接口,几乎所有主流的数据库都对 XA 规范 提供了支持。…

ContextCapture Master 倾斜摄影测量实景三维建模技术应用

查看原文>>>ContextCapture Master 倾斜摄影测量实景三维建模技术应用 目录 第一部分、倾斜摄影测量原理及应用领域 第二部分、倾斜摄影测量数据采集方法 第三部分、CC支持数据类型及导入数据方法 第四部分、CC空三计算参数设置及数据处理方法 第五部分、CC控制…

电气电工相关专业知识及名词解释

一、电流电压 火线、零线、地线:火线和零线的区别就是:火线带电,零线不带电。火线是传电流的,而零线是回流的。 红色是火线,零线一般是绿色的,通常可用电笔来测。电笔一头亮了是火线,不亮的则…

Python使用CV2库捕获、播放和保存摄像头视频

Python使用CV2库捕获、播放和保存摄像头视频 特别提示:CV2指的是OpenCV2(Open Source Computer Vision Library),安装的时候是 opencv_python,但在导入的时候采用 import cv2。 若想使用cv2库必须先安装,P…

InnoDB 引擎 底层逻辑

目录 0 课程视频 1 逻辑存储结构 1.1 结构图 1.2 表空间 -> 记录 索引 存储记录 等数据 1.2.1 储存在 cd/var/lib/mysql -> ll -> 目录 mysql.ibd 1.3 段 -> 索引 存储记录 具体存储 1.3.1 数据段 b树 叶子节点 1.3.2 索引段 b树的 非叶子节点 1.3.3 回滚段…

ChatGPT来了不用慌,广告人还有这个神器在手

#ChatGPT能取代广告人吗,#ChatGPT会抢走你的工作吗?#ChatGPT火了,会让营销人失业吗?自ChatGPT爆火以来,各种专业or非专业文章不停给广告人强加焦虑,但工具出现的意义,更多在于提效而非替代&…

【技术分享】防止根据IP查域名,防止源站IP泄露

有的人设置了禁止 IP 访问网站,但是别人用 https://ip 的形式,会跳到你服务器所绑定的一个域名网站上 直接通过 https://IP, 访问网站,会出现“您的连接不是私密连接”,然后点高级,会出现“继续前往 IP”,…

简单分享微信小程序上的招聘链接怎么做

招聘小程序的主要用户就是企业招聘端和找工作人员的用户端,下面从这两个端来对招聘小程序开发的功能进行介绍。 企业端功能 1、岗位发布:企业根据自身岗位需求,在招聘app上发布招聘岗位及所需技能。 2.简历筛选:根据求职者提交的简历选择合适的简历,并对公开发布的简历进行筛…

【五一创作】【Simulink】采用延时补偿的三相并网逆变器FCS-MPC

👉 【Simulink】基于FCS-MPC的三相并网逆变器控制 上一篇博客介绍了FCS-MPC的基本操作,并且以三相并网逆变器为控制对象进行了Simulink仿真。 但实际仿真中没有考虑补偿延时。本篇博客将讨论为什么要考虑延时并进行补偿,最后对此仿真验证。 …

【Java数据结构】顺序表、队列、栈、链表、哈希表

顺序表 定义 存放数据使用数组但是可以编写一些额外的操作来强化为线性表&#xff0c;底层依然采用顺序存储实现的线性表&#xff0c;称为顺序表 代码实现 创建类型 先定义一个新的类型 public class ArrayList<E> {int capacity 10; //顺序表的最大容量int size …
最新文章