Pytorch实现扩散模型【DDPM代码解读篇1】

本篇内容属于对DDPM 原理-代码 项目的解读。

具体内容参考一篇推文,里面对DDPM讲解相对细致:

扩散模型的原理及实现(Pytorch)

下面主要是对其中源码的细致注解,帮助有需要的朋友更好理解代码。

目录

ConvNext块

 正弦时间戳嵌入

时间多层感知器

注意力

整合


ConvNext块

class ConvNextBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        mult=2,  # 输出通道数相对于输入通道数的倍数
        time_embedding_dim=None,  # 表示时间嵌入的维度
        norm=True,  # 是否使用归一化层
        group=8,  # 卷积操作的分组数量,默认为8
    ):
        super().__init__()
        # 多层感知机(MLP),用于处理时间嵌入。
        # 如果time_embedding_dim不为None,则创建一个包含GELU激活函数和线性层的序列;否则为None。
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_embedding_dim, in_channels))
            if time_embedding_dim
            else None
        )
 		# in_conv是一个输入卷积层,对输入进行卷积操作。
        self.in_conv = nn.Conv2d(
            in_channels, in_channels, 7, padding=3, groups=in_channels
        )
 		# block是一个序列模块,包含一系列卷积操作和归一化层。这里使用了GELU作为激活函数。
        self.block = nn.Sequential(
            nn.GroupNorm(1, in_channels) if norm else nn.Identity(),
            nn.Conv2d(in_channels, out_channels * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, out_channels * mult),
            nn.Conv2d(out_channels * mult, out_channels, 3, padding=1),
        )
 		# residual_conv是一个残差连接的卷积层,用于调整输入和输出的通道数,
 		# 如果输入通道数和输出通道数不同,则使用1x1卷积进行调整;否则为恒等映射。
        self.residual_conv = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )
 
    def forward(self, x, time_embedding=None):
        h = self.in_conv(x)  # 首先对输入x进行输入卷积操作。
        # 如果mlp不为None且time_embedding不为None,则对时间嵌入进行处理并与输入相加。
        if self.mlp is not None and time_embedding is not None:
            assert self.mlp is not None, "MLP is None"
            h = h + rearrange(self.mlp(time_embedding), "b c -> b c 1 1")  # 然后将处理后的特征输入到块中进行卷积操作。
        h = self.block(h)
        return h + self.residual_conv(x)  # 最后将卷积结果与输入进行残差连接,并返回。

 正弦时间戳嵌入

# 通常用于为序列数据添加位置信息。
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        self.dim = dim  # 位置编码的维度。
        self.theta = theta  # theta是用于计算位置编码的参数,默认值为10000。
 
    def forward(self, x):
        device = x.device  # 首先获取输入x的设备信息。
        half_dim = self.dim // 2  # 然后计算位置编码的维度一半的值half_dim
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)  # 生成位置编码矩阵emb,其中每一行对应一个位置的编码,使用正弦和余弦函数计算。
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# DownSample & UpSample 上下采样
class DownSample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        self.net = nn.Sequential(
        	# 用于对输入进行重新排列,将2x2的空间块转换为通道数的维度,从而将空间维度减小为原来的四分之一。
        	'''
        		Rearrange层用于对输入的张量进行重新排列,将其从四维张量(batch size、通道数、高度、宽度)转换为新的形状
        		b:表示batch size,保持不变。
				c:表示通道数,保持不变。
				(h p1)和(w p2):表示对高度和宽度进行的操作。p1和p2是两个额外的参数,用于指定在高度和宽度上的扩展倍数。这意味着将输入的高度和宽度分别扩展为原来的p1倍和p2倍。
				b (c p1 p2) h w:表示输出张量的形状,其中通道数乘以高度和宽度。这样做的效果是将原始的空间维度拼接到通道维度后面,使得输出的张量变为三维(batch size、新的通道数、高度、宽度)。
        	'''
            Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
            # 接着是一个1x1的卷积层,用于将输入通道数变换为dim_out或者保持不变
            nn.Conv2d(dim * 4, default(dim_out, dim), 1),
        )
 
    def forward(self, x):
        return self.net(x)  # 将输入通过Sequential网络进行前向传播,返回处理后的结果。
 
 
 class Upsample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),  # 用于对输入进行上采样,采用最近邻插值的方式,并将图像沿着两个维度放大两倍。
            nn.Conv2d(dim, dim_out or dim, kernel_size=3, padding=1), # 接着是一个3x3的卷积层,用于将输入通道数变换为dim_out或者保持不变
        )
 
    def forward(self, x):
        return self.net(x)

时间多层感知器

sinu_pos_emb = SinusoidalPosEmb(dim, theta=10000)  # 用于生成正弦和余弦位置编码
 
  time_dim = dim * 4  # 四倍维度,通常用于增加时间信息的表示能力
 
  time_mlp = nn.Sequential(
      sinu_pos_emb,  # 将输入的时间信息进行正弦和余弦位置编码
      nn.Linear(dim, time_dim),  # 一个线性层,将输入维度dim映射为time_dim,以增加时间信息的表示能力。
      nn.GELU(),  # 激活函数,用于引入非线性。
      nn.Linear(time_dim, time_dim),  # 另一个线性层,将time_dim映射回time_dim,以保持输出维度不变。
  )

注意力

class BlockAttention(nn.Module):
	# gate_in_channel:门输入的通道数, residual_in_channel:残差输入的通道数, scale_factor:尺度因子,用于初始化门和残差卷积层的权重。

    def __init__(self, gate_in_channel, residual_in_channel, scale_factor):
        super().__init__()
        self.gate_conv = nn.Conv2d(gate_in_channel, gate_in_channel, kernel_size=1, stride=1)
        self.residual_conv = nn.Conv2d(residual_in_channel, gate_in_channel, kernel_size=1, stride=1)
        self.in_conv = nn.Conv2d(gate_in_channel, 1, kernel_size=1, stride=1)  # 输入卷积层,将门和残差的输出进行卷积处理,将结果映射为范围在0到1之间的值
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
 	
 	# 前向传播方法接受两个张量作为输入:x表示残差输入,g表示门输入。
    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        in_attention = self.relu(self.gate_conv(g) + self.residual_conv(x))
        in_attention = self.in_conv(in_attention)
        in_attention = self.sigmoid(in_attention)
        return in_attention * x

整合

将前面讨论的所有块(不包括注意力块)整合到一个Unet中。
每个块都包含两个残差连接,而不是一个。
这个修改是为了解决潜在的过度拟合问题。

# 这个模块实现了一个双重残差 U 型网络,用于图像处理任务,如图像去噪、超分辨率等。
class TwoResUNet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        sinusoidal_pos_emb_theta=10000,
        convnext_block_groups=8,  # 卷积块中的分组数
    ):
        super().__init__()
        self.channels = channels
        input_channels = channels
        # init_dim 不为 None,则返回 init_dim;否则返回 dim。这样做的目的是提供了一种灵活的方式,允许用户在初始化模型时选择是否指定初始维度,如果未指定,则使用输入的维度作为初始维度。
        self.init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, self.init_dim, 7, padding=3)
 
        dims = [self.init_dim, *map(lambda m: dim * m, dim_mults)]
        '''
            使用 map 函数对 dim_mults 中的每个值 m 进行操作,将其乘以 dim。
            这样可以得到一个新的列表,其中的每个值都是 dim 与 dim_mults 中的相应值相乘得到的结果。
            * 运算符用于解包操作,将 map 函数生成的结果解包成单独的元素。
        '''
        in_out = list(zip(dims[:-1], dims[1:]))
        # 使用 zip 函数将 dims 中的相邻两个元素组合成一个元组。这样可以得到一个列表,其中每个元素都是一个包含相邻两个阶段的输入通道数和输出通道数的元组。

        sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)
 
        time_dim = dim * 4
 
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )
 
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)  # 计算了图像的分辨率数,存储在 num_resolutions 中
        
        # 下面的循环用于创建下采样部分(downs):
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
 
            self.downs.append(
                nn.ModuleList(
                    [
                        ConvNextBlock(
                            in_channels=dim_in,
                            out_channels=dim_in,
                            time_embedding_dim=time_dim,
                            group=convnext_block_groups,
                        ),
                        ConvNextBlock(
                            in_channels=dim_in,
                            out_channels=dim_in,
                            time_embedding_dim=time_dim,
                            group=convnext_block_groups,
                        ),
                        DownSample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )
        
        # 创建了中间残差块 mid_block1 和 mid_block2:
        mid_dim = dims[-1]  # 通常会将最后一个阶段的输出通道数作为中间残差块的输入通道数
        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)
 
        # 下面的循环用于创建上采样部分(ups):
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            is_first = ind == 0
 
            self.ups.append(
                nn.ModuleList(
                    [
                        ConvNextBlock(
                            in_channels=dim_out + dim_in,
                            out_channels=dim_out,
                            time_embedding_dim=time_dim,
                            group=convnext_block_groups,
                        ),
                        ConvNextBlock(
                            in_channels=dim_out + dim_in,
                            out_channels=dim_out,
                            time_embedding_dim=time_dim,
                            group=convnext_block_groups,
                        ),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1)
                    ]
                )
            )
 
        default_out_dim = channels
        self.out_dim = default(out_dim, default_out_dim)
 
        # 创建了最终的残差块 final_res_block 和输出卷积层 final_conv:
        self.final_res_block = ConvNextBlock(dim * 2, dim, time_embedding_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
 
    def forward(self, x, time):
        b, _, h, w = x.shape
        x = self.init_conv(x)  # 对输入张量进行初始卷积操作
        r = x.clone()  # 克隆 x
 
        t = self.time_mlp(time)  # 使用时间多层感知器处理时间信息
 
        unet_stack = []  # 创建一个空列表 unet_stack,用于存放下采样阶段的特征。

        # 对每个下采样模块进行操作:先执行两个卷积块,然后执行下采样,并将特征存储在 unet_stack 中。
        for down1, down2, downsample in self.downs:
            x = down1(x, t)
            unet_stack.append(x)
            x = down2(x, t)
            unet_stack.append(x)
            x = downsample(x)
        
        # 中间残差块
        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)
        
        # 对每个上采样模块进行操作:从 unet_stack 中取出特征,与当前特征拼接后执行两个卷积块,然后执行上采样。
        for up1, up2, upsample in self.ups:
            x = torch.cat((x, unet_stack.pop()), dim=1)
            x = up1(x, t)
            x = torch.cat((x, unet_stack.pop()), dim=1)
            x = up2(x, t)
            x = upsample(x)
 
        # 将初始特征 r 与最终的特征拼接后,执行最终的残差块和输出卷积层,得到最终的输出。
        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
 
        return self.final_conv(x)

 Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/592381.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

快速构建vscode pytest 开发测试环境

如果不想用 heavy 的pycharm vscode 也是1个很好的选择 安装python SDK pacman -S python [gatemanmanjaro-x13 tmp]$ pacman -Q python python 3.11.8-1安装Vscode 很多中方法 yay -S visual-studio-code-bin [gatemanmanjaro-x13 tmp]$ pacman -Q | grep -i visual visua…

关于怎么计算重复路段的算法问题

这是代码&#xff1a; #define _CRT_SECURE_NO_WARNINGS #define MAX_TEST 100 #include<stdio.h> int main() {int l, m;int u0, v0,j0;int arr[MAX_TEST][10] {0};int count0;scanf("%d", &l);scanf("%d", &m);if (l < 500 &…

性能优化(一):ArrayList还是LinkedList?

引言 集合作为一种存储数据的容器&#xff0c;是我们日常开发中使用最频繁的对象类型之一。JDK为开发者提供了一系列的集合类型&#xff0c;这些集合类型使用不同的数据结构来实现。因此&#xff0c;不同的集合类型&#xff0c;使用场景也不同。 很多同学在面试的时候&#x…

短视频矩阵系统ai剪辑 矩阵 文案 无人直播四合一功能核心独家源头saas开发

抖去推矩阵AI小程序是一款针对短视频平台的智能创作和运营工具&#xff0c;它具有以下功能特点&#xff1a; 1.批量视频生成&#xff1a;抖去推可以在短时间内生成大量视频&#xff0c;帮助商家快速制作出适合在短视频平台上推广的内容 2.全行业覆盖&#xff1a;适用于多个行业…

环形链表面试题详解

A. 环形链表1 给你一个链表的头节点 head &#xff0c;判断链表中是否有环. 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置…

每日OJ题_贪心算法二⑦_力扣942. 增减字符串匹配

目录 力扣942. 增减字符串匹配 解析代码 力扣942. 增减字符串匹配 942. 增减字符串匹配 难度 简单 由范围 [0,n] 内所有整数组成的 n 1 个整数的排列序列可以表示为长度为 n 的字符串 s &#xff0c;其中: 如果 perm[i] < perm[i 1] &#xff0c;那么 s[i] I 如果 …

奈氏准则和香农定理

一、奈奎斯特和香农 哈里奈奎斯特&#xff08;Harry Nyquist&#xff09;(左) 克劳德艾尔伍德香农&#xff08;Claude Elwood Shannon&#xff09;(右) 我们应该在心里记住他们&#xff0c;记住所有为人类伟大事业做出贡献的人&#xff0c;因为他们我们的生活变得越来越精彩&…

计算机毕业设计Python+Spark考研预测系统 考研推荐系统 考研数据分析 考研大数据 大数据毕业设计 大数据毕设

安顺学院本科毕业论文(设计)题目申请表 院别&#xff1a;数学与计算机科学 专业&#xff1a;数据科学与大数据 时间&#xff1a;2022年 5月26日 题 目 情 况 题目名称 基于hive数据仓库的考研信息离线分析系统的设计与实现 学生姓名 杨娣荧 学号 201903144042 …

【Hadoop】--基于hadoop和hive实现聊天数据统计分析,构建聊天数据分析报表[17]

目录 一、需求分析 1、背景介绍 2、目标 3、需求 4、数据内容 5、建库建表 二、ETL数据清洗 1、数据问题 2、需求 3、实现 4、扩展概念&#xff1a;ETL 三、指标计算 1、指标1&#xff1a;统计今日消息总量 2、指标2&#xff1a;统计每小时消息量、发送量和接收用…

shpfile转GeoJSON;控制shp转GeoJSON的精度;如何获取GeoJSON;GeoJSON是什么有什么用;GeoJSON结构详解(带数据示例)

目录 一、GeoJSON是什么 二、GeoJSON的结构组成 2.1、点&#xff08;Point&#xff09;数据示例 2.2、线&#xff08;LineString&#xff09;数据示例 2.3、面&#xff08;Polygon&#xff09;数据示例 2.4、特征&#xff08;Feature&#xff09;数据示例 2.5、特征集合&…

Leetcode—1056. 易混淆数【简单】Plus

2024每日刷题&#xff08;126&#xff09; Leetcode—1056. 易混淆数 &#x1f4a9;山实现代码 class Solution { public:bool confusingNumber(int n) {int arr[10] {0};int notNum 0;int arr2[12] {0};int size 0;while(n) {int x n % 10;arr[x] 1;arr2[size] x;if(…

自动化测试适用场景

日常大家都用自动化去写测试脚本。但是自动化可不仅仅可以工作写脚本&#xff0c;还可以应用到如下领域&#xff1a; 1. 自动化测试脚本&#xff1a;自动化测试是软件测试 领域中最常见的自动化应用领域。它可以通过 自动化测试工具和脚本来自动化执行测试用例 &#xff0c…

水仙花数问题

问题描述&#xff1a; 求出0&#xff5e;100000之间的所有“水仙花数”并输出。 “水仙花数”是指一个n位数&#xff0c;其各位数字的n次方之和确好等于该数本身&#xff0c;如:153&#xff1d;1^3&#xff0b;5^3&#xff0b;3^3&#xff0c;则153是一个“水仙花数”。 #in…

VS Code 保存+格式化代码

在 VSCode 中&#xff0c;使用 Ctrl S 快捷键直接保存并格式化代码&#xff1a; 打开 VSCode 的设置界面&#xff1a;File -> Preferences -> Settings在设置界面搜索框中输入“format on save”&#xff0c;勾选“Editor: Format On Save”选项&#xff0c;表示在保存…

Java 【数据结构】常见排序算法实用详解(下) 冒泡排序/快速排序/归并排序/非基于比较排序【贤者的庇护】

登神长阶 上古神器-常见排序算法 冒泡排序/快速排序/归并排序/非基于比较排序 &#x1f4b0;一.前言 为保障知识获取的可读性&#xff0c;以及连贯性&#xff0c;再开始可以适当的重新温习前文内容 &#xff1a;Java 【数据结构】常见排序算法实用详解&#xff08;上&#xf…

TWS 蓝牙耳机 ESD EOS保护方案

1. TWS 蓝牙耳机 TWS&#xff08;True Wireless Stereo&#xff09;蓝牙耳机是指没有传统连接线的完全无线耳机&#xff0c;通常由两个分别放置在耳朵中的独立耳机组成&#xff0c;提供立体声音效。这类耳机在近年来越来越受欢迎&#xff0c;因为它们提供了更自由、更便捷的音…

有限单元法-编程与软件应用(崔济东、沈雪龙)【PDF下载】

专栏导读 作者简介&#xff1a;工学博士&#xff0c;高级工程师&#xff0c;专注于工业软件算法研究本文已收录于专栏&#xff1a;《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现&#xff0c;并提供所有案例完整源码&#xff1b;2.单元…

MLP手写数字识别(3)-使用tf.data.Dataset模块制作模型输入(tensorflow)

1、tensorflow版本查看 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())2、MNIST数据集下载与预处理 (train_images,train_labels),(test_images,test_labels) tf.keras.datasets.mnist.load_data()…

JSON.toJSONString() 输出 “$ref“:“$[0]“问题解决及原因分析

一、背景 在构建一个公共的批处理方法类的时候&#xff0c;在测试输出的时候&#xff0c;打印了" r e f " : " ref":" ref":"[0][0]"的内容&#xff0c;这让我比较疑惑。不由得继续了下去… 二、问题分析 首先&#xff0c;我们需要…
最新文章