CycleGAN论文解读及代码实现

paper: https://arxiv.org/pdf/1703.10593.pdf
github: https://github.com/aitorzip/PyTorch-CycleGAN

1 cycleGAN 小结

  • 网络:
    生成器2个:G_A,G_B
    判别器两个: D_A,D_B
  • 损失函数8个
    6个生成器损失函数
    2个判别器损失函数

1.1 数据

  • fake_B
    原始A,经过生成器G_A,生成fake_B
    A->G_A = fake_B
  • rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • fake_A
    B->G_B = fake_A
    原始B,经过生成器G_B,生成fake_A
  • rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B

1.2 损失函数

6个生成器损失,2个判别器损失
1)6个生成器损失:

  • 生成器一致损失 2个
    数据B经过生成器G_A,后生成的B,与原始B距离最小。A同理
    ① B-> G_A->B’ : 使得B 与B’距离最小
    ② A-> F_B->A’ : 使得A 与A’距离最小
  • 生成器损失 2个
    生成器生成的数据,让判别器都判别为真
    ③ MSELoss(D_A(fake_B), True)
    ④ MSELoss(D_B(fake_A), True)
  • 循环一致损失 2个
    ⑤ 原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
    ⑥原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
    B->G_B ->G_A = rec_B

2) 2个判别器损失

  • 判别器损失 2个
    使真实图片为判别为真,假图片判别为假
    ① D_A
    pred_real = D_A(real); pred_fake= D_A(fake)
    MSELoss(pred_real, True)+MSELoss(pred_fake, False)
    ②D_B
    pred_real = D_B(real); pred_fake= D_B(fake)
    MSELoss(pred_real, True)+MSELoss(pred_fake, False)

2 模型架构

  • 两个生成网络:
    G: X——> Y ,输入X生成Y
    F: Y——> X :输入Y生成X
  • 两个判别网络:
    D_A: 用于区分真实A和 F(B)生成的假A.
    D_B:用于区分真实B和 G(A)生成的假B.

在这里插入图片描述

3 损失函数

3.1 Adversarial Loss

对抗损失:

  • 对于生成器 G: X——> Y
    生成器G_X: 最小化以下目标函数
    对于判别器D_Y:最大化以下目标函数
    在这里插入图片描述
  • 对于生成器 F: Y——> X,损失函数同上
    生成器F_Y,使判别器D_X判断为真
    对于判别器D_X:是真实X判断为真,F_Y生成的X,判断为假。
    L G A N ( F , D X , Y , X ) L_{GAN}(F,D_X,Y,X) LGAN(F,DX,Y,X)

3.2 Cycle Consistency Loss

循环一致损失,即 X 经过生成器G_x后 得到Y,Y再过F_Y生成X,使得前后生成的X距离最小。
1) 前向一致损失
即从x 经过网络后还原为x的过程
X − > G ( x ) − > F ( G ( x ) ) = X X -> G(x) -> F(G(x)) =X X>G(x)>F(G(x))=X

2)反向一致损失
即y从经过网络后还原为y的过程
Y − > F ( y ) − > G ( F ( y ) ) = Y Y -> F(y) -> G(F(y)) =Y Y>F(y)>G(F(y))=Y

在这里插入图片描述

3.3 Full Objective

在这里插入图片描述

4 代码实现

4.1网络结构

  • 1 生成器A :
    netG_A:可以选用resnet,或者unet网络
    输入数据A,生成数据B
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
  • 2 生成器B:
    netG_B: 与netG_A网络一样
    输入数据B,生成数据A
 self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
  • 3 判别器A
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

  • 4 判别器B
 self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)


5 损失

5.1 前向传播数据

  • self.fake_B
    原始A,经过生成器G_A,生成fake_B
    A->G_A = fake_B
  • self.rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • self.fake_A
    B->G_B = fake_A
    原始B,经过生成器G_B,生成fake_A
  • self.rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
self.fake_B = self.netG_A(self.real_A)  # G_A(A)
self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B)  # G_B(B)
self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

5.2 生成器一致损失

数据B经过生成器G_A,后生成的B,与原始B距离最小。
B-> G_A->B’ : 使得B 与B’距离最小
A-> F_B->A’ : 使得A 与A’距离最小

self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A= self.L1Loss(self.idt_A, self.real_B)

self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.L1Loss(self.idt_B, self.real_A)

5.2 生成器损失

生成器生成的数据,让判别器都判别为真

(备注:判别器输出不是一个值,而是一个矩阵,需要使判别器输出矩阵每一个值都接近1)

# GAN loss D_A(G_A(A))
self.loss_G_A = self.MSELoss(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterioMSELossGAN(self.netD_B(self.fake_A), True)

5.3 循环一致损失

使得重构的A与原始A距离最近,使用L1Loss

  • self.rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • self.rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.L1Loss(self.rec_A, self.real_A) 
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.L1Loss(self.rec_B, self.real_B) 

5.4 生成器总loss

上面6个生成器损失求和即为总的生成损失函数

self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B

5.5 判别器损失

判别器:使真实图片为判别为真,假图片判别为假

pred_real = netD(real)
loss_D_real = self.MSELoss(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.MSELoss(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/65903.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Cesium相机理解

关于cesium相机,包括里面内部原理网上有很多人讲的都很清楚了,我感觉这两个人写的都挺好得: 相机 Camera | Cesium 入门教程 (syzdev.cn) Cesium中的相机—setView&lookAtTransform_cesium setview_云上飞47636962的博客-CSDN博客上面这…

记录线上一次mysql只能查询,不能插入或更新的bug

错误复现 突然有一天产品通知xx服务不可用,想着最近也没有服务更新,就先排查一下服务日志 使用postman测试的时候请求明显超时,查看日志显示是一个锁的问题 使用工具连接到mysql,查看information_schema.INNODB_TRX,发现有一个事…

docker删除容器时报错:Error response from daemon: reference does not exist

前言 之前使用的docker版本太低了,升级高版本docker之后的错误。 低版本docker(1.30.1)中的镜像有:golang、mysql,将docker升级为24.0.5并新拉取mysql最新版本之后,执行docker images命令,发现…

构建Docker容器监控系统(2)(Cadvisor +Prometheus+Grafana)

Cadvisor产品简介 Cadvisor是Google开源的一款用于展示和分析容器运行状态的可视化工具。通过在主机上运行Cadvisor用户可以轻松的获取到当前主机上容器的运行统计信息,并以图表的形式向用户展示。 接着上一篇来继续 部署Cadvisor 被监控主机上部署Cadvisor容器…

比较研发项目管理系统:哪个更适合您的需求?

项目管理系统对于保持项目进度、提高效率和确保质量至关重要。然而,市场上众多的研发项目管理系统让许多团队陷入选择困难。本文将对几个主流的研发项目管理系统进行深入分析,以帮助您找到最适合您团队的解决方案。 “哪个研发项目管理系统好用好&#x…

在时间和频率域中准确地测量太阳黑子活动及使用信号处理工具箱(TM)生成广泛的波形,如正弦波、方波等研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

vue3获得url上的参数值

1、引入 import { useRoute } from vue-router2、获得const route useRoute() console.log(route.query.number)

程序员必备技能-九大分布式ID生成策略

九大分布式ID生成策略 1.UUID UUID (Universally Unique Identifier),通用唯一识别码。UUID是基于当前时间、计数器(counter)和硬件标识(通常为无线网卡的MAC地址)等数据计算生成的。 UUID由以下几部分的组合&#x…

《Python入门到精通》os模块详解,Python os标准库

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 os模块详解 1、文件目录操作os.stat() 获取文件状态os.utime() 修改文件时间os.r…

vuejs 设计与实现 - 简单diff算法

DOM 复用与key的作用: DOM 复用什么时候可复用? key 属性就像虚拟节点的“身份证”号,只要两个虚拟节点的 type属性值和 key 属性值都相同,那么我们就认为它们是相同的,即可以进行 DOM 的复用。即 我们通过【移动】来…

无需公网-用zerotier异地组网

无需公网-用zerotier异地组网 在前面的文章中我们讲到利用frp进行内网穿透,但是他的局限在于你需要一台公网服务器。并且对公网服务器的带宽有一定的要求。因此这里我们推荐一款异地组网工具搭建属于自己的虚拟网络,经过授权连接成功之后彼此都在同一网…

Oracle单实例升级补丁

目录 1.当前DB环境2.下载补丁包和opatch的升级包3.检查OPatch的版本4.检查补丁是否冲突5.关闭数据库实例,关闭监听6.应用patch7.加载变化的SQL到数据库8.ORACLE升级补丁查询 oracle19.3升级补丁到19.18 1.当前DB环境 [oraclelocalhost ~]$ cat /etc/redhat-releas…

\vendor\github.com\godror\orahlp.go:531:19: undefined: VersionInfo

…\goAdmin\vendor\github.com\godror\orahlp.go:531:19: undefined: VersionInfo 解决办法 降了go版本(go1.18),之前是go1.19 gorm版本不能用最新的,降至(gorm.io/gorm v1.21.16)就可以 修改交插编译参数 go env -w CGO_ENABLED1…

# ⛳ Docker 安装、配置和详细使用教程-Win10专业版

目录 ⛳ Docker 安装、配置和详细使用教程-Win10专业版🚜 一、win10 系统配置🎨 二、Docker下载和安装🏭 三、Docker配置🎉 四、Docker入门使用 ⛳ Docker 安装、配置和详细使用教程-Win10专业版 🚜 一、win10 系统配…

区块链实验室(15) - 编译FISCO BCOS的过程监测

首次编译开源项目,一般需要下载很多依赖包,尤其是从github、sourceforge等下载依赖包时,速度很慢,编译进度似乎没有一点反应,似乎陷入死循环,似乎陷入一个没有结果的等待。本文提供一种监测方法&#xff0c…

redis的事务和watch机制

这里写目录标题 第一章、redis事务和watch机制1.1)redis事务,事务的三大命令语法:开启事务 multi语法:执行事务 exec语法:取消事务 discard 1.2)redis事务的错误和回滚的情况1.3)watch机制语法&…

【Linux】为.sh脚本制作桌面快捷方式(.desktop,可双击执行),且替换显示图标(图文详情)

目录 0.背景环境 1、原理 2、详细步骤 1)创建.desktop快捷方式 2) 给test.desktop快捷方式增加可执行权限 3)编辑test.desktop内容和参数 4)修改快捷方式属性为双击可执行 5)将桌面快捷方式发送到桌面 0.背景环…

2023全新UI好看的社区源码下载/反编译版

2023全新UI好看的社区源码下载/反编译版 这次分享一个RuleAPP二开美化版(尊重每个作者版权),无加密可反编译版本放压缩包了,自己弄吧!!! RuleAPP本身就是一款免费开源强大的社区,基…

交替方向乘子

目录 一,交替方向乘子ADMM 1,带线性约束的分离优化模型 2,常见优化模型转带线性约束的分离优化模型 3,带线性约束的分离优化模型求解 4,交替方向乘子ADMM 本文部分内容来自教材 一,交替方向乘子ADMM …

Linux计划任务管理at、crond

一、单次任务at at命令可以设置在一个指定的时间执行一个指定任务,只能执行一次,使用前确认系统开启了atd服务。 例如:定时执行某命令或脚本, 1、输入at 19:00,回车; 2、输入需要执行的命令或脚本文件&am…
最新文章