深度学习(四):pytorch搭建GAN(对抗网络)

1.GAN

生成对抗网络(GAN)是一种深度学习模型,由两个网络组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成假数据,而判别器则负责判断数据是真实的还是 fake的。这两个网络互相竞争,生成器试图生成更真实的数据以欺骗判别器,而判别器则试图更好地识别生成的数据。
在这里插入图片描述

GAN 的基本思想是:通过训练生成器和判别器,使得生成器能够生成与真实数据非常相似的数据,同时使得判别器能够更有效地识别这些数据。

1.1 概念

  1. 生成器(Generator):生成器是一个神经网络,其目的是生成假的数据,看起来像是真实的。生成器通常包含一些神经网络层,如卷积层、全连接层等。生成器接受随机噪声作为输入,并生成看起来像是真实数据的输出。
  2. 判别器(Discriminator):判别器也是一个神经网络,其目的是识别数据是真实的还是 fake的。判别器通常也包含一些神经网络层,如卷积层、全连接层等。判别器接受输入数据,并输出一个分数,表示输入数据是真实的还是 fake的。
  3. 生成对抗训练:生成对抗训练是指同时训练生成器和判别器。生成器试图生成更真实的数据,以欺骗判别器。判别器则试图更好地识别生成的数据,以避免被欺骗。生成器和判别器之间的竞争导致它们不断改进,以提高生成数据的真实性。
  4. 生成器损失和判别器损失:生成器损失是指生成器试图生成更真实数据的损失。生成器损失通常使用生成器的对抗损失和生成损失之和来计算。判别器损失是指判别器试图更好地识别真实数据和假数据的损失。判别器损失通常使用判别器识别真实数据和假数据的损失之和来计算。
  5. 对抗性训练:对抗性训练是指在训练过程中,使用生成器生成的假数据来训练判别器,以提高判别器的识别能力。同时,使用判别器识别的反馈来训练生成器,以提高生成器生成更真实数据的能力。

1.2 优势

GAN(Generative Adversarial Network)是一种生成对抗网络,主要由生成器和判别器组成。生成器负责生成假数据,而判别器负责判断数据是真实的还是 fake的。GAN 的训练过程相对复杂,但是它可以生成非常真实的数据,并且可以用来进行数据增强、图像生成、视频生成等应用。

GAN 的优势主要体现在以下几个方面:

  1. 生成数据非常真实:GAN 可以生成非常真实的数据,可以用来进行数据增强、图像生成、视频生成等应用。
  2. 可以生成大量数据:GAN 可以生成大量的数据,可以用来进行机器学习、深度学习等应用。
  3. 可以生成不同类型的数据:GAN 可以生成不同类型的数据,可以用来进行图像生成、视频生成等应用。
  4. 可以进行对抗训练:GAN 可以进行对抗训练,可以提高模型的鲁棒性和泛化能力。

虽然 GAN 具有优势,但是也存在一些挑战,例如训练过程复杂、生成器容易过拟合、对抗训练难以实现等。因此,在实际应用中,需要根据具体情况进行优化和调整。

1.3 训练技巧

  1. 使用批归一化(Batch Normalization):批归一化是一种在卷积神经网络中常用的加速训练和提高模型性能的方法。在 GAN 的生成器和判别器中可以使用批归一化来提高性能。
  2. 使用 Leaky ReLU 激活函数:Leaky ReLU 激活函数是一种在 ReLU 激活函数中加入一个小于 1 的常数,以避免神经元死亡的方法。在 GAN 的生成器和判别器中可以使用 Leaky ReLU 激活函数来提高性能。
  3. 使用 U-Net 结构:U-Net 是一种用于图像分割的网络结构,其结构可以同时实现编码器和解码器。在 GAN 的生成器中可以使用 U-Net 结构来提高生成图像的质量。
  4. 使用对抗性损失(Adversarial Loss):对抗性损失是一种可以增加生成器损失的方法,通过在损失函数中加入一个与真实数据接近的噪声来增加生成器的难度。在 GAN 的训练过程中可以使用对抗性损失来提高性能。
  5. 使用预训练模型:预训练模型是一种在已有数据集上训练好的模型,可以用于迁移学习和提高性能。在 GAN 的生成器和判别器中可以使用预训练模型来提高性能。
  6. 使用注意力机制(Attention):注意力机制是一种可以提高模型性能和泛化能力的方法,可以在 GAN 的生成器和判别器中使用注意力机制来提高性能。

总结起来,GAN 的训练过程需要综合考虑多个方面,包括数据预处理、损失函数选择、正则化、梯度裁剪、对抗性训练、数据增强和 early stopping 等技巧。同时,还可以使用一些额外的技巧,如批归一化、Leaky ReLU 激活函数、U-Net 结构、对抗性损失、预训练模型和注意力机制等来进一步提高 GAN 的性能。

2 代码实现

步骤:

  1. 导入所需的库和模块。
  2. 定义生成器的网络结构,包括全连接层和激活函数。
  3. 定义判别器的网络结构,也包括全连接层和激活函数。
  4. 定义训练函数,包括将模型移动到设备、定义损失函数和优化器、开始训练的循环等。
  5. 设置随机种子。
  6. 设置设备,如果有可用的GPU则使用GPU,否则使用CPU。
  7. 加载MNIST数据集,并进行数据预处理。
  8. 初始化生成器和判别器。
  9. 设置训练的参数,如训练轮数、生成器的输入维度等。
  10. 调用训练函数进行训练。
# 导入torch模块
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 定义生成器的网络结构
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),  # 全连接层,输入latent_dim维,输出256维
            nn.LeakyReLU(0.2),  # LeakyReLU激活函数
            nn.Linear(256, 512),  # 全连接层,输入256维,输出512维
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),  # 全连接层,输入512维,输出1024维
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),  # 全连接层,输入1024维,输出784维
            nn.Tanh()  # Tanh激活函数
        )

    def forward(self, x):
        return self.model(x)

# 定义判别器的网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),  # 全连接层,输入784维,输出512维
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),  # 全连接层,输入512维,输出256维
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # 全连接层,输入256维,输出1维
            nn.Sigmoid()  # Sigmoid激活函数
        )

    def forward(self, x):
        return self.model(x)

# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs, latent_dim, device):
    # 将模型移动到设备
    generator.to(device)
    discriminator.to(device)

    # 定义损失函数和优化器
    criterion = nn.BCELoss()  # 二分类交叉熵损失函数
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器的优化器
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器的优化器

    # 开始训练
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            # 将图像转换为向量
            real_images = real_images.view(-1, 784).to(device)
            # 获取图像的batch_size
            batch_size = real_images.size(0)
            # 定义真实标签和 fake标签
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # 训练判别器
            optimizer_D.zero_grad()
            # 计算真实图像的输出
            real_outputs = discriminator(real_images)
            # 计算真实图像的损失
            real_loss = criterion(real_outputs, real_labels)

            # 生成假图像
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(z)
            # 计算假图像的输出
            fake_outputs = discriminator(fake_images.detach())
            # 计算假图像的损失
            fake_loss = criterion(fake_outputs, fake_labels)

            # 计算判别器的损失
            d_loss = real_loss + fake_loss
            # 反向传播
            d_loss.backward()
            # 更新参数
            optimizer_D.step()

            # 训练生成器
            optimizer_G.zero_grad()
            # 计算假图像的输出
            fake_outputs = discriminator(fake_images)
            # 计算生成器的损失
            g_loss = criterion(fake_outputs, real_labels)

            # 反向传播
            g_loss.backward()
            # 更新参数
            optimizer_G.step()

            # 每200步打印一次损失
            if (i+1) % 200 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                      f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
        # 每1步打印一次图像
        if (epoch+1) % 1 == 0:
            # 生成图像
            with torch.no_grad():
                z = torch.randn(10, 100).to(device)
                generated_images = generator(z).cpu().view(-1, 28, 28)

            # 展示原始数据和生成数据的图像
            fig, axes = plt.subplots(2, 5, figsize=(10, 4))
            for i, ax in enumerate(axes.flat):
                if i < 5:
                    ax.imshow(real_images[i].view(28, 28), cmap='gray')
                    ax.set_title('Real')
                else:
                    ax.imshow(generated_images[i-5], cmap='gray')
                    ax.set_title('Generated')
                ax.axis('off')
            plt.tight_layout()
            plt.show()

# 设置随机种子
torch.manual_seed(42)

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化生成器和判别器
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()

# 训练GAN模型
num_epochs = 50
train(generator, discriminator, train_dataloader, num_epochs, latent_dim, device)

2.1结果

第一轮:

在这里插入图片描述
训练之后:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/211400.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python常见实战问题解析与解决方案

更多Python学习内容&#xff1a;ipengtao.com 大家好&#xff0c;我是涛哥&#xff0c;今天为大家分享 Python常见实战问题解析与解决方案&#xff0c;全文5200字&#xff0c;阅读大约13分钟。 Python作为一门强大而灵活的编程语言&#xff0c;常常面临各种实际挑战。在本文中&…

Flink(九)【时间语义与水位线】

前言 2023-12-02-20:05&#xff0c;终于写完啦&#xff0c;最近状态不错。刚写完又收到了她的消息哈哈哈哈&#xff0c;开心。 再去全力打拼一次&#xff0c;奋战一场&#xff0c;就算最后打了败仗也无所谓&#xff0c;至少你留下了足迹。 《解忧杂货店》 1、时间语义 …

【计算机网络】15、NAT、NAPT 网络地址转换、打洞

文章目录 一、概念二、分类&#xff08;主要是传统 NAT&#xff09;2.1 基本 NAT2.2 NAPT 三、访问NAT下的内网设备的方式3.1 多拨3.2 端口转发、DMZ3.3 UPnP IGD、NAT-PMP3.4 服务器中转&#xff1a;frp 内网穿透3.4.1 NAT 打洞3.4.2 NAT 类型与打洞成功率3.4.2.1 完全圆锥形 …

C++设计模式——Bridge模式(下)

在上篇 《C设计模式——Bridge模式&#xff08;上&#xff09;》中我们对于桥接模式做了一些介绍。介于桥接模式在实际项目开发中使用广泛&#xff0c;而且也是面试中常问常新的话题。在本篇&#xff0c;我们专注bridge模式在具体的项目开发中的应用&#xff0c;举几个例子来说…

快手自动评论助手:开发流程与所需技术的深度解析

先来看实操成果&#xff0c;↑↑需要的同学可看我名字↖↖↖↖↖&#xff0c;或评论888无偿分享 一、引言 随着互联网的发展&#xff0c;越来越多的人开始使用快手这款短视频平台。在这个平台上&#xff0c;用户可以分享自己的生活点滴&#xff0c;观看他人的精彩瞬间。然而&am…

Ext4文件系统解析(一)

1、前言 熟悉Linux操作系统的都应该或多或少的了解或者使用过Ext4文件系统。 接下来&#xff0c;会简单介绍Ext4文件系统的一些特性和工作原理。 2、常用概念 在介绍Ext文件系统之前&#xff0c;先简单描述一些相关概念。 块(Block)&#xff1a;Ext文件系统存储分配的基本单…

软件工程 - 第8章 面向对象建模 - 4 - 物理体系结构建模

构件图 构件图概述 构件图描述了软件的各种构件和它们之间的依赖关系。 构件图的作用 在构件图中&#xff0c;系统中的每个物理构件都使用构件符号来表示&#xff0c;通常&#xff0c;构件图看起来像是构件图标的集合&#xff0c;这些图标代表系统中的物理部件&#xff0c;…

java学习part30callabel和线程池方式

140-多线程-线程的创建方式3、4&#xff1a;实现Callable与线程池_哔哩哔哩_bilibili 1.Callable 实现类 使用方式 返回值 2.线程池

Linux expect命令详解

在Linux系统中&#xff0c;expect 是一款非常有用的工具&#xff0c;它允许用户自动化与需要用户输入进行交互的程序。本文将深入探讨expect命令的基本语法、使用方法以及一些最佳实践。 什么是Expect命令&#xff1f; expect 是一个用于自动化交互式进程的工具。它的主要功能…

【PyTorch】线性回归

文章目录 1. 代码实现1.1 一元线性回归模型的训练 2. 代码解读2.1. tensorboardX2.1.1. tensorboardX的安装2.1.2. tensorboardX的使用 1. 代码实现 波士顿房价数据集下载 1.1 一元线性回归模型的训练 import numpy as np import torch import torch.nn as nn from torch.ut…

Ext4文件系统解析(二)

1、前言 想要了解EXT文件系统的工作原理&#xff0c;那了解文件系统在磁盘上的分布就是必不可少的。这一节主要介绍EXT文件系统硬盘存储的物理结构。 由于当前主流的CPU架构均采用小端模式&#xff0c;因此下文介绍均已小端模式为准。 2、超级块 2.1 属性 下表列举出超级块…

Java 8 中 ReentrantLock 与 Synchronized 的区别

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f33a; 仓库主页&#xff1a; Gitee &#x1f4ab; Github &#x1f4ab; GitCode &#x1f496; 欢迎点赞…

分布式ID生成框架Leaf升级踩坑

背景&#xff1a; 在项目中需要一个统一的拿单号等唯一ID的服务&#xff0c;就想起了之前用到的leaf&#xff0c;但是因为项目要求&#xff0c;leaf的版本不符合&#xff0c;需要做一些升级 项目地址&#xff1a;https://github.com/Meituan-Dianping/Leaf 升级点&#xff1…

231202 刷题日报

周四周五&#xff0c;边值班边扯皮&#xff0c;没有刷题。。 今天主要是做了: 1. 稀疏矩阵压缩&#xff0c;十字链表法 2. 快速排序 3.349. 两个数组的交集​​​​​ 4. 174. 地下城游戏 要注意溢出问题&#xff01;

Motion 5 for Mac,释放创意,打造精彩视频特效!

Motion 5 for Mac是一款强大的视频后期特效处理软件&#xff0c;为Mac用户提供了无限的创意可能性。无论你是专业的影视制作人&#xff0c;还是想为个人视频添加独特特效的爱好者&#xff0c;Motion 5都能满足你的需求&#xff0c;让你的视频脱颖而出。 Motion 5提供了丰富多样…

跳表的基础

跳表的作用 无需数组查找目标元素-----从头遍历---O(n); 有序数组查找目标元素-----二分查找---O(logn); 链表查找目标元素----------只能从头遍历---O(n); 那么链表要如何实现O(logn)的查找时间复杂度呢-----跳表。 跳表的定义 有序链表多级索引跳表 就是一个多级链表 …

TA-Lib学习研究笔记(八)——Momentum Indicators 中

TA-Lib学习研究笔记&#xff08;八&#xff09;——Momentum Indicators 中 Momentum Indicators 动量指标&#xff0c;是最重要的股票分析指标&#xff0c;能够通过数据量化分析价格、成交量&#xff0c;预测股票走势和强度&#xff0c;大部分指标都在股票软件中提供。 11. …

力扣题:字符串的反转-11.22

力扣题-11.22 [力扣刷题攻略] Re&#xff1a;从零开始的力扣刷题生活 力扣题1&#xff1a;541. 反转字符串 II 解题思想&#xff1a;进行遍历翻转即可 class Solution(object):def reverseStr(self, s, k):""":type s: str:type k: int:rtype: str"&quo…

[计算机网络] 高手常用的几个抓包工具(下)

文章目录 高手常用的抓包工具一览什么是抓包工具优秀抓包工具HTTP Debugger ProFree Network AnalyzerKismetEtherApeNetworkMiner 结尾 高手常用的抓包工具一览 什么是抓包工具 抓包工具是一种可以捕获、分析和修改网络流量的软件。它可以帮助您进行网络调试、性能测试、安全…

JavaScript 延迟加载的艺术:按需加载的最佳实践

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…
最新文章