PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

PyTorch深度学习实战(31)——生成对抗网络

    • 0. 前言
    • 1. GAN
    • 2. GAN 模型分析
    • 3. 利用 GAN 模型生成手写数字
    • 小结
    • 系列链接

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. GAN

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

GAN原理

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. GAN 模型分析

为了生成手写数字的图像,我们采取以下策略:

  • 导入 MNIST 数据
  • 初始化随机噪声
  • 定义生成网络模型
  • 定义判别网络模型
  • 使用生成网络生成伪造图像,生成网络在最初只能生成噪声图像,噪声图像是通过将一组噪声值通过权重随机的神经网络得到的图像
  • 交替训练两个模型
    • 将生成的图像与原始图像串联起来,判别网络预测每个图像是伪造图像还是真实图像,对判别网络进行训练,判别网络的损失是图像的预测值和实际值(标签)的二进制交叉熵,生成的伪造图像的实际值(标签)为 0,原始数据集中真实图像的实际值(标签)为 1
    • 训练生成网络利用输入噪声生成伪造图像,使其看起来更接近真实图像,从而使生成图像有可能欺骗判别网络
    • 输入噪声通过生成网络传递输出伪造图像,将生成网络生成的图像输入到判别网络中,此时,判别网络权重被冻结,因为生成网络的目标是欺骗判别网络,因此,假设生成的伪造图像实际值(标签)为 1,生成网络的损失是判别网络对输入图像的预测值和实际值 (1) 的二进制交叉熵

了解了 GAN 的基本原理后,在下一小节,我们实现 GAN 生成 MNIST 手写数字图像。

3. 利用 GAN 模型生成手写数字

(1) 导入相关库并定义设备:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
device = "cuda" if torch.cuda.is_available() else "cpu"

from torchvision.datasets import MNIST
from torchvision import transforms

(2) 导入 MNIST 数据,定义具有内置数据转换功能的数据加载器,以便缩放输入数据:

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

data_loader = torch.utils.data.DataLoader(MNIST('MNIST/', train=True, download=True, transform=transform),batch_size=128, shuffle=True, drop_last=True)

(3) 定义判别网络模型类:

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( 
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

在以上代码中,使用 LeakyReLU 激活函数替换 ReLU。打印判别网络的简要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (1,784)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1              [-1, 1, 1024]         803,840
         LeakyReLU-2              [-1, 1, 1024]               0
           Dropout-3              [-1, 1, 1024]               0
            Linear-4               [-1, 1, 512]         524,800
         LeakyReLU-5               [-1, 1, 512]               0
           Dropout-6               [-1, 1, 512]               0
            Linear-7               [-1, 1, 256]         131,328
         LeakyReLU-8               [-1, 1, 256]               0
           Dropout-9               [-1, 1, 256]               0
           Linear-10                 [-1, 1, 1]             257
          Sigmoid-11                 [-1, 1, 1]               0
================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.57
Estimated Total Size (MB): 5.61
----------------------------------------------------------------

(4) 定义生成网络模型类 Generator

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

生成网络根据 100 维随机噪声输入生成图像。打印生成网络模型的简要信息:

generator = Generator().to(device)
print(summary(generator, (1,100)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 1, 256]          25,856
         LeakyReLU-2               [-1, 1, 256]               0
            Linear-3               [-1, 1, 512]         131,584
         LeakyReLU-4               [-1, 1, 512]               0
            Linear-5              [-1, 1, 1024]         525,312
         LeakyReLU-6              [-1, 1, 1024]               0
            Linear-7               [-1, 1, 784]         803,600
              Tanh-8               [-1, 1, 784]               0
================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.67
Estimated Total Size (MB): 5.71
----------------------------------------------------------------

(5) 定义函数生成随机噪声并将其注册到设备中:

def noise(size):
    n = torch.randn(size, 100)
    return n.to(device)

(6) 定义函数来训练判别网络。

判别网络训练函数 (discriminator_train_step) 将真实数据 (real_data) 和伪造数据 (fake_data) 作为输入:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):

重置优化器梯度:

    d_optimizer.zero_grad()

在对损失值执行反向传播之前,预测真实数据 (real_data) 并计算损失 (error_real):

    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()

在真实数据上计算判别网络损失时,我们期望判别网络预测输出为 1。因此,在判别网络的训练过程中,使用 torch.ones 作为标签,期望判别网络在真实数据上的输出为 1,从而计算判别网络在真实数据上的损失。

在对损失值执行反向传播之前,预测伪造数据 (fake_data) 并计算损失 (error_fake):

    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()

在伪造数据上计算判别网络损失时,我们期望判别网络预测输出为 0。因此,在判别网络的训练过程中,使用 torch.zeros 作为标签,期望判别网络在伪造数据上的输出为 0,从而计算判别网络在伪造数据上的损失。

更新权重并返回整体损失(将模型在 real_dataerror_realfake_dataerror_fake 的损失值相加):

    d_optimizer.step()
    return error_real + error_fake

(7) 训练生成网络模型。

定义生成网络训练函数 generator_train_step 并传入伪造数据 fake_data 作为参数:

def generator_train_step(real_data, fake_data, loss, g_optimizer):

重置优化器梯度:

    g_optimizer.zero_grad()

预测判别网络对伪造数据 (fake_data) 的输出:

    prediction = discriminator(fake_data)

在计算生成网络的损失时,使用 torch.ones 作为标签,期望判别网络在伪造数据上的输出为 1,以在训练生成网络时欺骗判别网络输出值 1,以此来鼓励生成网络生成更加逼真的数据,并让判别网络无法区分其真伪:

    error = loss(prediction, torch.ones(len(real_data), 1).to(device))

执行反向传播,更新权重,并返回损失:

    error.backward()
    g_optimizer.step()
    return error

(8) 定义模型对象、生成网络和判别网络的优化器,以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

(9) 训练模型。

循环训练模型 200epochs (num_epochs):

num_epochs = 200

d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):
    N = len(data_loader)
    d_loss_items = []
    g_loss_items = []
    for i, (images, _) in enumerate(data_loader):

加载真实数据 (real_data) 和伪造数据,其中,伪造数据是通过将大小与真实数据样本数相同的噪声数据 (batch_size = len(real_data)) 传入生成网络网络获得的。需要注意的是,必须调用 fake_data.detach(),否则训练无法正常进行。通过 detach() 函数分离出来一个新的张量,这样在 discriminator_train_step() 中调用 error.backward() 时,与生成网络相关的张量(生成 fake_data )不会受到影响。使用 discriminator_train_step 函数训练判别网络:

        real_data = images.view(len(images), -1).to(device)
        fake_data = generator(noise(len(real_data))).to(device)
        fake_data = fake_data.detach()

训练判别网络后,继续训练生成网络。从噪声数据生成一组新的伪造图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(noise(len(real_data))).to(device)
        g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())
        g_loss_items.append(g_loss.item())
    d_loss_epoch.append(np.average(d_loss_items))
    g_loss_epoch.append(np.average(g_loss_items))

绘制判别网络和生成网络的损失随训练的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型性能检测

(10) 可视化模型训练后生成的伪造数据:

z = torch.randn(64, 100).to(device)
sample_images = generator(z).data.cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0), cmap='gray')
plt.show()

生成结果

在上图中,可以看到利用 GAN 生成逼真的图像,但仍有一定的改进空间,在之后的学习中,我们将介绍更多 GAN 的改进模型生成更逼真的图像。

小结

生成对抗网络是一种强大的深度学习模型,由生成器网络和判别器网络组成,通过彼此之间的竞争来提高性能,已经在图像生成、图像修复、图像转换和自然语言处理等领域取得了巨大的成功。其核心思想是通过生成器和判别器之间的博弈过程来实现真实样本的生成。生成器负责生成逼真的样本,而判别器则负责判断样本是真实还是伪造。通过不断的训练和迭代,生成器和判别器会相互竞争并逐渐提高性能。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/339421.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

EOCR电机保护器带煤电厂的具体应用

EOCR系列电动机智能保护器在煤电厂中有广泛的应用。这些保护器具有齐全的保护功能、直观的测量参数、快速的反应灵敏度、可靠的行动以及与上位机通讯构成远程监控的能力,使其成为理想的低压电动机保护及远程监控产品。 在煤电厂中,电动机保护器需要具备过…

SpringCloud Aliba-Sentinel【上篇】-从入门到学废【4】

🎵诗词分享🎵 大江东去,浪淘尽,千古风流人物。 ——苏轼《念奴娇赤壁怀古》 目录 🍿1.Sentinel是什么 🧂2.特点 🧈3.下载 🌭4.sentinel启动 🥓5.实例演示 1.Senti…

IOT pwn

已经过了填坑的黄金时期 环境搭建 交叉编译工具链 很多开源项目需要交叉编译到特定架构上,因此需要安装对应的交叉编译工具链。 sudo apt install gcc-arm-linux-gnueabi g-arm-linux-gnueabi -y sudo apt install gcc-aarch64-linux-gnu g-aarch64-linux-gnu -…

RK3568笔记十:Zlmediakit交叉编译

若该文为原创文章,转载请注明原文出处。 编译Zlmediakit的主要目的是想实现在RK3568拉取多路RTPS流,并通过MPP硬解码,DRM显示出来。为了实现拉取多路流选择了Zlmediakit,使用FFMEPG也可以,在RV1126上已经验证了可行性。 一、环境…

对MODNet 主干网络 MobileNetV2的剪枝探索

目录 1 引言 1.1 MODNet 原理 1.2 MODNet 模型分析 2 MobileNetV2 剪枝 2.1 剪枝过程 2.2 剪枝结果 2.2.1 网络结构 2.2.2 推理时延 2.3 实验结论 3 模型嵌入 3.1 模型保存与加载 法一:保存整个模型 法二:仅保存模型的参数 小试牛刀 小结…

服务端实现微信小游戏登录

1 微信小程序用户登录及其流程 小程序可以通过微信官方提供的登录能力,便能方便的获取微信提供的用户身份标识,达到建立用户体系的作用。 官方文档提供了登录流程时序图,如下: 从上述的登录流程时序图中我们发现,这里总共涉及到三个概念。 第一个是小程序,小程序即我们…

C#,入门教程(30)——扎好程序的笼子,错误处理 try catch

上一篇: C#,入门教程(29)——修饰词静态(static)的用法详解https://blog.csdn.net/beijinghorn/article/details/124683349 程序员语录:凡程序必有错,凡有错未必改! 程序出错的原因千千万&…

Rockchip linux USB 驱动开发

Linux USB 驱动架构 Linux USB 协议栈是一个分层的架构,如下图 5-1 所示,左边是 USB Device 驱动,右边是 USB Host 驱动,最底层是 Rockchip 系列芯片不同 USB 控制器和 PHY 的驱动。 Linux USB 驱动架构 USB PHY 驱动开发 USB 2…

LeetCode 77. 组合

77. 组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[1,4], ] 示例 2: 输入:…

Kafka(八)使用Kafka构建数据管道

目录 1 使用场景2 构建数据管道时需要考虑的问题2.1 及时性2.2 可靠性高可用可靠性数据传递 2.3 高吞吐量2.4 数据格式2.5 转换ETLELT 2.6 安全性2.7 故障处理2.8 耦合性和灵活性临时数据管道元数据丢失末端处理 3 使用Connect API3.1 Connect的数据处理流程sourcesinkconnecto…

csrf漏洞之DedeCMS靶场漏洞(超详细)

1.Csrf漏洞: 第一步,查找插入点,我选择了网站栏目管理的先秦两汉作为插入点 第二步修改栏目名称 第三步用burp拦截包 第四步生成php脚本代码 第五步点击submit 第六步提交显示修改成功 第二个csrf 步骤与上述类似 红色边框里面的数字会随着…

第91讲:MySQL主从复制集群主库与从库状态信息的含义

文章目录 1.主从复制集群正常状态信息2.从库状态信息中重要参数的含义 1.主从复制集群正常状态信息 通过以下命令查看主库的状态信息。 mysql> show processlist;在主库中查询当前数据库中的进程,看到Master has sent all binlog to slave; waiting for more u…

Rocky Linux 8.9 安装图解

风险告知 本人及本篇博文不为任何人及任何行为的任何风险承担责任,图解仅供参考,请悉知!本次安装图解是在一个全新的演示环境下进行的,演示环境中没有任何有价值的数据,但这并不代表摆在你面前的环境也是如此。生产环境…

2023 IoTDB Summit:北京城建智控科技股份有限公司高级研发主管刘喆《IoTDB在城市轨道交通综合监控系统中的应用》...

12 月 3 日,2023 IoTDB 用户大会在北京成功举行,收获强烈反响。本次峰会汇集了超 20 位大咖嘉宾带来工业互联网行业、技术、应用方向的精彩议题,多位学术泰斗、企业代表、开发者,深度分享了工业物联网时序数据库 IoTDB 的技术创新…

Laykefu客服系统 任意文件上传漏洞复现

0x01 产品简介 Laykefu 是一款基于workerman+gatawayworker+thinkphp5搭建的全功能webim客服系统,旨在帮助企业有效管理和提供优质的客户服务。 0x02 漏洞概述 Laykefu客服系统/admin/users/upavatar.html接口处存在文件上传漏洞,而且当请求中Cookie中的”user_name“不为…

SpringBoot+Email发送邮件

引言 邮件通知是现代应用中常见的一种通信方式,特别是在需要及时反馈、告警或重要事件通知的场景下。Spring Boot提供了简单而强大的邮件发送功能,使得实现邮件通知变得轻而易举。本文将研究如何在Spring Boot中使用JavaMailSender实现邮件发送&#xf…

geemap学习笔记052:影像多项式增强

前言 下面介绍的主要内容是应用Image.polynomial() 对图像进行多项式增强。 1 导入库并显示地图 import ee import geemap ee.Initialize() Map geemap.Map(center[40, -100], zoom4)2 多项式增强 # 使用函数 -0.2 2.4x - 1.2x^2 对 MODIS 影像进行非线性对比度增强。# L…

Oracle Linux 7.9 安装图解

风险告知 本人及本篇博文不为任何人及任何行为的任何风险承担责任,图解仅供参考,请悉知!本次安装图解是在一个全新的演示环境下进行的,演示环境中没有任何有价值的数据,但这并不代表摆在你面前的环境也是如此。生产环境…

画眉(京东科技设计稿转代码平台)介绍

前言 随着金融App业务的不断发展,为了满足不同场景下的用户体验及丰富的业务诉求,业务产品层面最直接体现就是大量新功能的上线及老业务的升级,随之也给研发带来了巨大的压力,所以研发效率的提升就是当前亟需解决的问题&#xff…

赛氪公开课|AI发展的影响与对策

欢迎参赛者们参加中国计算机应用技术大赛,由中国计算机学会主办的中国计算机应用技术大赛于2023年10月至2024年7月举办。为了增强参赛者对计算机领域专业知识与技能的兴趣,提升参赛者解决问题、逻辑思维、自主创新、团队协作等能力。赛氪网特别邀请了人工…
最新文章