进行生成简单数字图片

1.之前只能做一些图像预测,我有个大胆的想法,如果神经网络正向就是预测图片的类别,如果我只有一个类别那就可以进行生成图片,专业术语叫做gan对抗网络
在这里插入图片描述
2.训练代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import os

# 设置环境变量
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=784):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x

# 定义判别器模型
class Discriminator(nn.Module):
    def __init__(self, input_dim=784, output_dim=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x

# 加载 MNIST 手写数字图片数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataroot = "path_to_your_mnist_dataset"  # 替换为 MNIST 数据集的路径
dataset = dset.MNIST(root=dataroot, train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# 创建生成器和判别器实例
input_dim = 100
output_dim = 784
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)

# 定义优化器和损失函数
lr = 0.0002
beta1 = 0.5
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()

# 训练 GAN 模型
num_epochs = 50
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)
generator.to(device)
discriminator.to(device)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        real_images, _ = data
        real_images = real_images.to(device)
        batch_size = real_images.size(0)  # 获取批次样本数量

        # 训练判别器
        optimizer_d.zero_grad()
        real_labels = torch.full((batch_size, 1), 1.0, device=device)
        fake_labels = torch.full((batch_size, 1), 0.0, device=device)
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        real_outputs = discriminator(real_images.view(batch_size, -1))
        fake_outputs = discriminator(fake_images.detach())
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # 训练生成器
        optimizer_g.zero_grad()
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        # 输出训练信息
        if i % 100 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                  % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

    # 保存生成器的权重和图片示例
    if epoch % 10 == 0:
        with torch.no_grad():
            noise = torch.randn(64, input_dim, device=device)
            fake_images = generator(noise).view(64, 1, 28, 28).cpu().numpy()
            fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(12, 12), sharex=True, sharey=True)
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(fake_images[i][0], cmap='gray')
                ax.axis('off')
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            plt.savefig("epoch_%d.png" % epoch)
            plt.close()
        torch.save(generator.state_dict(), "generator_epoch_%d.pth" % epoch)

3.测试模型的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image

# 定义生成器模型
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = torch.tanh(self.fc4(x))
        return x

# 创建生成器模型
generator = Generator(input_dim=100, output_dim=784)

# 加载预训练权重
generator_weights = torch.load("generator_epoch_40.pth", map_location=torch.device('cpu'))

# 将权重加载到生成器模型
generator.load_state_dict(generator_weights)

# 生成随机噪声
noise = torch.randn(1, 100)

# 生成图像
fake_image = generator(noise).view(1, 1, 28, 28)

# 保存生成的图片
save_image(fake_image, "generated_image.png", normalize=False)

#测试结果,由于我的训练集是数字的,所以会生成各种各样的数字,下面明显的是1
在这里插入图片描述
#应该也是1
在这里插入图片描述

#再次运行,我也看不出来,不过只要我训练只有一个种类的问题就可以生成这个种类的图像
在这里插入图片描述
#搞定黑白图,那彩色图应该距离不远了,我需要改进的是把对抗网络的代码改为训练一个种类的图形,不过我感觉这种图形具有随机性,虽然通过训练我们得到了所有图像他们的规律,但是如果需要正常点的图片还是挺难的,就像是上面这张人都不一定知道他是什么东西(在没有颜色的情况下)总结就是精度不够,而且随机性太强了,现在普遍图片AI生成工具具有这个缺点(生成的物体可能会扭曲,挺阴间的),而且生成的图片速度慢,如果谁比较受益那一定是老黄(英伟达)哈哈哈
//比如下面这个图片生成视频的网站
https://app.runwayml.com/login

#每一帧看起来都没有问题,就是连起来变成视频不自然,如果有改进方法的话那可能需要引入重力/加速度/光处理 等等物理公式,来让图片更自然…
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/227170.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据分析基础之《matplotlib(3)—散点图》

一、常见图形种类及意义 1、matplotlib能够绘制折线图、散点图、柱状图、直方图、饼图。我们需要知道不同的统计图的意义,以此来决定选择哪种统计图来呈现我们的数据 2、折线图plot 说明:以折线的上升或下降来表示统计数量的增减变化的统计图 特点&…

Django + Matplotlib:实现数据分析显示与下载为PDF或SVG

写作背景 首先,数据分析在当前的信息时代中扮演着重要的角色。随着数据量的增加和复杂性的提高,人们对于数据分析的需求也越来越高。 其次,笔者也确确实实曾经接到过一个这样的开发需求,甲方是一个医疗方面的科研团队&#xff0…

网络安全(五)--Linux 入侵检测分析技术

8. Linux 入侵检测分析技术 目标 了解入侵检测分析的基本方法掌握查看登录失败用户的方法掌握查阅历史命令的方法掌握检查系统开机自启服务的方法 8.1. 概述 最好的安全防护当然是“域敌于国门之外”, 通过安全防护技术,来保证当前主机不被非授权人员…

【链表Linked List】力扣-117 填充每个节点的下一个右侧节点指针II

目录 问题描述 解题过程 官方题解 问题描述 给定一个二叉树: struct Node {int val;Node *left;Node *right;Node *next; } 填充它的每个 next 指针,让这个指针指向其下一个右侧节点。如果找不到下一个右侧节点,则将 next 指针设置为 N…

IT行业软件数据文件传输安全与高效是如何保障的?

在当今迅速发展的科技世界中,云计算、大数据、移动互联网等信息技术正迎来蓬勃发展,IT行业正置身于一个全新的世界。数据不仅是最重要的资产,也是企业竞争力的核心所在。然而,如何缩短信息共享时间、高速流转数据、跨部门/跨区域协…

最新版本——Hadoop3.3.6单机版完全部署指南

大家好,我是独孤风,大数据流动的作者。 本文基于最新的 Hadoop 3.3.6 的版本编写,带大家通过单机版充分了解 Apache Hadoop 的使用。本文更强调实践,实践是大数据学习的重要环节,也能在实践中对该技术有更深的理解&…

企业计算机服务器中了mallox勒索病毒如何处理,Mallox勒索病毒解密

随着计算机技术的不断发展,越来越多的企业利用网络来提高工作效率,但随之而来的网络安全威胁也在不断增加,各种勒索病毒种类不断增加,给企业的数据安全带来严重的威胁,影响企业的生产业务开展。近期,云天数…

微信小程序js数组对象根据某个字段排序

一、排序栗子 注: 属性字段需要进行转换,如String类型或者Number类型 //升序排序 首元素(element1)在前 降序则(element1)元素在后 data data.sort((element1, element2) >element1.属性 - element2.属性 ); 二、代码 Page({/*** 页面的初始数据*/data: {user:…

【Flink系列三】数据流图和任务链计算方式

上文介绍了如何计算并行度和slot的数量,本文介绍Flink代码提交后,如何生成计算的DAG数据流图。 程序和数据流图 所有的Flink程序都是由三部分组成的:Source、Transformation和Sink。Source负责读取数据源,Transformation利用各种…

idea本地调试hadoop 遇到的几个问题

1.DEA对MapReduce的toString调用报错:Method threw ‘java.lang.IllegalStateException‘ exception. Cannot evaluate org.apache.hadoop.mapreduc 解决方法:关闭 IDEA 中的启用“ tostring() ”对象视图 2.代码和hdfs路径都对的情况下,程序…

【EI会议征稿】第三届密码学、网络安全和通信技术国际会议(CNSCT 2024)

第三届密码学、网络安全和通信技术国际会议(CNSCT 2024) 2024 3rd International Conference on Cryptography, Network Security and Communication Technology 随着互联网和网络应用的不断发展,网络安全在计算机科学中的地位越来越重要&…

MySQL 中Relay Log打满磁盘问题的排查方案

MySQL 中Relay Log打满磁盘问题的排查方案 引言: MySQL Relay Log(中继日志)是MySQL复制过程中的一个重要组件,它用于将主数据库的二进制日志事件传递给从数据库。然而,当中继日志不断增长并最终占满磁盘空间时&…

5组10个共50个音频可视化效果PR音乐视频制作模板

我们常常看到的图形跟着音乐跳动,非常有节奏感,那这个是怎么做到的呢?5组10个共50个音频可视化效果PR音乐视频制作模板满足你的制作需求。 PR音乐模板|10个音频可视化视频制作模板05 https://prmuban.com/36704.html 10个音频可视化视频制作…

论文阅读——Deformable ConvNets v2

论文:https://arxiv.org/pdf/1811.11168.pdf 代码:https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch 1. 介绍 可变形卷积能够很好地学习到发生形变的物体,但是论文观察到当尽管比普通卷积网络能够更适应物体形变&#xff…

华为OD机试 - 攀登者2(Java JS Python C)

题目描述 攀登者喜欢寻找各种地图,并且尝试攀登到最高的山峰。 地图表示为一维数组,数组的索引代表水平位置,数组的元素代表相对海拔高度。其中数组元素0代表地面。 例如:[0,1,2,4,3,1,0,0,1,2,3,1,2,1,0],代表如下图所示的地图,地图中有两个山脉位置分别为 1,2,3,4,5…

Qt 如何使用VTK显示点云

开发环境 ubuntu 20.04 VTK 8.2 编译VTK 下载源码 git clone --recursive https://gitlab.kitware.com/vtk/vtk.git 使用版本管理工具,切换版本到8.2 更改编译选项,这里使用cmake-gui进行配置 1、编译类型修改为Release 2、安装路径可以设置&#xf…

Appium 并行测试多个设备

一、前置说明 在自动化测试中,经常需要验证多台设备的兼容性,Appium可以用同一套测试运例并行测试多个设备,以达到验证兼容性的目的。 解决思路: 查找已连接的所有设备;为每台设备启动相应的Appium Server&#xff1b…

【头歌系统数据库实验】实验7 SQL的复杂多表查询-1

目录 第1关:求各颜色零件的平均重量 第2关:求北京和天津供应商的总个数 第3关:求各供应商供应的零件总数 第4关:求各供应商供应给各工程的零件总数 第5关:求重量大于所有零件平均重量的零件名称 第6关&#xff1…

字节开源的netPoll底层LinkBuffer设计与实现

字节开源的netPoll底层LinkBuffer设计与实现 为什么需要LinkBuffer介绍设计思路数据结构LinkBufferNodeAPI LinkBuffer读 API写 APIbook / bookAck api 小结 本文基于字节开源的NetPoll版本进行讲解,对应官方文档链接为: Netpoll对应官方文档链接 netPoll底层有一个…

IntelliJ IDEA开启git版本控制的简单教程

这篇文章想要分享一下怎么在IntelliJ IDEA开启版本控制,博主使用的是gitee,首先需要安装git,关于git的安装这里就不介绍了,很简单。 目录 创建git仓库 创建项目 开启版本控制 拉取项目 创建git仓库 首先,需要登录…
最新文章