在 CelebA 数据集上训练的 PyTorch 中的基本变分自动编码器

摩西·西珀博士

一、说明

        我最近发现自己需要一种方法将图像编码到潜在嵌入中,调整嵌入,然后生成新图像。有一些强大的方法可以创建嵌入从嵌入生成。如果你想同时做到这两点,一种自然且相当简单的方法是使用变分自动编码器。

        这样的深度网络不仅可以进行编码和解码,而且相当简单,我可以在以后的研究中使用它,而不必过多担心编码解码阶段的各种隐藏复杂性。我也更喜欢对软件内部有尽可能多的控制。

        因此,考虑到所有这些规范,我从 GitHub 收集了一些零碎的东西,施展了一些我自己的魔法,最终得到了一个漂亮、简单的变分自动编码器。我将在下面描述主要部分,完整的包可在以下位置找到:

vae-torch-celeba,PyTorch 中 CelebA 数据集的变分自动编码器,下载vae-torch-celeba的源码_GitHub_帮酷

PyTorch 中用于 CelebA 数据集的变分自动编码器 - GitHub - moshesipper/vae-torch-celeba:变分...

它相当小并且完全独立——这就是我的意图!

二、自动编码器

        为了使本文简短易懂,我将避免提供变分自动编码器的冗长概述。此外,您还可以在 Medium 上找到有关基础知识的优秀文章。我只提供三张快速图片。

        这是基本自动编码器的样子:

来源: https: //commons.wikimedia.org/wiki/File :Autoencoder_schema.png

        简而言之,网络将输入数据压缩为潜在向量(也称为嵌入),然后将其解压缩回来。这两个阶段称为编码解码

        变分自动编码器(VAE)看起来非常相似,除了中间的嵌入部分。对于每个输入,VAE 的编码器输出潜在空间中预定义分布的参数,而不是潜在空间中的向量:

来源:https ://commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png

        最后一张图片:如果我们处理的是图像输入,我们需要一个卷积VAE,如下所示:

来源:https://github.com/arthurmeyer/Saliency_Detection_Convolutional_Autoencoder

        注意#1:观察编码器部分如何在每一层中添加越来越多的滤波器,图像变得越来越小;解码器则相反。

        注意#2:注意符号。如果只有一个通道,则术语“过滤器”和“内核”基本相同。对于多个通道,每个过滤器都是一组内核。查看这篇很棒的 Medium 文章:“直观地理解深度学习的卷积”。

三、CelebA数据集

我将使用的数据集是 CelebA,其中包含 202,599 张名人面孔图像。

CelebA 数据集

CelebFaces Attributes Dataset (CelebA) 是一个大规模人脸属性数据集,包含超过 20 万张名人图像……

可以通过以下方式访问它torchvision:

from torchvision.datasets import CelebA

train_dataset = CelebA(path, split='train')
test_dataset = CelebA(path, split='valid') # or 'test'

四、VAE类

        我的 VAE 基于此PyTorch 示例和存储库的普通 VAE模型(将我使用的普通 VAE 替换为中的任何其他模型PyTorch-VAE应该不会太难)。PyTorch-VAE

        该文件vae.py包含VAE类以及图像大小的定义、两个潜在向量的维度(均值和方差)以及数据集的路径:

CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

在课堂上VAE,我使用了以下隐藏过滤器维度:

hidden_dims = [32, 64, 128, 256, 512]

编码器看起来像这样:

in_channels = 3
modules = []
for h_dim in hidden_dims:
    modules.append(
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels=h_dim,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(h_dim),
            nn.LeakyReLU())
    )
    in_channels = h_dim
self.encoder = nn.Sequential(*modules)

然后是潜在向量:

self.fc_mu = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
self.fc_var = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)

        最后我们用解码器“倒退”:

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
   modules.append(
      nn.Sequential(
         nn.ConvTranspose2d(hidden_dims[i],
                            hidden_dims[i + 1],
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            output_padding=1),
         nn.BatchNorm2d(hidden_dims[i + 1]),
         nn.LeakyReLU())
   )

self.decoder = nn.Sequential(*modules)

这就是它的要点——还有一些零碎的内容vae.py可以完成这VAE门课。

五、训练

        该文件trainvae.py包含训练我们刚刚编码的 VAE 的代码。老实说,没什么花哨的......有 3 个主要函数:(train随着训练的进行,它也输出损失值),test(它还构建一个重建图像的小样本)和loss_function。训练和测试相当普通,损失函数是标准 VAE,带有重建组件 (MSE) 和 KL 散度组件。

        epoch 上的主循环执行 4 个操作:1) train、2) test、3) 生成随机潜在向量并调用decode以输出相应的输出图像,以及 4) 将 epoch 的模型保存到文件中pth

        以下是示例运行的输出。通过 20 个训练周期,您最终会得到 20 个重建图像文件、20 个潜在采样文件和 20 个 python 模型文件:

        这里reconstruction_20.png,顶行显示 8 张原始图片,底行显示经过训练的 VAE 的相应重建。

        在 epoch 20 时从模型重建(输出)图像。

        这里的sample_20.png,显示了从随机潜在向量生成的 64 张图像:

        只是为了好玩,我添加了一小段代码 — genpics.py— 从数据集中挑选一个随机图像并生成 7 个重建。以下是一些示例(最左边的图像是原始图像):

        最后,我再次放置 GitHub 链接。享受!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/118210.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习LevelDB架构的检索技术

目录 一、LevelDB介绍 二、LevelDB优化检索系统关键点分析 三、读写分离设计和内存数据管理 (一)内存数据管理 跳表代替B树 内存数据分为两块:MemTable(可读可写) Immutable MemTable(只读&#xff0…

力扣370周赛 -- 第三题(树形DP)

该题的方法,也有点背包的意思,如果一些不懂的朋友,可以从背包的角度去理解该树形DP 问题 题解主要在注释里 //该题是背包问题树形dp问题的结合版,在树上解决背包问题 //背包问题就是选或不选当前物品 //本题求的是最大分数 //先转…

京东商品详情API接口(PC端和APP端),京东详情页,商品属性接口,商品信息查询

京东开放平台提供了API接口来访问京东商品详情。通过这个接口,您可以获取到商品的详细信息,如商品名称、价格、库存量、描述等。 以下是使用京东商品详情API接口的一般步骤: 注册并获取API权限:您需要在京东开放平台上注册并获取…

arcgis pro模型构建器

如果你不想部署代码包环境来写arcpy代码,还想实现批量或便携封装的操作工具,那么使用模型构建器是最好的选择。1.简介模型构建器 1.1双击打开模型构建器 1.2简单模型构建步骤 先梳理整个操作流程,在纸上绘制在工具箱中找到所需工具拖进来把…

LangChain+LLM实战---实用Prompt工程讲解

原文:Practical Prompt Engineering 注:本文中,提示和prompt几乎是等效的。 这是一篇非常全面介绍Prompt的文章,包括prompt作用于大模型的一些内在机制,和prompt可以如何对大模型进行“微调”。讲清楚了我们常常听到的…

搭建二维码系统,轻松实现固定资产的一物一码管理

固定资产管理中普遍存在盘点难、家底不清、账实不一致、权责不清晰等问题,可以在草料上搭建固定资产管理系统,通过组合功能模块实现资产信息展示、领用登记、出入库管理、故障报修等功能,对固定资产进行一物一码规范化管理。 比如张掖公路事业…

创建基于多任务的并发服务器

有几个请求服务的客户端&#xff0c;我们就创建几个子进程。 这个过程有以下三个阶段&#xff1a; 这里父进程传递的套接字文件描述符&#xff0c;实际上不需要传递&#xff0c;因为子进程会复制父进程拥有的所有资源。 #include <stdio.h> #include <stdlib.h>…

票务营销数字化:景区增收利器

身处数字化时代&#xff0c;景区门票销售早已插上数字化翅膀&#xff0c;通过一站式的票务管理、精准的营销策略等为景区带来了数字化增长。票务营销系统如何帮助景区增收&#xff1f; l 提高工作效率&#xff1a;传统的景区售票方式往往需要大量的人工操作&#xff0c;不仅耗时…

微信小程序overflow-x超出部分样式不渲染

把display:flex改成display:inline-flex&#xff0c; 将对象作为内联块级弹性伸缩盒显示&#xff0c; 类似与是子元素将父元素撑开&#xff0c;样式就显示出来了

纺织布料行业小程序开发

随着互联网的发展&#xff0c;越来越多的消费者通过线上渠道购买纺织布料产品。为了满足市场需求&#xff0c;越来越多的纺织布料企业选择开发小程序&#xff0c;以提高销售效率、拓宽销售渠道和提升用户体验。下面是开发纺织布料行业小程序的具体步骤&#xff1a; 1. 登录乔拓…

Flume从入门到精通一站式学习笔记

文章目录 什么是FlumeFlume的特性Flume高级应用场景Flume的三大核心组件Source&#xff1a;数据源channelsink Flume安装部署Flume的使用案例&#xff1a;采集文件内容上传至HDFS案例&#xff1a;采集网站日志上传至HDFS 各种自定义组件例如&#xff1a;自定义source例如&#…

Python的切片操作详细用法解析

在利用Python解决各种实际问题的过程中&#xff0c;经常会遇到从某个对象中抽取部分值的情况&#xff0c;切片操作正是专门用于完成这一操作的有力武器。理论上而言&#xff0c;只要条件表达式得当&#xff0c;可以通过单次或多次切片操作实现任意切取目标值。切片操作的基本语…

开源 Wiki 软件 wiki.js

wiki.js简介 最强大、 可扩展的开源Wiki 软件。使用 Wiki.js 美观直观的界面让编写文档成为一种乐趣&#xff01;根据 AGPL-v3 许可证发布。 官方网站&#xff1a;https://js.wiki/ 项目地址&#xff1a;https://github.com/requarks/wiki 主要特性&#xff1a; 随处安装&a…

机器学习笔记 - 感知器的数学表达

一、假设前提 感知机(或称感知器,Perceptron)是Frank Rosenblatt在1957年就职于Cornell航空实验室(Cornell Aeronautical Laboratory)时所发明的一种人工神经网络。 它可以被视为一种最简单形式的前馈神经网络,是一种二元线性分类模型,其输入为实例的特征向量,输出为实…

绕开网站反爬虫原理及实战

1.摘要 在本文中,我首先对网站常用的反爬虫和反自动化技术做了一个梳理, 并对可能能够绕过这些反爬技术的开源库chromedp所使用的技术分拆做一个介绍, 最后利用chromedp库对一个测试网站做了爬虫测试, 并利用chromedp库绕开了爬虫限制,成功通过程序自动获取到信息。在测试过程…

如何发布自己的golang库

如何发布自己的golang库 1、在 github/gitee 上创建一个 public 仓库&#xff0c;仓库名与 go 库名一致&#xff0c;然后将该仓库 clone 到本地。 本文这里使用 gitee。 $ git clone https://gitee.com/zsx242030/goutil.git2、进入项目文件夹&#xff0c;进行初始化。 $ go…

Webpack介绍大全

Webpack 一 、什么是webpack WebPack是一个现代JS应用程序的静态模块打包器&#xff08;module bundler&#xff09; 模块&#xff08;模块化开发&#xff0c;可以提高开发效率&#xff0c;避免重复造轮子&#xff09; 打包&#xff08;将各个模块&#xff0c;按照一定的规则…

新版onenet平台安全鉴权的确定与使用

根据onenet官方更新的文档&#xff1a;平台提供开放的API接口&#xff0c;用户可以通过HTTP/HTTPS调用&#xff0c;进行设备管理&#xff0c;数据查询&#xff0c;设备命令交互等操作&#xff0c;在API的基础上&#xff0c;根据自己的个性化需求搭建上层应用。 为提高API访问安…

JavaScript作用域实战

● 首先&#xff0c;我们先创建一个函数&#xff0c;和以前一样&#xff0c;计算一个年龄的 function calcAge(birthYear) {const age 2037 - birthYear;return age; }● 然后我们创建一个全局变量&#xff0c;并调用这个函数 const firstName "IT知识一享"; cal…

【pytorch源码分析--torch执行流程与编译原理】

背景 解读torch源码方便算子开发方便后续做torch 模型性能开发 基本介绍 代码库 https://github.com/pytorch/pytorch 模块介绍 aten: A Tensor Library的缩写。与Tensor相关的内容都放在这个目录下。如Tensor的定义、存储、Tensor间的操作&#xff08;即算子/OP&#xff…
最新文章