【论文阅读】Consistency Models

文章目录

  • Introduction
  • Diffusion Models
  • Consistency Models
    • Definition
    • Parameterization
    • Sampling
  • Training Consistency Models via Distillation
  • Training Consistency Models in Isolation
  • Experiment

Introduction

  • 相比于单步生成的模型(例如 GANs, VAEs, normalizing flows),扩散模型的迭代式生成过程需要 10 到 2000 步计算来采样,导致推理速度低,实时性应用受限.

  • 本文的目的是创造高效、单步的生成,同时不牺牲迭代采样的优势。在数据到噪声的 PF-ODE 轨迹上,学习轨迹上任意点到轨迹起点的映射,对这些映射的建模成为 consistency model.
    在这里插入图片描述

  • 两种训练 consistency model的方法

    1. 使用 numerical ODE solver 和预训练的扩散模型在 PF-ODE 轨迹上生成若干相邻点对,通过最小化模型输出点对间的距离(相似度),蒸馏出 consistency model.
    2. 不依赖预训练扩散模型,独立训练一个 consistency model.
  • 在一些数据集上测试.

Diffusion Models

使用 p d a t a ( x ) p_{data}(\mathrm{x}) pdata(x)表示数据分布,扩散模型使用如下随机微分公式对服从原分布的数据进行扩散:

d x t = μ ( , x t , t ) + σ ( t ) d w t \large \mathrm{dx}_t = \mu(\mathrm,{x}_t, t) + \sigma(t)\mathrm{dw}_t dxt=μ(,xt,t)+σ(t)dwt

其中 t t t为时间步,范围是 0 0 0 T T T μ ( ⋅ , ⋅ ) \mu(·,·) μ(⋅,⋅) σ ( ⋅ ) \sigma(·) σ()分别是布朗运动中的漂移系数和扩散系数, x t \mathbf{x}_t xt服从分布 p t ( x ) p_{t}(\mathrm{x}) pt(x) x 0 \mathrm{x}_0 x0服从分布 p d a t a ( x ) p_{data}(\mathrm{x}) pdata(x). 该方程的一个重要属性是,其存在一个 PF-ODE 方程:

d x t = [ μ ( x t , t ) − 1 2 σ ( t ) 2 ∇ log ⁡ p t ( x t ) ] d t \large\mathrm{dx}_t = \left[ \mu(\mathrm{x}_t, t)-\frac{1}{2}\sigma(t)^2 \nabla\log{p_t(\mathrm{x}_t)} \right]\mathrm{d}t dxt=[μ(xt,t)21σ(t)2logpt(xt)]dt

其中 ∇ log ⁡ p t ( x ) \nabla\log{p_t(\mathrm{x})} logpt(x) p t ( x ) p_t(\mathrm{x}) pt(x)的 score function.
在 SDE 中,令漂移系数 μ ( x , t ) = 0 \mu(\mathrm{x}, t) = 0 μ(x,t)=0, 扩散系数 σ ( t ) = 2 t \sigma(t) = \sqrt{2t} σ(t)=2t . 使用得分匹配的方式训练模型 s ϕ ( x , t ) ≈ ∇ log ⁡ p t ( x ) s_{\phi}(\mathrm{x},t) \approx \nabla\log{p_t(\mathrm{x})} sϕ(x,t)logpt(x),代入 PF-ODE 方程,得到 empirical PF-ODE:

d x t d t = − t s ϕ ( x t , t ) \large \frac{\mathrm{dx}_t}{\mathrm{d}t}=-ts_{\phi}(\mathrm{x}_t,t) dtdxt=tsϕ(xt,t

采样时,使用 x ^ T ∼ N ( 0 , T 2 I ) \hat{\mathrm{x}}_T\sim\mathcal{N}(0, T^2I) x^TN(0,T2I)初始化,再使用 numerical ODE solver(例如 Euler, Heun)按时间步倒推出 x ^ 0 \hat{x}_0 x^0. 为了防止数值不稳定,会在 t = ϵ t=\epsilon t=ϵ是提前终止, ϵ \epsilon ϵ为一个正小数,同时将 x ^ ϵ \hat{\mathrm{x}}_{\epsilon} x^ϵ作为结果.

扩散模型的瓶颈在于采样速度慢, ODE solver 利用得分模型 s ϕ ( x , t ) s_{\phi}(\mathrm{x},t) sϕ(x,t)迭代求解,消耗算力多. 目前存在一些更快的 ODE solver,但是仍然需要大于 10 10 10 步的采样. 也存在一些蒸馏方法,但是大多数方法需要从扩散模型中采集巨大的数据集,同样消耗算力多.

Consistency Models

Definition

根据 PF-ODE 得到一条解路径 { x t } t ∈ [ ϵ , T ] \{\mathrm{x}_t\}_{t\in[\epsilon, T]} {xt}t[ϵ,T],将 consistency function 定义为:

f : ( x t , t ) ↦ x ϵ \large f:(\mathrm{x}_t, t) \mapsto \mathrm{x}_{\epsilon} f:(xt,t)xϵ

对于该路径上的任意点 ( x t , t ) (\mathrm{x}_t, t) (xt,t),其输出是一致的. 对于任意的 t , t ′ ∈ [ ϵ , T ] t, t' \in [\epsilon, T] t,t[ϵ,T],有 f ( x t , t ) = f ( x t ′ , t ′ ) f(\mathrm{x}_t, t) =f(\mathrm{x}_{t'}, t') f(xt,t)=f(xt,t)恒成立.
在这里插入图片描述

Parameterization

F θ ( x , t ) F_{\theta}(\mathrm{x}, t) Fθ(x,t)表示任意形式的神经网络,使用 sikp connection 可以将模型表示为:

f θ ( x , t ) = c s k i p ( t ) x + c o u t ( t ) F θ ( x , t ) \large f_{\theta}(\mathrm{x}, t)=c_{skip}(t)\mathrm{x}+c_{out}(t)F_{\theta}(\mathrm{x},t) fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)

其中边界条件为 c s k i p ( ϵ ) = 1 c_{skip}(\epsilon)=1 cskip(ϵ)=1 c o u t ( ϵ ) = 0 c_{out}(\epsilon)=0 cout(ϵ)=0.
具体为:

c s k i p ( t ) = σ d a t a 2 ( t − ϵ ) 2 + σ d a t a 2 \large c_{skip}(t)=\frac{\sigma_{data}^2}{(t-\epsilon)^2+\sigma_{data}^2} cskip(t)=(tϵ)2+σdata2σdata2

c o u t ( t ) = σ d a t a ( t − ϵ ) σ d a t a 2 + t 2 \large c_{out}(t)=\frac{\sigma_{data}(t-\epsilon)}{\sqrt{\sigma_{data}^2+t^2}} cout(t)=σdata2+t2 σdata(tϵ)

σ d a t a \sigma_{data} σdata取值 0.5 0.5 0.5.

Sampling

有了一个训练好的 consistency model f θ ( ⋅ , ⋅ ) f_{\theta}(·, ·) fθ(⋅,⋅)之后,从高斯噪声 N ( 0 , T 2 I ) \mathcal{N}(0, T^2I) N(0,T2I)采样 x ^ T \hat{\mathrm{x}}_T x^T,再代入模型一步推出 x ^ ϵ = f θ ( x T ^ , T ) \hat{\mathrm{x}}_{\epsilon}=f_{\theta}(\hat{\mathrm{x}_T}, T) x^ϵ=fθ(xT^,T).为了提高质量,也可以进行多步采样,算法如下:

在这里插入图片描述

Training Consistency Models via Distillation

作者的第一个方法是在预训练的得分模型 s ϕ ( x , t ) s_{\phi}(\mathrm{x},t) sϕ(x,t)上蒸馏.

首先考虑将 ϵ \epsilon ϵ T T T的时间离散化成 N − 1 N-1 N1 个间隔,也即 t 1 = ϵ < t 2 < t 3 < . . . < t N = T t_1=\epsilon<t_2<t_3<...<t_N=T t1=ϵ<t2<t3<...<tN=T. 在实践中,使用如下公式:

t i = ( ϵ 1 / ρ + i − 1 N − 1 ( T 1 / ρ − ϵ 1 / ρ ) ) ρ \large t_i=\left(\epsilon^{1/\rho} + \frac{i-1}{N-1}\left(T^{1/\rho}-\epsilon^{1/\rho}\right) \right)^{\rho} ti=(ϵ1/ρ+N1i1(T1/ρϵ1/ρ))ρ

其中 ρ = 7 \rho=7 ρ=7. 当 N N N充分大时,可以获得 x t n \mathrm{x}_{t_n} xtn x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1的准确估计,于是 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ可以定义为:

x ^ t n ϕ = x t n + 1 + ( t n − t n + 1 ) Φ ( x t n + 1 , t n + 1 ; ϕ ) \large \hat{\mathrm{x}}_{t_n}^{\phi}=\mathrm{x}_{t_{n+1}} + (t_n-t_{n+1})\Phi(\mathrm{x}_{t_{n+1}}, t_{n+1};\phi) x^tnϕ=xtn+1+(tntn+1)Φ(xtn+1,tn+1;ϕ)

Φ ( . . . ; ϕ ) \Phi(...;\phi) Φ(...;ϕ)为 one-step ODE solver(比如Euler).

从数据集中采样 x \mathrm{x} x,通过 SDE 加噪 N ( x , t n + 1 2 I ) \mathcal{N}(\mathrm{x}, t_{n+1}^2I) N(x,tn+12I)得到 x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1, 然后使用 ODE solver 求解出 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ,通过最小化在 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1计算结果的差距训练模型.

Definition 1
consistency distillation loss (CD)表示为:

L C D N ( θ , θ − ; ϕ ) = E [ λ ( t n ) d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x ^ t n ϕ , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-;\phi)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}_{t_{n+1}},t_{n+1}),f_{\theta^-}(\hat{\mathrm{x}}_{t_n}^{\phi}, t_n) \right] LCDN(θ,θ;ϕ)=E[λ(tn)d(fθ(xtn+1,tn+1),fθ(x^tnϕ,tn)]

其中, λ ( ⋅ ) ∈ R + \lambda(·)\in\mathbb{R}^+ λ()R+是正权重函数, θ − \theta^- θ θ \theta θ在优化过程中历史值的均值. d ( ⋅ , ⋅ ) d(·,·) d(⋅,⋅)是一个度量函数,满足当且仅当两个输入相等时为 0 0 0,其余情况大于 0 0 0.

作者考虑 d ( ⋅ , ⋅ ) d(·,·) d(⋅,⋅) 使用 l 1 l_1 l1 以及 l 2 l_2 l2,在实验中 λ ( t n ) ≡ 1 \lambda(t_n) \equiv1 λ(tn)1表现较好. θ − \theta^- θ使用 EMA 更新,计算公式如下:

θ − ← s t o p g a r d ( μ θ − + ( 1 − μ ) θ ) \large \theta^- \leftarrow \mathrm{stopgard}(\mu\theta^-+(1-\mu)\theta) θstopgard(μθ+(1μ)θ)

其中 0 ≤ μ < 1 0\le\mu<1 0μ<1. 使用 EMA 可以使训练更稳定,同时能提高模型的表现.
模型训练算法如下:
在这里插入图片描述

Training Consistency Models in Isolation

consistency model 可以不依赖预训练扩散模型训练,使用如下无偏估计替换 ∇ log ⁡ p t ( x ) \nabla\log{p_t(\mathrm{x})} logpt(x)

∇ log ⁡ p t ( x ) = − E [ x t − x t 2 ∣ x t ] \large \nabla\log{p_t(\mathrm{x})}=-\mathbb{E}\left[\left.\frac{\mathrm{x}_t-\mathrm{x}}{t^2}\right|\mathrm{x}_t \right] logpt(x)=E[t2xtx xt]

consistency training loss (CT)表示为:

L C D N ( θ , θ − ) = E [ λ ( t n ) d ( f θ ( x + t n + 1 z , t n + 1 ) , f θ − ( x + t n z , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}+t_{n+1}\mathrm{z},t_{n+1}),f_{\theta^-}(\mathrm{x}+t_{n}\mathrm{z},t_{n}) \right] LCDN(θ,θ)=E[λ(tn)d(fθ(x+tn+1z,tn+1),fθ(x+tnz,tn)]

其中 z ∼ N ( 0 , I ) \mathrm{z}\sim\mathcal{N}(0,I) zN(0,I). 损失函数的计算依赖于 f θ f_{\theta} fθ f θ − f_{\theta^-} fθ,且与扩散模型的无关.

为了提升模型效果,使用 schedule function N ( ⋅ ) N(·) N()控制 N N N 增长. 直觉上,当 N N N 小的时候,使用 consistency distillation loss 模型在一开始收敛更快,同时方差小、偏差大. 反之,在训练结束时,应当使 N N N 大,这样方差大、偏差小。同时,使用 schedule function μ ( ⋅ ) \mu(·) μ()替换 μ \mu μ,让它随着 N N N 增长而变化.
N ( ⋅ ) N(·) N() μ ( ⋅ ) \mu(·) μ()具体为

N ( k ) = ⌈ k K ( ( s 1 + 1 ) 2 − s 0 2 ) + s 0 2 − 1 ⌉ + 1 \large N(k)= \left\lceil\sqrt{\frac{k}{K}((s_1+1)^2-s_0^2)+s_0^2}-1 \right\rceil+1 N(k)= Kk((s1+1)2s02)+s02 1 +1

μ ( k ) = exp ⁡ ( s 0 log ⁡ μ 0 N ( k ) ) \large \mu(k)=\exp\left(\frac{s_0\log{\mu_0}}{N(k)}\right) μ(k)=exp(N(k)s0logμ0)

K K K表示整体训练步数, s 0 s_0 s0表示开始的离散化步数.

训练算法如下:
在这里插入图片描述

Experiment

关于 CD ,作者分别使用 l 1 l_1 l1, l 2 l_2 l2, L P I P S \mathrm{LPIPS} LPIPS作为度量函数,使用一阶Euler和二阶Heun座位 ODE solver, N N N { 9 , 12 , 18 , 36 , 50 , 60 , 80 , 120 } \{9,12,18,36,50,60,80,120\} {9,12,18,36,50,60,80,120},使用相应的预训练扩散模型做初始化. 使用 CT 训练的模型则随机初始化.
在这里插入图片描述

(a) 对比不同的度量函数在 CD 上的表现,其中 LPIPS 的效果最好.
(b, c) 对不不同 ODE solver 和 N N NCD 上的表现,使用 Heun 且 N N N 18 18 18时效果最好.在取相同的 N N N时,二阶Heun的表现优于一阶Euler,因为高阶的 ODE solver 的估计误差更小. 当 N N N充分大时,模型对 N N N变得不敏感.
(d) 根据之前的结论,关于 CT 的实验使用 LPIPS 作为度量函数. 更小的 N N N收敛更快,但是采样结构更差;使用自适应的 N ( ⋅ ) N(·) N() μ ( ⋅ ) \mu(·) μ()效果最好.

对比 CDprogressive disillation(PD) 在不同数据集上的效果,CD 的表现普遍比 PD 好.
在这里插入图片描述

对比 CT 和其它生成模型,仅使用一步或两步生成.
在这里插入图片描述

Zero-Shot Image Editing

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/319957.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

怎样制作一本旅游电子相册呢?

​随着数码技术的发展&#xff0c;旅游电子相册已成为越来越多旅游爱好者的必备工具。它不仅能让您随时随地欣赏自己的旅行回忆&#xff0c;还能分享给亲朋好友&#xff0c;甚至上传到社交媒体上&#xff0c;让更多人了解您的旅行故事。那么&#xff0c;如何制作一本精美的旅游…

最新国内可用GPT4、Midjourney绘画、DALL-E3文生图模型教程

一、前言 ChatGPT3.5、GPT4.0、GPT语音对话、Midjourney绘画&#xff0c;文档对话总结DALL-E3文生图&#xff0c;相信对大家应该不感到陌生吧&#xff1f;简单来说&#xff0c;GPT-4技术比之前的GPT-3.5相对来说更加智能&#xff0c;会根据用户的要求生成多种内容甚至也可以和…

Redis的主从配置,哨兵模式,集群模式

目录 什么是主从复制&#xff1f; 主从复制的作用&#xff1f; 主从复制的流程&#xff1f; 搭建Redis的主从复制 安装Redis 环境准备 修改内核参数 安装Redis 定义systemd服务管理脚本 修改Redis配置文件&#xff08;Master节点操作&#xff09;192.168.17.25 修改Re…

机器学习周报第28周

目录 摘要Abstract一、文献阅读1.题目&#xff1a;2.摘要3.问题描述4.过去方案5.论文方案6.论文模型7.相关代码 摘要 本周阅读了一篇混沌时间序列预测的论文&#xff0c;论文模型主要使用的是时间卷积网络&#xff08;Temporal Convolutional Network&#xff0c;TCN&#xff…

ZMQ_REQ\REP模式

文章内容&#xff1a; 学习ZMQ库中REQ\REP模式相关的内容 简介 应答模式&#xff1a;REQ&#xff08;客户端&#xff09;和REP&#xff08;服务端&#xff09; 典型的一问一答协议&#xff0c;即客户端需要首先发送hello&#xff0c;服务器则返回word&#xff0c;若客户端发…

深度学习工具-如何选择服务器和GPU

深度学习训练通常需要大量的计算。目前&#xff0c;GPU是深度学习最具成本效益的硬件加速器。与CPU相比&#xff0c;GPU更便宜&#xff0c;性能更高&#xff0c;通常超过一个数量级。此外&#xff0c;一台服务器可以支持多个GPU&#xff0c;高端服务器最多支持8个GPU。更典型的…

分布式缓存

分布式缓存 缓存雪崩 缓存雪崩我们可以简单的理解为&#xff1a;由于原有缓存失效&#xff0c;新缓存未到期间所有原本应该访问缓存的请求都去查询数据库了&#xff0c;而对数据库 CPU 和内存造成巨大压力&#xff0c;严重的会造成数据库宕机。从而形成一系列连锁反应&#xf…

自动粘贴文本:高效复制中国邮政编码,提升效率,释放创意

在快节奏的现代生活中&#xff0c;时间就是金钱&#xff0c;效率就是生命。中国邮政EMS&#xff0c;作为您的快递服务首选&#xff0c;一直致力于提供更加便捷、高效的寄递体验。今天&#xff0c;我们隆重推出全新功能——"自动粘贴文本"&#xff0c;让您轻松复制邮政…

【EAI 006】ChatGPT for Robotics:将 ChatGPT 应用于机器人任务的提示词工程研究

论文标题&#xff1a;ChatGPT for Robotics: Design Principles and Model Abilities 论文作者&#xff1a;Sai Vemprala, Rogerio Bonatti, Arthur Bucker, Ashish Kapoor 作者单位&#xff1a;Scaled Foundations, Microsoft Autonomous Systems and Robotics Research 论文原…

1.3K Star,让发送短信变的更简单

Hi&#xff0c;骚年&#xff0c;我是大 G&#xff0c;我的公众号「GitHub指北」会推荐 GitHub 上有趣有用的项目&#xff0c;一分钟 get 一个优秀的开源项目&#xff0c;挖掘开源的价值。 前言 在日常的开发过程中&#xff0c;短信的发送经常使用&#xff08;尤其是中小型的外…

C#,入门教程(18)——分支语句(switch-case)的基础知识

上一篇&#xff1a; C#&#xff0c;入门教程(17)——条件语句&#xff08;if-else&#xff09;的基础知识https://blog.csdn.net/beijinghorn/article/details/124033376 1、switch概述 switch-case分支语句 可以理解为 大号 的 if-else。 switch语句以switch关键字开头&…

x-cmd pkg | tmux - 开源终端多路复用器(terminal multiplexer)

目录 简介首次用户基本概念功能特点竞品和相关作品进一步阅读 简介 tmux 是一个用于 Unix 操作系统的开源终端复用器&#xff08;terminal multiplexer&#xff09;&#xff0c;它允许用户在一个终端窗口中创建多个虚拟终端会话&#xff0c;并同时在这些会话之间切换&#xff…

谈⼀谈你对TCPIP四层模型,OSI七层模型的理解

TCP/IP四层模型 对比 OSI七层模型 OSI七层模型 为了增强通⽤性和兼容性&#xff0c;计算机⽹络都被设计成层次机构&#xff0c;每⼀层都遵守⼀定的规则。因此有了OSI这样⼀个抽象的⽹络通信参考模型&#xff0c;按照这个标准使计算机⽹络系统可以互相连接 物理层 通过⽹线、光…

Harbor离线安装

下载安装包 $ wget https://github.com/goharbor/harbor/releases/download/v2.7.4/harbor-offline-installer-v2.7.4.tgz解压 $ tar xvf harbor-offline-installer-v2.7.4.tgz -C /usr/local修改配置 $ cd /usr/local/harbor $ cp harbor.yml.tmpl harbor.yml $ vim harbo…

第 4 课 创建工作空间与功能包

文章目录 第 4 课 创建工作空间与功能包1.工作环境的创建2.ROS功能包的创建 第 4 课 创建工作空间与功能包 消息和服务的创建、发布器和订阅器的编写、服务端和客户端的编写都是基于Ros功能包进行操作的&#xff0c;因此在进行上述操作前&#xff0c;需要先创建工作空间及功能包…

Java期末复习题库(封装,继承,抽象类,接口,GUI)

包与字符串 1.创建包的基本操作 在biology包中的animal包中有human类,它具有name,height,weight的属性,还具有eat(),sleep()和work()的行为,在biology包中的plant包中有flower类,它具有name,color,smell的属性,还具有drink()和blossom()的行为. 现在在一个school包中的garde…

服务器经常宕机的原因及解决办法

随着如今互联网信息化时代的不断发展&#xff0c;数据存储和传输在各种网络科技面前也显得越来越重要&#xff0c;对于企业来讲&#xff0c;建站之后服务器的安全稳定是至关重要的选择。那么选择一款好用的服务器愈发重要。 当然&#xff0c;不管是多好的服务器提供商&#xff…

延时任务的解决方案

延时任务的解决方案 1.数据库轮询2. JDK的延迟队列3.netty时间轮算法4.使用消息队列 1.数据库轮询 该方案通常是在小型项目中使用&#xff0c;即通过一个线程定时的去扫描数据库&#xff0c;通过订单时间来判断是否有超时的订单&#xff0c;然后进行update或delete等操作 代码示…

leetcode206.反转链表

https://leetcode.cn/problems/reverse-linked-list/description/ 题目 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&am…

ruoyi后台管理系统部署-2-安装mysql

centos7 mysql 安装 1. 手动安装 安装 首先查看系统是否安装了&#xff1a; rpm -qa|grep mariadb rpm -qa | grep mysql systemctl status mysqld find / -name mysql.cnf卸载自带的 mariadb: rpm -e mariadb-libs-5.5.68-1.el7.x86_64 --nodeps去官网下载 mysql 安装包&…