Pytorch实现扩散模型【DDPM代码解读篇2】

扩散的代码实现

本文介绍“扩散是如何实现的”。代码逻辑清晰,可快速上手学习。

# 扩散的代码实现
# 扩散过程是训练部分的模型。它打开了一个采样接口,允许我们使用已经训练好的模型生成样本。
class DiffusionModel(nn.Module):
	# 类变量,用于将字符串调度器名称映射到相应的调度函数
    SCHEDULER_MAPPING = {
        "linear": linear_beta_schedule,
        "cosine": cosine_beta_schedule,
        "sigmoid": sigmoid_beta_schedule,
    }
 
    def __init__(
        self,
        model: nn.Module,
        image_size: int,
        *,
        beta_scheduler: str = "linear",  # 调度器类型,默认为线性
        timesteps: int = 1000,
        schedule_fn_kwargs: dict | None = None,  # 调度函数的关键字参数,默认为 None
        auto_normalize: bool = True,
    ) -> None:
        super().__init__()
        self.model = model
 
        self.channels = self.model.channels
        self.image_size = image_size
 
 		# 从 SCHEDULER_MAPPING 字典中获取与 beta_scheduler 字符串相对应的调度函数
 		# 如果 beta_scheduler 字符串不存在于字典中,则返回 None
        self.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)
        # 检查获取到的调度函数是否为 None,即检查是否成功选择了β调度函数
        # 如果调度函数为 None,则说明指定的 beta_scheduler 字符串不在预定义的调度函数列表中,于是抛出 ValueError 异常
        if self.beta_scheduler_fn is None:
            raise ValueError(f"unknown beta schedule {beta_scheduler}")
 		# 检查是否提供了调度函数的关键字参数。若未提供,将schedule_fn_kwargs 设置为空字典。
        if schedule_fn_kwargs is None:
            schedule_fn_kwargs = {}
 		
 		# 用于计算扩散模型中的β调度函数,以及与β相关的其他参数,如α和后验方差:
        betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs)  # 生成一个包含β值的张量 betas
        alphas = 1.0 - betas

        # 对α值进行累积乘积,得到一个新的张量 alphas_cumprod,其形状与 betas 相同,包含了从0到 timesteps-1 时间步的所有α值的乘积。
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        '''
        	对 alphas_cumprod 进行填充操作,将其第一个元素用 1.0 填充,以确保在计算后验方差时不会出现除以零的情况。
			F.pad 函数用于在张量的指定维度上进行填充,这里在维度 0 上进行填充,向左填充一个元素。
        '''

        # 计算后验方差
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
 
 		# 注册缓冲区(buffer),并将每个相关的张量转换为 torch.float32 类型
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )
 
        register_buffer("betas", betas)  # 包含 β 值的张量
        register_buffer("alphas_cumprod", alphas_cumprod)  # 包含 α 累积乘积的张量
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)  # 包含 α 累积乘积的前一个时间步的张量
        register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))  # α 的倒数的平方根的张量
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))  # α 累积乘积的平方根的张量
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("posterior_variance", posterior_variance)  # 后验方差的张量
 
        timesteps, *_ = betas.shape
        '''
        	这里使用了“*”操作符,它的作用是在变量解构(destructuring)中丢弃不需要的部分。因为 betas 张量是一维的,所以这里的“*”操作符实际上没有起到什么作用,只是为了让代码更具通用性。
        '''
        self.num_timesteps = int(timesteps)  # 将时间步数转换为整数
 
        self.sampling_timesteps = timesteps
 
 		# 归一化
 		# auto_normalize 为 True,则选择 normalize_to_neg_one_to_one 函数进行归一化;否则选择 identity 函数,即不进行归一化操作。
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity

        #  auto_normalize 为 True,则选择 unnormalize_to_zero_to_one 函数进行反归一化;否则选择 identity 函数,即不进行反归一化操作。
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
 
    @torch.inference_mode()
    '''
    	可以将下面的函数或代码块置于推断模式中。这意味着,在装饰器声明的范围内,PyTorch 将禁用梯度计算,不会跟踪梯度,也不会进行任何与梯度相关的操作。这有助于提高推断速度,并且可以确保模型在进行推断时不会意外地进行训练相关的计算。
    '''

    def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:
    	# 使用了解构语法 *,将张量形状中的除最后一维之外的所有维度忽略掉,并将结果赋值给一个名为 _ 的临时变量,最后一个维度被赋值给 device
        b, *_, device = *x.shape, x.device

        batched_timestamps = torch.full(
            (b,), timestamp, device=device, dtype=torch.long
        )
        # 创建了一个形状为 (b,) 的张量 batched_timestamps,用于存储批次中每个样本的时间戳。
        # timestamp,其数据类型为 torch.long,并且张量存储在与输入张量相同的设备上
 		
 		# 将输入张量 x 和时间戳张量 batched_timestamps 传递给模型 self.model,以获取预测值 preds
        preds = self.model(x, batched_timestamps)
 		
 		# 使用函数 extract 从预先计算的参数 self.betas 中提取与批次时间戳对应的β值
        betas_t = extract(self.betas, batched_timestamps, x.shape)
        sqrt_recip_alphas_t = extract(
            self.sqrt_recip_alphas, batched_timestamps, x.shape
        )  # 提取与批次时间戳对应的α倒数的平方根
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape
        )  # 提取与批次时间戳对应的1减去α累积乘积的平方根
 		
 		# 计算预测的样本均值
        predicted_mean = sqrt_recip_alphas_t * (
            x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t
        )
 		
 		#如果时间戳为零,直接返回预测的样本均值;否则,计算样本的后验方差并添加噪声,然后返回结果。
        if timestamp == 0:
            return predicted_mean
        else:
            posterior_variance = extract(
                self.posterior_variance, batched_timestamps, x.shape
            )
            noise = torch.randn_like(x)
            return predicted_mean + torch.sqrt(posterior_variance) * noise
 
    @torch.inference_mode()
    def p_sample_loop(
        self, shape: tuple, return_all_timesteps: bool = False
    ) -> torch.Tensor:
        batch, device = shape[0], "mps"  # 从形状元组中获取批量大小 batch,并设置设备为 "mps"(多处理器尺寸)
 
        img = torch.randn(shape, device=device) # 函数生成一个具有指定形状的随机张量 img,其值服从标准正态分布
        # This cause me a RunTimeError on MPS device due to MPS back out of memory
        # No ideas how to resolve it at this point
 
        # imgs = [img]
 		
 		'''
 			使用 tqdm 函数创建一个迭代进度条,迭代范围是从 0 到 self.num_timesteps 的逆序。
 			每个时间步长 t 都会调用 p_sample 方法进行样本采样,并更新 img 的值。
 		'''
        for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):
            img = self.p_sample(img, t)
            # imgs.append(img)
            '''
            	将每个时间步长的采样结果添加到一个列表 imgs 中。在循环中,每次迭代会生成一个新的采样结果,并将其添加到列表中
            	允许在函数结束后返回所有时间步长的采样结果,以便进一步分析或处理
            '''
 		
 		# 最终的采样结果
        ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)
 		
 		# 调用 unnormalize 方法将最终的采样结果反归一化,使其返回到原始数据范围内。
        ret = self.unnormalize(ret)
        return ret
 
 	# return_all_timesteps指定是否返回所有时间步长的样本,默认为 False,表示只返回最终时间步长的样本
    def sample(
        self, batch_size: int = 16, return_all_timesteps: bool = False  
    ) -> torch.Tensor:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)
 
 	# 用于在给定时间步长 t 上生成样本
    def q_sample(
        self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None
    ) -> torch.Tensor:
    	# 首先检查是否提供了噪声
        if noise is None:
            noise = torch.randn_like(x_start)

 		# 接着根据 t 从预先计算的参数中提取相应的系数
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
 		
 		# 最后根据扩散过程的定义,计算并返回生成的样本
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
 
    def p_loss(
        self,
        x_start: torch.Tensor,
        t: int,
        noise: torch.Tensor = None,
        loss_type: str = "l2",
    ) -> torch.Tensor:
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noised = self.q_sample(x_start, t, noise=noise)  # 在给定时间步长 t 上生成经过噪声处理的样本 x_noised

        # 使用生成的 x_noised 作为输入,调用模型 self.model,并传入时间步长 t,以获取预测的噪声 predicted_noise。
        predicted_noise = self.model(x_noised, t)
 
        if loss_type == "l2":  # 均方误差损失函数
            loss = F.mse_loss(noise, predicted_noise)
        elif loss_type == "l1":  # 绝对值误差损失函数
            loss = F.l1_loss(noise, predicted_noise)
        else:
            raise ValueError(f"unknown loss type {loss_type}")
        return loss
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, device, img_size = *x.shape, x.device, self.image_size
        assert h == w == img_size, f"image size must be {img_size}"
 		# 解析输入 x 的形状,并确保输入的图像是正方形且大小与 image_size 相同。

 		# 生成一个随机的时间步长 timestamp,范围在 [0, num_timesteps) 内
        timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)
        x = self.normalize(x)
        return self.p_loss(x, timestamp)

Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/592733.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++set和map

目录 一、set的使用 1、set对象的创建 2、multiset 二、map的使用 1、map对象的创建 2、map的operator[] 序列式容器:vector、list、deque....单纯的存储数据,数据和数据之间没有关联 关联式容器:map、set.....不仅仅是存储数据&#x…

2000-2020年县域创业活跃度数据

2000-2020年县域创业活跃度数据 1、时间:2000-2020年 2、指标:地区名称、年份、行政区划代码、经度、纬度、所属城市、所属省份、年末总人口万人、户籍人口数万人、当年企业注册数目、县域创业活跃度1、县域创业活跃度2、县域创业活跃3 3、来源&#…

python数据可视化:显示两个变量间的关系散点图scatterplot()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 python数据可视化: 显示两个变量间的关系 散点图 scatterplot() [太阳]选择题 请问关于以下代码表述错误的选项是? import seaborn as sns import matplotlib.pyplot …

VISO流程图之子流程的使用

子流程的作用 整个流程图的框图多而且大,进行分块;让流程图简洁对于重复使用的流程,可以归结为一个子流程图,方便使用,避免大量的重复性工作; 新建子流程 方法1: 随便布局 框选3 和4 &#…

SQL:NOT IN与NOT EXISTS不等价

在对SQL语句进行性能优化时,经常用到一个技巧是将IN改写成EXISTS,这是等价改写,并没有什么问题。问题在于,将NOT IN改写成NOT EXISTS时,结果未必一样。 目录 一、举例验证二、三值逻辑简述三、附录:用到的S…

3.3Java全栈开发前端+后端(全栈工程师进阶之路)-前端框架VUE3框架-企业级应用-Vue组合式API

为什么要使用Composition API 一个Options API实例 在前面的课程中&#xff0c;我们都是采用 Options API&#xff08;基于选项的 API &#xff09; 来写一个组件的。下面是一个实例&#xff1a; <template> Count is: {{ count }}, doubleCount is: {{ doubleCount…

深入理解网络原理3----TCP核心特性介绍(上)【面试高频考点】

文章目录 前言TCP协议段格式一、确认应答【保证可靠性传输的机制】二、超时重传【保证可靠性传输的机制】三、连接管理机制【保证可靠性传输的机制】3.1建立连接&#xff08;TCP三次握手&#xff09;---经典面试题3.2断开连接&#xff08;四次挥手&#xff09;3.3TCP状态转换 四…

【skill】onedrive的烦人问题

Onedrive的迷惑行为 安装Onedrive&#xff0c;如果勾选了同步&#xff0c;会默认把当前用户的数个文件夹&#xff08;桌面、文档、图片、下载 等等&#xff09;移动到安装时提示的那个文件夹 查看其中的一个文件的路径&#xff1a; 这样一整&#xff0c;原来的文件收到严重影…

政安晨:【Keras机器学习示例演绎】(三十五)—— 使用 LayerScale 的类注意图像变换器

目录 简介 导入 层刻度层 随机深度层 类注意力 会说话的头注意力 前馈网络 其他模块 拼凑碎片&#xff1a;CaiT 模型 定义模型配置 模型实例化 加载预训练模型 推理工具 加载图像 获取预测 关注层可视化 结论 政安晨的个人主页&#xff1a;政安晨 欢迎 &#…

Topaz Video AI 5.0.3激活版 AI视频无损缩放增强

Topaz Video AI专注于很好地完成一些视频增强任务&#xff1a;去隔行&#xff0c;放大和运动插值。我们花了五年时间制作足够强大的人工智能模型&#xff0c;以便在真实世界的镜头上获得自然的结果。 Topaz Video AI 还将充分利用您的现代工作站&#xff0c;因为我们直接与硬件…

【数学建模】矩阵微分方程

一、说明 我相信你们中的许多人都熟悉微分方程&#xff0c;或者至少知道它们。微分方程是数学中最重要的概念之一&#xff0c;也许最著名的微分方程是布莱克-斯科尔斯方程&#xff0c;它控制着任何股票价格。 ​​ 股票价格的布莱克-斯科尔斯模型 微分方程可以由数学中的许多…

MidJourney提示词大全

大家好&#xff0c;我是无界生长。 这篇文章分享一下MidJourney提示词&#xff0c;篇幅内容有限&#xff0c;关注公众号&#xff1a;无界生长&#xff0c;后台回复&#xff1a;“MJ”&#xff0c;获取全部内容。 我是无界生长&#xff0c;如果你觉得我分享的内容对你有帮助&…

ArcGIS软件:地图投影的认识、投影定制

这一篇博客介绍的主要是如何在ArcGIS软件中查看投影数据&#xff0c;如何定制投影。 1.查看地图坐标系、投影数据 首先我们打开COUNTIES.shp数据&#xff08;美国行政区划图&#xff09;&#xff0c;并点击鼠标右键&#xff0c;再点击数据框属性就可以得到以下的界面。 我们从…

【Mac】graphpad prism for Mac(专业医学绘图工具) v10.2.3安装教程

软件介绍 GraphPad Prism for Mac是一款专业的科学数据分析和绘图软件&#xff0c;广泛用于生物医学和科学研究领域。它具有强大的统计分析功能&#xff0c;可以进行各种数据分析&#xff0c;包括描述性统计、生存分析、回归分析、方差分析等。同时&#xff0c;它还提供了丰富…

C++奇迹之旅:string类接口详解(上)

文章目录 &#x1f4dd;为什么学习string类&#xff1f;&#x1f309; C语言中的字符串&#x1f309;string考察 &#x1f320;标准库中的string类&#x1f309;string类的常用接口说明&#x1f320;string类对象的常见构造 &#x1f6a9;总结 &#x1f4dd;为什么学习string类…

FFmpeg学习记录(二)—— ffmpeg多媒体文件处理

1.日志系统 常用的日志级别&#xff1a; AV_LOG_ERRORAV_LOG_WARNINGAV_LOG_INFOAV_LOG_DEBUG #include <stdio.h> #include <libavutil/log.h>int main(int argc, char *argv[]) {av_log_set_level(AV_LOG_DEBUG);av_log(NULL, AV_LOG_DEBUG, "hello worl…

Cisco Nexus Dashboard 3.1(1k) - 云和数据中心网络管理软件

Cisco Nexus Dashboard 3.1(1k) - 云和数据中心网络管理软件 跨数据中心和云实现集中配置、运行和分析。 请访问原文链接&#xff1a;https://sysin.org/blog/cisco-nexus-dashboard/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sys…

根据docker部署nginx并且实现https

目录 一、Docker中启用HTTPS有几个重要的原因 二、https介绍 三、https过程 四、安装docker-20.10.18 五、如何获取证书 通过阿里云获取证书 六、docker部署nginx并且实现https 6.1准备证书 6.2准备nginx.conf 和 index.html文件 6.3生成容器 6.4浏览器验证证书 一、…

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《考虑碳捕集和电转气的综合能源系统优化调度》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

STM32标准库控制一盏LED闪烁

实物连接&#xff1a; ## 软件编程&#xff1a;默认已经有一个工程模板&#xff0c;代码实现逻辑&#xff1a; 1、使用RCC开启GPIO的时钟&#xff1b; 2、使用GPIO初始化函数实现初始化GPIO 3、使用输入或输出的函数控制GPIO口 #include "stm32f10x.h" …
最新文章