Pytorch建立MyDataLoader过程详解

简介

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=‘’)

详细:DataLoader

自己基于DataLoader实现各个模块

代码实现

MyDataset基于torch中的Data实现对个人数据集的载入,例如图像和标签载入
SingleSampler基于torch中的Sampler实现对于数据的batch个数图像的载入,例如,Batch_Size=4,实现对所有数据中选取4个索引作为一组,然后在MyDataset中基于__getitem__根据图像索引去进行图像操作
MyBathcSampler基于torch的BatchSampler实现自己对于batch_size数据的处理。需要基于SingleSampler实现Sampler的处理,更为灵活。MyBatchSampler的存在会自动覆盖DataLoader中的batch_size参数
注:Sampler的实现,将会与shuffer冲突,shuffer是在没有实现sampler前提下去自动判断选择的sampler类型
collate_fn是实现将batch_size的图像数据进行打包,遍历过程中就可以实现batch_size的images和labels对应
在这里插入图片描述

sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler


class MyDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        yield batch
    
    def __len__(self):
        return len(self._sampler) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # 顺序采样
        # indices = range(len(self._data))
        # 随机采样
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)

batch_sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler

class MyDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        yield batch
    
    def __len__(self):
        return len(self._sampler) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # 顺序采样
        # indices = range(len(self._data))
        # 随机采样
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)

参考

Sampler:https://blog.csdn.net/lidc1004/article/details/115005612

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/82896.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TensorFlow2.1 模型训练使用

文章目录 1、环境安装搭建2、神经网络2.1、解决线性问题2.2、FAshion MNIST数据集使用 3、卷积神经网络3.1、卷积神经网络使用3.2、ImageDataGenerator使用3.3、猫狗识别案例3.4、参数优化 3.5、石头剪刀布案例4、词条化4.1、讽刺数据集的词条化和序列化4.2、词嵌入 1、环境安装…

[保研/考研机试] KY11 二叉树遍历 清华大学复试上机题 C++实现

题目链接: 二叉树遍历_牛客题霸_牛客网编一个程序,读入用户输入的一串先序遍历字符串,根据此字符串建立一个二叉树(以指针方式存储)。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/43719512169254700747…

实数信号的傅里叶级数研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

(2)、将SpringCache扩展功能封装为starter

(2)、将SpringCache扩展功能封装为starter 1、准备工作 前面我们写了一个common-cache模块,尽可能的将自定义的RedisConnectionFactory, RedisTemplate, RedisCacheManager等Bean封装了起来。 就是为了方便我们将其封装为一个Starter。 我们这里直接《SpringCache+Redis实…

微信个人号开发,实现机器人辅助社群操作

微信 iPad 协议是指用于在 iPad 设备上使用微信应用的技术协议。一般来说,通过该协议可以将微信账号同步到 iPad 设备上,并且可以在 iPad 上发送和接收微信消息,查看好友列表、聊天记录等功能。微信 iPad 协议是通过私有API实现的。 需要一定…

嵌入式:ARM Day6

作业:完成cortex-A7核UART总线实验 目的:1.输入a,显示b,将输入的字符的ASCII码下一位字符输出 2.原样输出输入的字符串 源码: uart4.h #ifndef __UART4_H__ #define __UART4_H__#include "stm32mp1xx_rcc.h" #incl…

九、Linux下,如何在命令行进入文本编辑页面?

1、文本编辑基础 说到文本编辑页面,那就必须提到vi和vim,两者都是Linux系统中,常用的文本编辑器 2、三种工作模式 3、使用方法 (1)在进入Linux系统,在输入vim text.txt之后,会进入文本编辑中&…

Python多组数据三维绘图系统

文章目录 增添和删除坐标数据更改绘图逻辑源代码 Python绘图系统: 基础:将matplotlib嵌入到tkinter 📈简单的绘图系统 📈数据导入📈三维绘图系统自定义控件:坐标设置控件📉坐标列表控件 增添和…

【力扣】84. 柱状图中最大的矩形 <模拟、双指针、单调栈>

目录 【力扣】84. 柱状图中最大的矩形题解暴力求解双指针单调栈 【力扣】84. 柱状图中最大的矩形 给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为 1 。求在该柱状图中,能够勾勒出来的矩形的最大面积。 示例…

使用ChatGPT进行创意写作的缺点

Open AI警告ChatGPT的使用者要明白此工具的局限性,更不应完全依赖。作为一位创作者,这一点非常重要,应尽可能地避免让版权问题或不必要的文体问题出现在自己的作品中。[1] 毕竟使用ChatGPT进行创意写作目前还有以下种种局限或缺点[2]&#xf…

5.5.webrtc的线程管理

今天呢,我们来介绍一下线程的管理与绑定,首先我们来看一下web rtc中的线程管理类,也就是thread manager。对于这个类来说呢,其实实现非常简单,对吧? 包括了几个重要的成员,第一个成员呢就是ins…

学习笔记230804---逻辑跳转this.$router.push在写法上的优化

今天和资深前端代码写重,同时写页面带参跳转,组长觉得他写的方式比我高端一点,我觉得确实是,像资深大佬学习。 我的写法: this.$router.push(/bdesign?applicationId${this.data.id}&appName${this.data.name})…

[VS/C++]如何更好的配置DLL项目中的成品输出

注意,解决方案与项目不放在同一个文件夹中,即不选中图中选项 直入主题 首先右键项目选择属性,或者选中项目然后AltEnter 选择配置属性下的常规 分别在四种配置中编辑输出目录如下 注意,四种配置要分别配置,一个个来…

一百六十一、Kettle——Linux上安装的kettle9.2开启carte服务(亲测、附流程截图)

一、目的 在Linux上安装好kettle9.2并且连接好各个数据库后,下面开启carte服务 二、实施步骤 (一)carte服务文件路径 kettle的Linux运行的carte服务文件是carte.sh (二)修改kettle安装路径下的pwd文件夹里的服务器…

Netty+springboot开发即时通讯系统笔记(四)终

实时性 1.线程池多线程,把消息同步给其他端和对方用户,其中数据持久化往往是最浪费时间的操作,可以使用mq异步存储,因为其他业务不需要拿着整条数据,只需要这条数据的id进行操作。 2。消息校验前置,放在t…

谷歌推出首款量子弹性 FIDO2 安全密钥

谷歌在本周二宣布推出首个量子弹性 FIDO2 安全密钥,作为其 OpenSK 安全密钥计划的一部分。 Elie Bursztein和Fabian Kaczmarczyck表示:这一开源硬件优化的实现采用了一种新颖的ECC/Dilithium混合签名模式,它结合了ECC抵御标准攻击的安全性和…

机器学习与模式识别3(线性回归与逻辑回归)

一、线性回归与逻辑回归简介 线性回归主要功能是拟合数据,常用平方误差函数。 逻辑回归主要功能是区分数据,找到决策边界,常用交叉熵。 二、线性回归与逻辑回归的实现 1.线性回归 利用回归方程对一个或多个特征值和目标值之间的关系进行建模…

java版本spring cloud 企业工程系统管理 工程项目管理系统源码em

工程项目管理软件(工程项目管理系统)对建设工程项目管理组织建设、项目策划决策、规划设计、施工建设到竣工交付、总结评估、运维运营,全过程、全方位的对项目进行综合管理 工程项目各模块及其功能点清单 一、系统管理 1、数据字典&#xff…

消息中间件的选择:RabbitMQ是一个明智的选择

💗wei_shuo的个人主页 💫wei_shuo的学习社区 🌐Hello World ! MQ(Message Queue) MQ(消息队列)是一种用于在应用程序之间进行异步通信的技术;允许应用程序通过发送和接收…

Vue3 用父子组件通信实现页面页签功能

一、大概流程 二、用到的Vue3知识 1、组件通信 (1)父给子 在vue3中父组件给子组件传值用到绑定和props 因为页签的数组要放在父页面中, data(){return {tabs: []}}, 所以顶部栏需要向父页面获取页签数组 先在页签页面中定义props用来接…
最新文章