使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torch

def get_generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
        )
    def forward(self, noise):
        return self.gen(noise)
    def get_gen(self):
        return self.gen

def get_discriminator_block(input_dim, output_dim):
    return nn.Sequential(
         nn.Linear(input_dim, output_dim), #Layer 1
         nn.LeakyReLU(0.2, inplace=True)
    )

class Discriminator(nn.Module):
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, image):
        return self.disc(image)
    def get_disc(self):
        return self.disc

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples,z_dim,device=device)

criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'

dataloader = DataLoader(
    MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss
    
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Update optimizer
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(
                f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:
在这里插入图片描述
在这里插入图片描述
到后面的fake逐渐变成这样:
在这里插入图片描述

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。
在这里插入图片描述
生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。
在这里插入图片描述
生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。
在这里插入图片描述
每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

在这里插入图片描述
首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。
detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:
在这里插入图片描述
一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。
make_grid 函数用于将多个图像组成一个网格,方便显示。
在这里插入图片描述

在这里插入图片描述
然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

在这里插入图片描述
损失函数的原理在上面的“原理”中有讲解,这里不再赘述。
在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播

在这里插入图片描述
retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。
Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。
所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/497924.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

进阶了解C++(6)——二叉树OJ题

Leetcode.606.根据二叉树创建字符串: 606. 根据二叉树创建字符串 - 力扣(LeetCode) 难度不大,根据题目的描述,首先对二叉树进行一次前序遍历,即: class Solution { public:string tree2str(Tr…

TheMoon 恶意软件短时间感染 6,000 台华硕路由器以获取代理服务

文章目录 针对华硕路由器Faceless代理服务预防措施 一种名为"TheMoon"的新变种恶意软件僵尸网络已经被发现正在侵入全球88个国家数千台过时的小型办公室与家庭办公室(SOHO)路由器以及物联网设备。 "TheMoon"与“Faceless”代理服务有关联,该服务…

【算法题】三道题理解算法思想--滑动窗口篇

滑动窗口 本篇文章中会带大家从零基础到学会利用滑动窗口的思想解决算法题,我从力扣上筛选了三道题,难度由浅到深,会附上题目链接以及算法原理和解题代码,希望大家能坚持看完,绝对能有收获,大家有更好的思…

Flask学习(六):蓝图(Blueprint)

蓝图(Blueprint):将各个业务进行区分,然后每一个业务单元可以独立维护,Blueprint可以单独具有自己的模板、静态文件或者其它的通用操作方法,它并不是必须要实现应用的视图和函数的。 Demo目录结构&#xf…

计算机专业学习单片机有什么意义吗?

玩单片机跟玩计算机区别还是很大的, 单片机有众多的种类,每一种又可能有很多个系列.可以说单片机就是为了专款专用而生的.这样来达到产品成本的降低,这就是现在身边的很多的电子产品价格一降再降的原因之一.在开始前我有一些资料,是我根据网友给的问题精心整理了一…

Python拆分PDF、Python合并PDF

WPS能拆分合并&#xff0c;但却是要输入编辑密码&#xff0c;我没有。故写了个脚本来做拆分&#xff0c;顺便附上合并的代码。 代码如下&#xff08;extract.py) #!/usr/bin/env python """PDF拆分脚本(需要Python3.10)Usage::$ python extract.py <pdf-fil…

腾讯云4核8g服务器多少钱?2024轻量和CVM收费价格表

2024年腾讯云4核8G服务器租用优惠价格&#xff1a;轻量应用服务器4核8G12M带宽646元15个月&#xff0c;CVM云服务器S5实例优惠价格1437.24元买一年送3个月&#xff0c;腾讯云4核8G服务器活动页面 txybk.com/go/txy 活动链接打开如下图&#xff1a; 腾讯云4核8G服务器优惠价格 轻…

uniapp 微信小程序 canvas 手写板获取书写内容区域并输出

uni.canvasGetImageData 返回一个数组&#xff0c;用来描述 canvas 区域隐含的像素数据&#xff0c;在自定义组件下&#xff0c;第二个参数传入自定义组件实例 this&#xff0c;以操作组件内 组件。 // 获取目标 canvas 的像素信息 pixelData let canvas uni.createSelector…

Linux 系统 CentOS7 上搭建 Hadoop HDFS集群详细步骤

集群搭建 整体思路:先在一个节点上安装、配置,然后再克隆出多个节点,修改 IP ,免密,主机名等 提前规划: 需要三个节点,主机名分别命名:node1、node2、node3 在下面对 node1 配置时,先假设 node2 和 node3 是存在的 **注意:**整个搭建过程,除了1和2 步,其他操作都使…

linux 内存介绍

大致共有四类&#xff1a;VSS、RSS、PSS、USS &#xff0c;通常情况下&#xff0c;VSS > RSS > PSS > USS 1.VSS(Virtual Set Size)虚拟耗用内存&#xff08;包含共享库占用的内存&#xff09; VSS表示一个进程可访问的全部内存地址空间的大小。这个大小包括了进程已…

Vue3使用vue-office插件实现word预览

首先, 我们先来创建一个Vue3项目 npm init vuelatest pnpm i npm run dev运行起来之后, 我们将App.vue中的代码全部删除掉 现在, 页面干净了, 我们需要安装vue-office插件 npm install vue-office/docx vue-demi安装完成之后, 我们就可以在页面中进行使用了 需要我们将组件…

边缘计算AI盒子目前支持的AI智能算法、视频智能分析算法有哪些,应用于大型厂矿安全生产风险管控

一、前端设备实现AI算法 主要是基于安卓的布控球实现&#xff0c;已有的算法包括&#xff1a; 1&#xff09;人脸&#xff1b;2&#xff09;车牌&#xff1b;3&#xff09;是否佩戴安全帽&#xff1b;4&#xff09;是否穿着工装&#xff1b; 可以支持定制开发 烟雾&#xf…

(免费分享)基于springboot,vue超市管理系统

开发工具&#xff1a;IDEA 服务器&#xff1a;Tomcat9.0&#xff0c; jdk1.8 项目构建&#xff1a;maven 数据库&#xff1a;mysql5.7 项目采用前后端分离 前端技术&#xff1a;vueelementUI 服务端技术&#xff1a;springbootmybatis-plusredis 本项目分为系统管理员、…

|行业洞察·手机|《2024手机行业及营销趋势报告-18页》

报告的主要内容解读&#xff1a; 手机行业概述及品牌分布&#xff1a; 2022年&#xff0c;受疫情影响&#xff0c;中国国内手机市场出货量下降22.6%&#xff0c;总计2.72亿部。5G手机市场占有率中&#xff0c;苹果领先&#xff0c;其次是vivo、OPPO和华为。消费者换机时更注重性…

鸿蒙OS开发实战:【悬浮窗口】

背景 悬浮视图或者窗体&#xff0c;在Android和iOS两大移动平台均有使用&#xff0c;HarmonyOS 也实现了此功能&#xff0c;如下为大家分享一下效果 准备 熟读HarmonyOS 悬浮窗口指导 熟读HarmonyOS 手势指导 熟读ALC签名指导&#xff0c;用于可以申请 “ohos.permission.S…

github | ssh拉取github仓库报错connect to host github.com port 22: Connection refused

配置ssh key 通过 ssh key 解决本地和服务器连接的问题 $ cd ~/. ssh #检查本机已存在的ssh密钥 如果提示 No such file or directory 则表示第一次使用git 输入&#xff1a; ssh-keygen -t rsa -C "邮件地址" 并且连续3次回车&#xff0c;最终会生成一个文件&am…

如何在Flutter中进行网络请求?

Hello&#xff01;大家好&#xff0c;我是咕噜铁蛋&#xff0c;你们的好朋友&#xff01;今天&#xff0c;我想和大家分享一下在Flutter中如何进行网络请求。Flutter作为一个跨平台的开发框架&#xff0c;网络请求是其实现数据交互的重要一环。下面&#xff0c;我将详细介绍几种…

JVM实战之性能调优[2](线程转储案例认识和分析)

文章目录 版权声明案例1&#xff1a;CPU占用率高问题问题描述解决思路补充内容 案例2&#xff1a;接口响应时间长问题问题描述解决思路Arthas trace命令Arthas watch命令解决问题 案例3&#xff1a;定位偏底层性能问题问题描述解决思路&#xff1a;Arthas火焰图问题解决 案例4&…

Siemens S7-1500TCPU 运动机构系统功能简介

目录 引言&#xff1a; 1.0 术语定义 2.0 基本知识 2.1 运动系统工艺对象 2.2 坐标系与标架 3.0 运动机构系统类型 3.1 直角坐标型 3.2 轮腿型 3.3 平面关节型 3.4 关节型 3.5 并联型 3.6 圆柱坐标型 3.7 三轴型 4.0 运动系统的运动 4.1 运动类型 4.1.1 线性运动…

ArcGIS Pro横向水平图例

终于知道ArcGIS Pro怎么调横向图例了&#xff01; 简单的像0一样 旋转&#xff0c;左转右转随便转 然后调整图例项间距就可以了&#xff0c;参数太多就随便试&#xff0c;总有一款适合你&#xff01; 要调整长度&#xff0c;就调整图例块的大小。完美&#xff01; 好不容易…
最新文章