【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)

  上一篇Diffusion实战是确确实实一步一步走的公式,这回采用一个更方便的库:diffusers,来实现Diffusion模型训练。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


0、所需安装

pip install diffusers  # diffusers库
pip install datasets  

1、数据集下载

  下载地址:蝴蝶数据集
  下载好后的文件夹中包括以下文件,放在当前目录下就可以了。

在这里插入图片描述
加载数据集,并对一批数据进行可视化:

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image

def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')

image_size = 32
batch_size = 64

# 数据增强
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

dataset.set_transform(transform)

# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出可视化结果:

在这里插入图片描述


2、加噪调度器

  即DDPM论文中需要预定义的 β t {\beta_t } βt ,可使用DDPMScheduler类来定义,其中num_train_timesteps参数为时间步 t {t} t

from diffusers import DDPMScheduler

# βt值
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

根据定义的 β t {\beta_t } βt ,可视化 α ˉ t {\sqrt {{{\bar \alpha }_t}}} αˉt 1 − α ˉ t {\sqrt {1 - {{\bar \alpha }_t}}} 1αˉt

在这里插入图片描述

  通过设置beta_start、beta_end和beta_schedule三个参数来控制噪声调度器的超参数 β t {\beta_t } βt

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)

在这里插入图片描述

  beta_schedule可以通过一个函数映射来为模型推理的每一步生成一个 β t {\beta_t } βt值。

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

在这里插入图片描述

x t = α ˉ t x 0 + 1 − α ˉ t ε {{x_t} = \sqrt {{{\bar \alpha }_t}} {x_0} + \sqrt {1 - {{\bar \alpha }_t}} \varepsilon } xt=αˉt x0+1αˉt ε 加噪前向过程可视化:

timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出为:

在这里插入图片描述


3、扩散模型定义

  diffusers库中模型的定义也非常简洁:

# 创建模型
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)

model.to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同

4、扩散模型训练

  定义优化器,和传统模型一样的训练写法:

# 定义噪声调度器
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        
        # 为图像添加随机噪声
        noise = torch.randn(clean_images.shape).to(clean_images.device)  # eps
        bs = clean_images.shape[0]

        # 为每一张图像随机选择一个时间步
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()  

        # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # 获得模型预测结果
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # 计算损失, 损失回传
        loss = F.mse_loss(noise_pred, noise)  
        loss.backward(loss)
        losses.append(loss.item())

        # 更新模型参数
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

30个epoch训练过程如下所示:

在这里插入图片描述

可用以下代码查看损失曲线:

# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()

损失曲线可视化:

在这里插入图片描述


5、图像生成

  (1)通过建立pipeline生成图像:

# 图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)

pipeline_output = image_pipe()
plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()

# 保存pipeline
image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

生成的蝴蝶图像如下:

在这里插入图片描述

生成的my_pipeline文件夹如下:

在这里插入图片描述

  (2)通过随机采样循环生成图像:

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # 获得模型预测结果
    with torch.no_grad():
        residual = model(sample, t).sample

    # 根据预测结果更新图像
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)

8张生成图像如下:
在这里插入图片描述


6、代码汇总

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image


def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

# --------------------------------------------------------------------------------
# 1、数据集加载与可视化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')

image_size = 32
batch_size = 64

# 数据增强
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

dataset.set_transform(transform)

# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 2、噪声调度器
from diffusers import DDPMScheduler

# 加噪声的系数βt
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 加噪声可视化
timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 3、创建模型
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)

model.to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 4、扩散模型训练
# 定义噪声调度器
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        
        # 为图像添加随机噪声
        noise = torch.randn(clean_images.shape).to(clean_images.device)  # eps
        bs = clean_images.shape[0]

        # 为每一张图像随机选择一个时间步
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()  

        # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # 获得模型预测结果
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # 计算损失, 损失回传
        loss = F.mse_loss(noise_pred, noise)  
        loss.backward(loss)
        losses.append(loss.item())

        # 更新模型参数
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 5、图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)

pipeline_output = image_pipe()

plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()

image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # 获得模型预测结果
    with torch.no_grad():
        residual = model(sample, t).sample

    # 根据预测结果更新图像
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)

grid_im = show_images(sample).resize((8 * 64, 64), resample=Image.NEAREST)
plt.figure(dpi=300)
plt.imshow(grid_im)
plt.axis('off')
plt.show()
# --------------------------------------------------------------------------------

  参考资料:扩散模型从原理到实践. 人民邮电出版社. 李忻玮, 苏步升等.

  diffusers确实很方便使用,有点子PyCaret的感觉了~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/577865.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

锁,数据同步

目录 原子操作中断控制自旋锁信号量小结 经常遇到数据同步的问题,具体有哪些情况呢?来聊一聊。我知道的,应该有以下四种,原子操作,中断控制,自旋锁和信号量。 原子操作 适用情况和场景 原子操作经常用在单…

vue实现录音并转文字功能,包括PC端web,手机端web

vue实现录音并转文字功能,包括PC端,手机端和企业微信自建应用端 不止vue,不限技术栈,vue2、vue3、react、.net以及原生js均可实现。 原理 浏览器实现录音并转文字最快捷的方法是通过Web Speech API来实现,这是浏览器…

【Linux系统编程】第八弹---权限管理操作(中)

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、修改文件权限的做法(二) 2、文件类型 3、可执行权限 4、创建文件/目录的默认权限 4.1、权限掩码 总结 前面一弹我们学…

CH4INRULZ-v1靶机练习实践报告

CH4INRULZ-v1靶机练习实践报告 1 安装靶机 靶机是.ova文件,需要用VirtualBox打开,但我习惯于使用VMWare,因此修改靶机文件,使其适用于VMWare打开。 解压ova文件,得到.ovf文件和.vmdk文件。直接用VMWare打开.ovf文件即可。 2 夺…

社区新零售:重构邻里生活圈,赋能美好未来

新时代的邻里脉动 在城市的肌理中,社区作为生活的基本单元,正经历一场由新零售引领的深刻变革。社区新零售,以其独特的商业模式、创新的技术手段和以人为本的服务理念,重新定义了社区商业的边界,重构了邻里生活的形态…

[C++ QT项目实战]----C++ QT系统实现多线程通信

前言 在C QT中,多线程通信原理主要涉及到信号与槽机制和事件循环机制。 1、信号与槽机制: 在QT中,信号与槽是一种用于对象间通信的机制。对象可以通过发送信号来通知其他对象,其他对象通过连接槽来接收信号并进行相应的处…

软件物料清单(SBOM)生成指南 .pdf

如今软件安全攻击技术手段不断升级,攻击数量显著增长。尤其是针对软件供应链的安全攻击,具有高隐秘性、追溯难的特点,对企业软件安全威胁极大。 同时,软件本身也在不断地更新迭代,软件内部成分安全性在持续变化浮动。…

web题目实操 5(备份文件和关于MD5($pass,true)注入的学习)

1.[ACTF2020 新生赛]BackupFile (1)打开页面后根据提示是备份文件 (2)查看源码发现啥都没有 (3)这里啊直接用工具扫描,可以扫描到一个文件名为:/index.php.bak的文件 (…

json解析大全

JSON解析案例1 将String转为JSONObject JSONObject res JSONObject.parseObject(result);获取documents JSONArray array res.getJSONObject("result").getJSONArray("documents");遍历JSONArray for (int i 0; i < array.size(); i) {JSONObject…

IDEA pom.xml依赖警告

IDEA中&#xff0c;有时 pom.xml 中会出现如下提示&#xff1a; IDEA 2022.1 升级了检测易受攻击的 Maven 和 Gradle 依赖项&#xff0c;并建议修正&#xff0c;通过插件 Package Checker 捆绑到 IDE 中。 这并不是引用错误&#xff0c;不用担心。如果实在强迫症不想看到这个提…

稳态视觉诱发电位 (SSVEP) 分类学习系列 (4) :Temporal-Spatial Transformer

稳态视觉诱发电位分类学习系列:Temporal-Spatial Transformer 0. 引言1. 主要贡献2. 提出的方法2.1 解码的主要步骤2.2 网络的主要结构 3. 结果和讨论3.1 在两个数据集下的分类效果3.2 与基线模型的比较3.3 消融实验3.4 t-SNE 可视化 4. 总结欢迎来稿 论文地址&#xff1a;http…

Hive——DDL(Data Definition Language)数据定义语句用法详解

1.数据库操作 1.1创建数据库 CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] [LOCATION hdfs_path] [WITH DBPROPERTIES (property_nameproperty_value, ...)];IF NOT EXISTS&#xff1a;可选参数&#xff0c;表示如果数据库已经存在&#xff0c…

软考-系统分析师-精要2

5、可行性分类 经济可行性&#xff1a;成本收益分析&#xff0c;包括建设成本、运行成本和项目建设后可能的经济收益。 技术可行性&#xff1a;技术风险分析&#xff0c;现有的技术能否支持系统目标的实现&#xff0c;现有资源&#xff08;员工&#xff0c;技术积累&#xff0…

GEM TSU Interface Details and IEEE 1588 Support

摘要&#xff1a;Xilinx ZNYQ ULTRASCALE MPSOC的GEM和1588的使用 对于FPGA来说&#xff0c;只需要勾选一些znyq的配置就行了&#xff0c;其余的都是软件的工作&#xff1b; 所有配置都勾选之后&#xff0c;最终会露出来的接口如下&#xff1a; GEM需要勾选的配置如下&#xf…

泰坦尼克号乘客生存情况预测分析2

泰坦尼克号乘客生存情况预测分析1 泰坦尼克号乘客生存情况预测分析2 泰坦尼克号乘客生存情况预测分析3 泰坦尼克号乘客生存情况预测分析总 背景描述 Titanic数据集在数据分析领域是十分经典的数据集&#xff0c;非常适合刚入门的小伙伴进行学习&#xff01; 泰坦尼克号轮船的…

AI新闻速递:揭秘本周科技界最热的AI创新与发展

兄弟朋友们&#xff0c;本周的AI领域又迎来了一系列激动人心的进展。在这个快速变化的时代&#xff0c;不会利用AI的人&#xff0c;就像在数字化高速公路上步行的旅行者&#xff0c;眼看着同行者驾驶着智能汽车绝尘而去&#xff0c;而自己却束手无策。 1. Adobe Firefly 3&…

【基础算法总结】双指针算法二

双指针 1.有效三角形的个数2.和为S的两个数字3.和为S的两个数字4.四数之和 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f603; 1.有效三角形的个数…

深度学习运算:CUDA 编程简介

一、说明 如今&#xff0c;当我们谈论深度学习时&#xff0c;通常会将其实现与利用 GPU 来提高性能联系起来。GPU&#xff08;图形处理单元&#xff09;最初设计用于加速图像、2D 和 3D 图形的渲染。然而&#xff0c;由于它们能够执行许多并行操作&#xff0c;因此它们的实用性…

Python游戏工具包pygame

当你涉及游戏开发时&#xff0c;Pygame是一个强大的工具包&#xff0c;它提供了一系列功能丰富的模块和工具&#xff0c;让你可以轻松地创建各种类型的游戏。在本文中&#xff0c;我将介绍Pygame的依赖以及其详细属性&#xff0c;同时提供一些示例代码来说明其用法。 目录 一…

Github 2024-04-27 开源项目日报 Top9

根据Github Trendings的统计,今日(2024-04-27统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目6TypeScript项目2C++项目1JavaScript项目1Open-Sora: 构建自己的视频生成模型 创建周期:17 天开发语言:Python协议类型:Apache Lic…