第G6周:CycleGAN实践

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、CycleGAN原理

(一)CycleGAN的原理结构:

CycleGAN(循环生成对抗网络)是一种生成对抗网络(GAN),它能够在没有成对训练样本的情况下,将一个域(比如照片中的马)转换成另一个域(比如照片中的斑马)。CycleGAN 主要由两部分组成:两个生成器和一个判别器。生成器的作用是在两个域之间进行转换,而判别器则用于判断输入的图像是真实的还是由生成器生成的。

1. 生成器(Generator)

CycleGAN 包含两个生成器,一个用于将图像从域A转换到域B,另一个则用于将图像从域B转换回域A。每个生成器都是一个神经网络,通常采用卷积神经网络(CNN)的结构。生成器的目标是学习如何将输入图像转换成目标域中的图像,同时欺骗判别器,使其认为生成的图像是真实的。

2. 判别器(Discriminator)

与生成器相对应,CycleGAN 也有两个判别器,一个用于判断图像是否属于域A,另一个用于判断图像是否属于域B。判别器同样通常采用卷积神经网络结构,其目标是能够准确地区分真实图像和由生成器生成的图像。

3. 循环一致性损失(Cycle Consistency Loss)

为了确保在没有成对训练样本的情况下,生成器能够学习到有效的映射,CycleGAN 引入了循环一致性损失。这个损失函数确保当图像从域A转换到域B,然后再转换回域A时,得到的图像与原始图像尽可能相似。同样地,从域B转换到域A,再转换回域B也应该保持循环一致性。

4. 对抗损失(Adversarial Loss)

对抗损失是GAN的核心,它确保生成器能够生成足以欺骗判别器的图像。对于CycleGAN,每个生成器都需要最小化对抗损失,使其生成的图像能够在相应的判别器上获得高分。

5. identity损失(Identity Loss)

除了上述两种损失函数,CycleGAN 还引入了identity损失,以确保当输入图像已经属于目标域时,生成器能够返回原始图像。这有助于保持生成器在训练过程中的稳定性,并防止过度拟合。

6.训练过程:

在训练过程中,CycleGAN 通过不断调整生成器和判别器的参数来最小化上述损失函数。生成器尝试生成越来越能欺骗判别器的图像,而判别器则尝试越来越准确地识别真实图像和生成图像。通过这种对抗过程,生成器最终能够学习到如何在两个域之间进行有效的转换。

7.应用:

CycleGAN 在计算机视觉领域有广泛的应用,如风格迁移、季节变换、照片增强等,它为那些没有成对训练数据的图像转换任务提供了一种有效的解决方案。

(二)CycleGAN和传统GAN的对比

CycleGAN与普通GAN相比,有几个特殊之处,这些特性使得CycleGAN适合于图像到图像的转换任务,尤其是在没有成对训练数据的情况下:

  1. 无配对数据要求
    • 普通GAN通常需要成对的训练数据,即每个输入图像都有一个对应的输出图像。而CycleGAN不需要这样的成对数据,它可以学习一个域(比如照片)到另一个域(比如画作)的转换,即使没有直接的成对映射。
  2. 循环一致性损失
    • CycleGAN引入了循环一致性损失(Cycle Consistency Loss),这是其核心的创新之一。这个损失函数确保当图像从源域转换到目标域,然后再转换回源域时,能够尽可能地恢复到原始图像。这样的循环保证了即使在没有成对数据的情况下,转换过程也是合理的。
  3. 两个生成器和两个判别器
    • CycleGAN包含两个生成器,每个生成器负责一个方向的转换(从域A到域B和从域B到域A)。同时,也有两个判别器,分别用于判断图像是否属于域A或域B。这种结构使得CycleGAN能够同时学习两个域之间的映射。
  4. 对抗性损失和身份损失
    • 除了循环一致性损失,CycleGAN还结合了对抗性损失和身份损失。对抗性损失确保生成器能够生成足以欺骗判别器的图像,而身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像,这有助于保持生成器的稳定性。
  5. 多样化的应用场景
    • 由于不需要成对数据,CycleGAN可以应用于多种不同的图像到图像的转换任务,如风格迁移、季节变换、艺术作品风格化等,而这些任务在普通GAN中很难实现。
  6. 更强的泛化能力
    • 由于循环一致性损失的设计,CycleGAN在训练过程中学习到了更加泛化的特征表示,这使得它在面对未见过的数据时,也能表现出较好的转换效果。

二、CycleGAN代码分析

第一部分

导入的库的作用的分析

  1. argparse
    • 这是一个Python的标准库,用于解析命令行参数。在训练或测试CycleGAN时,可以通过命令行传入各种参数,如学习率、批量大小、数据集路径等,argparse可以帮助程序解析这些参数。
  2. itertools
    • 这也是Python的标准库之一,它提供了多种迭代操作的函数。在处理数据集或进行模型训练时,可能会用到itertools来生成迭代器,例如用于循环遍历数据批次。
  3. torchvision.utils
    • 这个模块包含了多个实用函数,用于处理和展示图像。save_image函数用于将Tensor保存为图像文件,make_grid函数用于将多个图像拼接成一个网格图像,这在可视化训练过程中的图像时非常有用。
  4. torch.utils.data
    • 这个模块提供了数据加载和处理的工具,DataLoader类是其中的核心,它允许我们以批量方式加载数据,并提供数据并行处理的功能,这对于实现高效的数据加载非常重要。
  5. modelsdatasetsutils:这三个是自带的数据,都全部抓取,导入模型中
import argparse
import itertools
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from models import *
from datasets import *
from utils import *
import torch

第二部分

代码分析
parser = argparse.ArgumentParser() 这一行代码创建了一个 ArgumentParser
对象,这个对象将用于解析命令行参数。
argparse 是 Python 的一个库,它提供了一个方便的方式来解析命令行参数。ArgumentParser 类是 argparse
库中最重要的类,它设置了解析器的基本行为,并允许开发者添加参数。

创建了解析器对象之后,开发者可以通过调用 add_argument() 方法来指定程序需要接受的命令行参数。每个
add_argument() 调用都为一个特定的参数添加了规则,包括参数的名称、类型、帮助信息等。
接下来一段主要就是设置模型的各个参数,接下来着重关注每个参数的作用

  1. b1b2
    • 这两个参数是Adam优化器的超参数,分别代表一阶矩估计的指数衰减率和二阶矩估计的指数衰减率。它们用于计算梯度的一阶矩估计(mean)和二阶矩估计(uncentered
      variance),并对学习率进行自适应调整。
  2. batch_size
    • 批量大小,即每次训练时传递给模型的样本数量。批量大小的大小会影响模型的收敛速度和稳定性。
  3. channels
    • 图像的通道数。对于彩色图像,通常是3(红、绿、蓝通道)。
  4. checkpoint_interval
    • 检查点间隔,即在训练过程中,每过多少个周期(epoch)保存一次模型的权重。
  5. dataset_name
    • 数据集的名称。CycleGAN可以用于不同的图像到图像的转换任务,这个参数指定了当前使用的数据集。
  6. decay_epoch
    • 学习率衰减开始的周期。通常在训练过程中,学习率会随着训练的进行而逐渐减小,decay_epoch指定了何时开始衰减。
  7. epoch
    • 当前周期数。这个参数通常在训练开始时设置为0,并在每个周期结束时递增。
  8. img_heightimg_width
    • 输入图像的高度和宽度。CycleGAN要求所有输入图像具有相同的大小,这些参数指定了图像的尺寸。
  9. lambda_cyc
    • 循环一致性损失的权重。在CycleGAN中,循环一致性损失用于确保图像在转换回原始域后尽可能接近原始图像。lambda_cyc控制这个损失在总损失中的重要性。
  10. lambda_id
    • 身份损失的权重。身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像。lambda_id控制这个损失在总损失中的重要性。
  11. lr
    • 学习率,即模型参数在每次更新时的调整幅度。学习率的选择对模型的训练至关重要,过大的学习率可能导致模型无法收敛,过小的学习率可能导致训练过程缓慢。
  12. n_cpu
    • 用于数据加载的CPU核心数。在加载数据时,可以使用多个CPU核心来并行处理,以提高数据加载的效率。
  13. n_epochs
    • 总的训练周期数。一个周期是指模型对整个训练数据集进行一次完整的遍历。
  14. n_residual_blocks
    • 残差块的数量。在CycleGAN的生成器中,残差块用于构建网络的主干,增加残差块的数量可以增加模型的容量和表达能力。
  15. sample_interval
    • 采样间隔,即在训练过程中,每过多少个批次保存一次生成的图像样本。这有助于监控训练过程中模型的表现。 这些参数共同决定了CycleGAN模型的训练过程和表现。通过调整这些参数,可以优化模型的性能,并适应不同的训练环境和任务需求。

parser.parse_args() 是 argparse 库中的一个方法,它用于解析命令行参数。在代码片段中,parser
是一个 ArgumentParser 对象,它已经定义了程序可以接受的参数和它们的属性。当调用 parse_args()
方法时,会发生以下几件事情:
1.解析命令行参数: parse_args() 会检查命令行中提供的参数,并根据 parser
对象中定义的规则来解析它们。如果命令行参数的格式正确,它们将被转换成相应的数据类型(例如,字符串、整数、浮点数等)。
2.填充 args 对象:解析后的参数会被填充到一个名为 args 的命名空间对象中。这个对象包含了所有解析后的参数值,可以通过点号操作符访问这些值,例如args.batch_size。
3.提供默认值: 如果在命令行中没有提供某个参数的值,parse_args() 会使用在 add_argument() 调用中指定的默认值。
4. 错误处理: 如果命令行参数的格式不正确或者提供了未定义的参数,parse_args()
会自动打印出错误信息和一个用法提示,并退出程序。

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=50, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=7.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

第三部分

代码知识点
在CycleGAN的实现中,使用了三种损失函数来指导模型的训练。这些损失函数分别对应于不同的训练目标,下面是每个损失函数的
解释:

  1. criterion_GAN (对抗损失):
    • 这是一种用于计算生成对抗网络(GAN)中的对抗性损失的函数。在CycleGAN中,每个生成器都试图生成能够欺骗对应判别器的图像。通常,GAN使用二进制交叉熵损失(BCELoss)作为对抗性损失,但在某些情况下,也可以使用均方误差损失(MSELoss)。torch.nn.MSELoss()计算的是预测值和目标值之间的均方误差,这种损失函数在图像生成任务中可以提供平滑的梯度,有助于生成器的学习。
  2. criterion_cycle (循环一致性损失):
    • 这个损失函数用于确保图像在经过两个生成器(从域A到域B,再从域B回到域A)的转换后,能够尽可能地恢复到原始图像。循环一致性损失使用L1范数(绝对值误差)来计算,因为L1损失对异常值不那么敏感,能够产生更清晰的图像。torch.nn.L1Loss()计算的是预测值和目标值之间的平均绝对误差。
  3. criterion_identity (身份损失):
    • 身份损失确保当输入图像已经属于目标域时,生成器能够返回原始图像。这个损失函数也使用L1损失来计算,因为它能够帮助生成器学习到保持图像结构的映射。例如,如果我们将一张风景照片作为输入,生成器应该输出与输入非常相似的风景照片,而不是将其转换为另一幅完全不同的图像。
      总的来说,这三种损失函数共同作用于CycleGAN的训练过程中,使得生成器能够学习到在不同域之间进行有效转换的同时,保持图像的循环一致性和身份一致性。通过平衡这些损失函数,CycleGAN能够在没有成对训练样本的情况下,实现高质量的图像到图像的转换。
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (opt.channels, opt.img_height, opt.img_width)

第四部分

这段代码是CycleGAN模型训练脚本的的一部分,它涉及到创建生成器和判别器模型、将模型移动到GPU上(如果可用)、以及加载或初始化模型的权重。下面是分段解释:

创建生成器和判别器模型:

G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

这四行代码创建了四个模型:两个生成器G_ABG_BA,以及两个判别器D_AD_BGeneratorResNet是生成器的类,它接收输入图像的形状和残差块的数量作为参数。Discriminator是判别器的类,它接收输入图像的形状作为参数。

将模型和损失函数移动到GPU上:

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

这段代码检查是否使用了GPU(cuda变量为True),如果是,则将创建的模型和损失函数移动到GPU上。这样可以利用GPU加速模型的训练和计算。

加载或初始化模型的权重:

if opt.epoch != 0:
    # 加载预训练模型
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # 初始化权重
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

这段代码检查opt.epoch是否不等于0,如果不等于0,则从磁盘加载对应周期的预训练模型权重。这些权重保存在以数据集名称和周期数命名的文件中。如果opt.epoch等于0,则表示从头开始训练,这时会使用weights_init_normal函数来初始化模型的权重。weights_init_normal函数通常是一个自定义函数,用于将模型的权重初始化为正态分布。
总的来说,这段代码负责创建CycleGAN所需的模型,并将它们配置为在GPU上运行(如果可用),然后根据是否有预训练的权重来加载或初始化这些模型的权重。

第五部分

这段代码涉及到为CycleGAN的生成器和判别器创建优化器,设置学习率更新调度器,以及定义Tensor类型和重播缓冲区。下面是分段解释:

创建优化器:

optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

这五行代码创建了三个优化器:一个用于两个生成器G_ABG_BA的联合参数(optimizer_G),以及两个分别用于判别器D_AD_B的优化器(optimizer_D_Aoptimizer_D_B)。所有优化器都是使用Adam算法,它是一种适用于大规模数据和高维空间的优化算法。lr参数设置学习率,betas参数设置Adam算法中的两个超参数,分别是第一和第二矩估计的指数衰减率。

设置学习率更新调度器:

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

这四行代码为每个优化器创建了一个学习率更新调度器。LambdaLR是一个基于自定义函数的学习率调度器,它允许用户根据迭代次数来调整学习率。LambdaLRlr_lambda参数是一个函数,它根据当前周期数、总周期数和开始衰减的周期数来计算学习率的乘数。LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step是一个函数,它返回学习率乘数。

定义Tensor类型:

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

这行代码定义了一个Tensor类型,如果cuda为True(表示使用GPU),则使用torch.cuda.FloatTensor,否则使用torch.Tensor。这个Tensor类型将在后续代码中用于创建张量。

创建重播缓冲区:

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

这两行代码创建了两个重播缓冲区,用于存储生成的样本。这些缓冲区在训练判别器时用于提供历史的生成样本,以提高训练的稳定性。ReplayBuffer是一个自定义类,它提供了一个固定大小的存储,用于存储最近生成的样本。

第六部分

这段代码主要涉及到创建数据加载器、定义一个用于保存样本图像的函数,并且在训练过程中定期保存生成的样本。下面是分段解释:

1-7.定义图像转换操作:

transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

这行代码定义了一个列表transforms_,其中包含了多个图像转换操作。这些操作包括:

  • Resize:将图像尺寸放大到原始高度的1.12倍,使用双三次插值方法。
  • RandomCrop:从放大的图像中随机裁剪出原始尺寸的图像。
  • RandomHorizontalFlip:以一定的概率水平翻转图像。
  • ToTensor:将图像转换为PyTorch张量。
  • Normalize:对图像进行归一化处理,将像素值范围从[0, 1]转换为[-1, 1]。

8-12. 创建训练数据加载器:

dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

这五行代码创建了一个训练数据加载器dataloaderImageDataset是一个自定义类,它从指定的数据集路径加载图像,并应用转换操作。DataLoader是一个迭代器,它允许我们按批次加载数据,并提供数据混洗、多进程数据加载等功能。

13-18.创建测试数据加载器:

val_dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)

这五行代码创建了一个测试数据加载器val_dataloader。它与训练数据加载器类似,但是batch_size设置为5,num_workers设置为1,并且ImageDatasetmode参数设置为"test",表示加载的是测试集。

19-33. 定义保存样本图像的函数:

def sample_images(batches_done):
    # ... (代码逻辑见下文)

这行代码定义了一个名为sample_images的函数,它接受一个参数batches_done,表示已经完成的批次数。这个函数的作用是在训练过程中定期保存生成的样本图像。

20-24. 加载并处理测试集的图像:

imgs = next(iter(val_dataloader))
G_AB.eval()
G_BA.eval()
real_A = Variable(imgs["A"].type(Tensor))
fake_B = G_AB(real_A)
real_B = Variable(imgs["B"].type(Tensor))
fake_A = G_BA(real_B)

这些代码行从测试数据加载器中获取下一批图像,并将生成器G_ABG_BA设置为评估模式。然后,它将真实图像real_Areal_B转换为PyTorch张量,并通过生成器生成伪造的图像fake_Bfake_A

26-30. 创建图像网格并保存:

real_A = make_grid(real_A, nrow=5, normalize=True)
real_B = make_grid(real_B, nrow=5, normalize=True)
fake_A = make_grid(fake_A, nrow=5, normalize=True)
fake_B = make_grid(fake_B, nrow=5, normalize=True)
image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)

这些代码行将真实图像和伪造图像排列成网格,并将它们拼接成一个大的图像网格。然后,使用save_image函数将图像网格保存到指定的文件路径。make_grid函数将多个图像排列成一个网格,save_image函数将张量保存为图像文件。

第七部分

这段代码是CycleGAN训练脚本的一部分,涉及生成器的训练过程。下面是逐行解释:

1-3. 检查是否为主程序入口:定义一个变量prev_time来记录开始训练的时间。

if __name__ == '__main__':
    prev_time = time.time()

这行代码首先检查当前脚本是否作为主程序入口运行。如果是,则继续执行。

6-7. 循环遍历所有周期和批次:

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

这两个循环遍历了所有的训练周期和每个周期内的批次。opt.epoch是当前的周期数,opt.n_epochs是总的周期数。

8-12. 设置模型输入和对抗性地面真值:

# Set model input
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))
# Adversarial ground truths
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake  = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

这些代码行设置模型的输入,包括真实图像real_Areal_B,以及用于判别器的对抗性地面真值。validfake分别表示真实图像和生成图像的标签。

13-22. 训练生成器:

# ------------------
#  Train Generators
# ------------------
G_AB.train()
G_BA.train()
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total损失
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()

这些代码行定义了生成器的损失函数,包括身份损失、对抗性损失、循环一致性损失和总损失。然后,它们计算这些损失,反向传播并更新生成器的权重。
这段代码的主要目的是通过训练生成器来学习从源域(A)到目标域(B)的映射,并从目标域(B)到源域(A)的映射。通过这种方式,生成器能够学习到两个域之间的映射关系,从而实现图像到图像的转换。

第八部分

这段代码是CycleGAN训练脚本的一部分,涉及判别器的训练过程。下面是逐行解释:

1-4. 重置判别器A的优化器梯度:

optimizer_D_A.zero_grad()

这行代码将判别器A的优化器中的梯度清零,以便在反向传播时不会累加到之前的梯度上。

5-11. 计算判别器A的损失:

# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2

这些代码行计算判别器A的损失,包括真实图像的损失(loss_real)和生成图像的损失(loss_fake)。fake_A_是从fake_A_buffer中弹出的先前生成的图像批次的副本,以确保判别器不会反向传播梯度回生成器。

12-16. 计算判别器A的总损失,并进行反向传播和权重更新:

loss_D_A.backward()
optimizer_D_A.step()

这行代码计算判别器A的总损失,并执行反向传播,将梯度传播回判别器的权重。然后,使用优化器optimizer_D_A来更新判别器A的权重。

17-22. 重置判别器B的优化器梯度:

optimizer_D_B.zero_grad()

这行代码将判别器B的优化器中的梯度清零,以便在反向传播时不会累加到之前的梯度上。

23-30. 计算判别器B的损失:

# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2

这些代码行计算判别器B的损失,包括真实图像的损失(loss_real)和生成图像的损失(loss_fake)。fake_B_是从fake_B_buffer中弹出的先前生成的图像批次的副本。

31-35. 计算判别器B的总损失,并进行反向传播和权重更新:

loss_D_B.backward()
optimizer_D_B.step()

这行代码计算判别器B的总损失,并执行反向传播,将梯度传播回判别器的权重。然后,使用优化器optimizer_D_B来更新判别器B的权重。

36. 计算判别器A和B的总损失,以平衡两个判别器的训练:

loss_D = (loss_D_A + loss_D_B) / 2

这行代码计算判别器A和B的总损失的平均值,以平衡两个判别器的训练过程。

第九部分

这段代码是CycleGAN训练脚本的循环外部分,涉及打印日志、保存图像样本、更新学习率以及保存模型检查点。下面是逐行解释:

1-8. 计算剩余批次数和剩余时间:

batches_done = epoch * len(dataloader) + i
batches_left = opt.n_epochs * len(dataloader) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()

这些代码行计算了到目前为止已经完成的批次数(batches_done),剩余的批次数(batches_left),以及剩余的训练时间(time_left)。prev_time用于计算时间差,以便估算剩余时间。

9-15. 打印日志:

# Print log
sys.stdout.write(
    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
    % (
        epoch,
        opt.n_epochs,
        i,
        len(dataloader),
        loss_D.item(),
        loss_G.item(),
        loss_GAN.item(),
        loss_cycle.item(),
        loss_identity.item(),
        time_left,
    )
)

这行代码使用sys.stdout.write来打印训练日志,包括当前周期数、批次数、判别器损失、生成器损失以及剩余训练时间。\r是一个回车符,用于在同一行中打印信息。

16-21. 保存图像样本:

# If at sample interval save image
if batches_done % opt.sample_interval == 0:
    sample_images(batches_done)

这行代码检查是否达到了预定的样本间隔(opt.sample_interval),如果是,则调用sample_images函数来保存生成的图像样本。

22-25. 更新生成器的学习率:

# Update learning rates
lr_scheduler_G.step()

这行代码更新生成器G_ABG_BA的学习率。lr_scheduler_G是生成器的学习率调度器,它根据当前周期数和预设的参数来调整学习率。

26-29. 更新判别器的 learning rate:

lr_scheduler_D_A.step()
lr_scheduler_D_B.step()

这行代码更新判别器D_AD_B的学习率。lr_scheduler_D_Alr_scheduler_D_B是判别器的学习率调度器,它们根据当前周期数和预设的参数来调整学习率。

30-34. 保存模型检查点:

if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
    # Save model checkpoints
    torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
    torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
    torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
    torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))

这行代码检查是否达到了预定的检查点间隔(opt.checkpoint_interval),如果是,则保存生成器G_ABG_BA和判别器D_AD_B的权重状态。权重被保存为.pth文件,以便在需要时可以加载

三、学习过程中遇到的问题及其解决方案

1. 如果想要在jupyter notebook中运行.py文件的同学,可以使用

%load 文件名.py

将这个输入在cell中并运行就可以将.py文件转化成.ipynb文件

2. 如果遇到ipykernel_launcher.py: error: unrecognized arguments: -f /home/这个报错问题的同学
这是因为调用parser.parse_args()会读取系统参数:sys.argv[],命令行调用时是正确参数,而在jupyter notebook中调用时,sys.argv的值为ipykrnel_launcher.py:
解决方法是:在代码中加入这段

import sys
sys.argv = ['run.py']

3. 一定要把文件按照要求的目录结构放,不然会影响读取
在这里插入图片描述
4. 这个训练超级费算力,我天真的以为我能跑满200轮,放了一天才跑了四十轮,结果还把那个结搞丢了。后来又重新跑了十几轮

四、训练结果

在这里插入图片描述这点已经花了我六个小时来跑了,其实已经能有很不错的结果了,话不多说,上结果!

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/556726.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

思维导图软件Xmind for Mac 中文激活版 支持M

XMind是一款非常受欢迎的思维导图软件,它应用了Eclipse RCP软件架构,注重易用性、高效性和稳定性,致力于帮助用户提高生产率。 Xmind for Mac 中文激活版下载 XMind的程序主体由一组插件构成,包括一个核心主程序插件、一组Eclipse…

文件后缀变成.halo? 如何恢复重要数据

.halo 勒索病毒是什么? .halo勒索病毒是一种恶意软件,属于勒索软件(Ransomware)的一种。这种病毒会加密用户计算机上的文件,并要求受害者支付赎金才能获取解密密钥,从而恢复被加密的文件。勒索软件通常会通…

力扣(leetcode) 42. 接雨水 (带你逐步思考)

力扣(leetcode) 42. 接雨水 (带你逐步思考) 链接:https://leetcode.cn/problems/trapping-rain-water/ 难度:hard 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多…

【方便 | 重要】#LLM入门 | Agent | langchain | RAG # 3.7_代理Agent,使用langchain自带agent完成任务

大型语言模型(LLMs)虽强大,但在逻辑推理、计算和外部信息检索方面能力有限,不如基础计算机程序。例如,LLMs处理简单计算或最新事件查询时可能不准确,因为它们仅基于预训练数据。LangChain框架通过“代理”(…

机器学习在安全领域的应用:从大数据中识别潜在安全威胁

🧑 作者简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向的学习指导…

Ubuntu20.04 ISAAC SIM仿真下载使用流程(4.16笔记补充)

机器:华硕天选X2024 显卡:4060Ti ubuntu20.04 安装显卡驱动版本:525.85.05 参考: What Is Isaac Sim? — Omniverse IsaacSim latest documentationIsaac sim Cache 2023.2.3 did not work_isaac cache stopped-CSDN博客 Is…

LeetCode in Python 704. Binary Search (二分查找)

二分查找是一种高效的查询方法&#xff0c;时间复杂度为O(nlogn)&#xff0c;本文给出二分查找的代码实现。 示例&#xff1a; 代码&#xff1a; class Solution:def search(self, nums, target):l, r 0, len(nums) - 1while l < r:mid (l r) // 2if nums[mid] > ta…

C++11 数据结构1 线性表的概念,线性表的顺序存储,实现,测试

一 线性表的概念 线性结构是一种最简单且常用的数据结构。 线性结构的基本特点是节点之间满足线性关系。 本章讨论的动态数组、链表、栈、队列都属于线性结构。 他们的共同之处&#xff0c;是节点中有且只有一个开始节点和终端节点。按这种关系&#xff0c;可以把它们的所有…

MC9S12A64 程序烧写方法

前言 工作需要对MC9S12A64 单片机进行程序烧写。 资料 MC9S12A64 单片机前身属于 飞思卡尔半导体&#xff0c;后来被恩智浦收购&#xff0c;现在属于NXP&#xff1b; MC9S12A64 属于16位S12系列&#xff1b;MC9S12 又叫 HCS12。 数据手册下载连接 S12D_16位微控制器 | N…

[大模型]TransNormerLLM-7B 接入 LangChain 搭建知识库助手

TransNormerLLM-7B 接入 LangChain 搭建知识库助手 环境准备 在 autodl 平台中租赁一个 3090/4090 等 24G 显存的显卡机器&#xff0c;如下图所示镜像选择 PyTorch–>2.0.0–>3.8(ubuntu20.04)–>11.8 接下来打开刚刚租用服务器的 JupyterLab&#xff0c;并且打开其…

简单实用的备忘录小工具 记事提醒备忘效果超好

在这个信息爆炸的时代&#xff0c;我们每个人都需要处理大量的信息和任务。有时候&#xff0c;繁忙的生活和工作会让我们感到压力山大。幸运的是&#xff0c;现在有很多简单实用的软件工具&#xff0c;像得力的小助手一样&#xff0c;帮助我们整理思绪&#xff0c;提高效率&…

Redis系列1:深刻理解高性能Redis的本质

1 背景 分布式系统绕不开的核心之一的就是数据缓存&#xff0c;有了缓存的支撑&#xff0c;系统的整体吞吐量会有很大的提升。通过使用缓存&#xff0c;我们把频繁查询的数据由磁盘调度到缓存中&#xff0c;保证数据的高效率读写。 当然&#xff0c;除了在内存内运行还远远不够…

Docker 和 Podman的区别

文章目录 Docker 和 Podman的区别安装架构和特权要求运行容器方面安全性(root的权限)镜像管理方面命令方面Docker常用命令Podman常用命令 Docker 和 Podman的区别 安装 Docker安装&#xff1a;官方文档 Podman安装&#xff1a;官方文档 架构和特权要求 Docker使用client-se…

11、电科院FTU检测标准学习笔记-越限及告警上送功能

作者简介&#xff1a; 本人从事电力系统多年&#xff0c;岗位包含研发&#xff0c;测试&#xff0c;工程等&#xff0c;具有丰富的经验 在配电自动化验收测试以及电科院测试中&#xff0c;本人全程参与&#xff0c;积累了不少现场的经验 ———————————————————…

git 快问快答

我在实习的时候&#xff0c;是用本地开发&#xff0c;然后 push 到 GitHub 上&#xff0c;之后再从 Linux 服务器上拉 GitHub 代码&#xff0c;然后就可以了。一般程序是在 Linux 服务器上执行的&#xff0c;我当时使用过用 Linux 提供的命令来进行简单的性能排查。 在面试的时…

详解爬虫基本知识及入门案列(爬取豆瓣电影《热辣滚烫》的短评 详细讲解代码实现)

目录 前言什么是爬虫&#xff1f; 爬虫与反爬虫基础知识 一、网页基础知识 二、网络传输协议 HTTP&#xff08;HyperText Transfer Protocol&#xff09;和HTTPS&#xff08;HTTP Secure&#xff09;请求过程的原理&#xff1f; 三、Session和Cookies Session Cookies Session与…

抖音小店流量差怎么办?做好这三大细节,就可以完美解决!

大家好&#xff0c;我是电商糖果 很多刚开店的朋友&#xff0c;遇到的第一个难题就是店铺流量差。 没有流量&#xff0c;也就不会出单&#xff0c;更别提起店了。 糖果做抖音小店四年多了&#xff0c;也开了很多家小店。 很多新店没有流量&#xff0c;其实主要原因是这三个…

在mysql函数中启动事物和行锁/悲观锁实现并发条件下获得唯一流水号

业务场景 我有一个业务需求&#xff1a;我有一个报卡表 report里面有一个登记号字段 fcardno、地区代码 faddrno和发病年份 fyear&#xff0c;登记号由**“4位地区代码”“00”“发病年份”“5位流水号”**组成&#xff0c;我要在每次插入一张报卡&#xff08;每一行数据&#…

【MATLAB源码-第46期】基于matlab的OFDM系统多径数目对比,有无CP(循环前缀)对比,有无信道均衡对比。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 OFDM&#xff08;正交频分复用&#xff09;是一种频域上的多载波调制技术&#xff0c;经常用于高速数据通信中。以下是关于多径数目、有无CP&#xff08;循环前缀&#xff09;以及有无信道均衡在OFDM系统中对误码率的影响&am…

自如电费均摊问题

3月份搬了次家&#xff0c;嫌麻烦租了自如&#xff0c;第一个月的电费账单出来了&#xff0c;由于我是中途搬进去的&#xff0c;于是乎就好奇他会如何计算均摊&#xff0c;这个月电费账单出来了&#xff0c;算了下发现了点东西。 先说结论&#xff1a;按照我的这个均摊的方式&a…
最新文章