如何使用pytorch的Dataset, 来定义自己的Dataset

Dataset与DataLoader的关系

在这里插入图片描述
在这里插入图片描述

  1. Dataset: 构建一个数据集,其中含有所有的数据样本
  2. DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """创建自己的数据集"""
    def __init__(self):
        """初始化构建数据集所需要的参数"""
        pass

    def __getitem__(self, index):
        """来获取数据集中样本的索引"""
        pass

    def __len__(self):
        """获取数据集中的样本个数"""
        pass

# 实例化自定义的数据集
dataset = MyDataset()
# 将自定义的数据集加载到可训练的迭代容器
train_loader = DataLoader(dataset=dataset,  # 自定义的数据集
                          batch_size=32,  # 数据集中小批量的大小
                          shuffle=True,  # 是否要打乱数据集中样本的次序
                          num_workers=2)  # 是否要并行

实战1:CSV数据集(结构化数据集)

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """创建自己的数据集"""
    def __init__(self, filepath):
        """初始化构建数据集所需要的参数"""
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # 查看数据集中样本的个数
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        print("数据已准备好......")

    def __getitem__(self, index):
        """为了支持下标操作, 即索引dataset[index]:来获取数据集中样本的索引"""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """为了使用len(dataset):获取数据集中的样本个数"""
        return self.len

file = "D:\\BaiduNetdiskDownload\\Dataset_Dataload\\diabetes1.csv"

""" 1.使用 MyDataset类 构建自己的dataset """
mydataset = MyDataset(file)
""" 2.使用 DataLoader 构建train_loader """
train_loader = DataLoader(dataset=mydataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=0)

class MyModel(torch.nn.Module):
    """定义自己的模型"""
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmooid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmooid(self.linear1(x))
        x = self.sigmooid(self.linear2(x))
        x = self.sigmooid(self.linear3(x))
        return x

# 实例化模型
model = MyModel()

# 定义损失函数
criterion = torch.nn.BCELoss(size_average=True)
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


if __name__ == "__main__":
    for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            # 1. 准备数据
            inputs, labels = data

            # 2. 前向传播
            y_pred= model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())

            # 3. 反向传播
            optimizer.zero_grad()
            loss.backward()

            # 4. 梯度更新
            optimizer.step()

在这里插入图片描述

实战2:图片数据集

├── flower_data
—├── flower_photos(解压的数据集文件夹,3670个样本)
—├── train(生成的训练集,3306个样本)
—└── val(生成的验证集,364个样本)

主函数文件main.py
import os

import torch
from torchvision import transforms

from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

root = "../data/flower_data/flower_photos"  # 数据集所在根目录


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for epoch in range(100):
	    for step, data in enumerate(train_loader):
	        images, labels = data
	        # 然后在进行相应的训练操作即可


if __name__ == '__main__':
    main()

自定义数据集文件my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


功能文件utils.py(训练集、验证集的划分与可视化)
import os
import json
import pickle
import random

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)  # 判断路径是否存在

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引: 字典{’花名‘:0,’花名‘:1,···}
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)  # 将花名与对应的序号分行保存
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = True
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/339515.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HYBBS 表白墙网站PHP程序源码 可封装成APP

源码介绍 PHP表白墙网站源码,可以做校园内的,也可以做校区间的,可封装成APP。告别QQ空间的表白墙吧。 安装PHP5.6以上随意 上传程序安装,然后设置账号密码,登陆后台切换模板手机PC都要换开启插件访问前台。 安装完…

IS-IS:01 ISIS基本配置

这是实验拓扑,下面是基本配置: R1: sys sysname R1 user-interface console 0 idle-timeout 0 0 int loop 0 ip add 1.1.1.1 24 int g0/0/0 ip add 192.168.12.1 24 qR2: sys sysname R2 user-interface console 0 idle-timeout 0 0 int loop 0 ip add …

Python Web 开发之 Flask 入门实践

导语:Flask 是一个轻量级的 Python Web 框架,广受开发者喜爱。本文将带领大家了解 Flask 的基本概念、搭建一个简单的 Web 项目以及如何进一步扩展功能。 一、Flask 简介 Flask 是一个基于 Werkzeug 和 Jinja2 的微型 Web 框架,它的特点是轻…

万物简单AIoT 端云一体实战案例学习 之 快速开始

学物联网,来万物简单IoT物联网!! 下图是本案的3步导学,每个步骤中实现的功能请参考图中的说明。 1、简介 物联网具有场景多且复杂、链路长且开发门槛高等特点,让很多想学习或正在学习物联网的学生或开发者有点不知所措,甚至直接就放弃了。    万物简单AIoT物联网教育…

批量修改拓展名的方法

新建一个文本文档 输入ren *(.你需要更改的拓展名)*.(更改后的拓展名) 注意:*前面要有空格, txt前面有一个 ". "如上图所示 注意:这个文件建在你需要更改拓展名的文件夹,此文件夹中的所有的txt…

深度学习记录--指数加权平均

指数加权移动平均(exponentially weighted moving averages) 如何对杂乱的数据进行拟合? 通过指数加权平均可以把数据图近似拟合成一条曲线 公式: 其中表示第t个平均数,表示第t-1个平均数,表示第t个数据,表示变化参数…

蚂蚁数科CTO王维首次公开亮相:进一步拓展数据相关技术布局

“AI与数据是相生相伴的共同体,高质量的行业数据才能使大模型在产业发挥更大价值。蚂蚁数科将进一步拓展数据相关技术的布局,以加速产业数字化迈入下一阶段。”1月19日,王维首次以蚂蚁数科CTO的身份亮相媒体沟通会。 数据是数字时代的“新石…

【优化技术专题】「性能优化系列」针对Java对象压缩及序列化技术的探索之路

针对Java对象压缩及序列化技术的探索之路 序列化和反序列化为何需要有序列化呢?Java实现序列化的方式二进制格式 指定语言层级二进制格式 跨语言层级JSON 格式化类JSON格式化:XML文件格式化 序列化的分类在速度的对比上一般有如下规律:Java…

如何优雅的实现前端国际化?

JavaScript 中每个常见问题都有许多成熟的解决方案。当然,国际化 (i18n) 也不例外,有很多成熟的 JavaScript i18n 库可供选择,下面就来分享一些热门的前端国际化库! i18next i18next 是一个用 JavaScript 编写的全面的国际化框架…

ubuntu安装vm和Linux,安装python环境,docker和部署项目(一篇从零到部署)

1、下载Ubuntu Index of /releaseshttps://old-releases.ubuntu.com/releases/ 2、下载VMware 官方正版VMware下载(16 pro):https://www.aliyundrive.com/s/wF66w8kW9ac 下载Linux系统镜像(阿里云盘不限速)&#xff…

[数据结构 - C++] 红黑树RBTree

文章目录 1、前言2、红黑树的概念3、红黑树的性质4、红黑树节点的定义5、红黑树的插入Insert6、红黑树的验证7、红黑树与AVL树的比较附录: 1、前言 我们在学习了二叉搜索树后,在它的基础上又学习了AVL树,知道了AVL树是靠平衡因子来调节左右高…

Mybatis之关联

一、一对多关联 eg:一个用户对应多个订单 建表语句 CREATE TABLE t_customer (customer_id INT NOT NULL AUTO_INCREMENT, customer_name CHAR(100), PRIMARY KEY (customer_id) ); CREATE TABLE t_order ( order_id INT NOT NULL AUTO_INCREMENT, order_name C…

git克隆/拉取报错过早的文件结束符(EOF)的原因及解决

近期使用git拉取仓库的时候,拉取了好几次都不行,总是反馈说过早的文件结束符 总是这样,当然我的报错信息并没有描述完整,因为在我检索此类问题的时候,我发现有好多种所谓的过早的文件结束符这样的报错,但是…

机器学习实验报告——EM算法

目录 一、算法介绍 1.1算法背景 1.2算法引入 1.3算法假设 1.4算法原理 1.5算法步骤 二、算法公式推导 2.1数学基础 2.2EM算法推导 三、算法实现 3.1关于EM聚类 3.2EM工具包的使用 3.3 实例测试 四、算法讨论 4.1EM算法的优缺点 4.2EM算法的应用 4.3对于EM算法…

RT Thread Stdio生成STM32L431RCT6无法启动问题

一、问题现象 使用RT thread Stdio生成STM32L431RCT6工程后,编译下载完成后系统无法启动,无法仿真debug; 二、问题原因 如果当前使用的芯片支持包版本为0.2.3,可能是这个版本问题,目前测试0.2.3存在问题&#xff0c…

【51单片机】外部中断

0、前言 参考&#xff1a;普中 51 单片机开发攻略 第16章 及17章 1、硬件 2、软件 #include <reg52.h> #include <intrins.h> #include "delayms.h"typedef unsigned char u8; typedef unsigned int u16;sbit led P2^0; sbit key3 P3^2;//外部中断…

晨控CK-FR03-EC与欧姆龙NX系列EtherCAT通讯连接说明手册

晨控CK-FR03-EC与欧姆龙NX系列EtherCAT通讯连接说明手册 晨控CK-FR03-EC是一款基于射频识别技术的高频RFID标签读卡器&#xff0c;读卡器工作频率为13.56MHZ&#xff0c;支持对I-CODE 2、I-CODE SLI等符合ISO15693国际标准协议格式标签的读取。 读卡器同时支持标准工业通讯协…

机器学习周记(第二十六周:文献阅读-DPGCN)2024.1.15~2024.1.21

目录 摘要 ABSTRACT 1 论文信息 1.1 论文标题 1.2 论文摘要 1.3 论文背景 2 论文模型 2.1 问题描述 2.2 论文模型 2.2.1 时间感知离散图结构估计&#xff08;Time-aware Discrete Graph Structure Estimation Module&#xff0c;TADG Module&#xff09; 2.2.2 时间…

HNU-数据挖掘-实验2-数据降维与可视化

数据挖掘课程实验实验2 数据降维与可视化 计科210X 甘晴void 202108010XXX 文章目录 数据挖掘课程实验<br>实验2 数据降维与可视化实验背景实验目标实验数据集说明实验参考步骤实验过程1.对数据进行初步降维2.使用无监督数据降维方法&#xff0c;比如PCA&#xff0c;I…

网络要素服务(WFS)详解

文章目录 1. 概述2. GetCapabilities3. DescribeFeatureType4. GetFeature4.1 Get访问方式4.2 Post访问方式 5. Transaction5.1 Insert5.2 Replace5.3 Update5.4 Delete 6 注意事项 1. 概述 前置文章&#xff1a; 地图服务器GeoServer的安装与配置 GeoServer发布地图服务&#…
最新文章