深度学习——(生成模型)DDPM

前置数学知识

1、先验概率和后验概率

先验概率:根据以往经验和分析得到的概率,它往往作为“由因求果”问题中的“因”出现,如 q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xtxt1)

后验概率:指在得到“结果”的信息后重新修正的概率,是“执果寻因”问题中的“因", 如 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt)

2、条件概率:设 A A A B B B为任意两个事件,若 P ( A ) > 0 P(A)>0 P(A)>0,称在已知事件 A A A发生的条件下,事件 B B B发生的概率为条件概率,记为 P ( B ∣ A ) P(B|A) P(BA)
P ( B ∣ A ) = P ( A , B ) P ( A ) P(B|A)=\frac{P(A,B)} {P(A)} P(BA)=P(A)P(A,B)

3、乘法公式:
P ( A , B ) = P ( B ∣ A ) P ( A ) P(A,B)=P(B|A)P(A) P(A,B)=P(BA)P(A)

4、乘法公式一般形式:
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B , A ) P ( B ∣ A ) P ( A ) P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A)\\ P(A,B,C)=P(CB,A)P(B,A)=P(CB,A)P(BA)P(A)

5、贝叶斯公式:
P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A|B)=\frac{P(B|A)P(A)}{P(B)} P(AB)=P(B)P(BA)P(A)
6、多元贝叶斯公式:
P ( A ∣ B , C ) = P ( A , B , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A ∣ C ) P ( C ) P ( B ∣ C ) P ( C ) = P ( B ∣ A , C ) P ( A ∣ C ) ) P ( B ∣ C ) P(A|B,C)=\frac{P(A,B,C)}{P(B,C)}=\frac{P(B|A,C)P(A,C)}{P(B,C)}=\frac{P(B|A,C)P(A|C)P(C)}{P(B|C)P(C)}=\frac{P(B|A,C)P(A|C))}{P(B|C)} P(AB,C)=P(B,C)P(A,B,C)=P(B,C)P(BA,C)P(A,C)=P(BC)P(C)P(BA,C)P(AC)P(C)=P(BC)P(BA,C)P(AC))

7、正态分布的叠加性:当有两个独立的正态分布变量 N 1 N_{1} N1 N 2 N_{2} N2,它们的均值和方差分别为 μ 1 \mu_{1} μ1, μ 2 \mu_{2} μ2 σ 1 2 \sigma_{1}^2 σ12, σ 2 2 \sigma_{2}^2 σ22它们的和为 N = a N 1 + b N 2 N=a N_{1}+b N_{2} N=aN1+bN2的均值和方差可以表示如下:
E ( N ) = E ( a N 1 + b N 2 ) = a μ 1 + b μ 2 V a r ( N ) = V a r ( a N 1 + b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}+bN_{2})=a\mu_{1}+b\mu_{2}\\ Var(N)=Var(aN_{1}+bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1+bN2)=aμ1+bμ2Var(N)=Var(aN1+bN2)=a2σ12+b2σ22
相减时:
E ( N ) = E ( a N 1 − b N 2 ) = a μ 1 − b μ 2 V a r ( N ) = V a r ( a N 1 − b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}-bN_{2})=a\mu_{1}-b\mu_{2}\\ Var(N)=Var(aN_{1}-bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1bN2)=aμ1bμ2Var(N)=Var(aN1bN2)=a2σ12+b2σ22

8、重参数化:从 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2) 采样等价于从 N ( 0 , 1 ) N(0,1) N(0,1)采样一个 ϵ \epsilon ϵ, ϵ ⋅ σ + μ \epsilon\cdot\sigma+\mu ϵσ+μ

9、高斯分布的概率密度函数
f ( x ) = 1 2 π σ e − ( x − μ ) 2 2 σ 2 f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}} f(x)=2π σ1e2σ2(xμ)2
10、高斯分布的KL散度公式
K L ( p ∣ q ) = l o g σ 2 σ 1 + σ 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(p|q)=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} KL(pq)=logσ1σ2+2σ22σ2+(μ1μ2)221
11、二次函数配方
a x 2 + b x = a ( x + b 2 a ) 2 + c ax^2+bx=a(x+\frac{b}{2a})^2+c ax2+bx=a(x+2ab)2+c
12、随机变量的期望公式
X X X是随机变量, Y = g ( X ) Y=g(X) Y=g(X),则:
E ( Y ) = E [ g ( X ) ] = { ∑ k = 1 ∞ g ( x k ) p k ∫ − ∞ ∞ g ( x ) p ( x ) d x E(Y)=E[g(X)]= \begin{cases} \displaystyle\sum_{k=1}^\infty g(x_k)p_k\\ \displaystyle\int_{-\infty}^{\infty}g(x)p(x)dx \end{cases} E(Y)=E[g(X)]= k=1g(xk)pkg(x)p(x)dx

13、KL散度公式
K L ( p ( x ) ∣ q ( x ) ) = E x ∼ p ( x ) [ p ( x ) q ( x ) ] = ∫ p ( x ) p ( x ) q ( x ) d x KL(p(x)|q(x))=E_{x \sim p(x)}[\frac{p(x)}{q(x)}]=\int p(x) \frac{p(x)}{q(x)}dx KL(p(x)q(x))=Exp(x)[q(x)p(x)]=p(x)q(x)p(x)dx

介绍DDPM

2020年Berkeley提出DDPM(Denoising Diffusion Probabilistic Models),简称扩散模型,是AIGC的核心算法,在生成图像的真实性和多样性方面均超越了GAN,而且训练过程稳定。缺点是计算成本较高,实时推理比较困难,但也有相关技术在时间和空间维度上降低计算量。

扩散模型包括两个过程:前向扩散过程(前向加噪过程)反向去噪过程

img

前向过程和反向过程都是马尔可夫链,全过程大约需要1000步,其中反向过程用来生成数据,它的推导过程可以描述成:

img

前向扩散的过程

前向扩散过程是对原始数据逐渐增加高斯噪声,直至变成标准高斯分布的过程。

img

从原始数据集采样 x 0 ∼ q ( x 0 ) x_0\sim q(x_0) x0q(x0),按照预定义的noise schedule策略添加随机噪声,得到一系列噪声图像 x 1 , x 2 , … , x T x_1,x_2,\dots,x_T x1,x2,,xT,用概率表示为:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t I ) \begin{aligned} q(x_{1:T}|x_{0})&=\prod_{t=1}^{T}q(x_t|x_{t-1}) \\q(x_{t}|x_{t-1})&=\mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1},\beta_{t}I)\\ \end{aligned} q(x1:Tx0)q(xtxt1)=t=1Tq(xtxt1)=N(xt;αt xt1,βtI)
进行重参数化(前置知识数学知识8),得到
x t = α t x t − 1 + β t ϵ t      ϵ t ∼ N ( 0 , I ) α t = 1 − β t \begin{aligned} x_{t}&=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \space \space \space \space \epsilon_{t}\sim \mathcal{N}(0,I) \\ \alpha_{t}&=1-\beta_{t} \end{aligned} xtαt=αt xt1+βt ϵt    ϵtN(0,I)=1βt

利用上述公式进行迭代推导
x t = α t x t − 1 + β t ϵ t = α t ( α t − 1 x t − 2 + β t − 1 ϵ t − 1 ) + β t ϵ t = ( α t … α 1 ) x 0 + ( α t … α 2 ) β 1 ϵ 1 + ( α t … α 3 ) β 2 ϵ 2 + ⋯ + α t β t − 1 ϵ t − 1 + β t ϵ t \begin{aligned} x_{t}&=\sqrt{\alpha_{t}} x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{\beta_{t-1}}\epsilon_{t-1})+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{(\alpha_{t}\dots\alpha_{1})}x_{0}+\sqrt{(\alpha_{t}\dots\alpha_{2})\beta_{1}}\epsilon_{1}+\sqrt{(\alpha_{t}\dots\alpha_{3})\beta_{2}}\epsilon_{2}+\dots+\sqrt{\alpha_{t}\beta_{t-1}}\epsilon_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \end{aligned} xt=αt xt1+βt ϵt=αt (αt1 xt2+βt1 ϵt1)+βt ϵt=(αtα1) x0+(αtα2)β1 ϵ1+(αtα3)β2 ϵ2++αtβt1 ϵt1+βt ϵt

设: α t ˉ = α 1 α 2 … α t \bar{\alpha_{t}}=\alpha_{1}\alpha_{2}\dots\alpha_{t} αtˉ=α1α2αt

根据正态分布的叠加性得到

x t = α t ˉ x 0 + 1 − α t ˉ ϵ     ϵ ∼ N ( 0 , I ) q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , 1 − α t ˉ I ) x_{t}=\sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon \space \space\space \epsilon\sim \mathcal{N}(0,I)\\ \textcolor{REd}{q(x_{t}|x_{0})=\mathcal{N}(x_{t};\sqrt{\bar{\alpha_{t}}}x_{0},\sqrt{1-\bar{\alpha_{t}}}I)} xt=αtˉ x0+1αtˉ ϵ   ϵN(0,I)q(xtx0)=N(xt;αtˉ x0,1αtˉ I)
这个公式表示任意步骤 t t t的噪声图像 x t x_t xt ,都可以通过 x 0 x_0 x0直接加噪得到,后面需要用到。

注:上述前向过程在代码实现时是一步到位的!!!!!

反向去噪过程,神经网络拟合过程

反向去噪过程就是数据生成过程,它首先是从标准高斯分布中采样得到一个噪声样本,再一步步地迭代去噪,最后得到数据分布中的一个样本。

img

如果知道反向过程的每一步真实的条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt),那么从一个随机噪声开始,逐步采样就能生成一个真实的样本。但是真实的条件分布利用贝叶斯公式
q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ) q ( x t ) q(x_{t-1}|x_{t}) =\frac{q(x_{t}|x_{t-1})q(x_{t-1})}{q(x_{t})} q(xt1xt)=q(xt)q(xtxt1)q(xt1)
无法直接求解,原因是其中 q ( x t − 1 ) q(x_{t-1}) q(xt1) , q ( x t ) q(x_{t}) q(xt) 未知,因此无法从 x t x_{t} xt 推导到 x t − 1 {x_{t-1}} xt1,所以必须通过神经网络** p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)来近似。为了简化起见,将反向过程也定义为一个马尔卡夫链,且服从高斯分布**,建模如下:

p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(x_{0:T})=p(x_T)\prod_{t=1}^Tp_\theta(x_{t-1}|x_t)\\ p_\theta(x_{t-1}|x_t)=N(x_{t-1};\mu_\theta(x_t,t),\sum_\theta(x_t,t)) pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))

--------------------下面这段讲解与上面有些跳脱,是为损失函数做铺垫------------------------------

虽然真实条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)无法直接求解,但是加上已知条件 x 0 x_0 x0的后验分布$q(x_{t-1}|x_{t},x_{0}) $却可以通过贝叶斯公式求解,再结合前向马尔科夫性质可得
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(x_{t-1}|x_{t},x_{0}) =\frac{q(x_{t}|x_{t-1},x_{0})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})}=\frac{q(x_{t}|x_{t-1})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})} q(xt1xt,x0)=q(xtx0)q(xtxt1,x0)q(xt1x0)=q(xtx0)q(xtxt1)q(xt1x0)
因此可以得到:
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ ∼ N ( α ˉ t − 1 x 0 , ( 1 − α ˉ t − 1 ) I ) q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t ϵ ∼ N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) q ( x t ∣ x t − 1 ) = α t x t − 1 + β t ϵ ∼ N ( α t x t − 1 , β t I ) \begin{aligned} q(x_{t-1}|x_{0})&=\sqrt{\bar{\alpha}_{t-1}}x_{0}+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}x_{0},(1-\bar{\alpha}_{t-1})I)\\ q(x_{t}|x_{0})&=\sqrt{\bar{\alpha}_{t}}x_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t}}x_{0},(1-\bar{\alpha}_{t})I)\\ q(x_{t}|x_{t-1})&=\sqrt{\alpha}_{t}x_{t-1}+\beta_{t}\epsilon\sim \mathcal{N}(\sqrt{\alpha}_{t}x_{t-1},\beta_{t}I) \end{aligned} q(xt1x0)q(xtx0)q(xtxt1)=αˉt1 x0+1αˉt1 ϵN(αˉt1 x0,(1αˉt1)I)=αˉt x0+1αˉt ϵN(αˉt x0,(1αˉt)I)=α txt1+βtϵN(α txt1,βtI)
所以
q ( x t − 1 ∣ x t , x 0 ) ∝ e x p ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t ) + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) = e x p ( − 1 2 ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α t ˉ 1 − α t ˉ x 0 ) x t − 1 + C ( x t , x 0 ) ) \begin{aligned} q(x_{t-1}|x_{t},x_{0}) &\propto exp(-\frac{1}{2}(\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{\beta_{t}})+\frac{(x_{t-1}-\sqrt{\bar{\alpha}}_{t-1}x_{0})^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_{t}-\sqrt{\bar{\alpha}_{t}}x_{0})^2}{1-\bar{\alpha}_{t}})\\ &=exp(-\frac{1}{2}(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_{t}}}{\beta_{t}}x_{t}+\frac{2\sqrt{\bar{\alpha_{t}}}}{1-\bar{\alpha_{t}}}x_{0})x_{t-1}+C(x_{t},x_{0})) \end{aligned} q(xt1xt,x0)exp(21(βt(xtαt xt1)2)+1αˉt1(xt1αˉ t1x0)21αˉt(xtαˉt x0)2)=exp(21(βtαt+1αˉt11)xt12(βt2αt xt+1αtˉ2αtˉ x0)xt1+C(xt,x0))

通过配方就可以得到
β ~ t = 1 / ( α t β t + 1 1 − α ˉ t − 1 ) = 1 − α ˉ t − 1 1 − α ˉ t β t μ ~ t = ( α t β t x t + α ˉ t 1 − α t ˉ x 0 ) / ( α t β t + 1 1 − α ˉ t − 1 ) = α t ( 1 − α ˉ t − 1 ) 1 − α t ˉ x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \widetilde{\beta}_t=1/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_{t}\\ \widetilde{\mu}_t=(\frac{\sqrt\alpha_{t}}{\beta_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t}}}{1-\bar{\alpha_{t}}}x_{0})/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_{t}}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}x_{0} β t=1/(βtαt+1αˉt11)=1αˉt1αˉt1βtμ t=(βtα txt+1αtˉαˉt x0)/(βtαt+1αˉt11)=1αtˉαt (1αˉt1)xt+1αˉtαˉt1 βtx0

又因为
x 0 = 1 α ˉ t ( x t − β t 1 − α ˉ t ϵ ) x_0= \frac{1}{\sqrt{\bar\alpha_t}}(x_t- \frac{\beta_t}{\sqrt{1-\bar \alpha_t} }\epsilon)\\ x0=αˉt 1(xt1αˉt βtϵ)
可以得
μ ~ t = 1 α t ( x t − β t ( 1 − α t ) ϵ ) \widetilde{\mu}_t=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon) μ t=αt 1(xt(1αt) βtϵ)

----------------------------------------------------------------------------------------------

采样过程(模型训练完后的预测过程)

μ θ ( x t , t ) = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) x t − 1 ∼ p θ ( x t − 1 ∣ x t ) x t − 1 = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) + β ~ t z      z ∼ N ( 0 , I ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t))\\ x_{t-1}\sim p_\theta(x_{t-1}|x_t)\\ x_{t-1}=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t))+\sqrt{\widetilde{\beta}_t}z \space \space\space\space z\sim N(0,I) μθ(xt,t)=αt 1(xt(1αt) βtϵθ(xt,t))xt1pθ(xt1xt)xt1=αt 1(xt(1αt) βtϵθ(xt,t))+β t z    zN(0,I)
这里用z是为了和之前的 ϵ \epsilon ϵ区别开

损失函数

https://blog.csdn.net/weixin_45453121/article/details/131223653

Code

import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
from torch.optim import Adam
from torch import nn
import math
from torchvision.utils import save_image


def show_images(data, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15,15))
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
        plt.imshow(img[0])


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)


def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    #print("out:",out)
    #print("out.shape:",out.shape)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)






def load_transformed_dataset(IMG_SIZE):
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.MNIST(root="./Data",transform=data_transform,train=True)
    test = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=False)

    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    #Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))





class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()

    def forward(self, x, t):
        #print("ttt:",t.shape)
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels =1   #灰度图为1,彩色图为3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1   #灰度图为1 ,彩色图为3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)




def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)



@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )

    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image(IMG_SIZE):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 1, img_size, img_size), device=device)   #生成第T步的图片
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        #print("t:",t)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            plt.title(str(i))
            show_tensor_image(img.detach().cpu())
    plt.show()


if __name__ =="__main__":

    # Define beta schedule
    T = 300
    betas = linear_beta_schedule(timesteps=T)

    # Pre-calculate different terms for closed form
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    # print(alphas_cumprod.shape)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    # print(alphas_cumprod_prev)
    # print(alphas_cumprod_prev.shape)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # print(posterior_variance.shape)


    IMG_SIZE = 32
    BATCH_SIZE = 16

    data = load_transformed_dataset(IMG_SIZE)
    dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    model = SimpleUnet()
    print("Num params: ", sum(p.numel() for p in model.parameters()))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 1 # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):  #由于batch 是包含标签的所以取batch[0]
            #print(batch[0].shape)
            optimizer.zero_grad()

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()

            if epoch % 1 == 0 and step %5== 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                sample_plot_image(IMG_SIZE)

参考文献

https://zhuanlan.zhihu.com/p/630354327](https://zhuanlan.zhihu.com/p/630354327)

https://blog.csdn.net/weixin_45453121/article/details/131223653

https://www.cnblogs.com/risejl/p/17448442.html

https://zhuanlan.zhihu.com/p/569994589?utm_id=0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/163198.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCo

题目描述如下: 罗马数字包含以下七种字符: I, V, X, L,C,D 和 M。 字符 数值 I 1 V 5 X 10 L 50 C 100 D 500 M …

component 动态组件的用法

一&#xff1a;前言 <component></component> 标签是Vue框架自定义的标签&#xff0c;它的用途就是可以动态绑定我们的组件&#xff0c;根据数据的不同需求来更换使用不同的组件。 在最上方的图片中&#xff0c;就是使用的 Element Plus 的 Tags 组件&#xff0c;根…

golang学习笔记——接口

文章目录 Go 语言接口例子空接口空接口的定义空接口的应用空接口作为函数的参数空接口作为map的值 类型断言接口值 类型断言例子001类型断言例子002 Go 语言接口 接口&#xff08;interface&#xff09;定义了一个对象的行为规范&#xff0c;只定义规范不实现&#xff0c;由具…

Codeforces Round #909 (Div. 3)

A. Game with Integers 签到题&#xff0c;但是本蒟蒻11分钟才AC&#xff0c;主要还是英文题面不熟练&#xff0c;题目中加粗了after&#xff0c;只有下一步操作之后能被整除才胜利。 英文题面的加粗单词很重要&#xff0c;注意提高签到题速度。 B. 250 Thousand Tons of TNT…

C语言的由来与发展历程

C语言的起源可以追溯到上世纪70年代&#xff0c;由Dennis Ritchie在贝尔实验室开发出来。C语言的设计目标是提供一种简洁、高效、可移植的编程语言&#xff0c;以便于开发底层的系统软件。在那个时代&#xff0c;计算机技术正在迅速发展&#xff0c;出现了多种高级编程语言&…

05-Spring Boot工程中简化开发的方式Lombok和dev-tools

简化开发的方式Lombok和dev-tools Lombok常用注解 Lombok用标签方式代替构造器、getter/setter、toString()等重复代码, 在程序编译的时候自动生成这些代码 注解名功能NoArgsConstructor生成无参构造方法AllArgsConstructor生产含所有属性的有参构造方法,如果不希望含所有属…

Pycharm中添加Python库指南

一、介绍 Pycharm是一款为Python开发者提供的集成开发环境&#xff08;IDE&#xff09;&#xff0c;支持执行、调试Python代码&#xff0c;并提供了许多有用的工具和功能&#xff0c;其中之一就是在Pycharm中添加Python库。 添加Python库有许多好处&#xff0c;比如能够增加开…

C/C++字符判断 2021年12月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C字符判断 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C字符判断 2021年12月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 对于给定的字符&#xff0c;如果该字符是大小写字母或…

Typecho用宝塔面板建站(保姆级教程)

提前准备&#xff1a; 1 已备案域名 注意:在腾讯云备案的域名部署阿里云服务器的话还需要在阿里云备案&#xff0c;反之亦然 2 服务器 服务器操作系统设置为windows 服务器实例设置&#xff1a;依次开放8888/888/443/3000-4000/21/22端口 个人用的阿里云&#xff0c;到安全组配…

代码随想录算法训练营第五十五天|392. 判断子序列、115. 不同的子序列

第九章 动态规划 part15 392. 判断子序列 给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些&#xff08;也可以不删除&#xff09;字符而不改变剩余字符相对位置形成的新字符串。&#xff08;例如&#xff0c;"ace&q…

实验(二):存储器实验

一、实验内容与目的 实验要求&#xff1a; 利用 CP226 实验仪上的 K16..K23 开关做为 DBUS 的数据&#xff0c;其它开关做为控制信号&#xff0c;实现主存储器 EM 的读写操作&#xff1b;利用 CP226 实验仪上的小键盘将程序输入主存储器 EM&#xff0c;实现程序的自动运行。 实…

leetcoe刷题日志-6N字形变换

将一个给定字符串 s 根据给定的行数 numRows &#xff0c;以从上往下、从左到右进行 Z 字形排列。 比如输入字符串为 “PAYPALISHIRING” 行数为 3 时&#xff0c;排列如下&#xff1a; 之后&#xff0c;你的输出需要从左往右逐行读取&#xff0c;产生出一个新的字符串&#…

openwrt配置ipv6

废话部分&#xff08;可跳过&#xff09; 历经多天&#xff0c;经过各种测试&#xff0c;终于把openwrt的ipv6配置成功了&#xff0c;这篇我将尽我所能详尽的描述一下可能遇到的问题和解决办法。这篇文章致力于让你完成整个openwrt的ipv6配置&#xff0c;希望对你有所帮助。在…

(Matalb回归预测)PSO-BP粒子群算法优化BP神经网络的多维回归预测

目录 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 亮点与优势&#xff1a; 二、实际运行效果&#xff1a; 三、部分程序&#xff1a; 四、完整程序数据说明文档下载&#xff1a; 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 本代码基于Matalb…

、如何在企业签名、超级签名、tf签名之间做选择

企业签名 (Enterprise Signing): 用途&#xff1a; 适用于企业内部发布应用&#xff0c;不需要经过App Store审核&#xff0c;可以通过企业内部渠道直接分发给员工或内部用户。限制&#xff1a; 仅限于企业内部使用&#xff0c;无法在App Store上发布或向外部用户分发。 超级签…

记一次解决Pyqt6/Pyside6添加QTreeView或QTreeWidget导致窗口卡死(未响应)的新路历程,打死我都想不到是这个原因

文章目录 💢 问题 💢🏡 环境 🏡📄 代码💯 解决方案 💯⚓️ 相关链接 ⚓️💢 问题 💢 我在窗口中添加了一个 QTreeWidget控件 ,但是程序在运行期间,只要鼠标进入到 QTreeWidget控件 内进行操作,时间超过几秒中就会出现窗口 未响应卡死的 状态 🏡 环境 �…

机器视觉工程师吐槽的常见100个名场面

学了后发现真没用&#xff0c;只能越干越多 德创跑的快&#xff0c;苏映视裁的快&#xff0c;上帝说&#xff0c;要有光&#xff0c;我是凌云光。 这群里面有多少从德创跑路的 去年我辛辛苦苦干一年顶两年了&#xff0c;单双休变单休或者无休&#xff0c;节假日全部对半砍。加班…

多聚焦图像融合算法

# File : PerfectFusion.py # Author : ShawnWang # Desc : 多焦点图像融合 # Time : 2023/9/24 08:25 import cv2 import matplotlib.pyplot as plt import numpy as np import pywt from PIL import Image# 基于小波变换的多聚焦图像融合…

基于SSM的古董拍卖系统

基于SSM的古董拍卖系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringMyBatisSpringMVC工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 主页 拍卖界面 管理员界面 摘要 古董拍卖系统是一个基于SSM框架&#xff08;Spring …

两数之和 II - 输入有序数组

给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] &#xff0c;则 1 < index1 < index2 < numbers.…
最新文章