第G1周:生成对抗网络(GAN)入门

📌 基础任务:了解什么是生成对抗网络(GAN) 学习本文代码,并跑通代码

🎈进阶任务: 调用训练好的模型生成新图像 

目录

一、理论基础

1.1生成器

1.2判别器

1.3基本原理

二、前期准备工作

2.1定义超参数

2.2下载数据

2.3配置数据

三、定义模型

3.1定义鉴别器

3.2定义生成器

四、训练模型

4.1创建实例

4.2训练模型

4.3保存模型


一、理论基础

       生成对抗网络(Generative Adversarial Networks, GAN)近年来深度学习领域的一个热点点向。
GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。 GAN由两个分别被称为生成器(Generator) 判别器(Discriminator) 的神经网络组成。其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则为真实样本或人工样本,期的将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具賄"伪造” 实样本能力的生成器。

1.1生成器

       GANs中,生成器 G 选取随机噪声 z 作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本 G ( z ) G(z)G(z) 。生成器的本质是一个使用生成式方法的模型,它对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设;然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务。这种方法不同于数学方法,学习的过程对人类理解较不直观。

1.2判别器

       GANs中,判别器 D 对于输入的样本 x,输出一个 [ 0 , 1 ] [0,1][0,1] 之间的概率数值 D ( x ) D(x)D(x)。x 可能是来自于原始数据集中的真实样本 x,也可能是来自于生成器 G 的人工样本 G ( z ) G(z)G(z)。通常约定,概率值 D ( x ) D(x)D(x) 越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明 GAN 是一个无监督的学习过程

1.3基本原理

       GAN是博弈论和机器学习相结合的产物,于2014年Ian Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能自动生成苹果的图片,具备些功能的算法即认为具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足Ian Goodfellow提出了GAN

如上图所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别模型1D另其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了欺瞒一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

二、前期准备工作

2.1定义超参数

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms

# 创建文件夹
os.makedirs('./output/images/', exist_ok=True)
os.makedirs('./output/', exist_ok=True)
os.makedirs('./data/MNIST/', exist_ok=True)

# 超参数配置
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

# 图像的尺寸:(1, 28, 28),和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

2.2下载数据

## mnist数据集下载
mnist = datasets.MNIST(root='./data/', 
                       train=True, 
                       download=False,
                       transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))

2.3配置数据

# 配置数据到加载器
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

三、定义模型

3.1定义鉴别器

'''
定义判别器 Discriminator
将图片28x28展开成784,然后通过多层感知器,
中间经过斜率设置为0.2的LeakyReLU激活函数,
最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
                     nn.Linear(img_area, 512),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Linear(512, 256),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Linear(256, 1),
                     nn.Sigmoid())
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
        validity = self.model(img_flat)       # 通过鉴别器网络
        return validity                       # 鉴别器返回的是一个[0, 1]间的概率

3.2定义生成器


'''
定义生成器 Generator
输入一个100维的0~1之间的高斯分布,
然后通过第一层线性变换将其映射到256维,
然后通过LeakyReLU激活函数,
接着进行一个线性变换,
再经过一个LeakyReLU激活函数,
然后经过陷先变换将其变成784维,
最后经过Tanh激活函数,
是希望生成的假的图片数据分布,能够再-1~1之间。
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 模型中间块
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        # prod():返回给定轴上的数组元素的乘积:1*28*28=784
        self.model = nn.Sequential(
                     *block(latent_dim, 128, normalize=False),
                     *block(128, 256),
                     *block(256, 512),
                     *block(512, 1024),
                     nn.Linear(1024, img_area),
                     nn.Tanh())
    
    # view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1,28, 28)
    def forward(self, z):                          # 输入的是(64, 100)的噪声数据
        imgs = self.model(z)                       # 噪声数据通过生成器模型
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1,28, 28)
        return imgs                                # 输出为64张大小为(1, 28, 28)的图像

四、训练模型

4.1创建实例

# 创建生成器、判别器对象
generator = Generator()
discriminator = Discriminator()
# 定义loss的度量方式(二分类的交叉熵)
criterion = torch.nn.BCELoss()
# 定义又换函数,学习率为0.0003
# betas: 用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 若有显卡,在cuda模式中运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

4.2训练模型

# 进行多个epoch的训练
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        ''' 训练判别器 Train Discriminator '''
        # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        imgs = imgs.view(imgs.size(0), -1)
        real_img = Variable(imgs).cuda()
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
        # 计算真实图片的损失
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out
        # 计算假的图片的损失
        # detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out
        # 损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        
        ''' 训练生成器 Train Generator '''
        # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        # 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        # 反向传播更新的参数是生成网络里面的参数,
        # 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        # 损失函数和优化
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        
        ''' 打印训练过程中的日志 '''
        # item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./output/images/%d.png" % batches_done, nrow=5, normalize=True)

4.3保存模型

# 保存模型
torch.save(generator.state_dict(), './output/generator.pth')
torch.save(discriminator.state_dict(), './output/discriminator.pth')

参考博文:https://blog.csdn.net/m0_58585940/article/details/131731770?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/79631.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于CNN卷积神经网络的口罩检测识别系统matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ............................................................ % 循环处理每张输入图像 for…

汽车租赁管理系统/汽车租赁网站的设计与实现

摘 要 租赁汽车走进社区,走进生活,成为当今生活中不可缺少的一部分。随着汽车租赁业的发展,加强管理和规范管理司促进汽车租赁业健康发展的重要推动力。汽车租赁业为道路运输车辆一种新的融资服务形式、广大人民群众一种新的出行消费方式和…

Centos7 配置Docker镜像加速器

docker实战(一):centos7 yum安装docker docker实战(二):基础命令篇 docker实战(三):docker网络模式(超详细) docker实战(四):docker架构原理 docker实战(五):docker镜像及仓库配置 docker实战(六):docker 网络及数据卷设置 docker实战(七):docker 性质及版本选择 认知升…

区块链中slot、epoch、以及在slot和epoch中的出块机制,分叉原理(自己备用)

以太坊2.0中有两个时间概念:时隙槽slot 和 时段(周期)epoch。其中一个slot为12秒,而每个 epoch 由 32 个 slots 组成,所以每个epoch共384秒,也就是 6.4 分钟。 对于每个epoch,使用RANDAO伪随机…

Unity小项目__打砖块

//1.添加地面 1)创建一个平面,命名为Ground。 2)创建一个Materials文件夹,并在其中创建一个Ground材质,左键拖动其赋给平面Plane。 3)根据喜好设置Ground材质和Ground平面的属性。 // 2.创建墙体 1)创建一个Cube&…

基于Simulink的Chaos混沌电路设计与仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 07_001m 4.算法理论概述 混沌电路是一类特殊的非线性电路,其输出信号表现出无规律…

数字后端笔试题(1)DCG后congestion问题

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 已知某模块的DCG结果显示存在congestion,有congestion部分逻辑结构如下图: 问题1: 如何分析该电路有congestion问题的原因? 答:data selecti…

基于STM32+FreeRTOS的四轴机械臂

目录 项目概述: 一 准备阶段(都是些废话) 二 裸机测试功能 1.摇杆控制 接线: CubeMX配置: 代码: 2.蓝牙控制 接线: CubeMX配置 代码: 3.示教器控制 4.记录动作信息 5.执…

el-table 多个表格切换多选框显示bug

今天写了个功能,点击左侧的树做判断,一级树节点显示系统页面,二级树节点显示数据库页面,三级树节点显示表页面。 数据库页面和表页面分别有2个el-table ,上面的没有多选框,下面的有多选框 现在出现bug,在…

Linux学习之iptables过滤规则的使用

cat /etc/redhat-release看到操作系统是CentOS Linux release 7.6.1810,uname -r看到内核版本是3.10.0-957.el7.x86_64,iptables --version可以看到iptables版本是v1.4.21。 iptables -t filter -A INPUT -s 10.0.0.8 -j ACCEPT会在最后一行插入。 10…

winform 封装unity web player 用户控件

环境: VS2015Unity 5.3.6f1 (64-bit) 目的: Unity官方提供的UnityWebPlayer控件在嵌入Winform时要求读取的.unity3d文件路径(Src)必须是绝对路径,如果移动代码到另一台电脑,需要重新修改src。于是考虑使…

Hadoop学习:深入解析MapReduce的大数据魔力之数据压缩(四)

Hadoop学习:深入解析MapReduce的大数据魔力之数据压缩(四) 4.1 概述1)压缩的好处和坏处2)压缩原则 4.2 MR 支持的压缩编码4.3 压缩方式选择4.3.1 Gzip 压缩4.3.2 Bzip2 压缩4.3.3 Lzo 压缩4.3.4 Snappy 压缩4.3.5 压缩…

Apache JMeter

下载 Apache JMeter 并安装 java链接 打开 apache-jmeter-5.4.1\bin 找到jmeter.bat 双击打开 或者 ApacheJMeter.jar 双击打开 设置中文 找到 options 》choose Language 》chinese 新建 计划 创建线程组 添加Http请求 配置元件添加请求头参数(content-type&…

腾讯云 CODING 荣获 TiD 质量竞争力大会 2023 软件研发优秀案例

点击链接了解详情 8 月 13-16 日,由中关村智联软件服务业质量创新联盟主办的第十届 TiD 2023 质量竞争力大会在北京国家会议中心召开。本次大会以“聚焦数字化转型 探索智能软件研发”为主题,聚焦智能化测试工程、数据要素、元宇宙、数字化转型、产融合作…

报名开启 | HarmonyOS第一课“营”在暑期系列直播

<HarmonyOS第一课>2023年再次启航&#xff01; 特邀HarmonyOS布道师云集华为开发者联盟直播间 聚焦HarmonyOS 4版本新特性 邀您一同学习赢好礼&#xff01; 你准备好了吗&#xff1f; ↓↓↓预约报名↓↓↓ 点击关注了解更多资讯&#xff0c;报名学习

CS:GO升级 Linux不再是“法外之地”

在前天的VAC大规模封禁中&#xff0c;有不少Linux平台的作弊玩家也迎来了“迟到”的VAC封禁。   一直以来&#xff0c;Linux就是VAC封禁的法外之地。虽然大部分玩家都使用Windows平台进行游戏。但实际上&#xff0c;使用Linux畅玩CS:GO的玩家也不在少数。 以前V社主要打击W…

LVS - DR

LVS-DR 数据流向 客户端发送请求到 Director Server&#xff08;负载均衡器&#xff09;&#xff0c;请求的数据报文&#xff08;源 IP 是 CIP,目标 IP 是 VIP&#xff09;到达内核空间。Director Server 和 Real Server 在同一个网络中&#xff0c;数据通过二层数据链路层来传…

商城-学习整理-高级-商城业务-商品上架es(十)

目录 一、商品上架1、sku在ES中存储模型分析2、nested数据类型场景3、构造基本数据&#xff08;商品上架&#xff09; 二、首页1、项目介绍2、整合thymeleaf&#xff08;spring-boot下模板引擎&#xff09;渲染页面3、页面修改不重启服务器实时更新4、渲染二级三级数据 三、搭建…

「UG/NX」Block UI 面收集器FaceCollector

✨博客主页何曾参静谧的博客📌文章专栏「UG/NX」BlockUI集合📚全部专栏「UG/NX」NX二次开发「UG/NX」BlockUI集合「VS」Visual Studio「QT」QT5程序设计「C/C+&#

LeetCode150道面试经典题-- 求算数平方根(简单)

1.题目 给你一个非负整数 x &#xff0c;计算并返回 x 的 算术平方根 。 由于返回类型是整数&#xff0c;结果只保留 整数部分 &#xff0c;小数部分将被 舍去 。 注意&#xff1a;不允许使用任何内置指数函数和算符&#xff0c;例如 pow(x, 0.5) 或者 x ** 0.5 。 2.示例 …
最新文章