将自己的数据集加载到dataloader中

from torch.utils.data import Dataset
class YourDataset(Dataset):  # 继承Dataset类
    # 构造函数必须存在
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()  # img_label是一个字典
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform  # 数据需要做的预处理操作
 
    def __len__(self):
        return len(self.img)
 
    # 获取图像和标签交给模型,该函数必须存在
    # 不要修改参数,每次调用时会传入随机的idx
    # 一个batch的数据就是由__getitem__函数处理数据传入得到的
    def __getitem__(self, idx):
        image = Image.open(self.img[idx]).convert('RGB')  # img保存了图像的路径
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)  # 对数据进行预处理操作
        label = torch.from_numpy(np.array(label))  # 转换label的数据类型,由list->numpy->tensor
        return image, label
 
    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos

前言

使用开源算法进行小规模的训练,例如分类任务需要加载自己的数据集和类别,需要使用dataloader格式的数据集。现在通过以下的方式将自己的数据集制作为dataloader格式:

参考

参考文档:深度学习(17)--DataLoader自定义数据集制作_自定义dataloader-CSDN博客

首先感谢分享,但是文档中有较多不明白的地方,并且也没有引用,在此上做了相对的改进:

实现过程

1 从txt文件中读取图片文件名和对应label

import numpy as np
def load_annotation(ann_file): # 参数为文本文件的路径
    # 创建一个字典结构用于保存数据,key作为图像的名字,value作为图像的标签
    data_infos = {}
    with open(ann_file) as f:
        # strip()去除一些换行符等
        # split(' ')是以空格为分隔符
        # samples是一个list,格式为图像名字,图像标签
        # eg:[['image11.jpg,'0'],['image22.jpg,'1'],['image33.jpg,'3']]
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            # filename是图像名字--'image11.jpg',gt_label--'0'是标签,加载到字典data_infos中去
            # value值设置为array(gt_label,dtype=int64)类型
            data_infos[filename] = np.array(gt_label, dtype=np.int64)
        # 得到的字典格式:{'image11.jpg':array(0,dtype=int64),'image22.jpg':array(1,dtype=int64)}
    return data_infos

文件格式如下:

c1.jpg 1
c2.jpg 1
c3.jpg 1
c4.jpg 1
c5.jpg 1
l1.jpg 2
l2.jpg 2
l3.jpg 2
l4.jpg 2
l5.jpg 2
l6.jpg 2
q1.jpg 3
q2.jpg 3
q3.jpg 3
q4.jpg 3
q5.jpg 3
q6.jpg 3

2 文件名和标签存入list

img_label = load_annotation('testload.txt')
image_name = list(img_label.keys())  # 取keys值
label = list(img_label.values())  # 取labels值
print(img_label.keys()) # 参看你的数据

3 整合为数据类

from torch.utils.data import Dataset
class YourDataset(Dataset):  # 继承Dataset类
    # 构造函数必须存在
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()  # img_label是一个字典
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform  # 数据需要做的预处理操作
 
    def __len__(self):
        return len(self.img)
 
    # 获取图像和标签交给模型,该函数必须存在
    # 不要修改参数,每次调用时会传入随机的idx
    # 一个batch的数据就是由__getitem__函数处理数据传入得到的
    def __getitem__(self, idx):
        image = Image.open(self.img[idx]).convert('RGB')  # img保存了图像的路径
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)  # 对数据进行预处理操作
        label = torch.from_numpy(np.array(label))  # 转换label的数据类型,由list->numpy->tensor
        return image, label
 
    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos

4 使用transform函数进行数据处理

# 创建一个字典结构的数据类型来进行图像预处理操作:key - value
import torchvision.transforms as transforms
data_transforms = {
    # 对训练集的预处理
    'train': transforms.Compose([
        transforms.Resize([256, 256]),  # 卷积神经网络处理的数据大小必须相同,通过Resize来设置
 
        # 数据增强
        transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
        #transforms.CenterCrop(64),  # 从中心开始裁剪,将原本96x96大小的图片数据裁剪为64x64大小的图片数据,可以获取更多的参数
        transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率,50%的概率进行水平翻转
        transforms.RandomVerticalFlip(p=0.5),  # 随 机垂直翻转,50%的概率进行竖直翻转
 
        #transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        #transforms.RandomGrayscale(p=0.025),  # 概率转换成灰度率,3通道就是R=G=B(三颜色通道转为单一颜色通道,很少进行此处理)
 
        # 将数据转为Tensor类型
        transforms.ToTensor(),
 
        # 标准化
        transforms.Normalize([0.5, 0.5, 0.5], [0.224, 0.224, 0.225])  # 设置均值,标准差,分别对应R、G、B三个颜色通道的三个均值和标准差值,(x-μ)/σ
    ]),
 
    # 对验证集的预处理(不需要进行数据增强)
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                 # 均值和标准差数值的设置和训练集的相同(验证集的数据对我们来说是未知的,不能利用其中的数据再计算出相关的均值和标准差)
                                 ]),
}

5 进行实例化

import os
from torch.utils.data import DataLoader
# 训练集
train_dataset = YourDataset(root_dir='.../...', ann_file='xxx.txt', transform=data_transforms['train'])
# 测试集
#valid_dataset = YourDataset(root_dir=valid_dir, ann_file='./.../valid.txt', transform=data_transforms['valid'])
# 实例化DataLoader(使用封装好的DataLoader包)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
#valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

6 调用使用你的数据

dataloaders = {'train': train_loader}
#dataloaders = {'train': train_loader, "valid": valid_loader}
for inputs, labels in dataloaders['train']:
    print("处理训练集")
#for inputs, labels in dataloaders['valid']:
#    print("处理验证集")

#或者使用枚举
#for i, (imgs, labels) in enumerate(dataloaders['train']):

#测试完成后你可以将以上6步写入你的程序,或封装成数据读取包

7 检查和展示你的数据

# 检查训练集
from PIL import Image
import torch
import matplotlib.pyplot as plt
image1, label1 = next(iter(train_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
sample = image1[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label1[0].numpy()))
 
 
# 检查训练集
#image2, label2 = next(iter(valid_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
#sample = image2[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
#sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
#sample *= [0.229, 0.224, 0.225]
#sample += [0.485, 0.456, 0.406]
#plt.imshow(sample)
#plt.show()
#print('Label is: {}'.format(label2[0].numpy()))
 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/574001.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

展览模型一般怎么打灯vray---模大狮模型网

在展览模型的设计中,灯光的运用是至关重要的,它不仅能够增强展品的视觉效果,还可以营造出独特的氛围和情感。在利用V-Ray进行灯光设置时,有一些常用的技巧和方法可以帮助设计师实现理想的展览效果。在本文中,我们将介绍…

漏洞修复优先级考虑-不错的思路

权威说法: 漏洞利用预测评分系统 (EPSS) 是一项数据驱动的工作,用于估计软件漏洞在野外被利用的可能性(概率) https://www.first.org/epss/ GitHub - TURROKS/CVE_Prioritizer: Streamline vulnerability…

在windows上安装MySQL数据库全过程

1.首先在MySQL的官网找到其安装包 在下图中点击MySQL Community(gpl) 找到MySQL Community Server 选择版本进行安装包的下载 2.安装包(Windows (x86, 64-bit), MSI Installer)安装步骤 继续点击下一步 继续进行下一步,直到出现此界面&#…

基于小程序实现的惠农小店系统设计与开发

作者主页:Java码库 主营内容:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】:Java 【框架】:spring…

leetcode-比较版本号-88

题目要求 思路 1.因为字符串比较大小不方便,并且因为需要去掉前导的0,这个0我们并不知道有几个,将字符串转换为数字刚好能避免。 2.当判断到符号位的时候加加,跳过符号位。 3.判断数字大小,来决定版本号大小 4.核心代…

探索直播+电商系统中台架构:连接消费者与商品的智能纽带

随着直播电商的崛起,电商行业进入了全新的智能时代。直播形式的互动性和即时性为消费者提供了全新的购物体验,而电商平台则为商品的展示、销售和配送提供了强大的支持。在这一背景下,直播电商系统中台架构成为了连接消费者与商品的智能纽带&a…

ABTest如何计算最小样本量-工具篇

如果是比例类指标,有一个可以快速计算最小样本量的工具: https://www.evanmiller.org/ab-testing/sample-size.html 计算样本量有4个要输入的参数:①一类错误概率,②二类错误概率 (一般是取固定取值)&…

【JavaScript】内置对象 ③ ( Math 内置对象 | Math 内置对象简介 | Math 内置对象的使用 )

文章目录 一、Math 内置对象1、Math 内置对象简介2、Math 内置对象的使用 二、代码示例1、代码示例 - Math 内置对象的使用2、代码示例 - 封装 Math 内置对象 一、Math 内置对象 1、Math 内置对象简介 JavaScript 中的 Math 内置对象 是一个 全局对象 , 该对象 提供了 常用的 数…

名家采访:国家级中国茶文化首席非遗传承人——罗大友

“崇高的理想是一个人心中的太阳,能照亮生活中的每一步。”罗大友,性别:男,国家级中国茶文化首席非遗传承人•中国茶文化研究院院长、美国巴拿马太平洋万国博览会终身评委兼中国区联合主席,大学文化,高级政工师。 “第…

【算法】删除有序数组中的重复项

本题来源---《删除有序数组中的重复项》 题目描述 给你一个 非严格递增排列 的数组 nums ,请你删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 示…

ZDOCK linux 下载(无需安装)、配置、使用

ZDOCK 下载 使用 1. 下载1)教育邮箱提交申请,会收到下载密码2)选择相应的版本3)解压 2. 使用方法Step 1:将pdb文件处理为ZDOCK可接受格式Step 2:DockingStep 3:创建所有预测结构 1. 下载 1&…

【matlab】reshape函数介绍及应用

【matlab】reshape函数介绍及应用 【先赞后看养成习惯】求点赞关注收藏😀 在MATLAB中,reshape函数是一种非常重要的数组操作函数,它可以改变数组的形状而不改变其数据。本文将详细介绍reshape函数的使用方法和应用。 1. reshape函数的基本语…

个人博客系统的设计与实现

https://download.csdn.net/download/liuhaikang/89222885http://点击下载源码和论文 本 科 毕 业 设 计(论文) 题 目:个人博客系统的设计与实现 专题题目: 本 科 毕 业 设 计(论文)任 务 书 题 …

2.6设计模式——Flyweight 享元模式(结构型)

意图 运用共享技术有效地支持大量细粒度的对象。 结构 其中 Flyweight描述一个接口,通过这个接口Flyweight可以接受并作用于外部状态。ConcreteFlyweight实现Flyweight接口,并作为内部状态(如果有)增加存储空间。ConcreteFlywe…

快速入门基础控制台API

目录 一、什么是win32API 二、API基础函数介绍 2.1控制台基础命令 2.1.1标题修改 2.1.2长宽修改 2.1.3坐标 2.2GetStdHandle 2.3GetConsoleCursorInfo 2.4SetConsoleCursorInfo 2.5SetConsoleCursorPosition 2.6GetAsyncKeyState 三、API函数综合应用 3.1设置光标…

Facebook的魅力魔法:探访数字社交的奇妙世界

1. 社交媒体的演变与Facebook的角色 在数字化时代,社交媒体已经成为我们日常生活中不可或缺的一部分。而在众多的社交媒体平台中,Facebook 以其深厚的历史和广泛的影响力,成为了全球数亿用户沟通、分享和互动的主要场所。从其初创之时起&…

雅特力AT32F435学习——3.PWM实验

PWM实验 定时器浑身都是包其中PWM占大头,因为PWM应用太广了:呼吸灯、电机、蜂鸣器,生日火炬里的声音都是PWM干的,接下来就让我们学一下雅特力AT32F435单片机的PWM吧。 基础知识 老样子对于PWM的基础了解那肯定直接从数据手册学…

动手学深度学习14 数值稳定性+模型初始化和激活函数

动手学深度学习14 数值稳定性模型初始化和激活函数 1. 数值稳定性2. 模型初始化和激活函数3. QA **视频:**https://www.bilibili.com/video/BV1u64y1i75a/?spm_id_fromautoNext&vd_sourceeb04c9a33e87ceba9c9a2e5f09752ef8 **电子书:**https://zh-v…

azure云服务器学生认证优惠100刀续订永久必过方法记录

前面的话 前几天在隔壁网站搞了个美国edu邮箱,可以自定义用户名。今天就直接认证Azure,本来打算等GitHub学生包过期后用这个edu邮箱重新认证白嫖Azure的。在昨天无意中看到续期,就把原本那个Azure账号续了一年,所以这个美国edu邮…

25计算机考研院校数据分析 | 浙江大学

浙江大学(Zhejiang University),简称“浙大”,坐落于“人间天堂”杭州。前身是1897年创建的求是书院,是中国人自己最早创办的新式高等学校之一。 浙江大学由教育部直属、中央直管(副部级建制)&a…
最新文章