交叉验证之KFold和StratifiedKFold的使用(附案例实战)

3f6a7ab0347a4af1a75e6ebadee63fc1.gif

🤵‍♂️ 个人主页:@艾派森的个人主页

✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+


846afc070de74e37936f1257fcab9122.png

 一、交叉验证简介

        交叉验证是在机器学习建立模型和验证模型参数时常用的办法。交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏。在此基础上可以得到多组不同的训练集和测试集,某次训练集中的某样本在下次可能成为测试集中的样本,即所谓“交叉”。

  那么什么时候才需要交叉验证呢?交叉验证用在数据不是很充足的时候。通常情况下,如果数据样本量小于一万条,我们就会采用交叉验证来训练优化选择模型。如果样本大于一万条的话,我们一般随机的把数据分成三份,一份为训练集(Training Set),一份为验证集(Validation Set),最后一份为测试集(Test Set)。用训练集来训练模型,用验证集来评估模型预测的好坏和选择模型及其对应的参数。把最终得到的模型再用于测试集,最终决定使用哪个模型以及对应参数。

        学习预测函数的参数,并在相同数据集上进行测试是一种错误的做法: 一个仅给出测试用例标签的模型将会获得极高的分数,但对于尚未出现过的数据它则无法预测出任何有用的信息。 这种情况称为 overfitting(过拟合).。为了避免这种情况,在进行机器学习实验时,通常取出部分可利用数据作为 test set(测试数据集) X_test, y_test。下面是模型训练中典型的交叉验证工作流流程图。通过网格搜索可以确定最佳参数。

3937883524934850800ff95749dcf35a.png

         k-折交叉验证得出的性能指标是循环计算中每个值的平均值。 该方法虽然计算代价很高,但是它不会浪费太多的数据(如固定任意测试集的情况一样), 在处理样本数据集较少的问题(例如,逆向推理)时比较有优势。

9e94e7738704429b896e25bf6f959dd4.png

k-折交叉验证步骤

 

  • 第一步,不重复抽样将原始数据随机分为 k 份。
  • 第二步,每一次挑选其中 1 份作为测试集,剩余 k-1 份作为训练集用于模型训练。
  • 第三步,重复第二步 k 次,这样每个子集都有一次机会作为测试集,其余机会作为训练集。
  • 在每个训练集上训练后得到一个模型,
  • 用这个模型在相应的测试集上测试,计算并保存模型的评估指标,
  • 第四步,计算 k 组测试结果的平均值作为模型精度的估计,并作为当前 k 折交叉验证下模型的性能指标。
     

例如:

十折交叉验证

  • 将训练集分成十份,轮流将其中9份作为训练数据,1份作为测试数据,进行试验。每次试验都会得出相应的正确率。
  • 10次的结果的正确率的平均值作为对算法精度的估计,一般还需要进行多次10折交叉验证(例如10次10折交叉验证),再求其均值,作为对算法准确性的估计
  • 模型训练过程的所有步骤,包括模型选择,特征选择等都是在单个折叠 fold 中独立执行的。
  • 此外:
    • 多次 k 折交叉验证再求均值,例如:10 次10 折交叉验证,以求更精确一点。
    • 数据量大时,k设置小一些 / 数据量小时,k设置大一些。
       

adf57d0cc7bd4840a675f17f53c0fd39.png

KFold和StratifiedKFold的使用

        StratifiedKFold用法类似Kfold,但是它是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同。这一区别在于当遇到非平衡数据时,StratifiedKFold() 各个类别的比例大致和完整数据集中相同,若数据集有4个类别,比例是2:3:3:2,则划分后的样本比例约是2:3:3:2;但是KFold可能存在一种情况:数据集有5类,抽取出来的也正好是按照类别划分的5类,也就是说第一折全是0类,第二折全是1类等等,这样的结果就会导致模型训练时没有学习到测试集中数据的特点,从而导致模型得分很低,甚至为0。

 

Parameters

  • n_splits : int, default=3   也就是K折中的k值,必须大于等于2
  • shuffle : boolean  True表示打乱顺序,False反之
  • random_state :int,default=None 随机种子,如果设置值了,shuffle必须为True
# KFold
from sklearn.model_selection import KFold
kfolds = KFold(n_splits=3)
for train_index, test_index in kfolds.split(X,y):
    print('X_train:%s ' % X[train_index])
    print('X_test: %s ' % X[test_index])

# StratifiedKFold
from sklearn.model_selection import StratifiedKFold
skfold = StratifiedKFold(n_splits=3)
for train_index, test_index in skfold.split(X,y):
    print('X_train:%s ' % X[train_index])
    print('X_test: %s ' % X[test_index])

KFold和StratifiedKFold实战案例

首先导入数据集,本数据集为员工离职数据,属于二分类任务

import pandas as pd
import warnings
warnings.filterwarnings('ignore')

data = pd.read_excel('data.xlsx')
data['薪资情况'].replace(to_replace={'低':0,'中':1,'高':2},inplace=True)
data.head()

5109f3a08a384ac19a08bb8406c8a858.png

 拆分数据集为训练集和测试集,测试集比例为0.2

from sklearn.model_selection import train_test_split
X = data.drop('是否离职',axis=1)
y = data['是否离职']
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)

初始化一个分类模型,这里用逻辑回归模型举例。方法1使用cross_val_score()可以直接得到k折训练的模型效果,比如下面使用3折进行训练,得分评估使用准确率,关于scoring这个参数我会在文末介绍。

# 初始化一个分类模型,比如逻辑回归
from sklearn.linear_model import LogisticRegression
lg = LogisticRegression()
# 方法1
from sklearn.model_selection import cross_val_score
scores = cross_val_score(lg,X_train,y_train,cv=3,scoring='accuracy')
print(scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

3107a95839c7430aa0ec7a3614dfad73.png

 接下来分别使用KFold和StratifiedKFold,其实两者代码非常类似,只是前面的方法不同。

KFold

# 方法2-KFold和StratifiedKFold
import numpy as np
from sklearn.model_selection import KFold,StratifiedKFold
from sklearn.metrics import accuracy_score,recall_score,f1_score
# KFold
kfolds = KFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in kfolds.split(X_train,y_train):
    # 准备交叉验证的数据
    X_train_fold = X_train.iloc[train_index]
    y_train_fold = y_train.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]
    y_test_fold = y_train.iloc[test_index]
    # 训练模型
    lg.fit(X_train_fold,y_train_fold)
    y_pred = lg.predict(X_test_fold)
    # 评估模型
    AccuracyScore = accuracy_score(y_test_fold,y_pred)
    RecallScore = recall_score(y_test_fold,y_pred)
    F1Score = f1_score(y_test_fold,y_pred)
    # 将评估指标存放对应的列表中
    accuracy_score_list.append(AccuracyScore)
    recall_score_list.append(RecallScore)
    f1_score_list.append(F1Score)
    # 打印每一次训练的正确率、召回率、F1值
    print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

8c46118c7f9c4dac8b3e24b2ece08182.png

StratifiedKFold

# StratifiedKFold
skfolds = StratifiedKFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in skfolds.split(X_train,y_train):
    # 准备交叉验证的数据
    X_train_fold = X_train.iloc[train_index]
    y_train_fold = y_train.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]
    y_test_fold = y_train.iloc[test_index]
    # 训练模型
    lg.fit(X_train_fold,y_train_fold)
    y_pred = lg.predict(X_test_fold)
    # 评估模型
    AccuracyScore = accuracy_score(y_test_fold,y_pred)
    RecallScore = recall_score(y_test_fold,y_pred)
    F1Score = f1_score(y_test_fold,y_pred)
    # 将评估指标存放对应的列表中
    accuracy_score_list.append(AccuracyScore)
    recall_score_list.append(RecallScore)
    f1_score_list.append(F1Score)
    # 打印每一次训练的正确率、召回率、F1值
    print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

74873fa67e7c464cb2cceba748a9d07f.png

补充

scoring 参数: 定义模型评估规则

Model selection (模型选择)和 evaluation (评估)使用工具,例如 model_selection.GridSearchCV 和 model_selection.cross_val_score ,采用 scoring 参数来控制它们对 estimators evaluated (评估的估计量)应用的指标。

常见场景: 预定义值

        对于最常见的用例, 可以使用 scoring 参数指定一个 scorer object (记分对象); 下表显示了所有可能的值。 所有 scorer objects (记分对象)遵循惯例 higher return values are better than lower return values(较高的返回值优于较低的返回值)。因此,测量模型和数据之间距离的 metrics (度量),如 metrics.mean_squared_error 可用作返回 metric (指数)的 negated value (否定值)的 neg_mean_squared_error 。

Scoring(得分)Function(函数)Comment(注解)
Classification(分类)  
‘accuracy’metrics.accuracy_score 
‘average_precision’metrics.average_precision_score 
‘f1’metrics.f1_scorefor binary targets(用于二进制目标)
‘f1_micro’metrics.f1_scoremicro-averaged(微平均)
‘f1_macro’metrics.f1_scoremacro-averaged(宏平均)
‘f1_weighted’metrics.f1_scoreweighted average(加权平均)
‘f1_samples’metrics.f1_scoreby multilabel sample(通过 multilabel 样本)
‘neg_log_loss’metrics.log_lossrequires predict_proba support(需要 predict_proba 支持)
‘precision’ etc.metrics.precision_scoresuffixes apply as with ‘f1’(后缀适用于 ‘f1’)
‘recall’ etc.metrics.recall_scoresuffixes apply as with ‘f1’(后缀适用于 ‘f1’)
‘roc_auc’metrics.roc_auc_score 
Clustering(聚类)  
‘adjusted_mutual_info_score’metrics.adjusted_mutual_info_score 
‘adjusted_rand_score’metrics.adjusted_rand_score 
‘completeness_score’metrics.completeness_score 
‘fowlkes_mallows_score’metrics.fowlkes_mallows_score 
‘homogeneity_score’metrics.homogeneity_score 
‘mutual_info_score’metrics.mutual_info_score 
‘normalized_mutual_info_score’metrics.normalized_mutual_info_score 
‘v_measure_score’metrics.v_measure_score 
Regression(回归)  
‘explained_variance’metrics.explained_variance_score 
‘neg_mean_absolute_error’metrics.mean_absolute_error 
‘neg_mean_squared_error’metrics.mean_squared_error 
‘neg_mean_squared_log_error’metrics.mean_squared_log_error 
‘neg_median_absolute_error’metrics.median_absolute_error 
‘r2’metrics.r2_score

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/372020.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

云计算、Docker、K8S问题

1 云计算 云计算作为一种新兴技术,已经在现代社会中得到了广泛应用。它以其高效、灵活和可扩展特性,成为了许多企业和组织在数据处理和存储方面的首选方案。 1.1 什么是云计算?它有哪些特点? 云计算是一种通过网络提供计算资源…

项目02《游戏-06-开发》Unity3D

基于 项目02《游戏-05-开发》Unity3D , 接下来做 背包系统的 存储框架 , 首先了解静态数据 与 动态数据,静态代表不变的数据,比如下图武器Icon, 其中,武器的名称,描述&#xff…

全网第一篇把Nacos配置中心客户端讲明白的

入口 我们依旧拿ConfigExample作为入口 public class ConfigExample {public static void main(String[] args) throws NacosException, InterruptedException {String serverAddr "localhost";String dataId "test";String group "DEFAULT_GROU…

搭建frp

1.frp 是什么? frp 是一款高性能的反向代理应用,专注于内网穿透。它支持多种协议,包括 TCP、UDP、HTTP、HTTPS 等,并且具备 P2P 通信功能。使用 frp,您可以安全、便捷地将内网服务暴露到公网,通过拥有公网…

解决nvrtc: error: invalid value for --gpu-architecture (-arch)

问题描述 在使用pytorch3d的时候,可以正常的import,但是在执行错误的使用就会报,nvrtc: error: invalid value for --gpu-architecture (-arch),的错误,图片如下: 我的环境是: 显卡&#xff1…

精细管理药厂设备,制药机械设备管理平台系统助力生产提效

制药行业的复杂性要求对药品的品质和安全性进行严格控制,而这离不开高效管理各类机械设备。然而,随着制药企业规模的不断扩大和技术的迅猛进步,如何有效管理这些设备成为一个亟待解决的问题。在这一挑战面前,PreMaint制药机械设备…

Antd+React+react-resizable实现表格拖拽功能

1、先看效果 2、环境准备 "dependencies": {"antd": "^5.4.0","react-resizable": "^3.0.4",},"devDependencies": {"types/react": "^18.0.33","types/react-resizable": "^…

前端面试题——Vue的双向绑定

前言 双向绑定机制是Vue中最重要的机制之一,甚至可以说是Vue框架的根基,它将数据与视图模板相分离,使得数据处理和页面渲染更为高效,同时它也是前端面试题中的常客,接下来让我们来了解什么是双向绑定以及其实现原理。…

Python的包安装工具——pip命令大全

对于大多数使用Python的人来说,一定知道pip这个包安装工具,但是对pip可能还不是很了解,今天作者给大家介绍一下pip的命令,以方便灵活使用pip。 一、pip工具使用方法 pip的语法如下: pip [options] 式中&#xff1a…

InverseMatrix3D

InverseMatrixVT3D: An Efficient Projection Matrix-Based Approach for 3D Occupancy Prediction https://github.com/DanielMing123/InverseMatrixVT3D InverseMatrix3D过程总结如下: 1. 用2D backbone提取N个视角的多尺度图像特征,表示如下&#xf…

机器学习聚类算法

聚类算法是一种无监督学习方法,用于将数据集中的样本划分为多个簇,使得同一簇内的样本相似度较高,而不同簇之间的样本相似度较低。在数据分析中,聚类算法可以帮助我们发现数据的内在结构和规律,从而为进一步的数据分析…

Centos 内存和硬盘占用情况以及top作用

目录 只查看内存使用情况: 内存使用排序取前5个: 硬盘占用情况 定位占用空间最大目录 top查看cpu及内存使用信息 前言-与正文无关 生活远不止眼前的苦劳与奔波,它还充满了无数值得我们去体验和珍惜的美好事物。在这个快节奏的世界中&…

Python 潮流周刊#38:Django + Next.js 构建全栈项目

△△请给“Python猫”加星标 ,以免错过文章推送 你好,我是猫哥。这里每周分享优质的 Python、AI 及通用技术内容,大部分为英文。本周刊开源,欢迎投稿[1]。另有电报频道[2]作为副刊,补充发布更加丰富的资讯,…

protoc结合go完成protocol buffers协议的序列化与反序列化

下载protoc编译器 下载 https://github.com/protocolbuffers/protobuf/releases ps: 根据平台选择需要的编译器,这里选择windows 解压 加入环境变量 安装go专用protoc生成器 https://blog.csdn.net/qq_36940806/article/details/135017748?spm1001.2014.3001.…

canvas图片上设置镂空文字效果

查看专栏目录 canvas实例应用100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

VR全景技术可以应用在哪些行业,VR全景技术有哪些优势

引言: VR全景技术(Virtual Reality Panorama Technology)是一种以虚拟现实技术为基础,通过360度全景影像、立体声音、交互元素等手段,创造出沉浸式的虚拟现实环境。该技术不仅在娱乐领域有着广泛应用,还可…

方案分享:F5怎么样应对混合云网络安全?

伴随着云计算走入落地阶段,企业的云上业务规模增长迅猛。具有部署灵活、成本低、最大化整合现有资产、促进业务创新等优点的混合云逐渐成为企业选择的部署方式。与此同时,安全运营的复杂度进一步提高。比如安全堆栈越来越复杂、多云基础设施和应用添加网…

小白Linux学习笔记-Linux开机启动流程

Linux 开机启动流程 文章目录 Linux 开机启动流程启动流程概览详细讲解开机软件 —— BIOS、Grub名词解释流程解释BIOS 开机文档 —— menu.lst、grub.confGrub 配置文档流程解释 init 程序流程解释init 执行的相关文件 run-level(启动等级) 相关的命令实验rhel6 单用户模式修改…

机器学习数据预处理方法(数据重编码)

文章目录 [TOC]基于Kaggle电信用户流失案例数据(可在官网进行下载)一、离散字段的数据重编码1.OrdinalEncoder自然数排序2.OneHotEncoder独热编码3.ColumnTransformer转化流水线 二、连续字段的特征变换1.标准化(Standardization)…

数字人客服技术预研

技术洞察 引言 在当今数字化时代,不断进步和创新的人工智能(AI)技术已经渗透到各行各业中。随着AI技术、大模型技术逐步发展,使得数字人的广泛应用成为可能,本文将跟大家一起探讨AI数字人客服的概念、优势、应用场景…
最新文章