sklearn.model_selection.learning_curve的详细介绍(包含ShuffleSplit()介绍)

提示:sklearn.model_selection.learning_curve的详细介绍

文章目录

  • 1、需求分析
  • 2、learning_curve主要输出参数
  • 3、learning_curve主要参数
  • 4、learning_curve作用
  • 5、learning_curve代码
  • 6、ShuffleSplit()


1、需求分析

通过参数train_size选取不同规模的数据集,再分别在不同规模的数据集上做交叉验证,通过参数cv选取交叉验证的类型;

例如:我们想选取含有1000个样本的数据集的10%,33%,55%,78%,100%的数据做实验,探究不同数据量下模型的预测准确度,选取不同规模的数据集后,我们又想分别在不同规模的数据集下做一下5折交叉验证。我们就可以设train_size=array([0.1, 0.33, 0.55, 0.78, 1.]), cv=5。

2、learning_curve主要输出参数

train_sizes_abs:返回生成的训练的样本数,如[ 80 , 260 , 620,800 ],80=0.1×1000×[(5-1)/5]即train_sizes_abs[i]=train_size[i]×样本数×(cv对应每一次交叉验证的训练样本占总样本的百分比)
train_scores:返回训练集分数,该矩阵为( len ( train_sizes_abs ) , cv分割数 )即(5,5)维的分数,每行数据代表该样本数对应不同折的分数。
test_scores:同train_scores,只不过是这个对应的是测试集分数

3、learning_curve主要参数

X:特征矩阵,包含输入样本的特征。
y:目标变量,包含与输入样本对应的真实标签。
train_sizes:一个数组或可迭代对象,表示训练集的不同大小的比例。每个比例都将生成一个学习曲线点。
cv:用于交叉验证的折数或交叉验证迭代器。
scoring:用于评估模型性能的指标。常见的指标包括准确率(accuracy)、均方误差(mean_squared_error)、R平方(r2_score)等。
shuffle:是否在每次迭代前对数据进行洗牌,默认为False。
random_state:随机数种子,用于控制随机性。
estimator:用于拟合数据的机器学习模型,例如分类器或回归器。

4、learning_curve作用

在这里插入图片描述
引用出处

5、learning_curve代码

代码如下(示例):

from sklearn.metrics import confusion_matrix as CM
import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn.linear_model import LogisticRegression as LR
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve  # 画学习曲线的类
from sklearn.model_selection import ShuffleSplit  # 设定交叉验证模式的类
from time import time
import datetime
from sklearn.metrics import brier_score_loss as bsl
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression as LR
from sklearn.model_selection import train_test_split


def plot_learning_curve(estimator, title, X, y,
                        ax,  # 选择子图
                        ylim=None,  # 设置纵坐标的取值范围
                        cv=None,  # 交叉验证
                        n_jobs=None  # 设定索要使用的线程
                        ):
    train_sizes, train_scores, test_scores = learning_curve(estimator, X, y
                                                            , cv=cv, n_jobs=n_jobs)
    # learning_curve() 是一个可视化工具,用于评估机器学习模型的性能和训练集大小之间的关系。它可以帮助我们理解模型在不同数据规模下的训练表现,
    # 进而判断模型是否出现了欠拟合或过拟合的情况。该函数会生成一条曲线,横轴表示不同大小的训练集,纵轴表示训练集和交叉验证集上的评估指标(例如
    # 准确率、损失等)。通过观察曲线,我们可以得出以下结论:
    # 1,训练集误差和交叉验证集误差之间的关系:当训练集规模较小时,模型可能过度拟合,训练集误差较低,交叉验证集误差较高;当训练集规模逐渐增大时,
    #    模型可能更好地泛化,两者的误差逐渐趋于稳定。
    # 2,训练集误差和交叉验证集误差对训练集规模的响应:通过观察曲线的斜率,我们可以判断模型是否存在高方差(过拟合)或高偏差(欠拟合)的问题。如果
    #    训练集和交叉验证集的误差都很高,且二者之间的间隔较大,说明模型存在高偏差;如果训练集误差很低而交叉验证集误差较高,且二者的间隔也较大,说
    #    明模型存在高方差。
    # cv : int:交叉验证生成器或可迭代的可选项,确定交叉验证拆分策略。v的可能输入是:
    #            - 无,使用默认的3倍交叉验证,
    #            - 整数,指定折叠数。
    #            - 要用作交叉验证生成器的对象。
    #            - 可迭代的yielding训练/测试分裂。
    #      ShuffleSplit:我们这里设置cv,交叉验证使用ShuffleSplit方法,一共取得100组训练集与测试集,
    #      每次的测试集为20%,它返回的是每组训练集与测试集的下标索引,由此可以知道哪些是train,那些是test。
    # n_jobs : 整数,可选并行运行的作业数(默认值为1)。windows开多线程需要
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    # *是可以接受任意数量的参数
    # 而 ** 可以接受任意数量的指定键值的参数
    # def m(*args,**kwargs):
    # 	print(args)
    #     print(kwargs)
    # m(1,2,a=1,b=2)
    # #args:(1,2),kwargs:{'b': 2, 'a': 1}
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")
    ax.grid()  # 显示网格作为背景,不是必须
    ax.plot(train_sizes, np.mean(train_scores, axis=1), 'o-'
            , color="r", label="Training score")
    ax.plot(train_sizes, np.mean(test_scores, axis=1), 'o-'
            , color="g", label="Test score")
    ax.legend(loc="best")
    return ax


digits = load_digits()
X, y = digits.data, digits.target
print(X.shape)
print(X)
title = ["Naive Bayes", "DecisionTree", "SVM, RBF kernel", "RandomForest", "Logistic"]
# model = [GaussianNB(), DTC(), SVC(gamma=0.001)
#     , RFC(n_estimators=50), LR(C=0.1, solver="lbfgs")]
model = [GaussianNB(), DTC(), SVC(gamma=0.001)
    , RFC(n_estimators=50), LR(C=0.1, solver="liblinear")]
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
# n_splits:
# 划分数据集的份数,类似于KFlod的折数,默认为10# test_size:
# 测试集所占总样本的比例,如test_size=0.2即将划分后的数据集中20%作为测试集
# random_state:
# 随机数种子,使每次划分的数据集不变
# train_sizes: 随着训练集的增大,选择在10%25%50%75%100%的训练集大小上进行采样。
#              比如(CV= 510%的意思是先在训练集上选取10%的数据进行五折交叉验证。
# train_sizes:数组类,形状(n_ticks),dtype floatint
# 训练示例的相对或绝对数量,将用于生成学习曲线。如果dtype为float,则视为训练集最大尺寸的一部分
# (由所选的验证方法确定),即,它必须在(01]之内,否则将被解释为绝对大小注意,为了进行分类,
# 样本的数量通常必须足够大,以包含每个类中的至少一个样本(默认值:np.linspace(0.11.05))
# 输出:
# train_sizes_abs:
# 返回生成的训练的样本数,如[ 10 , 100 , 1000 ]
# train_scores:
# 返回训练集分数,该矩阵为( len ( train_sizes_abs ) , cv分割数 )维的分数,
# 每行数据代表该样本数对应不同折的分数
# test_scores:
# 同train_scores,只不过是这个对应的是测试集分数

fig, axes = plt.subplots(1, 5, figsize=(30, 6))
for ind, title_, estimator in zip(range(len(title)), title, model):
    times = time()
    plot_learning_curve(estimator, title_, X, y,
                        ax=axes[ind], ylim=[0.7, 1.05], n_jobs=4, cv=5)
    print("{}:{}".format(title_, datetime.datetime.fromtimestamp(time() - times).strftime("%M:%S:%f")))
plt.show()

6、ShuffleSplit()

ShuffleSplit() 随机排列交叉验证,生成一个用户给定数量的独立的训练/测试数据划分。样例首先被打散然后划分为一对训练测试集合。

  • n_splits:划分训练集、测试集的次数,默认为10
  • test_size: 测试集比例或样本数量
  • random_state:随机种子值,默认为None,可以通过设定明确的random_state,使得伪随机生成器的结果可以重复
from sklearn.model_selection import ShuffleSplit
import numpy as np

X = np.arange(10)
print("所有数据:", X)
ss = ShuffleSplit(n_splits=5, test_size=0.25)
n_fold = 1
for train_indices, test_indices in ss.split(X):
    print('...............................')
    print("train_indices", train_indices)
    print("test_indices:", test_indices)

在这里插入图片描述
引用文献


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/459751.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OJ_点菜问题(背包问题)

题干 C实现 #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<vector> using namespace std;int main() {int c, n;scanf("%d%d", &c, &n);int p[101];int v[101];for (int i 0; i < n; i){scanf("%d%d", &p[i],…

深入探讨MES管理系统与MOM系统之间的关系

在制造业的信息化浪潮中&#xff0c;各种系统与技术层出不穷&#xff0c;其中MES制造执行系统和MOM制造运营管理无疑是备受瞩目的两大主角。尽管它们都是制造业信息化不可或缺的部分&#xff0c;但许多人对它们之间的区别与联系仍感到困惑。本文将对MES管理系统和MOM系统进行深…

一键分割,瞬间转换!轻松驾驭视频的无限可能

在数字化的世界里&#xff0c;视频内容已成为我们日常生活与工作中不可或缺的一部分。然而&#xff0c;处理这些多媒体文件时&#xff0c;常常需要花费大量的时间和精力进行分割、转换和编辑。现在&#xff0c;有了这款强大的“一键分割与转换”工具&#xff0c;你将能够轻松驾…

细说C++反向迭代器:原理与用法

文章目录 一、引言二、反向迭代器的原理与实现细节三、模拟实现C反向迭代器反向迭代器模板类的设计反向迭代器的使用示例与测试 一、引言 迭代器与反向迭代器的概念引入 迭代器&#xff08;Iterator&#xff09;是C标准模板库&#xff08;STL&#xff09;中的一个核心概念&am…

大话设计模式——7.抽象工厂模式(Abstract Factory Pattern)

1.介绍 抽象工厂模式是工厂模式的进一步优化&#xff0c;提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无需指定它们具体的类。属于创建型模式。 UML图&#xff1a; 2.示例 车辆制造工厂&#xff0c;不仅可以制造轿车也可以用来生产自行车。 1&#xff09;Abs…

基于Java+SpringBoot+vue+element实现校园闲置物品交易网站

基于JavaSpringBootvueelement实现校园闲置物品交易网站 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 ** 作者主页 央顺技术团队** 欢迎点赞 收藏 ⭐留言 文末获取源码联系方式 文章目录 基于…

掘根宝典之C++普通迭代器和反向迭代器详解

简介 迭代器是一种用于遍历容器元素的对象。它提供了一种统一的访问方式&#xff0c;使程序员可以对容器中的元素进行逐个访问和操作&#xff0c;而不需要了解容器的内部实现细节。 C标准库里每个容器都定义了迭代器&#xff0c;这迭代器的名字就叫容器迭代器 迭代器的作用类…

10、MongoDB -- MongoDB 的 MongoTemplate 的功能和用法介绍

目录 MongoTemplate 的功能和用法演示前提&#xff1a;登录单机模式的 mongodb 服务器命令登录【test】数据库的 mongodb 客户端命令登录【admin】数据库的 mongodb 客户端命令 为 MongoDB 提供的两个 Starterspring-boot-starter-data-mongodb&#xff08;为以同步方式操作 Mo…

Jmeter —— jmeter对图片验证码的处理!

jmeter对图片验证码的处理 在web端的登录接口经常会有图片验证码的输入&#xff0c;而且每次登录时图片验证码都是随机的&#xff1b;当通过jmeter做接口登录的时候要对图片验证码进行识别出图片中的字段&#xff0c;然后再登录接口中使用&#xff1b; 通过jmeter对图片验证码…

第N4周:中文文本分类-Pytorch实现

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)** # -*- coding: utf-8 -…

IDEA编译安卓源码TVBox(2)

一、项目结构&#xff1a;主要app和player app结构 二、增加遥控器按键选台 修改LivePlayActivity.java 1、声明变量 public String channelId "";public Timer timer new Timer();public Toast mToast;2、定义方法 private void mToastShow(String s){mToast …

攻防世界-misc-Make-similar

题目链接&#xff1a;攻防世界 (xctf.org.cn) 下载得到ogg文件。Olympic CTF 2014原题有提示120 LPM&#xff0c;对应Radiofax。需要将ogg格式文件转换成wav格式音频后&#xff0c;用OS X下的软件Multimode转换成单色传真图像&#xff1a; 文字部分为&#xff1a; section 1 of…

107. 如何使用Docker以及Docker Compose部署Go Web应用

文章目录 一、为什么需要Docker&#xff1f;二、Docker部署示例1. 准备代码2. 创建Docker镜像3. 编写Dockerfile4. Dockerfile解析5. 构建镜像6. 通过镜像创建容器运行 三、分阶段构建示例四、附带其他文件的部署示例五、关联其他容器六、Docker Compose模式七、总结 本文将介绍…

PlayBook 详解

4&#xff09;Playbook 4.1&#xff09;Playbook 介绍 PlayBook 与 ad-hoc 相比&#xff0c;是一种完全不同的运用 Ansible 的方式&#xff0c;类似与 Saltstack 的 state 状态文件。ad-hoc 无法持久使用&#xff0c;PlayBook 可以持久使用。 PlayBook 剧本是 由一个或多个 “…

5分钟上手Python爬虫:从干饭开始,轻松掌握技巧

很多人都听说过爬虫&#xff0c;我也不例外。曾看到别人编写的爬虫代码&#xff0c;虽然没有深入研究&#xff0c;但感觉非常强大。因此&#xff0c;今天我决定从零开始&#xff0c;花费仅5分钟学习入门爬虫技术&#xff0c;以后只需轻轻一爬就能查看所有感兴趣的网站内容。广告…

Docker 安装部署 ORACLE 11g数据库

Docker 安装部署 ORACLE 11g数据库 背景&#xff1a; ​ 最新在开发数据中台数据接入模块&#xff0c;其中设计很多数据类型&#xff0c;包括ORACLE &#xff0c;因为是测试使用&#xff0c;想着快速部署测试&#xff0c;于是使用Docker 部署 Oracle , 生产环境不建议使用Doc…

【LeetCode热题100】160. 相交链表(链表)

一.题目要求 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数…

鸿蒙Next学习-Flex布局

Entry Component struct FlexCase {build() {//需要在构造参数上传Flex({ direction: FlexDirection.Row,justifyContent:FlexAlign.Center }) {//flex布局Row().width(100).height(100).backgroundColor(Color.Red)Row().width(100).height(100).backgroundColor(Color.Yellow…

【ubuntu】安装 Anaconda3

目录 一、Anaconda 说明 二、操作记录 2.1 下载安装包 2.1.1 官网下载 2.1.2 镜像下载 2.2 安装 2.2.1 安装必要的依赖包 2.2.2 正式安装 shell 和 base 的切换 2.2.3 检测是否安装成功 方法一 方法二 方法三 2.3 其他 三、参考资料 3.1 安装资料 3.2 验证是否…

C语言函数—递归理解和练习

练习&#xff1a; 编写函数不允许创建临时变量&#xff0c;求字符串的长度。 我们看到这道题&#xff0c;第一个想到的是不是strlen int main() {char[] "bit";//[b][i][t][\0]//里面一共4个字符&#xff08;包括结尾的、0&#xff09;但是我们的strlen函数并不会计…
最新文章