变分自编码器(Variational AutoEncoder,VAE)

1 从AE谈起

说到编码器这块,不可避免地要讲起AE(AutoEncoder)自编码器。它的结构下图所示:

在这里插入图片描述
据图可知,AE通过自监督的训练方式,能够将输入的原始特征通过编码encoder后得到潜在的特征编码,实现了自动化的特征工程,并且达到了降维和泛化的目的。而后通过对进行decoder后,我们可以重构输出。一个良好的AE最好的状态就是解码器的输出能够完美地或者近似恢复出原来的输入, 即。为此,训练AE所需要的损失函数是: ∣ ∣ x − x ^ ∣ ∣ ||x-\hat{x}|| ∣∣xx^∣∣

AE的重点在于编码,而解码的结果,基于训练目标,如果损失足够小的话,将会与输入相同。从这一点上看解码的值没有任何实际意义,除了通过增加误差来补充平滑一些初始的零值或有些许用处。

易知,从输入到输出的整个过程,AE都是基于已有的训练数据的映射,尽管隐藏层的维度通常比输入层小很多,但隐藏层的概率分布依然只取决于训练数据的分布,这就导致隐藏状态空间的分布并不是连续的,它只是稀疏地记录下来你的输入样本和生成图像的一一对应关系。 因此如果我们随机生成隐藏层的状态,那么它经过解码将很可能不再具备输入特征的特点,因此想通过解码器来生成数据就有点强模型所难了

如下图所示,仅通过AE,我们在码空间里随机采样的点并不能生成我们所希望的相应图像。这就使得我的不能够达到AIGC的效果。
在这里插入图片描述
据此,我们对AE的隐藏层z作出改动(让隐空间连续光滑),得到了VAE。
在这里插入图片描述

在这里插入图片描述

2 变分自编码器(Variational AutoEncoder,VAE)

关于变分推断,请查看本人的另一篇博文:变分推断(Variational Inference)

这里只做一个总结:

  • 变分推断是使用另一个分布 q ( z ) q(z) q(z)近似 p ( z ∣ x ) p(z|x) p(zx)
  • 用KL距离衡量分布的近似程度: K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) KL(q(z)||p(z|x)) KL(q(z)∣∣p(zx)),所以最优的 q ∗ ( z ) = a r g m i n q ( z ) ∈ Q K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) q^*(z)=argmin_{q(z) \in Q}KL(q(z)||p(z|x)) q(z)=argminq(z)QKL(q(z)∣∣p(zx))
  • K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) KL(q(z)||p(z|x)) KL(q(z)∣∣p(zx))的最小化转化为对ELBO的最大化,也就是 q ∗ ( z ) = a r g m i n q ( z ) ∈ Q K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) = a r g m a x q ( z ) ∈ Q E L B O = a r g m a x q ( z ) ∈ Q E q ( l o g ( p ( x , z ) − l o g q ( z ) ) ) q^*(z)=argmin_{q(z) \in Q}KL(q(z)||p(z|x))=argmax_{q(z)\in Q}ELBO=argmax_{q(z)\in Q}E_q(log(p(x,z)-logq(z))) q(z)=argminq(z)QKL(q(z)∣∣p(zx))=argmaxq(z)QELBO=argmaxq(z)QEq(log(p(x,z)logq(z)))

VAE全称是Variational AutoEncoder,即变分自编码器。

在VAE中 q ( z ) q(z) q(z)用一个编码器神经网络表示,假如其参数是 θ \theta θ,那么我们用 q θ ( z ) q_{\theta}(z) qθ(z)或者 q θ ( z ∣ x ) q_{\theta}(z|x) qθ(zx)表示。 p ( z ∣ x ) p(z|x) p(zx)可以认为是自然界真实存在的一个概率分布,但是我们不知道,所以需要用一个神经网络把他近似出来。

2.1 VAE的目的

在这里插入图片描述
VAE的目的:
(1)用神经网络去逼近和模拟 p ( z ∣ x ) p(z|x) p(zx)近似 p ( x ∣ z ) p(x|z) p(xz)这两个概率分布
(2)并尽量保证隐空间是连续和平滑的,即 p ( z ) p(z) p(z) p ( z ∣ x ) p(z|x) p(zx)是平滑的

2.2 VAE方法与损失函数

作者方法“
(1)定义: p ( z ) ∼ N ( 0 , 1 ) p(z) \sim N(0,1) p(z)N(0,1)
(2)定义: q θ ( z ∣ x ) ∼ N ( g ( x ) , h ( x ) ) q_{\theta}(z|x) \sim N(g(x),h(x)) qθ(zx)N(g(x),h(x)),也就是 q θ ( z ∣ x ) q_{\theta}(z|x) qθ(zx)的期望和方差是用两个神经网络计算出来的
(3)定义: p θ ′ ( x ∣ z ) ∼ N ( f ( z ) , c I ) p_{\theta'}(x|z) \sim N(f(z),cI) pθ(xz)N(f(z),cI),所以解码器的输出的是 p θ ′ ( x ∣ z ) p_{\theta'}(x|z) pθ(xz)的期望
这样直接定义好吗?为这么直接这样定义出来?看下面的一个slide
在这里插入图片描述
对ELBO做一个推导:
在这里插入图片描述

因为 p ( x ∣ z ) = 1 2 π c e ∣ ∣ x − f ( z ) ∣ ∣ 2 2 c p(x|z) = \frac{1}{\sqrt{2\pi c}}e^{\frac{||x-f(z)||^2}{2c}} p(xz)=2πc 1e2c∣∣xf(z)2,所以有:
在这里插入图片描述
也就是找到这样的三个神经网络使得上面的式子最大。
对于上面的第二项:
在这里插入图片描述
所以损失函数可以写成:
l o s s = 1 2 ( − l o g h ( x ) 2 + h ( x ) 2 + g ( x ) 2 − 1 ) + C ∣ ∣ x − f ( z ) ∣ ∣ 2 loss=\frac{1}{2}(-logh(x)^2+h(x)^2+g(x)^2-1)+C||x-f(z)||^2 loss=21(logh(x)2+h(x)2+g(x)21)+C∣∣xf(z)2

2.3 重参数技巧

从高斯分布 N ( μ , σ ) N(μ,σ) N(μ,σ)中采样的操作被巧妙转换为了从 N ( 0 , 1 ) N(0,1) N(0,1)中采样得到 ϵ ϵ ϵ后,再通过 z = μ + σ × ϵ z=μ+σ \times ϵ z=μ+σ×ϵ变换得到。
在这里插入图片描述
而在重参数后,我们计算反向传播的过程 如下图所示:

在这里插入图片描述

2.4 整合起来

在这里插入图片描述

(1)从样本库中取图片x
(2)g(x)计算均值,h(x)计算方差,从标准正太分布中采样一个数 ζ \zeta ζ,然后计算 z = ζ h ( x ) + g ( x ) z=\zeta h(x)+g(x) z=ζh(x)+g(x),然后计算 f ( z ) f(z) f(z)
(3)计算损失
(4)反向传播

3 代码实现

3.1 VAE.py

import  torch
from    torch import nn
 
 
class VAE(nn.Module): 
    def __init__(self):
        super(VAE, self).__init__() 
 
        # [b, 784] =>[b,20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
 
        # [b,10] => [b, 784]
        # sigmoid函数把结果压缩到0~1
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        """
        :param x:
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # chunk 在第二维上拆分成两部分
        # [b, 20] => [b,10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize tirchk, epison~N(0, 1)
        # torch.randn_like(sigma)表示正态分布
        h = mu + sigma * torch.randn_like(sigma)
 
        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)
 
        # KL
        # 1e-8是防止σ^2接近于零时该项负无穷大
        # (batchsz*28*28)是让kld变小
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)
 
 
        return x, kld

3.2 main.py

import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets
 
from    ae_1 import AE
from    vae import VAE
from    vq-vae import VQVAE
 
import  visdom
 
def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
 
    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
 
    #无监督学习,不能使用label
    x, _ = iter(mnist_train).next()
    print('x:', x.shape)
 
    device = torch.device('cuda')
    #model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
 
    viz = visdom.Visdom()
 
    for epoch in range(1000):
 
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
 
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
 
            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo
 
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
    print(epoch, 'loss', loss.item(), kld.item())
 
    x, _ = iter(mnist_test).next()
    x = x.to(device)
    with torch.no_grad(): 
 
 
    x_hat = model(x)
    # nrow表示一行的图片
    viz.images(x, nrow=8, win='x', optis=dic(title='x'))
    iz.images(x_hat, nrow=8, win='x_hat', optis=dic(title='x_hat'))
 
if __name__ == '__main__':
    main()

参考

讲解变分自编码器-VAE(附代码)
VAE到底在做什么?VAE原理讲解系列#1
VAE的神经网络是如何搭建的?VAE原理讲解系列#3
从零推导:变分自编码器(VAE)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/330543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

npm依赖库备份

常用命令 设置默认使用本地缓存安装Nodejs时会自动安装npm,但是局路径是C:\Users\Caffrey\AppData\Roaming\npm默认的缓存路径是C:\Users\Caffrey\AppData\Roaming\npm-cache;查看npm的prefix和cache路径配置信息设置路径 设置默认使用本地缓存 npm con…

MySQL面试题 | 14.精选MySQL面试题

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

Java NIO (三)NIO Channel类

1 概述 前面提到,Java NIO中一个socket连接使用一个Channel来表示。从更广泛的层面来说,一个通道可以表示一个底层的文件描述符,例如硬件设备、文件、网络连接等。然而,远不止如此,Java NIO的通道可以更加细化。例如&a…

实战 | OpenCV两种不同方法实现粘连大米粒分割计数(步骤 + 源码)

导 读 本文主要介绍基于OpenCV的两种不同方法实现粘连大米分割计数,并给详细步骤和源码。源码和图片素材见文末。 背景介绍 测试图如下,图中有个别米粒相互粘连,本文主要演示如何使用OpenCV用两种不同方法将其分割并计数。 方法一:基于分水岭算法 基于分水岭算法…

【playwright】新一代自动化测试神器playwright+python系列课程18_playwritht元素相关操作_等待元素到某种状态

元素相关操作_等待元素到某种状态 对于自动化测试来说,本质上就是定位元素、操作元素。网页上的元素有不同状态,有些元素本来不在网页的DOM中,经过某一步操作后才出现。有些元素是本来就已经在DOM中但是是隐藏的状态,经过某一步操…

K8S--解决访问Harbor私有仓库无权限的问题(401 Unauthorized)

原文网址:K8S--解决访问Harbor私有仓库无权限的问题(401 Unauthorized)-CSDN博客 简介 本文解决K8S访问Harbor私有仓库无权限的问题:401 Unauthorized。 问题复现 用Harbor部署了私有仓库,将镜像推送上去。指定私有…

python数字图像处理基础(五)——Canny边缘检测、图像金字塔、图像分割

目录 Canny边缘检测原理步骤 图像金字塔1.高斯金字塔2.拉普拉斯金字塔 图像分割图像轮廓检测1.检测轮廓2.绘制轮廓3.补充 Canny边缘检测 梯度是什么? 梯度就是变化的最快的那个方向 edge cv2.Canny(image, threshold1, threshold2[, edges[, apertureSize[, L2gradient ]]…

第90讲:MySQL数据库主从复制集群原理概念以及搭建流程

文章目录 1.MySQL主从复制集群的核心概念1.1.什么是主从复制集群1.2.主从复制集群中的专业术语1.3.主从复制集群工作原理1.4.主从复制中的小细节1.5.搭建主从复制集群的前提条件1.6.MySQL主从复制集群的架构信息 2.搭建MySQL多实例环境2.1.在mysql-1中搭建身为主库的MySQL实例2…

小程序 自定义组件和生命周期

文章目录 ⾃定义组件创建⾃定义组件声明组件编辑组件注册组件 声明引⼊⾃定义组件⻚⾯中使⽤⾃定义组件定义段与⽰例⽅法组件-⾃定义组件传参过程 小程序生命周期应用生命周期页面生命周期页面生命周期 ⾃定义组件 类似vue或者react中的自定义组件 ⼩程序允许我们使⽤⾃定义组件…

设计模式的学习笔记

设计模式的学习笔记 一. 设计模式相关内容介绍 1 设计模式概述 1.1 软件设计模式的产生背景 设计模式最初并不是出现在软件设计中,而是被用于建筑领域的设计中。 1977 年美国著名建筑大师、加利福尼亚大学伯克利分校环境结构中心主任 Christopher Alexander 在…

【动态规划】【数学】【C++算法】18赛车

作者推荐 视频算法专题 本文涉及知识点 动态规划 数学 LeetCode818赛车 你的赛车可以从位置 0 开始,并且速度为 1 ,在一条无限长的数轴上行驶。赛车也可以向负方向行驶。赛车可以按照由加速指令 ‘A’ 和倒车指令 ‘R’ 组成的指令序列自动行驶。 当…

情人节专属--html5 canvas制作情人节告白爱心动画特效

💖效果展示 💖html展示 <!doctype html> <html> <head> <meta charset=

2023年移远车载全面开花,智能座舱加速进击

作为汽车智能化的关键组件&#xff0c;车载模组正发挥着越来越重要的作用。 移远通信进入车载模组领域近十年&#xff0c;已形成了完善的车载产品队列&#xff0c;不但在5G/4G车载通信、智能座舱、C-V2X车路协同等领域打造了一枝独秀的产品线&#xff0c;也推出了车规级Wi-Fi/蓝…

解决springboot启动报Failed to start bean ‘subProtocolWebSocketHandler‘;

解决springboot启动报 Failed to start bean subProtocolWebSocketHandler; nested exception is java.lang.IllegalArgumentException: No handlers 问题发现问题解决 问题发现 使用springboot整合websocket&#xff0c;启动时报错&#xff0c;示例代码&#xff1a; EnableW…

大数据时代的黄金机遇:阿里云大数据分析师ACP认证【一条龙服务100%通过】

扫码和我联系 随着大数据技术的迅速发展和广泛应用&#xff0c;成为了当今时代最具吸引力的技术之一。为了让更多技术人才把握这一时代机遇&#xff0c;阿里云推出了大数据分析师ACP认证&#xff08;Alibaba Cloud Certified Professional - Data Analyst&#xff09;&#xf…

数据结构:顺序栈

栈是一种先进后出的数据结构&#xff0c;只允许在一端&#xff08;栈顶&#xff09;操作&#xff0c;代码中top表示栈顶。 stack.h /* * 文件名称&#xff1a;stack.h * 创 建 者&#xff1a;cxy * 创建日期&#xff1a;2024年01月17日 * 描 述&#xff1a; …

LeetCode、2542. 最大子序列的分数【中等,排序+小顶堆】

文章目录 前言LeetCode、2542. 最大子序列的分数【中等&#xff0c;排序小顶堆】题目及类型思路及代码实现 资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝2W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿里云平台优质作者、专注于Java后端技术领…

基于Springboot的摄影分享网站系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的摄影分享网站系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…

GBASE南大通用数据库GBase BI V5支持的集群部署

GBaseBI V5可以单独部署在一个服务器上&#xff0c;在单套的情况下安装成功后不需要特殊的设置即可直接使用。某些用户的应用并发数可能很多&#xff0c;单个服务器处理请求太慢&#xff0c;GBaseBI V5支持集群和分布式部署。其中集群部署如下图所示&#xff1a; 集群部署 在集…

【Vue3】2-13 : 章节总结

本书目录&#xff1a;点击进入 一、总结内容 二、习题 2.1 【选择题】以下Vue指令中&#xff0c;哪些指令具备简写方式&#xff1f; 2.2 【编程题】以下Vue指令中&#xff0c;哪些指令具备简写方式&#xff1f; &#xff1e; 效果 &#xff1e; 代码 一、总结内容 了解核…
最新文章