如何创建一个pytorch的训练数据加载器(train_loader)用于批量加载训练数据

Talk is cheap,show me the code!  哈哈,先上几段常用的代码,以语义分割的DRIVE数据集加载为例:

DRIVE数据集的目录结构如下,下载链接DRIVE,如果官网下不了,到Kaggle官网可以下到:

1. 定义DriveDataset类,每行代码都加了注释,其中collate_fn()看不懂没关系:

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):  # 继承Dataset类
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"  # 根据train这个布尔类型确定需要处理的是训练集还是测试集
        data_root = os.path.join(root, "DRIVE", self.flag)  # 得到数据集根目录
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."   # 判断路径是否存在
        self.transforms = transforms   # 初始化图像变换操作
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]   # 遍历图像文件夹获取每个图像的文件名
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]  # 获取图像路径
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 获取手动标签的路径
                       for i in img_names]
        # 检查手动标签文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
        # 获取分割的ROI区域掩码
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # check files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')  # 加载图像,并转换为RGB模式
        manual = Image.open(self.manual[idx]).convert('L')   # 加载手动标注图像,并转换为灰度模式
        manual = np.array(manual) / 255   # 进行归一化操作
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # 加载ROI图像,并转换为灰度模式
        roi_mask = 255 - np.array(roi_mask)   # 对图像数组取反,使用这个方法将背景和前景颜色反转,白色是255,黑色是0,反转后ROI变成了内黑外白
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # 将手动标注图像和反转后的ROI图像相加,使用np.clip()将像素值控制在0-255范围,

        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    # 获取图像数据集长度
    def __len__(self):
        return len(self.img_list)

    # 用于将批量的图像和标签数据合并为一个批张量。
    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))  # 将批量数据拆分为图像和标签两个列表
        batched_imgs = cat_list(images, fill_value=0)  # 使用 cat_list() 函数将图像和标签列表合并成张量。用于将列表中的 PIL 图像数据堆叠成张量,
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))  # 找到图像中最大的形状,以元组形式返回给max_size
    batch_shape = (len(images),) + max_size   # 计算出堆叠后的张量形状,包括批量大小和图像大小两个维度
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)  # 创建一个新的空白张量 batched_imgs,其形状与 batch_shape 相同,并将其填充为指定的填充值 fill_value
    for img, pad_img in zip(images, batched_imgs):  # 使用 zip() 函数将输入列表中的每个图像与其对应的空白张量进行拼接,以得到一个完整的张量。
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)  # 将每个图像按照其实际大小插入到空白张量的左上角,以保持图像的相对位置不变。
    return batched_imgs

2. 构建训练集和验证集对象。调用上述自定义的DriveDataset数据集类,通过传入不同的参数来区分训练集和验证集。arg.data_path表示数据集所在的路径,transforms参数则是表示对数据进行预处理的操作,包括图像增强和归一化等,mean和std是规范化处理时用到的均值和标准差。get_transform这个类下面会介绍。train=True表示构建训练集对象,train=False表示构建验证集对象。

train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

3.这是定义图像预处理方式,包括训练集和测试集的图像和标签的预处理方式,每行代码的具体作用注释有介绍。

# 定义训练集图像的预处理方式
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]   # 对图像的短边(长和宽中最短的)进行随机缩放以适应不同图像输入尺寸,缩放范围为【min_size, max_size】
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))  # 加入水平翻转
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))  # 加入垂直翻转
        trans.extend([
            T.RandomCrop(crop_size),   # 对图像进行随机裁剪
            T.ToTensor(),  # 将数组矩阵转换为tensor类型,规范化到【0,1】范围
            T.Normalize(mean=mean, std=std),  # 加入图像归一化,并定义均值和标准差,RGB三通道的
        ])
        # trans是一个列表类型,包含各种了变换,将这些变换组成一个compose变换,注意transforms.Compose()函数需要接收一个列表类型
        self.transforms = T.Compose(trans)

    # 使用__call__()函数来调用transforms变换
    def __call__(self, img, target):
        return self.transforms(img, target)  # target是指标签图像,img是指待分割图像


# 定义验证集的图像预处理组合类,比较简单,只有张量化和规范化两个操作,这里规范化使用的是ImageNet推荐的参数,注意这种做法是针对彩色图像
class SegmentationPresetEval:
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

# 定义一个函数根据数据集的类型来调用对应的数据集处理类
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    base_size = 565
    crop_size = 480
    # 检查train是否为True
    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)

4. 定义训练时所使用的线程数目,这里如果时windows系统训练出错的话,建议把num_workers直接设置为0就可以解决。定义训练数据加载器train_loader,用于批量加载训练数据。在代码中,使用torch.utils.data.DataLoader类来创建数据加载器,构造函数的参数包括:

  • train_dataset:训练数据集,应该是一个符合 PyTorch Dataset 接口的对象。
  • batch_size:每个批次的样本数量。
  • num_workers:用于数据加载的线程数。
  • shuffle:是否在每个时期(epoch)重新洗牌数据。一般只在训练集用
  • pin_memory:是否将数据加载到固定的内存区域,可以加速数据传输。
  • collate_fn:用于将样本列表转换为批次张量的函数。

num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])    # 如果batch_size>1, 线程数num_workers取min(cpu核数,batch_size)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

至此数据集加载器创建完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/319221.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kibana:使用反向地理编码绘制自定义区域地图

Elastic 地图(Maps)附带预定义区域,可让你通过指标快速可视化区域。 地图还提供了绘制你自己的区域地图的功能。 你可以使用任何您想要的区域数据,只要你的源数据包含相应区域的标识符即可。 但是,当源数据不包含区域…

Spring集成

目录 概述1 声朋一个简单的集成流1.1 使用XML定义集成流1.2 使用Java配置集成流1.3 使用Spring lntegration 的 DSL 配置 2 Spring integration 功能概览2.1 消息通道2.2 过滤器2.3 转换器2.4 路由器2.5 切分器2.6 服务激活器2.7 网关2.8 通道适配器2.9 端点模块 概述 就像我们…

RibbonGroup 添加QLineEdit

RibbonGroup添加QLineEdit: QLineEdit* controlEdit new QLineEdit(); controlEdit->setToolTip(tr("Edit")); controlEdit->setText(tr("Edit")); controlEdit->setMinimumWidth(150); …

vue知识-04

计算属性computed 注意: 1、计算属性是基于它们的依赖进行缓存的 2、计算属性只有在它的相关依赖发生改变时才会重新求值 3、计算属性就像Python中的property,可以把方法/函数伪装成属性 4、computed: { ... } 5、计算属性必须要有…

Windows python下载

1、下载 进入到地址:https://www.python.org/dowmloads 可以直接下载最新的版本 或者点击Windows,查看下方更多的版本 点击文档就下载下来啦 2、安装 双击下载的文件,勾选Add python.exe to PATH添加到环境变量中,选择Coutomi…

【数据结构】树和二叉树堆(基本概念介绍)

🌈个人主页:秦jh__https://blog.csdn.net/qinjh_?spm1010.2135.3001.5343🔥 系列专栏:《数据结构》https://blog.csdn.net/qinjh_/category_12536791.html?spm1001.2014.3001.5482 ​​ 目录 前言 树的概念 树的常见名词 树与…

Linux工具-搭建文件服务器

当我们使用linux系统作为开发环境时,经常需要在Linux系统之间、Linux和Windows之间传输文件。 对少量文件进行传输时,可以使用scp工具在两台主机之间实现文件传输: rootubuntu:~$ ssh --help unknown option -- - usage: ssh [-46AaCfGgKkMN…

【面试突击】Java面试底层逻辑(HashMap、ConcurrentHashMap面试实战)

🌈🌈🌈🌈🌈🌈🌈🌈 欢迎关注公众号(通过文章导读关注:【11来了】),及时收到 AI 前沿项目工具及新技术 的推送 发送 资料 可领取 深入理…

HashMap集合万字源码详解(面试常考)

文章目录 HashMap集合1.散列2.hashMap结构3.继承关系4.成员变量5.构造方法6.成员方法6.1增加方法6.2将链表转换为红黑树的treeifyBin方法6.3扩容方法_resize6.3.1扩容机制6.3.2源码resize方法的解读 6.4 删除方法(remove)6.5查找元素方法(get)6.6遍历HashMap集合几种方式 7.初始…

12.2内核空间基于SPI总线的OLED驱动

在内核空间编写SPI设备驱动的要点 在SPI总线控制器的设备树节点下增加SPI设备的设备树节点,节点中必须包含 reg 属性、 compatible 属性、 spi-max-frequency 属性, reg 属性用于描述片选索引, compatible属性用于设备和驱动的匹配&#xff…

pyinstaller,一个超酷的 Python 库!

更多Python学习内容:ipengtao.com 大家好,今天为大家分享一个超酷的 Python 库 - pyinstaller。 Github地址:https://github.com/pyinstaller/pyinstaller PyInstaller是一个用于将Python应用程序打包成独立可执行文件的工具。这可以将Python…

Leading Dimension是什么

在LAPACK中频繁出现Leading Dimension(中文翻译为“主维度”),那么它是什么呢? 首先了解行主序(Row-Major)和列主序(Column-Major)的概念: Given a matrix A of shape …

宝塔nginx部署前端页面刷新报404

问题: 当我们使用脚手架打包前端项目的时候,如果前端项目并没有静态化的配置,如以下 当我们刷新页面,或进行路由配置访问的时候就会报404的错误 原因: 这是因为通常我们做的vue项目属于单页面开发。所以只有index.html…

Cocos 使用VsCode调试-跨域问题

解决方案: 在添加完debug配置后 在项目文件夹中打开vscode然后找到launch.json 这个runtimeArgs参数在原本的配置中是没有的,给他添加上去 "runtimeArgs": ["--disable-web-security" ] 原理: 禁用浏览器跨域检查(仅用于调试&…

格密码基础:SIS问题的定义与理解

目录 一. 介绍 二. SIS问题定义 2.1 直观理解 2.2 数学定义 2.3 基本性质 三. SIS与q-ary格 四. SIS问题的推广 五. Hermite标准型 六. 小结 一. 介绍 short interger solution problem短整数解问题,简称SIS问题。 1996年,Ajtai首次提出SIS问…

物联网介绍

阅读引言: 本文从多方面叙述物联网的定义以及在物联网当中的各种通信的介绍。 一、物联网的定义 1.1 通用的定义 物联网(Internet of Things,IOT;也称为Web of Things)是指通过各种信息传感设 备,如传感器、…

基于51单片机的智能热水器设计

需要全部文件请私信关注我!!! 基于51单片机的智能热水器设计 摘要一、绪论1.1 选题背景及意义1.2 完成目标与功能设计 二、硬件系统设计2.1 硬件完成要求2.2 方案选择2.3 电源电路设计2.4 键盘电路2.5 蜂鸣器报警电路2.6 温度检测电路2.7 红…

Web前端-移动web开发——flex布局

移动web开发——flex布局 1.0传统布局和flex布局对比 1.1传统布局 兼容性好布局繁琐局限性,不能再移动端很好的布局 1.2 flex布局 操作方便,布局极其简单,移动端使用比较广泛pc端浏览器支持情况比较差IE11或更低版本不支持flex或仅支持部…

AWS-SAA-C03认证——之基础知识扫盲

文章目录 前言一、Amazon ECS二、 Amazon EKS三、Amazon EC2四、Elastic Beanstalk五、AWS Fargate六、 AWS Lambda (serverless)七、Amazon EBS7.1 EBS生命周期 八、Amazon Elastic File System (EFS) -共享文件系统九、什么是 Amazon S3?9.…