ccc-pytorch-宝可梦自定义数据集实战-加载数据部分(9)

文章目录

      • 第一步:构建路径与种类的映射关系
      • 第二步:载入所有的宝可梦图像
      • 第三步:打散顺序并通过路径名提取映射关系构建映射文件
      • 第四步:完善选取、获取图片信息功能并可视化
      • 第五步:对数据进行预处理
      • 第六步:批量读取图片

文件/数据结构:
image-20230313185704642
在这里插入图片描述

第一步:构建路径与种类的映射关系

import os
from torch.utils.data import Dataset

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)

    def __len__(self):
        pass
    def __getitem__(self, idx):
        pass

def main():
    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

image-20230311204936302

第二步:载入所有的宝可梦图像

import os,glob

from torch.utils.data import Dataset

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.load_csv('images.csv')
    def load_csv(self,filename):
        images = []
        for name in self.name2label.keys():
            images +=glob.glob(os.path.join(self.root,name,'*.png'))
            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
        #1167,'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
        print(len(images),images)

    def __len__(self):
        pass
    def __getitem__(self, idx):
        pass

def main():

    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

image-20230311210101708

第三步:打散顺序并通过路径名提取映射关系构建映射文件

import csv
import os,glob
import random

from torch.utils.data import Dataset

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize
        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):

            images = []
            for name in self.name2label.keys():
                images +=glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1167,'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images :
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label])
                print('writen into csv file',filename)

            images,labels = [],[]
            with open(os.path.join(self.root,filename)) as f:
                reader = csv.reader(f)
                for row in reader:
                    img , label = row
                    label = int (label)
                    images.append(img)
                    labels.append(label)
            assert  len(images) == len(labels)
            return images,labels

    def __len__(self):
        pass
    def __getitem__(self, idx):
        pass

def main():

    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',224,'train')

if __name__ == '__main__':
    main()

image-20230311211317832
image-20230311211306693

第四步:完善选取、获取图片信息功能并可视化

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        if model == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif model == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.images))]
        else :
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.images)):]

    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):

            images = []
            for name in self.name2label.keys():
                images +=glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1167,'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images :
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img , label = row
                label = int (label)
                images.append(img)
                labels.append(label)
        assert  len(images) == len(labels)
        return images,labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        #img:D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png
        img , label = self.images[idx],self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((self.resize,self.resize)),
            transforms.ToTensor()
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img,label

def main():
    import visdom
    viz = visdom.Visdom()

    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',224,'train')
    # 得到迭代器第一个样本
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    viz.images(x,win='sample_x',opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

在这里插入图片描述

image-20230312210040000

第五步:对数据进行预处理

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        if model == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif model == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.images))]
        else :
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.images)):]

    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):

            images = []
            for name in self.name2label.keys():
                images +=glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1167,'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images :
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img , label = row
                label = int (label)
                images.append(img)
                labels.append(label)
        assert  len(images) == len(labels)
        return images,labels

    def __len__(self):
        return len(self.images)

    def denormalize(self,x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        #img:D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png
        img , label = self.images[idx],self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),#大小放缩
            transforms.RandomRotation(15),#随机旋转
            transforms.CenterCrop(self.resize),#中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],#通用数据
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img,label

def main():
    import visdom
    import time
    viz = visdom.Visdom()

    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',224,'train')
    # 得到迭代器第一个样本
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    viz.images(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))

if __name__ == '__main__':
    main()

image-20230312211459931
如果没有denormalize生成图片如下:
image-20230312210836559

第六步:批量读取图片

import csv
import os,glob
import random
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image

class Pokeman(Dataset):
    def __init__(self,root,resize,model):
        super(Pokeman,self).__init__()
        self.root=root
        self.resize=resize

        self.name2label={}
        print(root)
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        print(self.name2label)
        self.images,self.labels = self.load_csv('images.csv')

        if model == 'train':
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif model == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.images))]
        else :
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.images)):]

    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):

            images = []
            for name in self.name2label.keys():
                images +=glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1167,'D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png'
            print(len(images),images)

            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images :
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img,label])
                print('writen into csv file',filename)

        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img , label = row
                label = int (label)
                images.append(img)
                labels.append(label)
        assert  len(images) == len(labels)
        return images,labels

    def __len__(self):
        return len(self.images)

    def denormalize(self,x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        #img:D:\\pythonProject\\pythonProject39\\pokeman\\bulbasaur\\00000000.png
        img , label = self.images[idx],self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),#大小放缩
            transforms.RandomRotation(15),#随机旋转
            transforms.CenterCrop(self.resize),#中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],#通用数据
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img,label

def main():
    import visdom
    import time
    viz = visdom.Visdom()

    db =Pokeman('D:\pythonProject\pythonProject39\pokeman',64,'train')
    # 得到迭代器第一个样本
    x,y = next(iter(db))
    print('sample:',x.shape,y.shape)
    viz.images(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
    loader = DataLoader(db,batch_size=32,shuffle=True,num_workers=8)

    for x ,y in loader:
        viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
        viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()

image-20230312212317555
对于分类分类有序的结构可以更简单的调用API

tf = transforms.Compose([
                transforms.Resize((64,64)),
                transforms.ToTensor(),
])
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)

print(db.class_to_idx)

for x,y in loader:
    viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

    time.sleep(10)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/7977.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【李宏毅】深度学习——HW4-Speaker Identification

Speaker Identification 1.Goal 根据给定的语音内容,识别出说话者是谁 2.Data formats 2.1data directory 目录下有三个json文件和很多pt文件,三个json文件作用标注在下图中,pt文件就是语音内容。 mapping文件 metadata文件 n_mels:Th…

飞桨EasyDL落地三大工业场景,工业AI赋能产业升级

数智化时代,如何利用人工智能实现传统生产方式的转型升级,成为摆在每个工业制造企业的一道必答题。工业生产、质检、管理等环节,持续产生海量数据。以机器视觉为代表的AI技术,广泛应用在3C电子、快消品制造、汽车零部件制造等多个…

指令系统和寻址方式

文章目录指令系统指令的基本格式扩展码指令格式指令的操作类型指令的寻址方式指令寻址数据寻址隐含寻址立即寻址直接寻址间接寻址寄存器寻址寄存器间接寻址相对寻址基址寻址变址寻址堆栈寻址使用场景PSW小结程序的机器级代码表示CISC和RISC刷题小结指令系统 指令:计…

Revit插件 | 建模助手2023年度版本大更新,就是这么懂你

​大家好,本期是懂你的建模助手。 从去年开始,建模助手几乎每个月都会有大大小小的活动,目的是让大家用最低的成本尝试极棒的建模体验!强行挽尊ing 但作为一支很pro的团队,单一地搞活动肯定不行滴,还得在…

Python SMTP发送邮件和线程

文章目录一、Python SMTP发送邮件二、Python3 多线程总结一、Python SMTP发送邮件 SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。 python的smtplib提供…

使用ChatGPT帮助我们编码的10种场景

文章目录1、技术搜索2、生成常用工具函数3、帮助解读代码4、添加注释5、优化代码6、Vue2 转 Vue37、Vue 转 React8、补充 TypeScript 类型9、生成文档10、工具配置总结ChatGPT 的出现,彻底改变的很多代码开发的方式,特别是通用型的代码,使用它…

基于公私密钥的单点登录

目前已知的单点登陆方式有: 多个系统集群 建立一个SSO认证中心,用户只需要登录一次就可以访问所有相互信任的应用系统。 1、可以通过session广播机制实现:在一个集群中的一个模块登录后,然后把这个session复制n份,发…

JUC-01 线程的创建和状态转换

本次我们主要讲三个问题 线程是什么?线程有哪些状态?各状态间的转换了解吗?创建线程的3种方法你都了解吗? 1. 线程是什么?(了解即可) 进程: 进程是一个具有一定独立功能的程序在一…

计算机网络考试复习——第一章 1.7

1.7 计算机网络体系结构 两台计算机要互相传送文件需解决很多问题: (1) 必须有一条传送数据的通路。 (2) 发起方必须激活通路。 (3) 要告诉网络如何识别接收方。 (4) 发起方要清楚对方是否已开机,且与网络连接正常。 (5) 发起方要清楚对方是否准备好接收和存储文件…

数据结构——栈与队列相关题目

数据结构——栈与队列相关题目232. 用栈实现队列思路225. 用队列实现栈1.两个队列实现栈2.一个队列实现栈20. 有效的括号思路1047. 删除字符串中的所有相邻重复项思路155. 最小栈150. 逆波兰表达式求值思路239. 滑动窗口最大值单调队列347. 前 K 个高频元素思路232. 用栈实现队…

2023版Postman接口测试使用全指南(原来使用 Postman测试API接口如此简单)

下面是一篇详细介绍postman接口测试的文章,如果文章内容不太明白的话, 我建议看看视频版本,更加清洗,更加直观! 最详细的postman接口测试实战教程_哔哩哔哩_bilibili最详细的postman接口测试实战教程共计129条视频&am…

ToBeWritten之ARM汇编基础铺垫

也许每个人出生的时候都以为这世界都是为他一个人而存在的,当他发现自己错的时候,他便开始长大 少走了弯路,也就错过了风景,无论如何,感谢经历 转移发布平台通知:将不再在CSDN博客发布新文章,敬…

FPGA解码4line MIPI视频 IMX291/IMX290摄像头采集 提供工程源码和技术支持

目录1、前言2、Xilinx官方主推的MIPI解码方案3、我已有的MIPI解码方案4、纯Vhdl代码解码MIPI5、vivado工程介绍6、上板调试验证7、福利:工程代码的获取1、前言 FPGA图像采集领域目前协议最复杂、技术难度最高的应该就是MIPI协议了,MIPI解码难度之高&…

关键词采集软件在SEO优化中的应用与效果

搜索引擎的优化被广泛认为是提高网站排名和在线可见性的重要方法之一。SEO人员需要进行大量的工作以确保网站的内容和标签可以被搜索引擎正确地解析和索引。在这项任务中,使用搜索引擎关键词采集软件可以帮助SEO人员完成许多繁琐的任务并简化他们的工作流程。在本文…

Linux 基础IO(Input与output)学习

进程间通信:讲的是操作系统为用户提供的几种进程间的通信方式概念:进程间通信其实就是多个进程之间进行数据交互问题:进程间通信为什么不能直接进行数据交互,需要使用系统提供的方式?原因:进程之间是具有独…

电动力学问题中的Matlab可视化

电磁场的经典描述 小说一则 电磁场的经典描述就是没有啥玩意量子力学的经典电动力学下对电磁场的描述,以后有空写个科幻小说,写啥呢,就写有天张三遇见了一个外星人,外星人来自这样一个星球,星球上的物质密度特别低,导致外星人的测量会明显的影响物质的运动,外星人不能同时得到…

JNI 调用

简介 JNI是Java Native Interface的缩写,通过使用 Java本地接口书写程序,可以确保代码在不同的平台上方便移植。从Java1.1开始,JNI标准成为java平台的一部分,它允许Java代码和其他语言写的代码进行交互。 本地代码与 Java 虚拟机…

【ChatGPT】ChatGPT-5 强到什么地步?

Yan-英杰的主页 悟已往之不谏 知来者之可追 C程序员,2024届电子信息研究生 目录 ChatGPT-5 强到什么地步? 技术 深度学习模型的升级 更好的预测能力 自适应学习能力 特点 语言理解能力更强 自我修正和优化 更广泛的应用领域 应用 对话系统 智能写作…

【机器学习】吴恩达机器学习Deeplearning.ai

机器学习已经强大到可以独立成为人工智能的一个子领域。 可以通过对机器编程实现比如执行网络搜索、理解人类语言、通过x光诊断疾病,或制造自动驾驶汽车。 机器学习定义 一般来说,给一个算法学习的机会越多,它的表现就越好。 机器学习的两种…