Pytorch:Dataset类和DataLoader类

文章目录

  • 一、Dataset 类
    • 1、定义
    • 2、示例
  • 二、DataLoader 类
    • 1、定义
    • 2、参数
    • 3、示例:使用 DataLoader
  • 三、总结
  • 四、实战
    • 1、load_data函数:
    • 2、IrisDataset类
    • 3、DataLoader 的使用

  在机器学习和深度学习框架中,尤其是在 PyTorch 中,DatasetDataLoader 是处理和加载数据的重要工具。这里我们详细探讨这两个类的结构、用途和如何实际使用它们。
  数据集(Dataset)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。
  数据加载器(DataLoader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

  • 他俩都在torch.utils.data
  • from torch.utils.data import Dataset,DataLoader

一、Dataset 类

1、定义

Dataset 是一个抽象类,用于表示一个数据集的全部内容。在 PyTorch 中,任何继承自 torch.utils.data.Dataset 的自定义数据集需要实现两个必须的方法:

  • __getitem__(self, index)
    • 这个方法应该返回一个索引处的数据点和其对应的标签。例如,在图像数据集中,这可能是一对(图像,标签)。
  • __len__(self)
    • 这个方法返回数据集中的数据点的总数,即数据集的大小。

2、示例

下面是一个简单的形象化例子,展示如何创建一个用于加载图像数据集的自定义 Dataset 类:

import torch
from torch.utils.data import Dataset
class IceCreamDataset(Dataset):
    def __init__(self):
        self.flavors = ["vanilla", "chocolate", "strawberry"]

    def __len__(self):
        return len(self.flavors)

    def __getitem__(self, index):
        return f"One scoop of {self.flavors[index]} ice cream"
ice_cream_menu = IceCreamDataset()

在这个例子中,IceCreamDataset 类定义了一个冰激凌数据。

二、DataLoader 类

1、定义

DataLoader 是一个迭代器,用于将 Dataset 封装成易于访问的数据流,支持批量加载和多进程数据加载等操作。

2、参数

  • dataset: 要加载的 Dataset 对象。
  • batch_size(可选): 每个批次加载的样本数量。即对Dataset数据集进行等分,分成的份数(每份叫作一个batch)为len(dataset)/batch_sizebatch_size通常是单次训练使用的数据量,默认为1。
  • shuffle(可选): 是否在每个训练周期开始时打乱数据。
  • num_workers(可选): 用于数据加载的进程数。

3、示例:使用 DataLoader

一旦定义了 Dataset,就可以使用 DataLoader 来有效地加载数据:

from torch.utils.data import DataLoader

# 创建 DataLoader,每批三份不同口味的冰激凌
ice_cream_loader = DataLoader(ice_cream_menu)#等价于ice_cream_loader = DataLoader(ice_cream_menu,batch_size=1)

for batch in ice_cream_loader:
    print(batch)

在这个例子中,data_loader 会自动管理从 dataset 中加载数据的复杂性,如批量加载、打乱顺序和多进程加载。
输出:

['One scoop of vanilla ice cream']
['One scoop of chocolate ice cream']
['One scoop of strawberry ice cream']
ice_cream_loader = DataLoader(ice_cream_menu,batch_size=2)

输出:

['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream']
['One scoop of strawberry ice cream']
ice_cream_loader = DataLoader(ice_cream_menu,batch_size=3)#大于等于3的输出一样,因为就三个数据了
['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream', 'One scoop of strawberry ice cream']

三、总结

通过组合使用 DatasetDataLoader,PyTorch 用户可以高效、灵活地处理大规模数据集。Dataset 提供了一个清晰的接口来访问单个数据点__getitem__),而 DataLoader 管理整个数据集的批量处理和并行加载,这两者的结合极大地简化了在训练深度学习模型时的数据处理工作。

为了简单说明,以下我们将继承Dataset类的类,说成Dataset
根据上述简单的例子,我们可以知道,Dataset可以用来导入数据集,并规定整个数据集的长度是如何计算的,并规定单个数据点的格式;而DataLoader配合Dataset使用,可以导入数据集,并规定该数据集划分的批次数量和批次大小,以及导入数据集时是否打乱数据等。


对于:

for batch in dataloader:
	pass

为了理解batchbatch_size,可以这样去想:
  假设有512个箱子,将这些箱子,每16个分成一份,一共有32份,每一份叫作一个batch,而每个batch里面一共16个箱子。每16个箱子为一批,一批一批进行拆箱,即一个batch一个batch进行处理。遍历dataloader,每次取出的是一个batch,从上面的例子可以发现,batch里面的元素是通过列表组织在一起的。

  每一个batch实际上就是DataLoaderDataset划分成的一个批次,每个batch的大小就是batch_size(除非数据集不是它的整数倍,上面也有体现)。所有batch加起来才构成整个Dataset
  如果是图片数据集,batch_size可以认为,一个batchbatch_size张图片(如果该数据集规定单个数据点是一张图片的话。)(因为DataLoader访问数据时,会按照Dataset规定的数据点规格访问)。

四、实战

以上是一个简单的实例,方便理解,现在我们进行实战。

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
 
# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):
    x = torch.tensor(load_iris().data)
    y = torch.tensor(load_iris().target)
 
    # 数据归一化
    x_min = torch.min(x, dim=0).values
    x_max = torch.max(x, dim=0).values
    x = (x - x_min) / (x_max - x_min)
 
    if shuffle:
        idx = torch.randperm(x.shape[0])
        x = x[idx]
        y = y[idx]
    return x, y
 
# 自定义鸢尾花数据类
class IrisDataset(Dataset):
    def __init__(self, mode='train', num_train=120, num_dev=15):
        super(IrisDataset, self).__init__()
        x, y = load_data(shuffle=True)
        if mode == 'train':
            self.x, self.y = x[:num_train], y[:num_train]
        elif mode == 'dev':
            self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
        else:
            self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]
 
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
 
    def __len__(self):
        return len(self.x)
 
batch_size = 16
 
# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')
 
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

这段代码涉及到使用 PyTorch 加载和处理著名的鸢尾花(Iris)数据集,并将其分成训练集、验证集和测试集。下面逐部分详细解释:

1、load_data函数:

  1. 加载数据:

    • 使用 load_iris() 函数从 scikit-learn 库中加载鸢尾花数据集。这个函数返回包含特征(data)和目标(target)的数据结构。
    • 数据转换成 PyTorch 张量,方便后续使用 PyTorch 进行操作。
  2. 归一化:

    • 对特征进行归一化处理,使得每个特征的值范围都缩放到 [0, 1] 区间。这是通过从每个特征中减去最小值,然后除以其范围(最大值 - 最小值)来实现的。
    • 归一化有助于模型训练,因为它确保了所有特征都在相同的尺度上,从而加速学习过程。
  3. 打乱数据:

    • 如果启用 shuffle,则通过生成一个随机排列的索引并重新排序数据来打乱数据集。这通常用于训练数据集,以保证每次训练的随机性和泛化能力。
    • 这里使用的方法:
      • idx = torch.randperm(x.shape[0])x.shape[0]是二维张量的行数。 torch.randperm即随机打乱(生成一个 0 到样本数量减一的随机排列),得到一个随机排列。
      • x = x[idx];y = y[idx],使用的是高级索引:使用多个整数索引访问多个元素

2、IrisDataset类

  • IrisDataset 类继承自 Dataset。它用于封装鸢尾花数据,使其可以通过 PyTorch DataLoader 使用。
  • 在构造函数中,根据 mode(训练、验证或测试)来划分数据:
    • 训练集 (train): 使用数据集的前 num_train 个样本。
    • 验证集 (dev): 紧随训练集之后的 num_dev 个样本。
    • 测试集 (test): 剩余的样本。
  • 这种方式的好处是简单易实现,但在实际应用中可能需要更复杂的交叉验证策略来更好地评估模型。

3、DataLoader 的使用

  • 对于每种数据集(训练、验证、测试),通过创建 DataLoader 实例来进行封装。这允许以批量方式加载数据,可选择是否打乱。
  • 批量大小 (batch_size):
    • 对于训练数据,使用较大的批量(例如 16),有助于稳定和加速训练过程。
    • 对于验证数据,也采用同样大小的批量,以保持一致性。
    • 对于测试数据,每批只有一个样本,这常用于评估模型时逐个样本进行处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/579942.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

项目进度规划全攻略,助你成为项目管理高手

项目进度管理作为项目管理中的核心环节,对于确保项目按时交付、控制成本和提高质量至关重要。本文将详细介绍项目进度管理的基本步骤,帮助项目经理和团队成员更好地理解和执行进度管理工作。 项目进度管理的基本步骤主要包括以下几个方面: …

236基于matlab的三维比例导引法仿真

基于matlab的三维比例导引法仿真,可以攻击静止/机动目标。1.三维空间内的比例导引程序,采用龙哥库塔积分法;2.文件名为bili3dnew的.m文件是主函数,执行时需调用目标机动子函数、导引律子函数、数值积分法子函数;3.文件…

微服务之并行与分布式计算

一、概述 1.1集中式系统vs分布式系统 集中式系统 集中式系统完全依赖于一台大型的中心计算机的处理能力,这台中心计算机称为主机(Host 或 mainframe ),与中心计算机相连的终端设备具有各不相同非常低的计算能力。实际上大多数终…

《ESP8266通信指南》8-连接WIFI(Arduino开发)(非常简单)

往期 《ESP8266通信指南》7-Arduino 开发8266的环境配置与示例代码烧录-CSDN博客 《ESP8266通信指南》6-创建TCP服务器(AT指令)-CSDN博客 《ESP8266通信指南》5-TCP通信透传模式(AT指令)-CSDN博客 《ESP8266通信指南》4-以Client进行TCP通信&#xf…

弹性盒之主轴侧轴对齐方式

弹性盒设置主侧轴对齐方式 1.默认 justify-content: flex-start 2.justify-content: flex-end 3.justify-content: center 4.justify-content: space-between; 两端对齐 5.justify-content: space-around; 距离环绕 调整侧轴上中下 1.默认align-items: flex-start; …

论机器学习(ML)在网络安全中的重要性

机器学习是什么? 机器学习(ML)是人工智能的一个分支,它使用算法来使计算机系统能够自动地从数据和经验中进行学习,并改进其性能,而无需进行明确的编程。机器学习涉及对大量数据的分析,通过识别数据中的模式来做出预测…

在MySQL中isnull()函数不能作为替代null值!

在MySQL中isnull()函数不能作为替代null值! 如下: 首先有个名字为business的表: SELECT ISNULL(business_name,no business_name) AS bus_isnull FROM business WHERE id2 直接运行就会报错: 错误代码: 1582 Incor…

5本On Hold,6本预警被踢,学术诚信高风险期刊被踢9本,还剩1本你还敢投吗?

本周投稿推荐 SSCI • 2/4区经管类,2.5-3.0(录用率99%) SCIE(CCF推荐) • 计算机类,2.0-3.0(最快18天录用) SCIE(CCF-C类) • IEEE旗下,1/2…

禅道项目管理系统身份认证绕过漏洞

禅道项目管理系统身份认证绕过漏洞 1.漏洞描述 禅道项目管理软件是国产的开源项目管理软件,专注研发项目管理,内置需求管理、任务管理、bug管理、缺陷管理、用例管理、计划发布等功能,完整覆盖了研发项目管理的核心流程。 禅道项目管理系统…

如何使用 Internet Download Manager (IDM) 来加速和优化你的下载体验 IDM 6.41下载神器

在当今信息爆炸的时代,下载文件和媒体内容已成为我们日常生活的一部分。无论是工作学习还是娱乐休闲,我们都需要从互联网上下载各种资源。为了提高下载效率和确保文件完整性,选择一款优秀的下载管理软件至关重要。Internet Download Manager …

PotatoPie 4.0 实验教程(21) —— FPGA实现摄像头图像二值化(RGB2Gray2Bin)

PotatoPie 4.0开发板教程目录(2024/04/21) 为什么要进行图像的二值化? 当我们处理图像时,常常需要将其转换为二值图像。这是因为在很多应用中,我们只对图像中的某些特定部分感兴趣,而不需要考虑所有像素的…

JavaScript 中的IF判断竟然可以这样写,效率更高

当然,它们是创建控制流的一种简单而方便的方式,但你可以写下数十亿行条件性的 JavaScript 代码,而不需要一个 if 语句。 而且有很多情况下,使用不同的结构会更清晰地展示你想要做的事情 —— 只要我们还在为人类编写代码&#xf…

深度学习系列65:数字人openHeygen详解

1. 主流程分析 从inference.py函数进入,主要流程包括: 1) 使用cv2获取视频中所有帧的列表,如下: 2)定义Croper。核心代码为69行:full_frames_RGB, crop, quad croper.crop(full_frames_RGB)。…

基于springboot+vue+Mysql的乐校园二手书交易管理系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

OPPO手机支持深度测试+免深度测试解锁BL+ROOT权限机型整理-2024年3月更新

绿厂OPPO手机线上线下卖的都很不错,目前市场份额十分巨大,用户自然也非常多,而近期ROM乐园后台受到很多关于OPPO手机的私信,咨询哪些机型支持解锁BL,ROOT刷机,今天ROM乐园正式盘点当前市场上可以解BL刷root…

Android图片压缩、Drawable和Bitmap转换、bitmap和base64转换

1. Android图片压缩、Drawable和Bitmap转换、bitmap和base64转换 1.1. Drawable和Bitmap之间的转化 1.1.1. bitmap和Drawable间的区别 Bitmap - 称作位图,一般位图的文件格式后缀为bmp,当然编码器也有很多如RGB565、RGB888。作为一种逐像素的显示对象执…

【YesPMP】众包平台,最新项目

YesPMP平台专注于软件开发领域,是专业的一站式互联网众包平台,目前平台汇聚了上万个解决方案,覆盖全国,拥有众多专业优质的H5开发服务商,专为企业提供软件H5开发解决方案,提高企业的知名度。优秀的H5能为用…

云仓酒庄北京发布会与《综合品酒师》培训的延伸层次分享

原标题:云仓酒庄北京发布会与《综合品酒师》培训近日,云仓酒庄在北京举办了一场盛大的发布会,并近期举行了首届《综合品酒师》培训活动。这一事件不仅引起了业内的广泛关注,更成为了酒类行业专业化、规范化发展的重要里程碑。大世…

[移动端] “viewport“ content=“width=device-width, initial-scale=1.0“ 什么意思

布局视口, 代码如下 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Document</title><style>body,html {margin: 0;padding: 0;}.box {width: 200px;height: 200px;background-color: pi…

“无媒体,不活动”,这句话怎么理解?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 “无媒体&#xff0c;不活动”通常指的是在现代社会中&#xff0c;媒体对于各种活动&#xff0c;尤其是公共活动和事件的推广、宣传和影响力是至关重要的。它强调了媒体在塑造公众意识、…