PyTorch深度学习实战(32)——DCGAN详解与实现

PyTorch深度学习实战(32)——DCGAN详解与实现

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 模型分析
      • 1.2 数据集介绍
    • 2. 构建 DCGAN 生成人脸图像
    • 小结
    • 系列链接

0. 前言

DCGAN (Deep Convolutional Generative Adversarial Networks) 是基于生成对抗网络 (Convolutional Generative Adversarial Networks, GAN) 的深度学习模型,相比传统的 GAN 模型,DCGAN 通过引入卷积神经网络 (Convolutional Neural Networks, CNN) 架构来提升生成网络和判别网络的性能。DCGAN 中的生成网络和判别网络都是使用卷积层和反卷积层构建的深度神经网络。生成网络接收一个随机噪声向量作为输入,并通过反卷积层将其逐渐转化为与训练数据相似的输出图像,判别网络则是一个用于分类真实和生成图像的卷积神经网络。

1. 模型与数据集分析

1.1 模型分析

我们已经学习了 GAN 的基本原理并并使用 PyTorch 实现了 GAN 模型用于生成 MNIST 手写数字图像。同时,我们已经知道,与普通神经网络相比,卷积神经网络 (Convolutional Neural Networks, CNN) 架构能够更好地学习图像中的特征。在本节中,我们将学习使用深度卷积生成对抗网络生成图像,在模型中使用卷积和池化操作替换全连接层。
首先,介绍如何使用随机噪声( 100 维向量)生成图像,将噪声形状转换为 batch size x 100 x 1 x 1,其中 batch size 表示批大小,由于在 DCGAN 使用 CNN,因此需要添加额外的通道信息,即 batch size x channel x height x width 的形式,channel 表示通道数,heightwidth 分别表示高度和宽度。
接下来,利用 ConvTranspose2d 将生成的噪声向量转换为图像,ConvTranspose2d 与卷积操作相反,将输入的小特征图通过预定义的核大小、步幅和填充上上采样到较大的尺寸。利用上采样逐渐将向量形状从 batch size x 100 x 1 x 1 转换为 batch size x 3 x 64 x 64,即将 100 维的随机噪声向量转换成一张 64 x 64 的图像。

1.2 数据集介绍

为了训练对抗生成网络,我们需要了解本节所用的数据集,数据集取自 Celeb A,可以自行构建数据集,也可以下载本文所用数据集,下载地址:https://pan.baidu.com/s/1dvDCBLSGwblg57p9RDBEJQ,提取码:y9fiCelebA 是一个大规模的人脸属性数据集,其中包含超过 20 万张名人图像,每张图像有 40 个属性注释。CelebA 数据集的图像来源于互联网上的名人照片,包括电影、音乐和体育界等各个领域。这些图像具有多样的姿势、表情、背景和装扮,涵盖了各种真实世界的场景。

2. 构建 DCGAN 生成人脸图像

接下来,我们使用 PyTorch 构建 DCGAN 模型生成人脸图像。

(1) 下载并获取人脸图像,示例图像如下所示:

示例图像
(2) 导入相关库:

from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"

(3) 定义数据集和数据加载器。

裁剪图像,只保留面部区域并丢弃图像中的其他部分。首先,使用级联滤波器识别图像中的人脸:

face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')

OpenCV 提供了 4 个级联分类器用于人脸检测,可以从 OpenCV 官方下载这些级联分类器文件:

  • haarcascade_frontalface_alt.xml (FA1)
  • haarcascade_frontalface_alt2.xml (FA2)
  • haarcascade_frontalface_alt_tree.xml (FAT)
  • haarcascade_frontalface_default.xml (FD)

可以使用不同的数据集评估这些级联分类器的性能,总的来说这些分类器具有相似的准确率。

创建一个新文件夹,并将所有裁剪后的人脸图像转储到新文件夹中:

if not os.path.exists('cropped_faces'):
    os.mkdir('cropped_faces')

images = glob('male_female_face_images/females/*.jpg')+glob('male_female_face_images/males/*.jpg')
for i in range(len(images)):
    img = cv2.imread(images[i],1)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    for (x,y,w,h) in faces:
        img2 = img[y:(y+h),x:(x+w),:]
    cv2.imwrite('cropped_faces/'+str(i)+'.jpg', img2)

裁剪后的面部示例图像如下:

面部裁剪图像
定义要对每个图像执行的转换:

transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

定义 Faces 数据集类:

class Faces(Dataset):
    def __init__(self, folder):
        super().__init__()
        self.folder = folder
        self.images = sorted(glob(folder))
    def __len__(self):
        return len(self.images)
    def __getitem__(self, ix):
        image_path = self.images[ix]
        image = Image.open(image_path)
        image = transform(image)
        return image

创建数据集对象 ds

ds = Faces(folder='cropped_faces/*.jpg')

定义数据加载器类:

dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)

(4) 定义权重初始化函数,使权重的分布较小:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

(5) 定义判别网络模型类 Discriminator,接收形状为 batch size x 3 x 64 x 64 的图像,并预测输入图像是真实图像还是生成图像:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,64,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*2,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*4,64*8,4,2,1,bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*8,1,4,1,0,bias=False),
            nn.Sigmoid()
        )
        self.apply(weights_init)
    def forward(self, input):
        return self.model(input)

打印模型的摘要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (3,64,64)))

模型摘要输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           3,072
         LeakyReLU-2           [-1, 64, 32, 32]               0
            Conv2d-3          [-1, 128, 16, 16]         131,072
       BatchNorm2d-4          [-1, 128, 16, 16]             256
         LeakyReLU-5          [-1, 128, 16, 16]               0
            Conv2d-6            [-1, 256, 8, 8]         524,288
       BatchNorm2d-7            [-1, 256, 8, 8]             512
         LeakyReLU-8            [-1, 256, 8, 8]               0
            Conv2d-9            [-1, 512, 4, 4]       2,097,152
      BatchNorm2d-10            [-1, 512, 4, 4]           1,024
        LeakyReLU-11            [-1, 512, 4, 4]               0
           Conv2d-12              [-1, 1, 1, 1]           8,192
          Sigmoid-13              [-1, 1, 1, 1]               0
================================================================
Total params: 2,765,568
Trainable params: 2,765,568
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 2.31
Params size (MB): 10.55
Estimated Total Size (MB): 12.91
----------------------------------------------------------------

(6) 定义生成网络模型类,使用形状为 batch size x 100 x 1 x 1 的输入生成图像:

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100,64*8,4,1,0,bias=False,),
            nn.BatchNorm2d(64*8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64*4,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64*2,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64,3,4,2,1,bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)
    def forward(self,input):
        return self.model(input)

打印模型的摘要信息:

generator = Generator().to(device)
print(summary(generator, (100,1,1)))

代码输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ConvTranspose2d-1            [-1, 512, 4, 4]         819,200
       BatchNorm2d-2            [-1, 512, 4, 4]           1,024
              ReLU-3            [-1, 512, 4, 4]               0
   ConvTranspose2d-4            [-1, 256, 8, 8]       2,097,152
       BatchNorm2d-5            [-1, 256, 8, 8]             512
              ReLU-6            [-1, 256, 8, 8]               0
   ConvTranspose2d-7          [-1, 128, 16, 16]         524,288
       BatchNorm2d-8          [-1, 128, 16, 16]             256
              ReLU-9          [-1, 128, 16, 16]               0
  ConvTranspose2d-10           [-1, 64, 32, 32]         131,072
      BatchNorm2d-11           [-1, 64, 32, 32]             128
             ReLU-12           [-1, 64, 32, 32]               0
  ConvTranspose2d-13            [-1, 3, 64, 64]           3,072
             Tanh-14            [-1, 3, 64, 64]               0
================================================================
Total params: 3,576,704
Trainable params: 3,576,704
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 3.00
Params size (MB): 13.64
Estimated Total Size (MB): 16.64
----------------------------------------------------------------

(7) 定义训练生成网络 (generator_train_step) 和判别网络 (discriminator_train_step) 的函数:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):
    d_optimizer.zero_grad()
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real.squeeze(), torch.ones(len(real_data)).to(device))
    error_real.backward()
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake.squeeze(), torch.zeros(len(fake_data)).to(device))
    error_fake.backward()
    d_optimizer.step()
    return error_real + error_fake

def generator_train_step(real_data, fake_data, loss, g_optimizer):
    g_optimizer.zero_grad()
    prediction = discriminator(fake_data)
    error = loss(prediction.squeeze(), torch.ones(len(real_data)).to(device))
    error.backward()
    g_optimizer.step()
    return error

在以上代码中,在判别网络预测结果上执行 .squeeze 操作,因为模型的输出形状为 batch size x 1 x 1 x 1,而预测结果需要与形状为 batch size x 1 的张量进行比较。

(8) 创建生成网络和判别网络模型对象、优化器以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
loss = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

(9) 训练模型。

加载真实数据 (real_data) 并通过生成网络生成图像 (fake_data):

num_epochs = 100
d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):
    N = len(dataloader)
    d_loss_items = []
    g_loss_items = []
    for i, images in enumerate(dataloader):
        real_data = images.to(device)
        fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device)).to(device)
        fake_data = fake_data.detach()

原始 GANDCGAN 的主要区别在于,在 DCGAN 模型中,由于使用了 CNN,因此不必展平 real_data

使用 discriminator_train_step 函数训练判别网络:

        d_loss = discriminator_train_step(real_data, fake_data, loss, d_optimizer)

利用噪声数据 (torch.randn(len(real_data))) 生成新图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device)).to(device)
        g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())
        g_loss_items.append(g_loss.item())
    d_loss_epoch.append(np.average(d_loss_items))
g_loss_epoch.append(np.average(g_loss_items))

(10) 绘制模型训练期间,判别网络和生成网络损失变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')

损失变化
从上图中可以看出,生成网络和判别网络损失的变化与手写数字生成模型的损失变化模式并不相同,原因如下:

  • 人脸图像的尺寸相比手写数字更大,手写数字图像形状为 28 x 28 x 1,人脸图像形状为 64 x 64 x 3
  • 与人脸图像中的特征相比,手写数字图像中的特征较少
  • 与人脸图像中的信息相比,手写数字图像中仅少数像素中存在可用信息

(11) 训练过程完成后,生成图像样本:

generator.eval()
noise = torch.randn(64, 100, 1, 1, device=device)
sample_images = generator(noise).detach().cpu()
grid = vutils.make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0))
plt.show()

生成图像样本

小结

DCGAN 是优秀的图像生成模型,其生成网路和判别网络都是使用卷积层和反卷积层构建的深度神经网络。生成网络接收一个随机噪声向量作为输入,并通过逐渐减小的反卷积层将其逐渐转化为与训练数据相似的输出图像;判别网络则是一个用于分类真实和生成图像的卷积神经网络。在本节中,我们学习了如何构建并训练 DCGAN 生成人脸图像。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes
PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/346854.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

香港代理IP为何受欢迎?

香港代理IP深受用户欢迎的原因主要有以下几点: 1,地理位置优势:香港位于亚洲的中心地带,与中国大陆、东南亚和其他亚洲国家都有良好的网络连接。这使得使用香港代理IP可以实现较快的网络连接速度和较低的延迟,特别是对…

架构篇24:排除架构可用性隐患的利器-FMEA方法

文章目录 FMEA 介绍FMEA 方法FMEA 实战小结 前面的专栏分析高可用复杂度的时候提出了一个问题:高可用和高性能哪个更复杂,根据墨菲定律“可能出错的事情最终都会出错”,架构隐患总有一天会导致系统故障。因此,我们在进行架构设计的…

shopee的AI学习之路——GPTs通过AdInteli 广告变现

GPTs|AdInteli 广告变现 一、什么是 AdInteli AdIntelli 是一个旨在为生成 GPTs 接入广告并实现变现的平台。它连接了全球最大的广告联盟,允许广告商进行竞价,确保展示最有价值的广告。AdIntelli 采用 AI 驱动的收入生成技术,优化广告选择。…

【github】使用github action 拉取国外docker镜像

使用github action 拉取国外docker镜像 k8s部署经常用到国外镜像,如果本地无法拉取可以考虑使用github action环境 github action的ci服务器在国外,不受中国防火墙影响github action 自带docker命令运行时直接将你仓库代码拉取下来 步骤 你的国内dock…

SAP PO平台配置

多个系统分配 : XPATH : /p1:mt_ERP_ZSSF_HFM_001/sapClient SPACE : p1 http://lstech.com/erp/IF0523/ZSSF_HFM_001

qml与C++的交互

qml端使用C对象类型、qml端调用C函数/c端调用qml端函数、qml端发信号-连接C端槽函数、C端发信号-连接qml端函数等。 代码资源下载: https://download.csdn.net/download/TianYanRen111/88779433 若无法下载,直接拷贝以下代码测试即可。 main.cpp #incl…

HDMI之ALLM

概述 ALLM(Auto Low-latency Mode)即自动低延迟模式,在自动低延迟模式下智能电视的用户不用根据电视播放的内容手动来切换低延迟模式,而会根据电视播放的内容自动启用或者禁用低延迟模式。这里的启用或者禁用低延迟功能通常是信号源设备控制的(如游戏设备 Xbox One,或 PS5…

LIMS源码,实验室信息系统源码,后端框架:asp.net

LIMS(laboratory information management system)即实验室信息管理系统是实验室管理科学发展的成果,是实验室管理科学与现代信息技术结合的产物,是利用计算机网络技术、数据存储技术、快速数据处理技术等,对实验室进行全方位管理的计算机软件…

66.Spring是如何整合MyBatis将Mapper接口注册为Bean的原理?

原理 首先MyBatis的Mapper接口核心是JDK动态代理 Spring会排除接口,无法注册到IOC容器中 MyBatis 实现了BeanDefinitionRegistryPostProcessor 可以动态注册BeanDefinition 需要自定义扫描器(继承Spring内部扫描器ClassPathBeanDefinitionScanner ) 重…

微信小程序如何自定义单选和多选

实现单选 实现效果&#xff1a;点击显示单选状态&#xff0c;每次仅能点击一个元素。 实现方式&#xff1a; wxml&#xff1a; <view wx:for"{{item_list}}" data-info"{{index}}" class"{{menu_indexindex?choose:no_choose}}" bind:ta…

【Linux】shell外壳和权限

文章目录 shell外壳用户切换权限 shell外壳 什么是shell外壳呢&#xff1f;首先我们应该知道&#xff0c;用户和操作系统内核是不能直接接触的&#xff0c;因为首先操作系统本身就很难去操作&#xff0c;另一方面也是为了操作系统安全考虑&#xff0c;不能让用户直接去操作内核…

Matlab图像处理——谷物颗粒计数

针对目前谷物人工计数和光电计数方法存在的不足 , 提出了一种基于 Matlab 图像识别和处理技术的谷物计数方法 , 并用实例验证了其可靠性。该方法减轻了操作者劳动强度 , 弥补了人视觉的不足之处 , 提高了效率及准确率 , 为今后进一步研究奠定了必要的理论与实践基础 , 对完善“…

有趣的数学 了解TensorFlow的自动微分的实现

一、简述 这里主要介绍了TensorFlow的自动微分(autodiff)功能如何工作,以及与其他解决方案的比较。假设您定义了一个函数,并且需要计算它的偏导数和,通常用于执行梯度下降(或某些其他优化算法)。可用的主要选择是手动微分、有限差分近似、正向模式自动微分和反向模式自动…

uniapp微信小程序图片上传功能实现,页面显示文件列表、删除功能

uniapp小程序图片上传功能效果预览 一、template 页面结构 <view class"upload-box"><view class"upload-list"><view class"upload-item" v-for"(item,index) of fileList" :keyindex><image class"img…

SMD NTC Thermistor NTC热敏电阻在锂电池充放电中的作用(NYFEA徕飞)

热敏电阻器&#xff08;Thermistor&#xff09;是一种电阻值对温度极为灵敏的半导体元件&#xff0c;温度系数可分为正温度系数热敏电阻PTC和负温度系数热敏电阻NTC。 NTC热敏电阻用于温度测量&#xff0c;温度控制&#xff0c;温度补偿等&#xff0c;称为温度传感器。 PTC热…

【python】—— 文件操作

目录 &#xff08;一&#xff09;文件是什么 &#xff08;二&#xff09;文件路径 &#xff08;三&#xff09;文件操作 3.1 打开文件 3.2 关闭文件 3.3 写文件 3.4 读文件 &#xff08;四&#xff09;关于中文的处理 &#xff08;五&#xff09;使用上下文管理器 &…

打印机怎么连接到电脑?完整指南助你顺利连接

随着科技的不断发展&#xff0c;打印机作为一种常见的办公设备&#xff0c;已经成为我们日常工作不可或缺的一部分。可是打印机怎么连接到电脑呢&#xff1f;本文将介绍三种常见的方法&#xff0c;详细解释如何将打印机连接到电脑&#xff0c;以便用户在面对这一操作时能够迅速…

【软件测试】学习笔记-Nginx 在系统架构中的作用

本篇文章你探讨 Nginx 在应用架构中的作用&#xff0c;并从性能测试角度看如何利用 Nginx 数据统计用户访问量。 Nginx 重要的两个概念 代理 首先要来解释一下什么是代理&#xff0c;正向代理和反向代理是什么意思&#xff1f;各自作用是什么&#xff1f;不少同学经常听到这…

[BSidesCF 2020]Had a bad day

先看url&#xff0c;发现可能有注入 http://655c742e-b427-485c-9e15-20a1e7ef1717.node5.buuoj.cn:81/index.php?categorywoofers 试试能不能查看index.php直接?categoryindex.php不行&#xff0c;试试伪协议 把.php去掉试试 base64解码 <?php$file $_GET[category];…