【GAN】简单的GAN模型搭建 -- 以线性模型和MNIST数据集为例子

文章目录

  • 确定损失函数
  • 生成器网络架构

不讲原理,从简单的代码一步步开始,学会怎么用、怎么设计损失函数即可。

确定损失函数

生成器的任务是生成足够以假乱真的数据,判别器的任务是分辨出哪些数据是真实的,哪些数据是假的。因此,对于判别器来讲,需要判别真伪,也就是true/false,从这个角度看,是个二分类问题。所以损失函数使用二类分类损失,即BCELoss。

import torch
import torch.nn as nn

adversarial_loss = nn.BCELoss()

生成器网络架构

这里使用纯线性网络作为生成器,得到的输出为[batch_size, np.pord(28*28)]

import torch.nn as nn
import numpy as np

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_features, out_features, normalization=True):
            layers = [nn.Linear(in_features, out_features)]
            if normalization:
                layers.append(nn.BatchNorm1d(out_features, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalization=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod((1, 28, 28)))), # generate a photo size, but in line mode.
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

判别器网络架构:
判别器的功能为判断出来哪一个是生成的图片,哪一个是真实的图片。对于生成的图片,我们希望判别器打上假的标签,对于真实的图片,我们希望判别器打上真的标签,因此,判别器的输出为一个数,即0或者1。

import torch.nn as nn
import numpy as np
class Disctiminator(nn.Module):
    def __init__(self):
        super(Disctiminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod((1, 28, 28))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        img_flat = x.view(x.size(0), -1)
        validity = self.model(img_flat)
        return validity

训练流程:

载入数据 — 训练 (生成图片 — 损失 — 反向传播) — 测试(这里没有加测试代码,可以照着训练代码改一下)

损失函数:生成器损失函数和判别器损失函数,两个损失函数分别进行反向传播,即生成器损失函数优化生成器,判别器损失函数优化判别器。

import torch
import torch.nn as nn
import argparse
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset, dataset
from torchvision import datasets
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from models.generator import Generator
from models.dicsriminator import Disctiminator





os.makedirs('/home/sjr/gxj/study/data/mnist', exist_ok=True)

dataloader = DataLoader(datasets.MNIST('/home/sjr/gxj/study/data/mnist',
                                       train=True, download=True,
                                       transform=transforms.Compose([transforms.Resize(28), transforms.ToTensor(),
                                                                     transforms.Normalize([0.5], [0.5])])
                                       ),
                        batch_size=64, shuffle=True, num_workers=4)

adversarial_loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Disctiminator()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    adversarial_loss.cuda(device)
    generator.cuda(device)
    discriminator.cuda(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

for epoch in range(60):
    for i, (imgs, _) in enumerate(dataloader):
        valid = Tensor(imgs.size(0), 1).fill_(1.0)
        fake = Tensor(imgs.size(0), 1).fill_(0.0)
        real_imgs = Variable(imgs.type(Tensor))

        optimizer_G.zero_grad()

        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))
        gen_imgs = generator(z)


        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        #
        print(f"[Epoch {epoch}/{200}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        #
        if (epoch + 1) % 20 == 0:
            save_image(gen_imgs.data[:25], f'images/{epoch+1}.png', nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/609079.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

科林算法_3 图

一、图论基础 多对多的关系 定义&#xff1a;G(V,E) Vertex顶点 Edge边 顶点的集合V{v1,v2} 边的结合E{(v1,v2)} 无向图(1,2) 有向图<1,2> 依附&#xff1a;边(v1,v2)依附于顶点v1,v2 路径&#xff1a;&#xff08;v1,v2)(v2,v3) 无权路径最短&#xff1a;边最少…

深入了解 Flask Request

文章目录 获取请求数据获取请求信息文件上传总结 Flask 是一个轻量级的 Python Web 框架&#xff0c;其简洁的设计和灵活的扩展性使其成为了许多开发者的首选。在 Flask 中&#xff0c;处理 HTTP 请求是至关重要的&#xff0c;而 Flask 提供了丰富而强大的 request 对象来处理…

【限时免费,手慢无】Unity 怪物资源包,MONSTER 动作超丰富,不领后悔!

Unity 怪物资源包&#xff0c;MONSTER 动作超丰富 前言资源包内容领取兑换码 前言 &#x1f47e; 突破想象&#xff01;惊艳众人的怪物模型登场 &#x1f47e; 今天要向大家介绍一款令人瞩目的游戏怪物模型&#xff01;这个看似丑陋的小怪物&#xff0c;却有着巨大的潜力&…

代码随想录刷题随记31-贪心5

代码随想录刷题随记31-贪心5 435. 无重叠区间 leetcode链接 按照右边界排序&#xff0c;从左向右记录非交叉区间的个数。 此时问题就是要求非交叉区间的最大个数。 这里记录非交叉区间的个数还是有技巧的&#xff0c;如图&#xff1a; 左边界排序可不可以呢&#xff1f; 也是…

前缀索引与单列联合索引的选择

&#x1f4dd;个人主页&#xff1a;五敷有你 &#x1f525;系列专栏&#xff1a;面经 ⛺️稳中求进&#xff0c;晒太阳 前缀索引 当字段类型为字符串(varchar,text等) 时&#xff0c;有时候需要索引很长的字符串&#xff0c;这会让索引变得很大。查询的时候浪费大量的磁…

能恢复永久删除文件的十大数据恢复软件

当您不小心删除了重要数据&#xff0c;或者由于病毒攻击而丢失了重要数据时&#xff0c;请不要惊慌&#xff0c;我们已经为您准备好了。别无他处&#xff0c;这是您目前市场上最佳数据恢复软件列表的一站式目的地。 能恢复永久删除文件的十大数据恢复软件 1. 奇客数据恢复 这是…

教大家一键下载1688图片信息

电商图片是商品的视觉身份证&#xff0c;对销量有着决定性影响。它们不仅是展示产品&#xff0c;更是讲述品牌故事&#xff0c;激发情感共鸣的工具。高质量的图片能瞬间吸引顾客目光&#xff0c;准确传达产品细节&#xff0c;建立起顾客的信任与购买意愿。在无法亲身体验商品的…

使用网站内容进行多渠道品牌营销的主要优势

在选择服务提供商时&#xff0c;人们使用不同的方式来查找信息并与他们联系。有些人更喜欢网站&#xff0c;有些人则使用社交媒体或电子邮件。网站对于数字存在仍然至关重要&#xff0c;但跨多个渠道管理内容现在至关重要。多渠道营销以客户喜欢的方式与客户建立联系&#xff0…

mysql安装及基础设置

关系型数据库 MySQL是一种关系型数据库管理系统&#xff0c;采用了关系模型来组织数据的数据库&#xff0c;关系数据库将数据保存在不同的表中&#xff0c;用户通过查询 sql 来检索数据库中的数据。 yum 方式安装 mysql # yum -y install mysql-server # systemctl start my…

##12 深入了解正则化与超参数调优:提升神经网络性能的关键策略

文章目录 前言1. 正则化技术的重要性1.1 L1和L2正则化1.2 Dropout1.3 批量归一化 2. 超参数调优技术2.1 网格搜索2.2 随机搜索2.3 贝叶斯优化 3. 实践案例3.1 设置实验3.2 训练和测试 4. 结论 前言 在深度学习中&#xff0c;构建一个高性能的模型不仅需要一个好的架构&#xf…

《这就是ChatGPT》读书笔记

书名&#xff1a;这就是ChatGPT 作者&#xff1a;[美] 斯蒂芬沃尔弗拉姆&#xff08;Stephen Wolfram&#xff09; ChatGPT在做什么&#xff1f; ChatGPT可以生成类似于人类书写的文本&#xff0c;它基本任务是弄清楚如何针对它得到的任何文本产生“合理的延续”。当ChatGPT写…

Redis-新数据类型-Geospatia

新数据类型-Geospatia 简介 GEO&#xff0c;Geographic,地理信息的缩写。 该类型就是元素的二维坐标&#xff0c;在地图上就是经纬度。Redis基于该类型&#xff0c;提供了经纬度设置、查询、范围 查询、距离查询、经纬度Hash等常见操作。 常用命令 geoadd key longitude lat…

python循环结构练习

目录 前言 1、使用while实现模拟用户登录 1.1 题目要求 1.2 解题 2、输入数字&#xff0c;生成对应等腰三角形 2.1 题目要求 2.2 解题 3、输入数字&#xff0c;生成对应的菱形 3.1 题目要求 3.2 解题 总结 前言 本系列主要是讲解python中的题目&#xff0c;目的是为…

c++ libtorch使用cmake建立

如果我们安装好pytorch&#xff0c;其实不一定一定要安装libtorch&#xff0c;默认都已经安装过了 1 进入pytorch conda env list conda activate pytorch 命令行下使用 python -c 来获取libtorch的基本信息&#xff0c; python -c "import torch;print(torch.utils.c…

uniapp:抖音PK进度条(nvue)

nvue中,仿抖音PK进度条效果, <template><view class="index" :style="{width:windowWidth+px,height:index_windowHeight+px,paddingTop:windowTop+px}"><view class="pk"><text class="pk_jindu_left_val fsz-24 …

������森林消防泵:特点及使用场景 /恒峰智慧科技������

在大自然的怀抱中&#xff0c;森林是地球上最美丽的绿色家园。然而&#xff0c;森林火灾却时常威胁着这片美丽的家园。为了保护森林&#xff0c;我们有幸拥有了一支强大的森林消防泵队伍。今天&#xff0c;就让我们一起来了解一下这些守护者们的故事吧&#xff01; &#xfffd…

OpenHarmony 3.2 Release版本实战开发——Codec HDI适配过程

简介 OpenHarmony Codec HDI&#xff08;Hardware Device Interface&#xff09;驱动框架基于 OpenMax 实现了视屏硬件编解码驱动&#xff0c;提供 Codec 基础能力接口供上层媒体服务调用&#xff0c;包括获取组件编解码能力、创建组件、参数设置、数据的轮转和控制、以及销毁…

晶片的厚度会影响晶振的频率吗?

晶振&#xff0c;是一种能够产生稳定频率的电子元件&#xff0c;广泛应用于各种电子设备中。晶振的频率参数&#xff0c;即其振荡产生的频率大小&#xff0c;是晶振性能的重要指标之一。石英晶体的切割方式显得至关重要。不同的切割方式&#xff0c;如AT-cut、CT-cut、SC-cut等…

文心一言 VS 讯飞星火 VS chatgpt (255)-- 算法导论18.1 1题

一、请说明依次从图18-8(f)中删除C、P和V后的结果。如果要写代码&#xff0c;请用go语言。 文心一言&#xff1a; 首先&#xff0c;让我们明确一点&#xff1a;由于您没有提供具体的图18-8(f)的内容&#xff0c;我只能假设它是一个数据结构&#xff08;可能是图、树或列表等&…

1011: 二叉排序树的实现和查找

解法&#xff1a; 二叉排序树&#xff08;Binary Search Tree&#xff0c;简称BST&#xff09;也被称为二叉搜索树或二叉查找树&#xff0c;是一种重要的二叉树结构&#xff0c;它具有以下性质&#xff1a; 左子树上所有节点的值都小于根节点的值&#xff1b;右子树上所有节点的…
最新文章