CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10

(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息

#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力

#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)
    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,32,32,3))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/354364.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在AirPods Pro中使用降噪功能?这里提供几个方法

本文介绍了如何在AirPods Pro上使用降噪功能&#xff0c;如何关闭它&#xff0c;以及该功能的工作原理。 注意&#xff1a;AirPods Pro和AirPods Max支持噪音消除。你的设备必须运行iOS 13.2或iPadOS 13.2或更高版本才能使用降噪功能。 如何在AirPods Pro上打开降噪功能 Air…

CSS color探索

CSS 颜色探索 在 CSS 的世界里&#xff0c;颜色为网页元素赋予了丰富的视觉效果。通过预定义的颜色名称、RGB、HEX、HSL&#xff0c;以及支持透明度的 RGBA 和 HSLA&#xff0c;我们可以创造出各种吸引人的设计。接下来&#xff0c;我们将通过示例代码来深入了解这些颜色应用。…

重构改善既有代码的设计-学习(六):处理继承关系

1、函数上移&#xff08;Pull Up Method&#xff09; 无论何时&#xff0c;只要系统内出现重复&#xff0c;你就会面临“修改其中一个却未能修改另一个”的风险。通常&#xff0c;找出重复也有一定的难度。 所以&#xff0c;某个函数在各个子类中的函数体都相同&#xff08;它们…

MYSQL中group by分组查询的用法详解(where和having的区别)!

文章目录 前言一、数据准备二、使用实例1.如何显示每个部门的平均工资和最高工资2.显示每个部门的每种岗位的平均工资和最低工资3.显示平均工资低于2000的部门和它的平均工资4.having 和 where 的区别5.SQL查询中各个关键字的执行先后顺序 前言 在前面的文章中&#xff0c;我们…

指针的深入了解2

1.const修饰指针 在这之前我们还学过static修饰变量&#xff0c;那我们用const来修饰一下变量会有什么样的效果呢&#xff1f; 我们来看看&#xff1a; 我们可以看到编译器报错告诉我们a变成了一个不可修改的值&#xff0c;我们在变量前加上了const进行限制&#xff0c;但是我…

深入理解与防范C语言中的栈溢出问题

一、引言 栈溢出是计算机安全领域中一个常见的漏洞&#xff0c;特别是在C语言编程中。由于C语言的灵活性和对内存管理的直接操作性&#xff0c;如果程序员在编写代码时不注意&#xff0c;就可能导致栈溢出的发生。本文将全面解析栈溢出的概念、原因、影响以及防范措施。 二、…

绘制太极图 - 使用 PyQt

大家好&#xff01;今天我们将一起来探讨一下如何使用PyQt&#xff0c;这是一个强大的Python库&#xff0c;来绘制一个传统的太极图。这个图案代表着古老的阴阳哲学&#xff0c;而我们的代码将以大白话的方式向你揭示它的奥秘。 PyQt&#xff1a;是什么鬼&#xff1f; 首先&a…

嵌入式——窗口看门狗(WWDG)补充

目录 一、独立看门狗与窗口看门狗 1.功能描述 2.两者区别 二、WWDG功能描述 1.窗口看门狗时钟 2.计数器时钟 3. 计数器 4.窗口值 三、WWDG超时时间 一、独立看门狗与窗口看门狗 1.功能描述 STM32有两个看门狗&#xff1a;一个是独立看门狗&#xff08;IWDG&#xff0…

【GPU】GPU 硬件与 CUDA 程序开发工具

GPU 硬件与 CUDA 程序开发工具 笔记内容来自&#xff1a;《CUDA 编程&#xff1a;基础与实践》—樊哲勇 著 本文目录 GPU 硬件简介CUDA 程序开发工具CUDA 开发环境搭建用 nvidia-smi 检查与设置设备CUDA 的官方手册 GPU 硬件简介 GPU 是英文 graphics processing unit 的首字母…

【GitHub项目推荐--基于 AI 的口语训练平台】【转载】

Polyglot Polyglot 是一个开源的基于 AI 的口语训练平台客户端&#xff0c;可以在 Windows、Mac 上使用。 比如你想练习英语口语&#xff0c;只需在该平台配置一个虚拟的 AI 国外好友&#xff0c;你可以通过发语音的方式和 AI 好友交流&#xff0c;通过聊天的方式提升你的口…

黑马程序员——html css基础——day05——盒子模型

目录&#xff1a; 选择器 结构伪类选择器:nth-child(公式)伪元素选择器PxCook盒子模型 盒子模型-组成边框线 四个方向单方向边框线内边距尺寸计算外边距版心居中清除默认样式元素溢出外边距问题 合并现象外边距塌陷行内元素–内外边距问题圆角盒子阴影&#xff08;拓展&#x…

【Java】Spring的APO及事务

今日目标 能够理解AOP的作用 能够完成AOP的入门案例 能够理解AOP的工作流程 能够说出AOP的五种通知类型 能够完成"测量业务层接口万次执行效率"案例 能够掌握Spring事务配置 一、AOP 1 AOP简介 问题导入 问题1&#xff1a;AOP的作用是什么&#xff1f; 问题2&am…

java设计模式:工厂模式

1&#xff1a;在平常的开发工作中&#xff0c;我们可能会用到不同的设计模式&#xff0c;合理的使用设计模式&#xff0c;可以提高开发效率&#xff0c;提高代码质量&#xff0c;提高系统的可拓展性&#xff0c;今天来简单聊聊工厂模式。 2&#xff1a;工厂模式是一种创建对象的…

HCIP-交换机实验

实验拓扑 实验需求 实验思路 配置IP地址 配置vlan 实验步骤 配置IP地址 以pc1为例&#xff1a; 配置vlan 以sw1为例&#xff1a; <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sys sw1 [sw1]vlan 3 Jan 28 2024 15:39:45-08:00 sw1 DS/…

HPE ProLiant MicroServer Gen8安装windows server 2019

按照《HP MicroServer Gen8使用官方工具安装Windows Sverver 2016教程》安装时&#xff0c;安装系统选项并没有server 2019可选&#xff0c;依然只是server 2012&#xff0c;我还以为是从单位拿回来的镜像有误&#xff0c;从官方下载了server 2019评估版&#xff0c;但依然只有…

【算法专题】二分查找(进阶)

&#x1f4d1;前言 本文主要是二分查找&#xff08;进阶&#xff09;的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是青衿&#x1f947; ☁️博客首页&#xff1a;CSDN主页放风讲故事 &#x1f304;每日…

【UEFI实战】Redfish的BIOS实现——生成EDK数据

生成Redfish文件 Redfish数据的表示形式&#xff0c;最常用的是JSON。将JSON表示的数据转换成C语言可以操作的结构体&#xff0c;是必不可少的步骤。当然如果手动转换的话&#xff0c;需要浪费大量的时间&#xff0c;因此DMTF组织开发了一个工具&#xff0c;用于将JSON数据快速…

实验6:循环与子程序设计

1、实验目的&#xff1a; 通过完成将字节内存单元存储的8个数依次显示在屏幕上的程序设计&#xff0c;掌握循环与子程序设计的方法。 2、实验内容&#xff1a; 将内存单元存储的8个两位16进制数&#xff1a;01H, 25H, 38H, 62H, 8DH, 9AH, BAH, CEH依次显示在屏幕上。 3、实…

CSS设置单行文字水平垂直居中的方法

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>单行文字水平垂直居中</title><style>div {/* 给div设置宽高 */width: 400px;height: 200px;margin: 100px auto;background-color: red;/…

4.列表选择弹窗(CenterListPopup)

愿你出走半生,归来仍是少年&#xff01; 环境&#xff1a;.NET 7、MAUI 在屏幕中间弹窗的列表选择弹窗。 1.布局 <?xml version"1.0" encoding"utf-8" ?> <toolkit:Popup xmlns"http://schemas.microsoft.com/dotnet/2021/maui"x…
最新文章