diffusion 之 cifar/mnist 数据集

diffusion 之 mnist 数据集

  • mnist数据集
    • ddpm/script_utils.py
    • scripts/train_mnist.py
    • 展示采样结果

代码出处:https://github.com/abarankab/DDPM

wandb的问题解决方法:

step1: 按照这个https://blog.csdn.net/weixin_43164054/article/details/124156206一步步走 step2: 修改project_name=“cifar”,然后执行python train_cifar.py 若出现报错"wandb: ERROR It appears that you do not have permission to access the requested resource.",参看这个https://blog.csdn.net/weixin_43835996/article/details/126955917

mnist数据集

ddpm/script_utils.py

line 90:img_channel=1,因为cifar图片为3通道,而mnist图片为1通道
line 101: initial_pad=2, 是因为cifar数据集的图片大小为32,为2的指数倍,降采样过程中除以2的话一直能整除;而mnist的图片大小为28,所以要padding为32,即设置initial_pad=2
line 120:cifar10 的图片大小为3232, mnist的图片大小为2828,

import argparse
import torchvision
import torch.nn.functional as F

from .unet import UNet
from .diffusion import (
    GaussianDiffusion,
    generate_linear_schedule,
    generate_cosine_schedule,
)


def cycle(dl):
    """
    https://github.com/lucidrains/denoising-diffusion-pytorch/
    """
    while True:
        for data in dl:
            yield data

def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1

    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def diffusion_defaults():
    defaults = dict(
        num_timesteps=1000,
        schedule="linear",
        loss_type="l2",
        use_labels=False,

        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        num_res_blocks=2,
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation="silu",
        attention_resolutions=(1,),

        ema_decay=0.9999,
        ema_update_rate=1,
    )

    return defaults


def get_diffusion_from_args(args):
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }
    # base_channels=128
    model = UNet(
        img_channels=1,

        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,

        num_classes=None if not args.use_labels else 10,
        initial_pad=2,
    )
    # line102  在cifar中为initial_pad=0,  

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    # 本py文件共修改了3处:line 90 ; line 101 ;line 120.
    # model, (32, 32), 3, 10,    
    # cifar10 的图片大小为32*32,3channel, mnist的图片大小为28*28,1channel
    
    
    diffusion = GaussianDiffusion(
        model, (28, 28), 1, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )

    return diffusion

scripts/train_mnist.py

把entity=‘treaptofun’,给去掉

import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils


def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)


        # 接着上次中断保存的参数继续训练
        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            # wandb.init(project="ddpm_cifar")

            run = wandb.init(
                project=args.project_name,
                
                config=vars(args),
                name=args.run_name,
            )
            # entity='treaptofun',

            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )

        test_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
        
        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()
            
            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()
                
                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)
                
                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0
            
            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)
        
        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=80000,

        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="./ddpm_logs_mnist",
        project_name="mnist",
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()

命令行执行的训练命令:
python train.py
命令行执行的采样命令
python sample_images.py --model_path "your model path" --save_dir "your save img path" --schedule cosine

展示采样结果

import matplotlib.pyplot as plt
import numpy as np
import os

def show(num_imgs, dir_path):
    ''' 
    num_imgs: 要展示的图片的张数
    dir_path:图片的路径
    '''
    img_names=os.listdir (dir_path)
    img_names.sort(key=lambda x:int(x.split('.')[0]))

    plt.figure(figsize=(20,5)) # 画布大小
    N=2
    M=10
    #形成NxM大小的画布
    for i in range(num_imgs):#有张图片
        path = dir_path + img_names[i]
        img = plt.imread(path)
        plt.subplot(N,M,i+1)#表示第i张图片,下标只能从1开始,不能从0,
        plt.imshow(img)
        plt.title(img_names[i],color='black')
        #下面两行是消除每张图片自己单独的横纵坐标,不然每张图片会有单独的横纵坐标,影响美观
        plt.xticks([])
        plt.yticks([])
    plt.show()

print("mnist generation results:")
show(20, './scripts/save_dir_mnist/')  # 模型训练出来的保存的结果

这里的名字只是预测出来的图片的序号,并不是预测的label!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/4110.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于kubernetes部署gitlab

目录前提下载镜像部署服务前提 已经搭建完kubernets集群并可提供服务。 下载镜像 去docker hub 下载具体版本镜像,当使用最新版本时,也建议具体制定版本号,而不是使用latest. 如 gitlab/gitlab-ce:15.10.0-ce.0 当然可以pull到本地&#x…

Linux拒绝俄罗斯开发者合入

最近在Linux社区看到这样的信息https://lore.kernel.org/all/20230314103316.313e5f61kernel.org/我们不愿意接受你们的补丁。关于上面的内容,看到有一篇这样的文章https://www.phoronix.com/news/Linux-STMAC-Russian-Sanctions由于美国对俄罗斯实施制裁&#xff0…

一次内存泄露排查

前因: 因为测试 长时间压测导致 接口反应越来越慢,甚至 导致服务器 崩溃 排查过程 1、top 查看是 哪个进程 占用 内存过高 2、根据 进程 id 去查找 具体是哪个 程序的问题 ps -ef| grep 41356 可以看到 具体的 容器位置 排查该进程 对象存活 状态…

大数据学习路线图(2023完整高清版超详细)

送福利了!超详细的大数据学习路线图来啦,2023版是首发哟!大数据学习路线图分为7个阶段,包含: 数据仓库基础-->Linux &Hadoop生态-->Hadoop-->数据仓库与ETL技术-->BI数据分析与可视化-->自研数据仓…

计算机科学与技术专业-大三-学年设计-题目

大三-学年设计题目 西南大学 计算机与信息科学学院 周竹荣 课程概述 学年设计是重要的综合性设计训练,安排在修完相关专业平台课后进行。旨在培养学生综合运用所学的基础理论和专业知识,分析、解决实际问题的能力,理论联系实际。是一次系统的…

HTML 标签和属性

一些标签 单双标签 双标签。双标签指标签是成对出现的&#xff0c;也就是有一个开始标签和一个结束标签&#xff0c;开始标签用 <标签名> 表示&#xff0c;结束标签用 </标签名> 表示&#xff0c;只有一对标签一起使用才能表示一个具体的含义。例如 <html>&…

关于线程池你了解些什么?

前言学习线程池的思维导图线程池是什么?它有什么用?虽然线程比进程更轻量级,但是每个进程所占的资源空间是有限,如果我们频繁创建和销毁线程也会消耗很多CPU资源,那么我们该如何解决这个问题呢?官方解释:线程池是一种多线程处理形式,其处理过程可以将多个任务添加到阻塞队列…

电子招标采购系统:营造全面规范安全的电子招投标环境,促进招投标市场健康可持续发展

营造全面规范安全的电子招投标环境&#xff0c;促进招投标市场健康可持续发展 传统采购模式面临的挑战 一、立项管理 1、招标立项申请 功能点&#xff1a;招标类项目立项申请入口&#xff0c;用户可以保存为草稿&#xff0c;提交。 2、非招标立项申请 功能点&#xff1a;非招标…

Google巨大漏洞让Win10、11翻车,小姐姐马赛克白打了

早年间电脑截图这项技能未被大多数人掌握时&#xff0c;许多人应该都使用过手机拍屏幕这个原始的方式。 但由于较低的画面质量极其影响其他用户的观感&#xff0c;常常受到大家的调侃。 但到了 Win10、11 &#xff0c;预装的截图工具让门槛大幅降低。 WinShiftS 就能快速打开…

专业护眼灯什么牌子好?分享最专业护眼灯品牌排行

在日常生活中&#xff0c;照明灯具为人们提供了很多便利&#xff0c;但是对于长时间在灯光下用眼的学生跟办公族来说&#xff0c;容易导致用眼疲劳&#xff0c;甚至头晕等现象&#xff0c;所以现在普遍许多家庭都必备有护眼灯&#xff0c;能对眼睛起到缓解疲劳的作用&#xff0…

Docker安装Mysql集群(主从复制)

Docker安装Mysql集群(主从复制) 配置阿里云镜像 sudo vim /etc/docker/daemon.json插入如下镜像 {"registry-mirrors": ["https://sdiz8d27.mirror.aliyuncs.com"] }重启docker sudo systemctl daemon-reloadsudo systemctl restart docker保证images有…

python

好处 相对其他编程语言 比较简洁丰富的第三方库【做爬虫、机器学习、深度学习】 numpypandasmatplotlit用处 数据分析web开发游戏开发AI【比较广泛】安装部署python环境 1.官网下载python安装包【原生部署】 官网&#xff1a;python.org2.安装anaconda 1.自带python环境2.有丰富…

Mybatis-Plus进阶使用

一、逻辑删除 曾经我们写的删除代码都是物理删除。 逻辑删除&#xff1a;删除转变为更新 update user set deleted1 where id 1 and deleted0 查找: 追加 where 条件过滤掉已删除数据,如果使用 wrapper.entity 生成的 where 条件也会自动追加该字段 查找: select id,name,dele…

JConsole使用教程

JConsole是一个Java虚拟机的监控和管理工具&#xff0c;可以监控Java应用程序的内存使用、线程和类信息等。 以下是JConsole的使用教程&#xff1a; 1.启动JConsole JConsole是一个Java自带的工具&#xff0c;可以在bin目录下找到jconsole.exe文件。双击运行该文件即可启动JC…

正则表达式-元字符

文章目录一、正则表达式-元字符总结一、正则表达式-元字符 正则表达式 - 元字符 下表包含了元字符的完整列表以及它们在正则表达式上下文中的行为&#xff1a; 字符描述\将下一个字符标记为一个特殊字符、或一个原义字符、或一个 向后引用、或一个八进制转义符。例如&#xf…

强人工智能时代,区块链还有戏吗?

最近很多人都在问我&#xff0c;ChatGPT 把 AI 又带火了&#xff0c;区块链和 Web3 被抢了风头&#xff0c;以后还有戏吗&#xff1f;还有比较了解我的朋友问&#xff0c;当年你放弃 AI 而选择区块链&#xff0c;有没有后悔&#xff1f;这里有一个小背景。2017 年初我离开 IBM …

银行数字化转型导师坚鹏:如何制定银行数字化转型年度培训规划

如何制定银行数字化转型年度培训规划 ——以推动银行数字化转型战略落地为核心&#xff0c;实现知行果合一课程背景&#xff1a; 很多银行都在开展银行数字化转型培训工作&#xff0c;目前存在以下问题急需解决&#xff1a; 缺少针对性的银行数字化转型年度培训规划 不清楚如…

Spring --- 声明式事务

一、JdbcTemplate 1.1、简介 Spring 框架对 JDBC 进行封装&#xff0c;使用 JdbcTemplate 方便实现对数据库操作 1.2、准备工作 ① 加入依赖 <dependencies><!-- 基于Maven依赖传递性&#xff0c;导入spring-context依赖即可导入当前所需所有jar包 --><depen…

openAi ChatGPT调用性能优化的一些小妙招

参考的demo:GitHub - ddiu8081/chatgpt-demo: A demo repo based on OpenAI API. 扭曲调教&#xff1a; openai提供的chat接口&#xff08;https://api.openai.com/v1/chat/completions&#xff09;由于其模型很大&#xff08;什么1750亿个参数啥的&#xff09;&#xff0c;单…

Redis之底层数据结构

一 Redis数据结构 Redis底层数据结构有三层意思&#xff1a; 从Redis本身数据存储的结构层面来看&#xff0c;Redis数据结构是一个HashMap。从使用者角度来看&#xff0c;Redis的数据结构是String&#xff0c;List&#xff0c;Hash&#xff0c;Set&#xff0c;Sorted Set。从…