PyTorch中定义自己的数据集

文章目录

    • 1. 简介
    • 2. 查看PyTorch自带的数据集(可视化)
    • 3. 准备材料
      • 3.1 图片数据
      • 3.2 标签数据
    • 4. 方法

1. 简介

尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集,包括数据集的内容、结构、标签等信息。对于一些复杂的数据集,用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类,您可以更好地控制数据的加载和处理过程,提高代码的灵活性、可读性和可维护性,同时更好地满足模型训练的需求。

2. 查看PyTorch自带的数据集(可视化)

为了更好的定义自己的数据集,我们首先查看PyTorch自带的数据集的内容,代码如下

# 导入所需的库
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化
import torch  # 导入PyTorch库
from torchvision.datasets import MNIST  # 从torchvision中导入MNIST数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库

# 加载MNIST数据集
train_mnist_data = MNIST(root='./data',  # 数据集存储路径
                         train=True,  # 加载训练集
                         transform=transforms.Compose([transforms.Resize(size=(28, 28)), transforms.ToTensor()]),  # 数据预处理操作
                         download=True)  # 如果数据集不存在,则自动下载

# 设置要显示的样本数量
num_samples = 10

# 创建包含多个子图的大图窗口
fig, axes = plt.subplots(1, num_samples, figsize=(10, 6))

# 遍历选择要显示的样本
for i in range(num_samples):
    # 从数据集中获取图像数据和标签
    image, label = train_mnist_data[i]
    
    # 在子图中显示图像
    axes[i].imshow(image.squeeze().numpy(), cmap='gray')  # 使用imshow函数显示图像,将张量转换为NumPy数组
    axes[i].set_title(f"Label: {label}")  # 设置子图标题,显示图像对应的标签
    axes[i].axis('off')  # 关闭坐标轴显示
    
    # 将图像保存为PNG格式的图片文件,文件名以图像的标签命名
    plt.imsave(f"./data/mnist_images/{label}.png", image.squeeze().numpy(), cmap='gray')

# 显示图形窗口
plt.show()

这里,我们使用MNIST类加载MNIST数据集。在加载数据集时,通过transform参数指定了数据预处理操作,包括将图像大小调整为28x28像素,并将图像转换为张量。train=True表示加载训练集,download=True表示如果数据集不存在则自动下载到指定的路径。

接下来,我们选择一些样本进行可视化。我们在一个子图中显示了10个样本,每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本,从数据集中获取图像数据和标签,并使用Matplotlib的imshow()函数将图像显示在子图中。
在这里插入图片描述

同时,使用imsave()函数将每个图像保存为PNG格式的图片文件,文件名以标签命名。最后,使用plt.show()显示图形窗口,显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像,并将这些图像保存到指定路径下。保存的图片如下所示

在这里插入图片描述

通过上面程序可以看到,数据集主要是由图片数据和对应的标签构成,那么我们就可以用这两个主要构成成分来构建自己的数据集。

3. 准备材料

3.1 图片数据

这里我们就用刚才保存的十张图片,即

在这里插入图片描述

当然,你也可以准备其它的图片,并给图片分别命名为“0.png, 1.png, …”。

这里,十张图片的相对路径为

imgs_path = "./data/mnist_images"

注:你们要根据自己存储的路径来给定。

3.2 标签数据

创建一个txt文件,为每一幅图片指定标签数据,如下所示

在这里插入图片描述

这里,txt文件的相对路径为

labels_path = "labels.txt"

4. 方法

在PyTorch中,您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset类,并且实现两个主要的方法:__len____getitem____len__方法应该返回数据集的长度,而__getitem__方法则根据给定的索引返回数据集中的样本。

下面我们展示如何创建一个自定义的数据集类:

import os  # 导入os模块,用于操作文件路径
from PIL import Image  # 导入PIL库中的Image模块,用于图像处理
import torch  # 导入PyTorch库
from torch.utils.data import Dataset  # 从torch.utils.data模块导入Dataset类,用于定义自定义数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库,用于数值处理
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化


class CustomDataset(Dataset):
    def __init__(self, image_dir, label_file, transform=None):
        super().__init__()  # 调用父类的构造函数
        self.image_dir = image_dir  # 图像数据的路径
        self.label_file = label_file  # 标签文本的路径
        self.transform = transform  # 数据预处理操作
        self.samples = self._load_samples()  # 加载数据集样本信息

    def _load_samples(self):
        samples = []  # 存储样本信息的列表
        with open(self.label_file, 'r') as f:  # 打开标签文本文件
            for line in f:  # 逐行读取标签文本文件中的内容
                image_name, label = line.strip().split(',')  # 根据逗号分隔每行内容,获取图像文件名和标签
                image_path = os.path.join(self.image_dir, image_name)  # 拼接图像文件的完整路径
                samples.append((image_path, int(label)))  # 将图像路径和标签组成元组,加入样本列表
        return samples  # 返回样本列表

    def __len__(self):
        return len(self.samples)  # 返回数据集样本的数量

    def __getitem__(self, index):
        image_path, label = self.samples[index]  # 获取指定索引处的图像路径和标签
        image = Image.open(image_path).convert('L')  # 打开图像文件并将其转换为灰度图像
        if self.transform:  # 如果定义了数据预处理操作
            image = self.transform(image)  # 对图像进行预处理操作
        return image, label  # 返回预处理后的图像和标签


# 设置图片数据路径和标签文本路径
image_dir = './data/mnist_images'  # 图像数据的路径
label_file = 'labels.txt'  # 标签文本的路径

# 定义数据预处理操作,根据需要添加其他预处理操作
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整图像大小
    transforms.ToTensor(),  # 将图像转换为张量
])

# 创建自定义数据集实例
custom_dataset = CustomDataset(image_dir, label_file, transform=transform)

# 创建数据加载器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=1, shuffle=False)

# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:
    # 使用squeeze()函数去除图像张量中的单维度,将图像数据转换为NumPy数组,并存储在变量image中
    image = batch_images.squeeze().numpy()

    # 使用imshow()函数显示图像,cmap='gray'指定使用灰度色彩映射
    plt.imshow(image, cmap='gray')

    # 设置图像标题,显示图像对应的标签,使用f-string格式化字符串,将batch_labels转换为Python标量并获取其值
    plt.title(f"Label: {batch_labels.item()}")

    # 关闭坐标轴显示,即不显示坐标轴
    plt.axis('off')

    # 显示图形窗口
    plt.show()


这段代码实现了加载自定义数据集,并使用 PyTorch 的 DataLoader 将数据加载成批次,然后逐批次地展示图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/609862.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】栈的实现以及数组和链表的优缺点

个人主页:一代… 个人专栏:数据结构 1.栈 1.1栈的概念及结构 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端 称为栈顶,另一端称为栈底。栈中的数据元素遵守后进…

批量自定义重命名,一键添加顺序编号,文件夹管理更高效!

我们经常需要对文件夹进行管理和整理。然而,当面对大量需要改名的文件夹时,手动逐个修改不仅效率低下,还容易出错。那么,有没有一种方法能够批量自定义重命名文件夹,并在名称后自动添加顺序编号呢?答案是肯…

C++反汇编——多态,面试题01

文章目录 1.C的三大特性1.1封装1.2继承1.3多态1.3.1 虚函数1.3.2 多态代码反汇编分析。反汇编分析1——基类指针指向子类对象,构造过程。反汇编分析2——基类指针指向子类对象,调用虚函数getPrice()过程。反汇编分析3——基类对象,调用虚函数…

版本控制工具之Git的基础使用教程

Git Git是一个分布式版本控制系统,由Linux之父Linus Torvalds 开发。它既可以用来管理和追踪计算机文件的变化,也是开发者协作编写代码的工具。 本文将介绍 Git 的基础原理、用法、操作等内容。 一、基础概念 1.1 版本控制系统 版本控制系统&#x…

PSoc™62开发板之IoT应用

实验目的 使用PSoc62™开发板驱动OLED模块,实时监控室内的光照强度、温度信息 实验准备 PSoc62™开发板SSD1309 OLED模块DS18B20温度传感器BH1750光照传感器 模块电路 SSD1309 OLED模块的电路连接和模块配置教程请参考之前的文章,这里不详细展开描…

汽车EDI:IAC Elmdon EDI 对接指南

近期收到客户C公司的需求,需要与其合作伙伴IAC Elmdon建立EDI连接,本文将主要为大家介绍IAC Elmdon EDI 对接指南,了解EDI项目的对接流程。 项目需求 传输协议:OFTP2 IAC Elmdon 与其供应商之间使用的传输协议为OFTP2。OFTP2是…

云南区块链商户平台优化开发

背景 云南区块链商户平台是全省统一区块链服务平台。依托于云南省发改委、阿里云及蚂蚁区块链的国内首个省级区块链平台——云南省区块链平台同步上线,助力数字云南整体升级。 网页版并不适合妈妈那辈人使用,没有记忆功能,于是打算自己开发…

科技查新中医学科研项目查新点如何确立与提炼?案例讲解

一、前言 医学科技查新包括立项查新和成果查新两个部分,其中医学立项查新,它是指在医学科研项目申报开题之前,通过在一定范围内进行该课题的相关文献检索 ( 可以根据项目委托人的具体要求,进行国内检索或者进行国外检索 ) &#x…

媲美Suno、Udio!AI铁了心,要砸音乐人的饭碗

5月10日凌晨,著名语音生成式AI平台ElevenLabs在社交平台宣布,推出文本生成歌曲产品ElevenLabs Music。 从其展示的效果来看,音乐的节奏感、和声、乐器的搭配、情感表达、创意性、风格的多样性、高/低音,可媲美该领域的两款头部产…

k8s StatefulSet

Statefulset 一个 Statefulset 创建的每个pod都有一个从零开始的顺序索引,这个会体现在 pod 的名称和主机名上,同样还会体现在 pod 对应的固定存储上。这些 pod 的名称是可预知的,它是由 Statefulset 的名称加该实例的顺序索引值组成的。不同…

【元对象系统概述】

元对象系统概述 🌟 元对象🌟 元对象系统🌟 QT官方文档中给出的定义🌟《Qt5.9 C开发指南》中给出的定义 🌟 元对象 元对象是一个描述类的信息的数据结构,在qt中常常与QObject的类相关联。 可以通过QObject::…

这些企业注意!推荐使用OVSSL证书

JoySSL官网 注册码230918 SSL证书作为一种重要的安全措施,对于确保网站数据传输的安全性至关重要。而在众多SSL证书类型中,OV(Organization Validation,组织验证)SSL证书以其独特的功能和适用范围,成为众多…

夸克网盘免费扩容N次20T的方法

上文我们用:夸克网盘免费领取1TB空间的方法使自己的网盘扩容到1TB,但只有三个月还不够大。 所以用下面的方法那个免费的把自己的网盘扩容到20TB。 一、 登录任推邦 APP 需要借助这个平台,这是夸克网盘的第三方服务商,完善注册信…

2024年自动驾驶、车辆工程与智能交通国际会议(ICADVEIT2024)

2024年自动驾驶、车辆工程与智能交通国际会议(ICADVEIT2024) 会议简介 2024年自动驾驶、车辆工程和智能交通国际会议(ICADVEIT 2024)将在中国深圳举行。会议主要聚焦自动驾驶、车辆工程和智能交通等研究领域,旨在为从…

智慧便民小程序源码系统 求职招聘+房产出租+相亲交友 带完整的安装代码包以及系统搭建教程

在数字化、智能化的今天,我们的生活节奏越来越快,对于各种服务的需求也越发多元化和个性化。为了满足广大市民对于便捷、高效、全面的服务需求,罗峰给大家分享一款智慧便民小程序源码系统,集求职招聘、房产出租、相亲交友三大功能…

【全开源】Java U U跑腿同城跑腿小程序源码快递代取帮买帮送源码小程序+H 5+公众号跑腿系统

特色功能: 智能定位与路线规划:UU跑腿小程序能够利用定位技术,为用户提供附近的跑腿服务,并自动规划最佳路线,提高配送效率。订单管理:包括订单查询、订单状态更新、订单评价等功能,全行业覆盖…

Mac YOLO V9本地训练(命令行模式)

环境: Mac M1 (MacOS Sonoma 14.3.1) Python 3.11PyTorch 2.1.2 一、YOLO v9工程及模型准备 详见:Mac YOLO V9推理测试-CSDN博客 二、数据集准备 Roboflow Universe上有许多小规模的数据集,很适合用来进行目标检测。 首先安装依赖 pip …

NVIDIA 配置 Jetson 扩展针座

系列文章目录 前言 每个 Jetson 开发套件包括多个扩展接头和连接器(统称 "接头"): 40 针扩展接头: 可让您将 Jetson 开发套件连接到现成的 Raspberry Pi HAT(顶部附加硬件),如 Seee…

echarts-gl 离线3D地图

1、安装依赖 echarts-gl 与 echarts 版本关系: "echarts": "^5.2.0", "echarts-gl": "^2.0.8"# 执行安装 yarn add echarts-gl2、下载离线地图 免费下载实时更新的geoJson数据、行政区划边界数据、区划边界坐标集合_…

笨方法自学python(一)

我觉得python和c语言有很多相似之处,如果有c语言基础的话学习python也不是很难。这一系列主要是学习例题来学习python;我用的python版本是3.12 代码编辑器我用的是notepad,运行py程序用cmd 现在开始写第一个程序: print ("…
最新文章