生成式AI系列 —— DCGAN生成手写数字

1、模型构建

1.1 构建生成器

# 导入软件包
import torch
import torch.nn as nn

class Generator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Generator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size * 32,
                               kernel_size=4, stride=1),
            nn.BatchNorm2d(image_size * 32),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 32, image_size * 16,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 16),
            nn.ReLU(inplace=True))

        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 16, image_size * 8,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 8),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 8, image_size *4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 4),
            nn.ReLU(inplace=True))
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 4, image_size * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 2),
            nn.ReLU(inplace=True))
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 2, image_size,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size),
            nn.ReLU(inplace=True))
        self.last = nn.Sequential(
            nn.ConvTranspose2d(image_size, 3, kernel_size=4,
                               stride=2, padding=1),
            nn.Tanh())
        # 注意:因为是黑白图像,所以只有一个输出通道

    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out
    
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    G = Generator(z_dim=20, image_size=256)

    # 输入的随机数
    input_z = torch.randn(1, 20)

    # 将张量尺寸变形为(1,20,1,1)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)

    #输出假图像
    fake_images = G(input_z)
    print(fake_images.shape)
    img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)
    plt.imshow(img_transformed)
    plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0oWDbXr-1692468683782)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\Figure_1.png)]

1.1 构建判别器

class Discriminator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Discriminator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, image_size, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
       #注意:因为是黑白图像,所以输入通道只有一个

        self.layer2 = nn.Sequential(
            nn.Conv2d(image_size, image_size*2, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer3 = nn.Sequential(
            nn.Conv2d(image_size*2, image_size*4, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(image_size*4, image_size*8, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer5 = nn.Sequential(
            nn.Conv2d(image_size*8, image_size*16, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer6 = nn.Sequential(
            nn.Conv2d(image_size*16, image_size*32, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out

    
if __name__ == "__main__":
    #确认程序执行
    D = Discriminator(z_dim=20, image_size=64)

    #生成伪造图像
    input_z = torch.randn(1, 20)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
    fake_images = G(input_z)

    #将伪造的图像输入判别器D中
    d_out = D(fake_images)

    #将输出值d_out乘以Sigmoid函数,将其转换成0~1的值
    print(torch.sigmoid(d_out))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEGvF8xh-1692468683784)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230817224333376.png)]

2、数据集构建

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nn

from torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as plt


def make_datapath_list(root):
    """创建用于学习和验证的图像数据及标注数据的文件路径列表。 """

    train_img_list = list() #保存图像文件的路径

    for img_idx in range(200):
        img_path = f"{root}/img_7_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

        img_path = f"{root}/img_8_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

    return train_img_list


class ImageTransform:
    """图像的预处理类"""

    def __init__(self, mean, std):
        self.data_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

    def __call__(self, img):
        return self.data_transform(img)


class GAN_Img_Dataset(data.Dataset):
    """图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''返回图像的张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''获取经过预处理后的图像的张量格式的数据'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白

        # 图像的预处理
        img_transformed = self.transform(img)

        return img_transformed


# 创建DataLoader并确认执行结果

# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)

# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(
    file_list=train_img_list, transform=ImageTransform(mean, std)
)

# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现


def train_model(G, D, dataloader, num_epochs):
    # 确认是否能够使用GPU加速
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用设备:", device)

    # 设置最优化算法
    g_lr, d_lr = 0.0001, 0.0004
    beta1, beta2 = 0.0, 0.9
    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])

    # 定义误差函数
    criterion = nn.BCEWithLogitsLoss(reduction='mean')

    # 使用硬编码的参数
    z_dim = 20
    mini_batch_size = 8

    # 将网络载入GPU中
    G.to(device)
    D.to(device)

    G.train()  # 将模式设置为训练模式
    D.train()  # 将模式设置为训练模式

    # 如果网络相对固定,则开启加速
    torch.backends.cudnn.benchmark = True

    # 图像张数
    num_train_imgs = len(dataloader.dataset)
    batch_size = dataloader.batch_size

    # 设置迭代计数器
    iteration = 1
    logs = []

    # epoch循环
    for epoch in range(num_epochs):
        # 保存开始时间
        t_epoch_start = time.time()
        epoch_g_loss = 0.0  # epoch的损失总和
        epoch_d_loss = 0.0  # epoch的损失总和

        print('-------------')
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-------------')
        print('(train)')

        # 以minibatch为单位从数据加载器中读取数据的循环
        for imges in dataloader:
            # --------------------
            # 1.判别器D的学习
            # --------------------
            # 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免
            if imges.size()[0] == 1:
                continue

            # 如果能使用GPU,则将数据送入GPU中
            imges = imges.to(device)

            # 创建正确答案标签和伪造数据标签
            # 在epoch最后的迭代中,小批次的数量会减少
            mini_batch_size = imges.size()[0]
            label_real = torch.full((mini_batch_size,), 1).to(device)
            label_fake = torch.full((mini_batch_size,), 0).to(device)

            # 对真正的图像进行判定
            d_out_real = D(imges)

            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))
            d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))
            d_loss = d_loss_real + d_loss_fake

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            d_loss.backward()
            d_optimizer.step()

            # --------------------
            # 2.生成器G的学习
            # --------------------
            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # --------------------
            # 3.记录结果
            # --------------------
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            iteration += 1

        # epoch的每个phase的loss和准确率
        t_epoch_finish = time.time()
        print('-------------')
        print(
            'epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(
                epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size
            )
        )
        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))
        t_epoch_start = time.time()

    return G, D

4、训练


G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(
    G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EvqY6h3G-1692468683786)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230820020234209.png)]

5、测试

# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)

# 生成图像
fake_images = G_update(fixed_z.to(device))

# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素

# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):
    # 将训练数据放入上层
    plt.subplot(2, 5, i + 1)
    plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')

    # 将生成数据放入下层
    plt.subplot(2, 5, 5 + i + 1)
    plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/84939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

「Vue|网页开发|前端开发」01 快速入门:用vue-cli快速写一个Vue的HelloWorld项目

本文主要介绍如何用vue开发的标准化工具vue-cli快速搭建一个符合实际业务项目结构的hello world网页项目并理解vue的代码文件结构以及页面渲染流程。 文章目录 一、准备工作:安装node.js二、项目搭建创建项目目录全局安装vue-cli使用Webpack初始化项目启动项目学会…

SpringBoot---内置Tomcat 配置和切换

😀前言 本篇博文是关于内置Tomcat 配置和切换,希望你能够喜欢 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的文章可以帮助到大家,您的满意是我的动力&#x…

SVM详解

公式太多了,就用图片用笔记呈现,SVM虽然算法本质一目了然,但其中用到的数学推导还是挺多的,其中拉格朗日约束关于α>0这块证明我看了很长时间,到底是因为悟性不够。对偶问题也是,用了一个简单的例子才明…

React请求机制优化思路 | 京东云技术团队

说起数据加载的机制,有一个绕不开的话题就是前端性能,很多电商门户的首页其实都会做一些垂直的定制优化,比如让请求在页面最早加载,或者在前一个页面就进行预加载等等。随着react18的发布,请求机制这一块也是被不断谈起…

3 个 ChatGPT 插件您需要立即下载3 ChatGPT Extensions You need to Download Immediately

在16世纪,西班牙探险家皮萨罗带领约200名西班牙士兵和37匹马进入了印加帝国。尽管印加帝国的军队数量达到了数万,其中包括5,000名精锐步兵和3,000名弓箭手,他们装备有大刀、长矛和弓箭等传统武器。但皮萨罗的军队中有100名火枪手,…

eDP接口的PCB布局布线要求

eDP接口是一种基于DisplayPort架构和协议的一种全数字化接口,传递高分辨率信号只需要较简单的连接器以及较少的引脚就可以实现,同时还能够实现多数据同时传输。 图1 EDP接口 eDP接口的PCB设计布局布线注意事项: 1、远离干扰源,防…

数据分析15——office中的Excel基础技术汇总

0、前言: 这部分总结就是总结每个基础技术的定义,在了解基础技术名称和定义后,方便对相关技术进行检索学习。笔记不会详细到所有操作都说明,但会把基础操作的名称及作用说明,可自行检索。本文对于大部分读者有以下作用…

html动态爱心代码【一】(附源码)

前言 七夕马上就要到了,为了帮助大家高效表白,下面再给大家带来了实用的HTML浪漫表白代码(附源码)背景音乐,可用于520,情人节,生日,表白等场景,可直接使用。 效果演示 文案修改 var loverNam…

windows下, artemis学习

1. download artemis from apache ActiveMQhttps://activemq.apache.org/components/artemis/download/2. 解压缩到 C:/software/apache-artemis-2.30.0/ 2. 进入到cmd, 执行 C:\software\apache-artemis-2.30.0\bin>artemis create C:/software/apache-artem…

数据结构—树表的查找

7.3树表的查找 ​ 当表插入、删除操作频繁时,为维护表的有序表,需要移动表中很多记录。 ​ 改用动态查找表——几种特殊的树 ​ 表结构在查找过程中动态生成 ​ 对于给定值key ​ 若表中存在,则成功返回; ​ 否则&#xff0…

[C++ 网络协议编程] 域名及网络地址

1. DNS服务器 DNS(Domain Name System):是对IP地址和域名(如:www.baidu.com等)进行相互转换的系统,其核心是DNS服务器。 我们输入的www.baidu.com是域名,是一种虚拟地址,而非实际地…

JVM详解

文章目录 一、JVM 执行流程二、类加载三、双亲委派模型四、垃圾回收机制(GC) 一、JVM 执行流程 程序在执行之前先要把java代码转换成字节码(class文件),JVM 首先需要把字节码通过一定的方式 类加载器(Clas…

【nodejs】用Node.js实现简单的壁纸网站爬虫

1. 简介 在这个博客中,我们将学习如何使用Node.js编写一个简单的爬虫来从壁纸网站获取图片并将其下载到本地。我们将使用Axios和Cheerio库来处理HTTP请求和HTML解析。 2. 设置项目 首先,确保你已经安装了Node.js环境。然后,我们将创建一个…

【前端vue升级】vue2+js+elementUI升级为vue3+ts+elementUI plus

一、工具的选择 近期想将vuejselementUI的项目升级为vue3tselementUI plus,以获得更好的开发体验,并且vue3也显著提高了性能,所以在此记录一下升级的过程对于一个正在使用的项目手工替换肯定不是个可实现的解决方案,更优方案是基于…

大数据 算法

什么是大数据 大数据是指数据量巨大、类型繁多、处理速度快的数据集合。这些数据集合通常包括结构化数据(如数据库中的表格数据)、半结构化数据(如XML文件)和非结构化数据(如文本、音频和视频文件)。大数据…

国内常见的几款可视化Web组态软件

组态软件是一种用于控制和监控各种设备的软件,也是指在自动控制系统监控层一级的软件平台和开发环境。这类软件实际上也是一种通过灵活的组态方式,为用户提供快速构建工业自动控制系统监控功能的、通用层次的软件工具。通常用于工业控制,自动…

python 打印人口分布金字塔图

背景 今天介绍一个不使用 matplot,通过DebugInfo模块打印人口金字塔图的方法。 引入模块 pip install DebugInfo打印人口金字塔图 下面的代码构建了两个人口数据(仅做功能演示,不承诺任何参考价值),男性人口和女性…

FirmAE 工具安装(解决克隆失败 网络问题解决)

FirmAE官方推荐使用Ubuntu 18.04系统进行安装部署,FirmAE工具的安装部署十分简单,只需要拉取工具仓库后执行安装脚本即可。 首先运行git clone --recursive https://kgithub.com/pr0v3rbs/FirmAE命令 拉取FirmAE工具仓库,因为网络的问题&…

交叉熵--损失函数

目录 交叉熵(Cross Entropy) 【预备知识】 【信息量】 【信息熵】 【相对熵】 【交叉熵】 交叉熵(Cross Entropy) 是Shannon信息论中一个重要概念, 主要用于度量两个概率分布间的差异性信息。 语言模型的性能…

Java之继承详解二

3.7 方法重写 3.7.1 概念 方法重写 :子类中出现与父类一模一样的方法时(返回值类型,方法名和参数列表都相同),会出现覆盖效果,也称为重写或者复写。声明不变,重新实现。 3.7.2 使用场景与案例…
最新文章