2024不可不会的StableDiffusion之扩散模型(四)

1. 引言

这是我关于StableDiffusion学习系列的第四篇文章,如果之前的文章你还没有阅读,强烈推荐大家翻看前篇内容。在本文中,我们将学习构成StableDiffusion的第三个基础组件基于Unet的扩散模型,并针该组件的功能进行详细的阐述。

闲话少说,我们直接开始吧!

2. 概览

通常来说一个U-Net包含两个输入:
Noisy latent/Noise : 该Noisy latent主要是由VAE编码器产生并在其基础上添加了噪声;或者如果我们想仅根据文本描述来创建随机的新图像,则可以采用纯噪声作为输入。
Text embeddings: 基于CLIP的将文本输入提示转化为文本语义嵌入(embedding)
在这里插入图片描述

U-Net模型的输出是从包含输入噪声的Noisy Latents中预测其所包含的噪声。换句话说,它预测输出的为Noisy Latents减去de-noised latents后的结果。

3. 导入所需的库

让我们接着通过代码来了解 U-Net。我们将首先导入所需的库并加载我们的U-Net模型。

from diffusers import UNet2DConditionModel,LMSDiscreteScheduler

## Initializing a scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
## Initializing the U-Net model
sd_path = r'/media/stable_diffusion/stable-diffusion-v1-4'
unet = UNet2DConditionModel.from_pretrained(sd_path,subfolder="unet",
                                            local_files_only=True,
                                            torch_dtype=torch.float16).to("cuda")

4. 可视化Scheduler

正如大家从上面的代码中观察到的那样,我们不仅导入了 unet,还导入了一个scheduler库。scheduler的目的是确定在扩散过程中的给定的步骤中向latent 添加多少噪声。我们使用以下代码来可视化 scheduler 函数:

## Setting number of sampling steps
scheduler.set_timesteps(51)
plt.plot(scheduler.sigmas)
plt.xlabel("Sampling step")
plt.ylabel("sigma")
plt.title("Schedular routine")
plt.show()

得到结果如下:
在这里插入图片描述

可以看到,随着sample step的增大,我们添加噪声的权重在逐渐减小。

5. 可视化扩散过程

扩散过程遵循上述采样Scheduler,我们从高噪声开始,然后逐渐按照schedulerlatent_img添加对应权重的噪声。让我们可视化一下这个过程:

img_path = r'/home/VAEs/Scarlet-Macaw-2.jpg'
img = Image.open(img_path).convert("RGB").resize((512, 512))
latent_img = pil_to_latents(img, vae)
# Random noise
noise = torch.randn_like(latent_img) 
fig, axs = plt.subplots(2, 3, figsize=(16, 12))
for c, sampling_step in enumerate(range(0,51,10)):
    encoded_and_noised = scheduler.add_noise(latent_img, noise, 
                   timesteps=torch.tensor([scheduler.timesteps[sampling_step]]))
    axs[c//3][c%3].imshow(latents_to_pil(encoded_and_noised,vae)[0])
    axs[c//3][c%3].set_title(f"Step - {sampling_step}")
plt.show()

得到结果如下:
在这里插入图片描述

6. 生成Unet输入

接着我们需要对latent_img, 随机添加一些噪声,用以作为我们Unet输入的Noisy Latent , 代码如下:

encoded_and_noised = scheduler.add_noise(latent_img, noise, 
                     timesteps=torch.tensor([scheduler.timesteps[40]])) 
latents_to_pil(encoded_and_noised,vae)[0]

得到结果如下:

在这里插入图片描述

7. 调用Unet去噪

接着我们就可以使用Unet对其进行去噪了,相应的调用代码如下:

prompt = [""]
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True,return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to("cuda"))[0]

latent_model_input = torch.cat( [encoded_and_noised.to("cuda").float()]).half()
with torch.no_grad():
    noise_pred = unet(latent_model_input,40,encoder_hidden_states=text_embeddings
        )["sample"]
out_images = latent_to_pil((encoded_and_noised - noise_pred),vae)
plt.imshow(out_images[0])
plt.show()

得到去噪后结果如下:

在这里插入图片描述

我们也可以使用以下代码来将Unet预测出来的噪声进行可视化,代码如下:

out_noises = latent_to_pil( noise_pred,vae)
plt.imshow(out_noises[0])
plt.show()

得到结果如下:
在这里插入图片描述

8. Unet 在SD中的用途

潜在扩散模型使用U-Net分几个步骤通过逐渐减去潜伏空间中的噪声,来达到所需要的输出。在每一步中,添加到latents中的噪声量都会减少,直到我们达到最终的去噪输出。

我们知道,在深度学习领域U-Nets最先被引入用于生物医学图像分割任务。U-Net 有一个编码器和解码器,它们由 ResNet blocks组成。在stable diffusion中的U-Net还具有交叉注意力层,使它们能够根据提供的文本描述来控制调整输出。交叉注意力层通常添加到U-Net的编码器和解码器之间。可以在此处了解有关此 U-Net 架构的更多信息。

9. 总结

本文重点介绍了SD模型中的Unet组件的相关功能和具体工作原理,并详细介绍了其去噪过程;至此,我们完成了稳定扩散模型的三个关键组件,即 CLIP 文本编码器、VAE U-Net。在下一篇文章中,我们将研究使用这些组件的扩散过程。

您学废了嘛!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/357468.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RK3568平台开发系列讲解(Linux系统篇)platform 设备的注册

🚀返回专栏总目录 文章目录 一、platform_device_register 注册函数二、platform_device_unregister 反注册函数三、platform_device 结构体四、resource 结构体沉淀、分享、成长,让自己和他人都能有所收获!😄 一、platform_device_register 注册函数 platform_device_re…

海外云手机运营Instagram攻略

Instagram是世界著名的社交媒体平台,有着10亿实时用户,是跨境电子商务的优质流量来源。平台以女性用户为主,购物倾向高,转化率好。它被公认为外贸行业的优质社交媒体流量池。那么,如何使用海外云手机吸引Instagram上的…

【论文阅读】Long-Tailed Recognition via Weight Balancing(CVPR2022)

目录 论文使用方法weight decayMaxNorm 如果使用原来的代码报错的可以看下面这个 论文 问题:真实世界中普遍存在长尾识别问题,朴素训练产生的模型在更高准确率方面偏向于普通类,导致稀有的类别准确率偏低。 key:解决LTR的关键是平衡各方面&a…

力扣题集(第一弹)

一日练,一日功;一日不练十日空。 学编程离不开刷题,接下来让我们来看几个力扣上的题目。 1. 242. 有效的字母异位词 题目描述 给定两个字符串 s 和 t ,编写一个函数来判断 t 是否是 s 的字母异位词。 注意:若 s 和 t 中每个字符出现的次数…

JS图片二维码识别

前言 js识别QR图片&#xff0c;基于jsQR.js 代码 <!DOCTYPE html> <html> <head><meta charset"utf-8" /><title>图片二维码识别</title><script src"https://cdn.bootcss.com/jquery/3.4.1/jquery.min.js">…

什么是消息队列?

消息用队列的模式发送&#xff0c; 把要传输的数据放在队列中&#xff0c; 产生消息的叫做生产者&#xff0c; 从队列里取出消息的叫做消费者。 一、组成 生产者&#xff1a;Producer 消息的产生者与调用端 主要负责消息所承载的业务信息的实例化 是一个队列的发起方 代理…

网站小程序分类目录网源码系统+会员注册登录功能 附带完整的搭建教程

随着互联网的发展&#xff0c;小程序分类目录网站已经成为了人们获取各类信息的重要渠道。而在这个领域中&#xff0c;罗峰给大家分享一款网站小程序分类目录网源码系统以其强大的功能和易用性&#xff0c;脱颖而出。本系统集成了会员注册登录功能&#xff0c;让用户能够更加便…

uniapp H5 实现上拉刷新 以及 下拉加载

uniapp H5 实现上拉刷新 以及 下拉加载 1. 先上图 下拉加载 2. 上代码 <script>import DragableList from "/components/dragable-list/dragable-list.vue";import {FridApi} from /api/warn.jsexport default {data() {return {tableList: [],loadingHi…

Redis核心技术与实战【学习笔记】 - 6.Redis 的统计操作处理

1.前言 在 Web 业务场景中&#xff0c;我们经常保存这样一种信息&#xff1a;一个 key 对应了一个数据集合。比如&#xff1a; 手机 APP 中的每天用户登录信息&#xff1a;一天对应一系列用户 ID。电商网站上商品的用户评论列表&#xff1a;一个商品对应了一些列的评论。用户…

12 数据仓库理论

数仓基本概述 数据仓库基本概念 数据仓库是一个为数据分析而设计的企业级数据管理系统。数据仓库可集中 、整合多个信息源的大量数据。 数仓核心架构 数据仓库建模概述 数据仓库建模意义 数据模型就是数据组织和存储方法&#xff0c;它强调从业务、数据存取和使用角度合理…

Django配置websocket时的错误解决

基于移动群智感知的网络图谱构建系统需要手机app不断上传数据到服务器并把数据推到前端标记在百度地图上&#xff0c;由于众多手机向同一服务器发送数据&#xff0c;如果使用长轮询&#xff0c;则实时性差、延迟高且服务器的负载过大&#xff0c;而使用websocket则有更好的性能…

链表与二叉树-数据结构

链表与二叉树-数据结构 创建叶子node节点建立二叉树三元组&#xff1a;只考虑稀疏矩阵中非0的元素&#xff0c;并且存储到一个类&#xff08;三元组&#xff09;的数组中。 创建叶子node节点 class Node{int no;Node next;public Node(int no){this.nono;} } public class Lb…

YOLOv8改进 | 可视化热力图 | 支持YOLOv8最新版本密度热力图,和视频热力图

一、本文介绍 本文给大家带来的机制是集成了YOLOv8最新版本的可视化热力图功能,热力图作为我们论文当中的必备一环,可以展示出我们呈现机制的有效性,本文的内容支持YOLOv8最新版本的根据密度呈现的热力图,同时支持视频检测,根据视频中的密度来绘画热力图。 在开始之前给…

薅运营商羊毛?封杀!

最近边小缘在蓝点网上看到一则消息 “浙江联通也开始严格排查PCDN和PT等大流量行为 被检测到可能会封停宽带”。 此前中国联通已经在四川和上海等多个省市严查家庭宽带 (部分企业宽带也被查) 使用 PCDN 或 PT&#xff0c;当用户的宽带账户存在大量上传数据的情况&#xff0c;中…

数据库管理-第141期 DG PDB - Oracle DB 23c(20240129)

数据库管理141期 2024-01-29 第141期 DG PDB - Oracle DB 23c&#xff08;20240129&#xff09;1 概念2 环境说明3 操作3.1 数据库配置3.2 配置tnsname3.3 配置强制日志3.4 DG配置3.5 DG配置建立联系3.6 启用所有DG配置3.7 启用DG PDB3.8 创建源PDB的DG配置3.9 拷贝pdbprod1文件…

【C++】I/O多路转接详解(一)

目录 1. 背景引入1.1 IO的过程1.2 五种IO模型1.2.1 阻塞IO1.2.2 非阻塞IO1.2.3 信号驱动IO1.2.4 IO多路转接1.2.5 异步IO 1.3 同步通信 与 异步通信1.4 阻塞 与 非阻塞1.4.1 阻塞与非阻塞区别1.4.2 设置非阻塞IO 2. select2.1 接口使用2.2 select执行过程2.3 select代码实践 3.…

<网络安全>《9 入侵防御系统IPS》

1 概念 IPS&#xff08; Intrusion Prevention System&#xff09;是电脑网络安全设施&#xff0c;是对防病毒软件&#xff08;Antivirus Programs&#xff09;和防火墙&#xff08;Packet Filter, Application Gateway&#xff09;的补充。 入侵预防系统&#xff08;Intrusio…

JS第一课简单看看这是啥东西

1.什么是JavaScript JS是一门编程语言&#xff0c;是一种运行在客户端(浏览器)的编程语言&#xff0c;主要是让前端的画面动起来&#xff0c;注意HTML和CSS不是编程语言&#xff0c;他俩是一种标记语言。JS只要有浏览器就能运行不用跟Python或者Java一样上来装一个jdk或者Pyth…

2023年算法SAO-CNN-BiLSTM-ATTENTION回归预测(matlab)

2023年算法SAO-CNN-BiLSTM-ATTENTION回归预测&#xff08;matlab&#xff09; SAO-CNN-BiLSTM-Attention雪消融优化器优化卷积-长短期记忆神经网络结合注意力机制的数据回归预测 Matlab语言。 雪消融优化器( SAO) 是受自然界中雪的升华和融化行为的启发&#xff0c;开发了一种…

Docker入门篇(二)—— 命令

Docker入门篇&#xff08;二&#xff09;—— 命令 插播&#xff01;插播&#xff01;插播&#xff01;亲爱的朋友们&#xff0c;我们的Cmake/Makefile/Shell这三个课程上线啦&#xff01;感兴趣的小伙伴可以去下面的链接学习哦~ 构建工具大师-CSDN程序员研修院 一、引言 当…
最新文章