生成对抗网络(GAN)手写数字生成

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、网络结构
  • 四、构建生成器
  • 五、构建鉴别器
  • 六、训练模型
    • 1. 保存样例图片
    • 2. 训练模型
  • 七、生成动图

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    
# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

在这里插入图片描述

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

五、构建鉴别器

def build_discriminator():
    # ===================================== #
    #   鉴别器,对输入的图片进行判别真假
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

def sample_images(epoch):
    """
    保存样例图片
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%05d.png" % epoch)
    plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

def train(epochs, batch_size=128, sample_interval=50):
    # 加载数据
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()

    # 将图片标准化到 [-1, 1] 区间内   
    train_images = (train_images - 127.5) / 127.5
    # 数据
    train_images = np.expand_dims(train_images, axis=3)

    # 创建标签
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    # 进行循环训练
    for epoch in range(epochs): 

        # 随机选择 batch_size 张图片
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        
        # 生成噪音
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        
        # 训练鉴别器 
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # 返回loss值
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

import imageio

def compose_gif():
    # 图片地址
    data_dir = "images_old"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif",gif_images,fps=2)
    
compose_gif()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/206675.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《合成孔径雷达成像算法与实现》_使用CS算法对RADARSAT-1数据进行成像

CSA 简介:Chirp Scaling 算法 (简称 CS 算法,即 CSA) 避免了 RCMC 中的插值操作。该算法基于 Scaling 原理,通过对 chirp 信号进行频率调制,实现了对信号的尺度变换或平移。基于这种原理,可以通过相位相乘代替时域插值…

Tomcat安装及配置教程

Tomcat安装及配置教程 Tomcat是Apache软件基金会(Apache Software Foundation)的Jakarta 项目中的一个核心项目,由Apache、Sun和其他一些公司及个人共同开发而成。 Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量…

Spring Security 的使用

一、简介 1.1、Spring Security 相关概念 1.过滤器链(Filter Chain) 基于Servlet过滤器(Filter)处理和拦截请求,进行身份验证、授权等安全操作。过滤器链按顺序执行,每个过滤器负责一个具体的安全功能。 …

什么是网络可视化?网络可视化工具有用吗

网络可视化定义是自我描述的,因为它在单个屏幕上重新创建网络布局,以图形和图表的形式显示有关网络设备、网络指标和数据流的信息,为 IT 运营团队提供一目了然的理解和决策。 网络是复杂的实体,倾向于持续进化,随着业…

利用MCMC 获得泊松分布

写出概率流方程如下 if state 0: if np.random.random() < min([Lambda/2, 1]):state 1else:passelif state 1:if choose_prob_state[i] < 0.5:#选择 1 -> 0&#xff0c;此时的接受概率为min[2/Lambda, 1]if np.random.random() < min([2/Lambda, 1]…

STM32USART+DMA实现不定长数据接收/发送

STM32USARTDMA实现不定长数据接收 CubeMX配置代码分享实践结果 这一期的内容是一篇代码分享&#xff0c;CubeMX配置介绍&#xff0c;关于基础的内容可以往期内容 夜深人静学32系列11——串口通信夜深人静学32系列18——DMAADC单/多通道采集STM32串口重定向/实现不定长数据接收 …

『PyTorch学习笔记』分布式深度学习训练中的数据并行(DP/DDP) VS 模型并行

分布式深度学习训练中的数据并行(DP/DDP) VS 模型并行 文章目录 一. 介绍二. 并行数据加载2.1. 加载数据步骤2.2. PyTorch 1.0 中的数据加载器(Dataloader) 二. 数据并行2.1. DP(DataParallel)的基本原理2.1.1. 从流程上理解2.1.2. 从模式角度理解2.1.3. 从操作系统角度看2.1.…

【ESP32】手势识别实现笔记:红外温度阵列 | 双三次插值 | 神经网络 | TensorFlow | ESP-DL

目录 一、开发环境搭建与新建工程模板1.1、开发环境搭建与卸载1.2、新建工程目录1.3、自定义组件 二、驱动移植与应用开发2.1、I2C驱动移植与AMG8833应用开发2.2、SPI驱动移植与LCD应用开发2.3、绘制温度云图2.4、启用PSRAM&#xff08;可选&#xff09;2.5、画面动静和距离检测…

力扣:1419. 数青蛙

题目&#xff1a; 代码&#xff1a; class Solution { public:int minNumberOfFrogs(string croakOfFrogs){string s "croak";int ns.size();//首先创建一个哈希表来标明每个元素出现的次数&#xff01;vector<int>hash(n); //不用真的创建一个hash表用一个数…

事务--02---TCC模式

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 TCC模式两阶段提交 的模型 1.流程分析阶段一&#xff08; Try &#xff09;&#xff1a;阶段二&#xff08;Confirm)&#xff1a;阶段二(Canncel)&#xff1a; 2.事…

java编程:⼀个⽂件中存储了本站点下各路径被访问的次数,请编程找出被访问次数最多的10个路径

题目 编程题&#xff1a;⼀个⽂件&#xff08;url_path_statistics.txt&#xff09;中存储了本站点下各路径被访问的次数 请编程找出被访问次数最多的10个路径时间复杂是多少&#xff0c;是否可以优化&#xff08;假设路径数量为n&#xff09;如果路径访问次数⽂件很⼤&#x…

unity3d模型中缺失animation

在 模型的Rig-Animationtype 设置成Legacy https://tieba.baidu.com/p/2293580178

解决WPS拖动整行的操作

如上图&#xff0c;想要把第4行的整行内容&#xff0c;平移到第1行。 1.选中第4行的整行 2.鼠标出现如图的样子时&#xff0c;按住鼠标左键&#xff0c;上移到第1行位置后&#xff0c;放开左键即可。

vue项目和wx小程序

wx:key 的值以两种形式提供&#xff1a; 1、字符串&#xff0c;代表在 for 循环的 array 中 item 的某个 property&#xff0c;该 property 的值需要是列表中唯一的字符串或数字&#xff0c;且不能动态改变。 2、保留关键字 this 代表在 for 循环中的 item 本身&#xff0c;这种…

测试与管理 Quota

用myquota1创建一个大的文件测试 理论猜想&#xff1a;超过soft可以&#xff0c;但是超过hard就不行了&#xff0c;最大值就是hard&#xff0c;如果超过soft&#xff0c;过了17天不处理&#xff0c;最后限制值会被强制设置成soft。修改设置成hard值 切换测试用户&#xff0c;m…

易宝OA ExecuteSqlForSingle SQL注入漏洞复现

0x01 产品简介 易宝OA系统是一种专门为企业和机构的日常办公工作提供服务的综合性软件平台&#xff0c;具有信息管理、 流程管理 、知识管理&#xff08;档案和业务管理&#xff09;、协同办公等多种功能。 0x02 漏洞概述 易宝OA ExecuteSqlForSingle接口处存在SQL注入漏洞&a…

苹果TF签名全称TestFlight签名,需要怎么做才可以上架呢?

如果你正在开发一个iOS应用并准备进行内测&#xff0c;TestFlight是苹果提供的一个免费的解决方案&#xff0c;它使开发者可以邀请用户参加应用的测试。以下是一步步的指南&#xff0c;教你如何利用TestFlight进行内测以便于应用后续可以顺利上架App Store。 1: 准备工作 在测…

项目设计---网页五子棋

文章目录 一. 项目描述二. 核心技术三. 需求分析概要设计四. 详细设计4.1 实现用户模块4.1.1 约定前后端交互接口4.1.2 实现数据库设计4.1.3 客户端页面展示4.1.4 服务器功能实现 4.2 实现匹配模块4.2.1 约定前后端交互接口4.2.2 客户端页面展示4.2.3 服务器功能实现 4.3 实现对…

2023年计网408

第33题 33.在下图所示的分组交换网络中&#xff0c;主机H1和H2通过路由器互连&#xff0c;2段链路的带宽均为100Mbps、 时延带宽积(即单向传播时延带宽)均为1000bits。若 H1向 H2发送1个大小为 1MB的文件&#xff0c;分组长度为1000B&#xff0c;则从H1开始发送时刻起到H2收到…

Windows系列:windows2003-建立域

windows2003-建立域 Active Directory建立DNS建立域查看日志xp 加入域 Active Directory 活动目录是一个包括文件、打印机、应用程序、服务器、域、用户账户等对象的数据库。 常见概念&#xff1a;对象、属性、容器 域组件&#xff08;Domain Component&#xff0c;DC&#x…