使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。

本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
在这里插入图片描述
DCGAN模型有以下特点:

  1. 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
  2. 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  3. 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  4. 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。

代码

model.py:

from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )
    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            )
        else: # Final Layer
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )
    def unsqueeze_noise(self, noise):
        return noise.view(len(noise), self.z_dim, 1, 1)    # [b,c,h,w]
    def forward(self, noise):
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=16):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )
    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:  # Final Layer
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
            )
    def forward(self, image):
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'

# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Update gradients
        disc_loss.backward(retain_graph=True)
        # Update optimizer
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

每500个batch展示一次
每500个batch展示一次。
在这里插入图片描述
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
在这里插入图片描述
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
在这里插入图片描述
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。

下一篇构建WGAN_GP。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/498158.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch的hook函数

hook函数是勾子函数,用于在不改变原始模型结构的情况下,注入一些新的代码用于调试和检验模型,常见的用法有保留非叶子结点的梯度数据(Pytorch的非叶子节点的梯度数据在计算完毕之后就会被删除,访问的时候会显示为None&…

RegSeg 学习笔记(待完善)

论文阅读 解决的问题 引用别的论文的内容 可以用 controlf 寻找想要的内容 PPM 空间金字塔池化改进 SPP / SPPF / SimSPPF / ASPP / RFB / SPPCSPC / SPPFCSPC / SPPELAN  ASPP STDC:short-term dense concatenate module 和 DDRNet SE-ResNeXt …

快速入门Axure RP:解答4个关键问题!

软件Axure RP 是一种功能强大的设计工具,用于使用 Web、移动和桌面应用程序项目创建交互原型。Axure RP软件中的 RP代表快速原型制作,这是软件Axure RP的核心特征。用户使用Axurere RP软件可以快速地将简单的想法创建成线框图和原型。Axure 因此&#xf…

实时数仓之实时数仓架构(Hudi)

目前比较流行的实时数仓架构有两类,其中一类是以FlinkDoris为核心的实时数仓架构方案;另一类是以湖仓一体架构为核心的实时数仓架构方案。本文针对FlinkHudi湖仓一体架构进行介绍,这套架构的特点是可以基于一套数据完全实现Lambda架构。实时数…

【二叉树】Leetcode 98. 验证二叉搜索树【中等】

验证二叉搜索树 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的数。所有左子树和右子树自身必须也是二叉搜索树。 示例1&a…

【Python函数和类2/6】函数的参数

目录 目标 为函数设置参数 传递实参 关键字实参 关键字实参的顺序 位置实参 常见错误 缺少实参 位置实参的顺序 默认值形参 参数的优先级 默认值形参的位置 总结 目标 上篇博客中,我们在定义函数时,使用了空的括号。这表示它不需要任何信息就…

浅谈C语言编译与链接

个人主页(找往期文章包括但不限于本期文章中不懂的知识点):我要学编程(ಥ_ಥ)-CSDN博客 翻译环境和运行环境 在ANSI C(标准 C)的任何一种实现中,存在两个不同的环境。 第1种是翻译环境,在这个…

ssh 公私钥(github)

一、生成ssh公私钥 生成自定义名称的SSH公钥和私钥对,需要使用ssh-keygen命令,这是大多数Linux和Unix系统自带的标准工具。下面,简单展示如何使用ssh-keygen命令来生成具有自定义名称的SSH密钥对。 步骤 1: 打开终端 首先,打开我…

增强现实(AR)和虚拟现实(VR)营销的未来:沉浸式体验和品牌参与

--- 如何将AR和VR技术应用于营销,以提高品牌知名度、客户参与度 增强现实(AR)和虚拟现实(VR)不再只是游戏。这些技术为品牌与受众互动提供了创新的方式。营销人员可以创造更好的客户体验,并为身临其境的故…

hadoop-3.1.1分布式搭建与常用命令

一、准备工作 1.首先需要三台虚拟机: master 、 node1 、 node2 2.时间同步 ntpdate ntp.aliyun.com 3.调整时区 cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime 4.jdk1.8 java -version 5.修改主机名 三台分别执行 vim /etc/hostname 并将内容指定为…

电脑突然死机怎么办?

死机是电脑常见的故障问题,尤其是对于老式电脑来说,一言不合电脑画面就静止了,最后只能强制关机重启。那么你一定想知道是什么原因造成的吧,一般散热不良最容易让电脑死机,还有系统故障,比如不小心误删了系…

【实现报告】学生信息管理系统(顺序表)

目录 实验一 线性表的基本操作 一、实验目的 二、实验内容 三、实验提示 四、实验要求 五、实验代码如下: (一)顺序表的构建及初始化 (二)检查顺序表是否需要扩容 (三)根据指定学生个…

企业网站建设的方法的相关问题的解决办法的问题

现在市场上比较大的公司都建立了自己的企业网站,比如华为、小米等,在他们的企业网站中,可以充分展示自己产品的优势,介绍公司的优质服务。 这都是让顾客改变购买想法的重要因素。 现在互联网发达了,很多人在购买产品的…

详细分析axios.js:72 Uncaught (in promise) Error: 未知错误 的解决方法(图文)

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 调试接口的时候,打开一个网页,在终端出现如下错误: axios.js:72 Uncaught (in promise) Error: 未知错误at __webpack_exports__.default (axios.js:72:1)截图如下所示: 2. 原理分析 点击浏览器的Bug出错: // 如果…

C/C++语言学习路线: 嵌入式开发、底层软件、操作系统方向(持续更新)

初级:用好手上的锤子 1 【感性】认识 C 系编程语言开发调试过程 1.1 视频教程点到为止 1.2 炫技视频看看就行 1.3 编程游戏不玩也罢 有些游戏的主题任务就是编程,游戏和实际应用环境有一定差异(工具、操作流程),在…

进程知识点

引用的文章:操作系统——进程通信(IPC)_系统ipc-CSDN博客 面试汇总(五):操作系统常见面试总结(一):进程与线程的相关知识点 - 知乎 (zhihu.com) 二、进程的定义、组成、组成方式及特征_进程的组成部分必须包含-CSDN博…

2024年北京事业单位报名照片要求,注意格式

2024年北京事业单位报名照片要求,注意格式

【C语言】预处理常见知识详解(宏详解)

文章目录 1、预定义符号2、define2.1 define 定义常量2.2 define 定义宏 3、#和##3.1 **#**3.2 **##** 4、条件编译(开关) 1、预定义符号 在C语言中内置了一些预定义符号,可以直接使用,这些符号实在预处理期间处理的,…

工控安全双评合规:等保测评与商用密码共铸新篇章

01.双评合规概述 2017年《中华人民共和国网络安全法》开始正式施行,网络安全等级测评工作也在全国范围内按照相关法律法规和技术标准要求全面落实实施。2020年1月《中华人民共和国密码法》开始正式施行,商用密码应用安全性评估也在有序推广和逐步推进。…

信息安全之网络安全防护

先来看看计算机网络通信面临的威胁: 截获——从网络上窃听他人的通信内容中断——有意中断他人在网络上的通信篡改——故意篡改网络上传送的报文伪造——伪造信息在网络上传送 截获信息的攻击称为被动攻击,而更改信息和拒绝用户使用资源的攻击称为主动…
最新文章