对模型性能进行评估(Machine Learning 研习十五)

在上一篇我们已然训练了一个用于对数字图像识别的模型,但我们目前还不知道该模型在识别数字图像效率如何?所以,本文将对该模型进行评估。

使用交叉验证衡量准确性

评估模型的一个好方法是使用交叉验证,让我们使用cross_val_score() 函数来评估我们的 SGDClassifier模型,使用三折的 k 折交叉验证。k-fold 交叉验证意味着将训练集分成 k 个折叠(在本例中是三个),然后训练模型 k 次,每次取出一个不同的折叠进行评估:

在这里插入图片描述

当您看到这组数字,是不是感到很兴奋?毕竟所有交叉验证折叠的准确率(预测准确率)均超过了 95%。然而,在您兴奋于这组数字前,还是让我们来看看一个假分类器,它只是将每张图片归入最常见的类别,在本例中就是负类别(即非 5):

from sklearn.dummy import DummyClassifier

dummy_clf = DummyClassifier() 
dummy_clf.fit(X_train, y_train_5) 
print(any(dummy_clf.predict(X_train)))  # prints False: no 5s detected

您能猜出这个模型的准确度吗?让我们一探究竟:

在这里插入图片描述

没错,它的准确率超过 90%!这只是因为只有大约 10% 的图片是 5,所以如果你总是猜测图片不是 5,你就会有大约 90% 的时间是正确的。比诺斯特拉达穆斯还准。

这说明了为什么准确率通常不是分类器的首选性能指标,尤其是在处理偏斜``````数据集时(即某些类别的出现频率远高于其他类别)。评估分类器性能的更好方法是查看混淆矩阵(CM)。

实施交叉验证

Scikit-Learn现成提供的功能相比,您有时需要对交叉验证过程进行更多控制。在这种情况下,你可以自己实现交叉验证。下面的代码与 Scikit-Learn cross_val_score() 函数做了大致相同的事情,并会打印出相同的结果:

from sklearn.model_selection import StratifiedKFold 
from sklearn.base import clone

skfolds = StratifiedKFold(n_splits=3)  # add shuffle=True if the dataset is                                                # not already shuffled 
for train_index, test_index in skfolds.split(X_train, y_train_5):    
    clone_clf = clone(sgd_clf)    
    X_train_folds = X_train[train_index]    
    y_train_folds = y_train_5[train_index]    
    X_test_fold = X_train[test_index]    
    y_test_fold = y_train_5[test_index]
    clone_clf.fit(X_train_folds, y_train_folds)    
    y_pred = clone_clf.predict(X_test_fold)    
    n_correct = sum(y_pred == y_test_fold)    
    print(n_correct / len(y_pred))  # prints 0.95035, 0.96035, and 0.9604 

StratifiedKFold 类执行分层抽样,生成的折叠数包含每个类别的代表性比例。每次迭代时,代码都会创建分类器的克隆,在训练折叠上训练该克隆,并在测试折叠上进行预测。然后计算正确预测的次数,并输出正确预测的比例。

混淆矩阵

混淆矩阵的一般概念是计算在所有 A/B 对中,A 类实例被分类为 B 类的次数。例如,要知道分类器将 8 和 0 的图像混淆的次数,可以查看混淆矩阵的第 8 行第 0 列。

要计算混淆矩阵,首先需要有一组预测结果,以便与实际目标进行比较。你可以在测试集上进行预测,但最好暂时不要使用测试集(记住,只有在项目的最后阶段,也就是分类器准备好启动时,才会使用测试集)。相反,你可以使用 cross_val_predict() 函数:

from sklearn.model_selection import cross_val_predict

y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3) 

cross_val_score() 函数一样,cross_val_predict()也会执行 k 折交叉验证,但它返回的不是评估分数,而是在每个测试折上做出的预测。这意味着你可以得到训练集中每个实例的准确预测(我说的 "准确 "是指 “样本外”:模型对训练期间从未见过的数据进行预测)。

现在可以使用 confusion_matrix()函数获取混淆矩阵了。只需将目标类 (y_train_5) 和预测类 (y_train_pred) 传递给它即可:

在这里插入图片描述

混淆矩阵的每一行代表一个实际类别,每一列代表一个预测类别。矩阵的第一行是非 5 图像(负类): 其中 53 892 幅图像被正确分类为非 5 图像(称为真阴性图像),其余 687 幅图像被错误分类为 5 图像(称为假阳性图像,也称为 I 类错误)。第二行是 5 的图像(正类): 有 1 891 张图片被错误地归类为非 5(假阴性,也称为 II 类错误),而其余 3 530 张图片被正确地归类为 5(真阳性)。一个完美的分类器只有真阳性和真阴性,因此其混淆矩阵只有在主对角线上(从左上角到右下角)才有非零值:

在这里插入图片描述

混淆矩阵提供了大量信息,但有时您可能更喜欢更简洁的指标。一个有趣的指标是正向预测的准确度;这被称为分类器的精度(公式 见下图)。

在这里插入图片描述

TP 是正面的数量,FP是反面的数量。

要想获得完美的精度,一个简单的方法就是创建一个分类器,除了对它最有信心的实例进行一次正向预测外,它总是进行负向预测。如果这一个预测是正确的,那么分类器的精度就是 100%(精度 = 1/1 = 100%)。显然,这样的分类器用处不大,因为它会忽略除了一个正向实例之外的所有实例。因此,精度通常与另一个名为召回率的指标一起使用,召回率也称为灵敏度或真阳性率(TPR):这是分类器正确检测到的阳性实例的比率(公式见下图)。

在这里插入图片描述

FN当然是假不良的数量。

在这里插入图片描述

精确度和召回率

Scikit-Learn提供多种函数来计算分类器指标,包括精度和召回率:

在这里插入图片描述

现在,我们的 "5-检测器 "看起来不像我们观察它的准确性时那么闪亮了。当它声称一幅图像代表 5 时,正确率只有 83.7%。而且,它只能检测到 65.1% 的 5。

通常情况下,将精确度和召回率合并为一个称为 F1 分数的指标会比较方便,尤其是在需要用一个指标来比较两个分类器时。F1 分数是精确度和召回率的调和平均数(公式 见下图)。普通均值对所有值一视同仁,而调和均值对低值的权重要大得多。因此,分类器只有在召回率和精确率都很高的情况下才能获得较高的 F1 分数。

在这里插入图片描述

要计算 F1 分数,只需调用f1_score() 函数即可:

在这里插入图片描述

F1 分数有利于精确度和召回率相似的分类器。这并不总是你想要的:在某些情况下,你主要关心精度,而在另一些情况下,你真正关心的是召回率。例如,如果您训练了一个分类器来检测对儿童安全的视频,那么您可能更倾向于选择一个剔除了许多好视频(召回率低)但只保留安全视频(高精度)的分类器,而不是一个召回率高得多但却让一些非常糟糕的视频出现在您的产品中的分类器(在这种情况下,您甚至可能想要添加一个人工管道来检查分类器的视频选择)。另一方面,假设您训练了一个分类器来检测监控图像中的偷窃者:只要您的分类器的召回率达到 99%,即使它只有 30% 的精度也没有问题(当然,保安会收到一些错误警报,但几乎所有的偷窃者都会被抓住)。

不幸的是,鱼和熊掌不可兼得:提高精度会降低召回率,反之亦然。这就是所谓的精度/召回权衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/460398.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux从0到1——Linux第一个小程序:进度条

Linux从0到1——Linux第一个小程序&#xff1a;进度条 1. 输出缓冲区2. 回车和换行的本质3. 实现进度条3.1 简单原理版本3.2 实际工程版本 1. 输出缓冲区 1. 小实验&#xff1a; 编写一个test.c文件&#xff0c;&#xff1a; #include <stdio.h> #include <unistd.h…

js【详解】ajax (含XMLHttpRequest、 同源策略、跨域、JSONP)

ajax 的核心API – XMLHttpRequest get 请求 // 新建 XMLHttpRequest 对象的实例 const xhr new XMLHttpRequest(); // 发起 get 请求&#xff0c;open 的三个参数为&#xff1a;请求类型&#xff0c;请求地址&#xff0c;是否异步请求&#xff08; true 为异步&#xff0c;f…

数据结构(二)顺序表和链表

1.线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串... 线性表在逻辑上是线性结构&#xff0c;也就说是连续的一条直…

fiddle连接mumu模拟器到adb连接成功,保姆级

前言: 在现代的移动应用程序开发中&#xff0c;模拟器成为了一个必不可少的工具。而Mumu模拟器是一个非常受欢迎的选择&#xff0c;它提供了稳定的性能和丰富的功能。然而&#xff0c;要在模拟器上进行调试和测试&#xff0c;你需要将它与ADB连接起来。 首先&#xff0c;我将解…

网络层_IP

传输层解决的是传输控制&#xff0c;而实际真正决定数据能否发送到对端的是网络层。网络层是有概率传输&#xff0c;而传输层是可靠性传输。所以传输层网络层就可以做到将数据可靠发送到对端。网络层的常见协议有&#xff1a;IP、ICMP等&#xff0c;其中最重要的是IP协议&#…

el-table的border属性失效问题解决方案

目录 问题&#xff1a; 使用的代码&#xff1a; 官方文档的说明&#xff1a; 可能的问题所在&#xff1a; 关于使用了作用域插槽&#xff1a; a.自定义内容的样式覆盖&#xff1a; b.表格结构的改变&#xff1a; 解决方案&#xff1a; 通过css样式解决&#xff1a; 下面…

Unity AI Navigation插件快速使用方法

AI Navigation插件使您能够创建能够在游戏世界中智能移动的角色。这些角色利用的是根据场景几何结构自动生成的导航网格。障碍物可以让您在运行时改变角色的导航路径。 演示使用的Unity版本为Tuanjie 1.0.0,团结引擎是Unity中国的引擎研发团队基于Unity 2022 LTS版本为中国开发…

HTML、XHTML和HTML5系列对比

目录 HTML HTML的优点&#xff1a; HTML的缺点&#xff1a; 应用场景&#xff1a; XHTML XHTML的优点&#xff1a; XHTML的缺点&#xff1a; 应用场景&#xff1a; HTML5 HTML5的优点&#xff1a; HTML5的缺点&#xff1a; 应用场景&#xff1a; 回首发现&#xff0…

面试经典-31-随机链表的复制

题目 给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成&#xff0c;其中每个新节点的值都设为其对应的原节点的值。新节…

C语言- strcat(拼接函数的使用和模拟)

strcat&#xff08;拼接函数的使用和模拟&#xff09; strcat的语法 strcat 是 C 语言标准库中的一个字符串拼接函数&#xff0c;它用于将一个字符串&#xff08;source&#xff09;拼接到另一个字符串&#xff08;destination&#xff09;的末尾。该函数定义在 <string.h…

Java开发从入门到精通(九):Java的面向对象OOP:成员变量、成员方法、类变量、类方法、代码块、单例设计模式

Java大数据开发和安全开发 &#xff08;一)Java的变量和方法1.1 成员变量1.2 成员方法1.3 static关键字1.3.1 static修饰成员变量1.3.1 static修饰成员变量的应用场景1.3.1 static修饰成员方法1.3.1 static修饰成员方法的应用场景1.3.1 static的注意事项1.3.1 static的应用知识…

基于暗通道的图像去雾算法,Matlab实现

博主简介&#xff1a; 专注、专一于Matlab图像处理学习、交流&#xff0c;matlab图像代码/项目合作可以联系&#xff08;QQ:3249726188&#xff09; 个人主页&#xff1a;Matlab_ImagePro-CSDN博客 原则&#xff1a;代码均由本人编写完成&#xff0c;非中介&#xff0c;提供有偿…

MongoDB简单CRUD操作(含GO中的库操作)

MongoDBCRUD操作&#xff08;含GO中的库操作&#xff09; 这周开始尝试做新项目&#xff0c;涉及到了文章的存储&#xff0c;查了查MongoDB在这方面用的比较多&#xff0c;因此对MongoDB和他在Golang中的用法进行了学习&#xff0c;以下是我的整理 文章目录 MongoDBCRUD操作&a…

IDEA中的Project工程、Module模块的概念及创建导入

1、IDEA中的层级关系&#xff1a; project(工程) - module(模块) - package(包) - class(类)/接口具体的&#xff1a; 一个project中可以创建多个module一个module中可以创建多个package一个package中可以创建多个class/接口2、Project和Module的概念&#xff1a; 在 IntelliJ …

vue3模块化引用组件和引用ts,调用ts中的接口

以简单的登录功能为例子 1.在util中创建loginValidators.ts import { ref, reactive } from vueinterface User{email: string;password: string; }export const loginUserreactive<User>({email: ,password: })interface Rules{email: {required: boolean;message: …

Html提高——HTML5 新增的语义化标签

引入&#xff1a; 以前布局&#xff0c;我们基本用 div 来做。div 对于搜索引擎来说&#xff0c;是没有语义的。 但是在html5里增加了语义化标签&#xff0c;如 <header>&#xff1a;头部标签 <nav>&#xff1a;导航标签 <article>&#xff1a;内容标签 &…

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:ListItem)

用来展示列表具体item&#xff0c;必须配合List来使用。 说明&#xff1a; 该组件从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。该组件的父组件只能是List或者ListItemGroup。 子组件 可以包含单个子组件。 接口 从API…

[江苏工匠杯]easyphp

先看源码 <?php highlight_file(__FILE__); $key1 0; $key2 0; ​ $a $_GET[a]; $b $_GET[b]; ​ if(isset($a) && intval($a) > 6000000 && strlen($a) < 3){if(isset($b) && 8b184b substr(md5($b),-6,6)){$key1 1;}else{die("…

遗传算法及基于该算法的典型问题的求解实践

说明 遗传算法是一个很有用的工具&#xff0c;它可以帮我们解决生活和科研中的诸多问题。最近在看波束形成相关内容时了解到可以用这个算法来优化阵元激励以压低旁瓣&#xff0c;于是特地了解和学习了一下这个算法&#xff0c;觉得蛮有意思的&#xff0c;于是把这两天关于该算法…

大模型训练准备工作

一、目录 1 大模型训练需要多少算力&#xff1f; 2. 大模型训练需要多少显存&#xff1f; 3. 大模型需要多少数据量训练&#xff1f; 4. 训练时间估计 5. epoch 选择经验 6. 浮点计算性能测试 二、实现 1 大模型训练需要多少算力&#xff1f; 训练总算力&#xff08;Flops&…