【深度学习笔记】3_5 图像分类数据集fashion-mnist

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.5 图像分类数据集(Fashion-MNIST)

在介绍softmax回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集MNIST[1]。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST[2](这个数据集也比较小,只有几十M,没有GPU的电脑也能吃得消)。

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

3.5.1 获取数据集

首先导入本节需要的包或模块。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

注意: 由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug。 详见传送门2.2.4节。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的mnist_trainmnist_test都是torch.utils.data.Dataset的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1。后面两维分别是图像的高和宽。

Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

现在,我们看一下训练数据集中前10个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

3.5.2 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

我们将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。随着本书内容的不断深入,我们会进一步改进该函数。它的完整实现将在5.6节中描述。

最后我们查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

输出:

1.57 sec

小结

  • Fashion-MNIST是一个10类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为 h h h w w w像素的图像的形状记为 h × w h \times w h×w(h,w)

参考文献

[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.


注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/407003.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抖音视频评论数据提取软件|抖音数据抓取工具

一、开发背景&#xff1a; 在业务需求中&#xff0c;我们经常需要下载抖音视频。然而&#xff0c;在网上找到的视频通常只能通过逐个复制链接的方式进行抓取和下载&#xff0c;这种操作非常耗时。我们希望能够通过关键词自动批量抓取并选择性地下载抖音视频。因此&#xff0c;为…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的犬种识别系统(附完整代码资源+UI界面+PyTorch代码)

摘要&#xff1a;本文介绍了一种基于深度学习的犬种识别系统系统的代码&#xff0c;采用最先进的YOLOv8算法并对比YOLOv7、YOLOv6、YOLOv5等算法的结果&#xff0c;能够准确识别图像、视频、实时视频流以及批量文件中的犬种。文章详细解释了YOLOv8算法的原理&#xff0c;并提供…

RabbitMQ-消息队列:发布确认高级

18、发布确认高级 在生产环境中由于一些不明原因&#xff0c;导致 RabbitMQ 重启&#xff0c;在 RabbitMQ 重启期间生产者消息投递失败&#xff0c; 导致消息丢失&#xff0c;需要手动处理和恢复。于是&#xff0c;我们开始思考&#xff0c;如何才能进行 RabbitMQ 的消息可靠投…

零基础手把手教你创建微信小程序(二)·创建第一个微信小程序以及了解小程序代码的构成

零基础手把手教你创建微信小程序&#xff08;一&#xff09;微信小程序开发账号的注册以及开发者工具的安装和使用-CSDN博客 目录 ​编辑 1. 创建微信小程序 1.1 基本信息 1.2 在模拟器上查看项目效果 1.3 在真机上预览项目效果 1.4 主界面的5个组成部分 1.4.1 菜单…

Android studio 下的APK打包失败问题解决办法

嗨&#xff0c;各位小伙伴们&#xff0c;我是你们的好朋友咕噜铁蛋&#xff01;作为移动应用开发者&#xff0c;在使用Android Studio进行APK打包时&#xff0c;有时候可能会遇到各种问题导致打包失败&#xff0c;这给我们的开发工作带来了一定的挑战。今天&#xff0c;我将和大…

Excel 面试题及答案(2)

一、VLOOKUP+IF案例: A1 :根据左侧数据源,按姓名匹配《职级》,仅限用函数,不能做任何辅助A2 :根据左侧数据源,按姓名匹配《部门》,仅限用函数,不能做任何辅助A3 :根据右侧考核规则,匹配《绩效比例》,用函数完成(可适当做辅助的单元格区域) =VLOOKUP(F8,IF({1,0},…

qt波位图

1&#xff0c;QPainter 绘制&#xff0c;先绘制这一堆蓝色的东西, 2&#xff0c;在用定时器&#xff1a;QTimer&#xff0c;配合绘制棕色的圆。用到取余&#xff0c;取整 #pragma once#include <QWidget> #include <QPaintEvent>#include <QTimer>QT_BEGIN_…

小程序配置服务器域名:一步步教你如何设置

小程序配置服务器域名&#xff1a;一步步教你如何设置 在当今数字化时代&#xff0c;小程序已经成为了连接用户与服务的重要桥梁。然而&#xff0c;为了让小程序能够正常地与服务器进行通信&#xff0c;我们需要对小程序进行服务器域名的配置。本文将为大家详细介绍小程序配置…

【黑马程序员】STL容器之string

string string 基本概念 string本质 string是c风格的字符串&#xff0c;而string本质上是一个类 string和char* 区别 char* 是一个指针string是一个类&#xff0c;类内部封装了char*,管理这个字符串&#xff0c;是一个char*型的容器 特点 string 内部封装了很多成员方法…

C# (WebApi)整合 Swagger

SpringBoot-整合Swagger_jboot整合swagger-CSDN博客 C# webapi 也可以整合Swagger webapi运行其实有个自带的HELP页面 但是如果觉得UI不好看&#xff0c;且没办法显示方法注释等不方便的操作&#xff0c;我们也可以整合Swagger 一、使用NuGet控制台安装Swagger 在菜单中选择…

从软硬件以及常见框架思考高并发设计

目录 文章简介 扩展方式 横向扩展 纵向扩展 站在软件的层面上看 站在硬件的层面上看 站在经典的单机服务框架上看 性能提升的思考方向 可用性提升的思考方向 扩展性提升的思考方向 文章简介 先从整体&#xff0c;体系认识&#xff0c;理解高并发的策略&#xff0c;方…

LeetCode 448.找到所有数组中消失的数字

目录 1.题目 2.代码及思路 3.进阶 3.1题目 3.2代码及思路 1.题目 给你一个含 n 个整数的数组 nums &#xff0c;其中 nums[i] 在区间 [1, n] 内。请你找出所有在 [1, n] 范围内但没有出现在 nums 中的数字&#xff0c;并以数组的形式返回结果。 示例 1&#xff1a; 输入&am…

shiro 整合 springboot 实战

序言 前面我们学习了如下内容&#xff1a; 5 分钟入门 shiro 安全框架实战笔记 shiro 整合 spring 实战及源码详解 这一节我们来看下如何将 shiro 与 springboot 进行整合。 spring 整合 maven 依赖 <?xml version"1.0" encoding"UTF-8"?> …

神经网络系列---常用梯度下降算法

文章目录 常用梯度下降算法随机梯度下降&#xff08;Stochastic Gradient Descent&#xff0c;SGD&#xff09;&#xff1a;随机梯度下降数学公式&#xff1a;代码演示 批量梯度下降&#xff08;Batch Gradient Descent&#xff09;批量梯度下降数学公式&#xff1a;代码演示 小…

- 工程实践 - 《QPS百万级的有状态服务实践》05 - 持久化存储

本文属于专栏《构建工业级QPS百万级服务》 继续上篇《QPS百万级的有状态服务实践》04 - 服务一致性。目前我们的系统如图1。现在我们虽然已经尽量把相同用户的请求转发到相同的机器&#xff0c;并且在客户端做了适配。但是因为成本&#xff0c;更极端的情况下&#xff0c;服务依…

【多线程】synchronized 关键字 - 监视器锁 monitor lock

synchronized 1 synchronized 的特性1) 互斥2) 可重入 2 synchronized 使用示例1) 修饰代码块: 明确指定锁哪个对象.2) 直接修饰普通方法: 锁的 SynchronizedDemo 对象3) 修饰静态方法: 锁的 SynchronizedDemo 类的对象 3 Java 标准库中的线程安全类 1 synchronized 的特性 1)…

信号通信与消息队列实现的通信:2024/2/23

作业1&#xff1a;将信号和消息队列的课堂代码敲一遍 1.1 信号 1.1.1 信号默认、捕获、忽略处理(普通信号) 代码&#xff1a; #include <myhead.h> void handler(int signo) {if(signoSIGINT){printf("用户键入 ctrlc\n");} } int main(int argc, const ch…

招聘APP开发实践:技术选型、架构设计与开发流程

时下&#xff0c;招聘APP成为了企业和求职者之间连接的重要纽带。本文将深入探讨招聘APP的开发实践&#xff0c;重点关注技术选型、架构设计以及开发流程等关键方面&#xff0c;带领读者走进这一充满挑战与机遇的领域。 一、技术选型 在开始招聘APP的开发之前&#xff0c;首…

单片机51 输入和输出

一、IO口基本概念介绍 单片机的IO口&#xff08;Input/Output口&#xff09;是连接单片机与外部电路或设备的接口。单片机的IO口可以分为输入口和输出口两种&#xff0c;用于控制和监测外部设备的状态。 1. 输入口&#xff1a;单片机的输入口用于接收外部电路或设备的信号。输…

Day20_网络编程(软件结构,网络编程三要素,UDP网络编程,TCP网络编程)

文章目录 Day20 网络编程学习目标1 软件结构2 网络编程三要素2.1 IP地址和域名1、IP地址2、域名3、InetAddress类 2.2 端口号2.3 网络通信协议1、OSI参考模型和TCP/IP参考模型2、UDP协议3、TCP协议 2.4 Socket编程 3 UDP网络编程3.1 DatagramSocket和DatagramPacket1、Datagram…
最新文章