【Pytorch】学习记录分享11——GAN对抗生成网络

PyTorch GAN对抗生成网络

      • 0. 工程实现
      • 1. GAN对抗生成网络结构
      • 2. GAN 构造损失函数(LOSS)
      • 3. GAN对抗生成网络核心逻辑
        • 3.1 参数加载:
        • 3.2 生成器:
        • 3.3 判别器:

0. 工程实现

原理解析:
论文解析:GAN:Generative Adversarial Nets

1. GAN对抗生成网络结构

在这里插入图片描述

2. GAN 构造损失函数(LOSS)

LOSS公式与含义:
在这里插入图片描述
LOSS代码实现:

import torch
from torch import autograd
input = autograd.Variable(torch.tensor([[ 1.9072,  1.1079,  1.4906],
        [-0.6584, -0.0512,  0.7608],
        [-0.0614,  0.6583,  0.1095]]), requires_grad=True)
print(input)
print('-'*100)

from torch import nn
m = nn.Sigmoid()
print(m(input))
print('-'*100)

target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)
print('-'*100)

import math

r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))

r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))

r31 = 0 * math.log(0.4847) + (1-0) * math.log((1 - 0.4847))
r32 = 0 * math.log(0.6589) + (1-0) * math.log((1 - 0.6589))
r33 = 0 * math.log(0.5273) + (1-0) * math.log((1 - 0.5273))

r1 = -(r11 + r12 + r13) / 3
#0.8447112733378236
r2 = -(r21 + r22 + r23) / 3
#0.7260397266631787
r3 = -(r31 + r32 + r33) / 3
#0.8292933181294807
bceloss = (r1 + r2 + r3) / 3 
print(bceloss)
print('-'*100)

loss = nn.BCELoss()
print(loss(m(input), target))
print('-'*100)

loss = nn.BCEWithLogitsLoss()
print(loss(input, target))

loss BCEloss代码逐行运行结果
在这里插入图片描述

3. GAN对抗生成网络核心逻辑

整个简单的gan网络代码由以下三个部分组成:

3.1 参数加载:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

3.2 生成器:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
3.3 判别器:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/298830.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

综合跨平台全端ui自动化测试框架Airtest——AirtestIDE录制微信小程序脚本教学

前言 有在自动化测试领域的小伙伴应该都知道,app和小程序自动化这一类的自动化测试在实际操作中有时候很棘手让人心烦,动不动就是用appium写代码脚本维护什么的,不仅步骤繁琐,环境配置方面也是繁琐无比,动不动就与客户…

云计算:OpenStack 分布式架构管理VXLAN网络(单控制节点与多计算节点)

目录 一、实验 1.环境 2.各节点新增网卡准备VXLAN网络 3.控制节点配置私有网络 4.计算节点1配置私有网络 5.计算节点2配置私有网络 6.重启服务 7.修改Dashboard 8.新建项目(租户)及用户 9.新建网络与子网 10.新建实例 11.新建路由 12.新增浮…

【机器学习】循环神经网络(二)-LSTM示例(keras)国际航空乘客问题的回归问题...

使用 Keras 在 Python 中使用 LSTM 循环神经网络进行时间序列预测 国际航空乘客问题的回归问题 这个文件是一个CSV格式的数据集,它包含了从1949年1月到1960年12月的每个月的国际航空乘客的总数(以千为单位)。第一行是列名,分别是&…

贯穿设计模式-享元模式思考

写享元模式的时候,会想使用ConcurrentHashMap来保证并发,没有使用双重锁会不会有问题?但是在synchronize代码块里面需要尽量避免throw异常,希望有经验的同学能够给出解答? 1月6号补充:没有使用双重锁会有问…

Robot Operating System 2: Design, Architecture, and Uses In The Wild

Robot Operating System 2: Design, Architecture, and Uses In The Wild (机器人操作系统 2:设计、架构和实际应用) 摘要:随着机器人在广泛的商业用例中的部署,机器人革命的下一章正在顺利进行。即使在无数的应用程序和环境中,也…

Python爬虫-大麦网演出数据和票价数据

前言 本文是该专栏的第14篇,后面会持续分享python爬虫干货知识,记得关注。 本文以大麦网为例,获取大麦网全部的演出数据以及对应的票价数据。示例图如下所示: 如上图所示,笔者将在本文详细介绍通过python爬虫去获取全国的“演唱会,话剧歌剧,体育比赛,儿童亲子”等等以…

Odoo | Module | 统计系统周期使用人数/当前在线人数

文内材料 GITHUB地址 前言介绍 Odoo作为开源ERP系统的No.01,近年愈发的得到国内很多公司的关注。 虽然它的定位是中小型企业的ERP管理系统,但是在几年的Odoo开发实施过程中,有不足50人的小型企业,也有上万人的中大型企业。功能快速落地和…

即时设计:设计流程图,让您的设计稿更具条理和逻辑

流程图小助手 在设计工作中,流程图是一种重要的工具,它可以帮助设计师清晰地展示设计思路和流程,提升设计的条理性和逻辑性。今天,我们要向您推荐一款强大的设计工具,它可以帮助您轻松为设计稿设计流程图,让…

c#调试程序一次启动两个工程(多个工程)

概述 c# - Visual Studio : debug multiple projects at the same time? 以在解决方案中设置多个启动项目(右键单击解决方案,转到设置启动项目,选择多个启动项目),并为包含在解决方案(无、开始、不调试就开始)。如果您将多个项目设置为开始…

IDEA TODO

今天记录一个 IDEA 工具的小技巧, TODO。比如下班前有一个小功能没完善好,此时可以在响应代码上加上 TODO 注解, //密码比对 // TODO 后期需要进行md5加密,然后再进行比对 password DigestUtils.md5DigestAsHex(password.getByt…

Jenkins修改全局maven配置后不生效解决办法、以及任务读取不同的settings.xml文件配置

一、修改Global Tool Configuration的maven配置不生效 说明:搭建好jenkins后,修改了全局的settings.xml,导致读取settings一直是之前配置的。 解决办法一 Jenkins在创建工作任务时,会读取当前配置文件内容,固定在这…

【Leetcode】230. 二叉搜索树中第K小的元素

一、题目 1、题目描述 给定一个二叉搜索树的根节点 root ,和一个整数 k ,请你设计一个算法查找其中第 k 个最小元素(从 1 开始计数)。 示例1: 输入:root = [3,1,4,null,2], k = 1 输出:1示例2: 输入:root = [5,3,6,2,4,null,null,1], k = 3 输出:3提示: 树中…

关于CNN卷积神经网络与Conv2D标准卷积的重要概念

温故而知新,可以为师矣! 一、参考资料 深入解读卷积网络的工作原理(附实现代码) 深入解读反卷积网络(附实现代码) Wavelet U-net进行微光图像处理 卷积知识点 CNN网络的设计论:NAS vs Handcra…

【数据库】视图索引执行计划多表查询面试题

文章目录 一、视图1.1 概念1.2 视图与数据表的区别1.3 优点1.4 语法1.5 实例 二、索引2.1 什么是索引2.2.为什么要使用索引2.3 优缺点2.4 何时不使用索引2.5 索引何时失效2.6 索引分类2.6.1.普通索引2.6.2.唯一索引2.6.3.主键索引2.6.4.组合索引2.6.5.全文索引 三、执行计划3.1…

【leetcode】字符串中的第一个唯一字符

题目描述 给定一个字符串 s ,找到 它的第一个不重复的字符,并返回它的索引 。如果不存在,则返回 -1 。 用例 示例 1: 输入: s “leetcode” 输出: 0 示例 2: 输入: s “loveleetcode” 输出: 2 示例 3: 输入: s “aabb”…

【大数据进阶第三阶段之Datax学习笔记】阿里云开源离线同步工具Datax类图

【大数据进阶第三阶段之Datax学习笔记】阿里云开源离线同步工具Datax概述 【大数据进阶第三阶段之Datax学习笔记】阿里云开源离线同步工具Datax快速入门 【大数据进阶第三阶段之Datax学习笔记】阿里云开源离线同步工具Datax类图 【大数据进阶第三阶段之Datax学习笔记】使用…

FineBI:简介

1 介绍 FineBI 是帆软软件有限公司推出的一款商业智能(Business Intelligence)产品。 FineBI 是定位于自助大数据分析的 BI 工具,能够帮助企业的业务人员和数据分析师,开展以问题导向的探索式分析。 2 现阶段数据分析弊端 现阶…

系列七、Typora安装 配置

一、安装 1.1、下载安装包 我分享的链接: 链接:https://pan.baidu.com/s/1K5DjV_xhCH5WGiiEHlNQVQ?pwdyyds 提取码:yyds 1.2、安装 无脑下一步,下一步即可。 二、Typora中设置插入的图片左对齐 2.1、背景 往Typora中插入图…

开启Android学习之旅-4-Android集成FontAwesome

FontAwesome 是一个非常标准、统一风格的图标库。产品经理在原型中应用了很多图标都是FontAwesome。正常流程是 UI 需要再手工绘制或在 iconfont 或 iconpark 网站挨个找,如果在 Android 直接使用不是省了一步(注意版权问题,使用免费版&#…

echart图表

首先我们要知道ECharts是什么,它是怎么用的? ECharts是一个使用JavaScript实现的开源可视化库,它涵盖各行业图表,满足各种需求。它提供了丰富的图表类型和交互能力,使用户能够通过国简单的配置生成各种各样的图表,包括…