【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)

关于

 

近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)

工具

 数据集

方法实现

加载必要的库函数和自定义函数

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F


from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):
    """
        save sample 100 images
    """
    z = torch.randn(100, n_noise).to(DEVICE)
    y_hat = G(z).view(100, 28, 28) # (100, 28, 28)
    result = y_hat.cpu().data.numpy()
    img = np.zeros([280, 280])
    for j in range(10):
        img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)
    return img

定义判别模型

class Discriminator(nn.Module):
    """
        Convolutional Discriminator for MNIST
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, x, y=None):
        y_ = self.conv(x)
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_

定义生成模型

class Generator(nn.Module):
    """
        Convolutional Generator for MNIST
    """
    def __init__(self, input_size=100, num_classes=784):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 4*4*512),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            # input: 4 by 4, output: 7 by 7
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input: 7 by 7, output: 14 by 14
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input: 14 by 14, output: 28 by 28
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )
        
    def forward(self, x, y=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, 4, 4)
        y_ = self.conv(y_)
        return y_

 模型超参数定义配置

batch_size = 64

criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))

max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100

D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

 模型训练

for epoch in range(max_epoch):
    for idx, (images, labels) in enumerate(data_loader):
        # Training Discriminator
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = criterion(x_outputs, D_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z))
        D_z_loss = criterion(z_outputs, D_fakes)
        D_loss = D_x_loss + D_z_loss
        
        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        if step % n_critic == 0:
            # Training Generator
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = criterion(z_outputs, D_labels)

            D.zero_grad()
            G.zero_grad()
            G_loss.backward()
            G_opt.step()
        
        if step % 500 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))
        
        if step % 1000 == 0:
            G.eval()
            img = get_sample_image(G, n_noise)
            imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')
            G.train()
        step += 1

测试生成效果

# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

 

模型和状态参量保存

def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)


# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')

应用

DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/497764.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

scala-idea环境搭建及使用

环境搭建 创建一个新项目,选择maven工程 点击next,写入项目名,然后finish 注意:默认下,maven不支持scala的开发,需要引入scala框架,右键项目点击-》add framework pport....,在下图…

Unity 实现鼠标左键进行射击

发射脚本实现思路 分析 确定用户交互方式:通过鼠标左键点击发射子弹。确定子弹发射逻辑:每次点击后有一定时间间隔才能再次发射。确定子弹发射源和方向:子弹从枪口(Transform)位置发射,沿枪口方向前进。 变…

Spring Boot 工程开发常见问题解决方案,日常开发全覆盖

本文是 SpringBoot 开发的干货集中营,涵盖了日常开发中遇到的诸多问题,通篇着重讲解如何快速解决问题,部分重点问题会讲解原理,以及为什么要这样做。便于大家快速处理实践中经常遇到的小问题,既方便自己也方便他人&…

移动端开发思考:Uniapp的上位替代选择

文章目录 前言跨平台开发技术需求技术选型uniappFlutterMAUIAvalonia安卓原生 Flutter开发尝试Avalonia开发测试测试项目新建项目代码MainViewMainViewModel 发布/存档 MAUI实战,简单略过打包和Avalonia差不多 总结 前言 作为C# .NET程序员,我有一些移动…

Python图像处理——计算机视觉中常用的图像预处理

概述 在计算机视觉项目中,使用样本时经常会遇到图像样本不统一的问题,比如图像质量,并非所有的图像都具有相同的质量水平。在开始训练模型或运行算法之前,通常需要对图像进行预处理,以确保获得最佳的结果。图像预处理…

MySQL---触发器

一、介绍 触发器是与表有关的数据库对象,指在insert/update/delete之前(BEFORE)或之后(AFTER),触发并执行触发器中定义的SQL语句集合。触发器的这种特性可以协助应用在数据库端确保数据的完整性, 日志记录 , 数据校验等操作 。 使用别名OLD和NEW来引用触…

Ubuntu20.04安装OpenCV并在vsCode中配置

1. 安装OpenCV 1.1 安装准备: 1.1.1 安装cmake sudo apt-get install cmake 1.1.2 依赖环境 sudo apt-get install build-essential libgtk2.0-dev libavcodec-dev libavformat-dev libjpeg-dev libswscale-dev libtiff5-dev sudo apt-get install libgtk2.0-d…

go的通信Channel

go的通道channel是用于协程之间数据通信的一种方式 一、channel的结构 go源码:GitHub - golang/go: The Go programming language src/runtime/chan.go type hchan struct {qcount uint // total data in the queue 队列中当前元素计数,…

Day54:WEB攻防-XSS跨站Cookie盗取表单劫持网络钓鱼溯源分析项目平台框架

目录 XSS跨站-攻击利用-凭据盗取 XSS跨站-攻击利用-数据提交 XSS跨站-攻击利用-flash钓鱼 XSS跨站-攻击利用-溯源综合 知识点: 1、XSS跨站-攻击利用-凭据盗取 2、XSS跨站-攻击利用-数据提交 3、XSS跨站-攻击利用-网络钓鱼 4、XSS跨站-攻击利用-溯源综合 漏洞原理…

智慧管道物联网远程监控解决方案

智慧管道物联网远程监控解决方案 智慧管道物联网远程监控解决方案是近年来在智能化城市建设和工业4.0背景下,针对各类管道网络进行高效、安全、精准管理的前沿科技应用。它融合了物联网技术、大数据分析、云计算以及人工智能等多种先进技术手段,实现对管…

玫瑰图和雷达图(自备)

目录 玫瑰图 数据格式 绘图基础 绘图升级(文本调整) 玫瑰图 下载数据data/2020/2020-11-24 mirrors_rfordatascience/tidytuesday - 码云 - 开源中国 (gitee.com) R语言绘图—南丁格尔玫瑰图 - 知乎 (zhihu.com) 数据格式 rm(list ls()) libr…

Unity | 工具类-UV滚动

一、内置渲染管线Shader Shader"Custom/ImageRoll" {Properties {_MainTex ("Main Tex", 2D) "white" {}_Width ("Width", float) 0.5_Distance ("Distance", float) 0}SubShader {Tags {"Queue""Trans…

AugmentedReality之路-显示隐藏AR坐标原点(3)

本文介绍如何显示/隐藏坐标原点,分析AR坐标原点跟手机的位置关系 1、AR坐标原点在哪里 当我们通过AugmentedReality的StartARSession函数打开AR相机的那一刻,相机所在的位置就是坐标原点。 2、创建指示箭头资产 1.在Content/Arrow目录创建1个Actor类…

腾讯云4核8G服务器多少钱?12M带宽646元15个月,买1年送3月

2024年腾讯云4核8G服务器租用优惠价格:轻量应用服务器4核8G12M带宽646元15个月,CVM云服务器S5实例优惠价格1437.24元买一年送3个月,腾讯云4核8G服务器活动页面 txybk.com/go/txy 活动链接打开如下图: 腾讯云4核8G服务器优惠价格 轻…

记录minio、okhttp、kotlin一连环的版本冲突问题

问题背景 项目中需要引入minio&#xff0c;添加了如下依赖 <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.5.2</version></dependency> 结果运行报错&#xff1a; Caused by: java.la…

黑群晖基于docker配置frp内网穿透

前言 我的黑群晖需要设置一下内网穿透来外地访问&#xff0c;虽然zerotier的p2p组网已经很不错了&#xff0c;但是这个毕竟有一定的局限性&#xff0c;比如我是ios的国区id就下载不了zerotier的app&#xff0c;组网不了 1.下载镜像 选择第一个镜像 2.映射文件 配置frpc.ini&a…

基于Spring Boot 3 + Spring Security6 + JWT + Redis实现登录、token身份认证

基于Spring Boot3实现Spring Security6 JWT Redis实现登录、token身份认证。 用户从数据库中获取。使用RESTFul风格的APi进行登录。使用JWT生成token。使用Redis进行登录过期判断。所有的工具类和数据结构在源码中都有。 系列文章指路&#x1f449; 系列文章-基于Vue3创建前端…

【机器学习300问】55、介绍推荐系统中的矩阵分解算法是什么、有什么用、怎么用?

本来这篇文章我想先讲矩阵分解算法是什么东西的&#xff0c;但这样会陷入枯燥的定义中去&#xff0c;让原本非常有趣技术在业务场景中直观的使用路径被切断。所以我觉得先通过一个具体的推荐算法的例子&#xff0c;来为大家感性的介绍矩阵分解有什么用会更加合理。 如果你还不知…

iOS开发进阶(十一):ViewController 控制器详解

文章目录 一、前言二、UIViewController三、UINavigationController四、UITabBarController五、UIPageViewController六、拓展阅读 一、前言 iOS 界面开发最重要的首属ViewController和View&#xff0c;ViewController是View的控制器&#xff0c;也就是一般的页面&#xff0c;…

WordPress Git主题 响应式CMS主题模板

分享的是新版本&#xff0c;旧版本少了很多功能&#xff0c;尤其在新版支持自动更新后&#xff0c;该主题可以用来搭建个人博客&#xff0c;素材下载网站&#xff0c;图片站等 主题特点 兼容 IE9、谷歌 Chrome 、火狐 Firefox 等主流浏览器 扁平化的设计加响应式布局&#x…