生成式 AI:使用 Pytorch 通过 GAN 生成合成数据

导 读

生成对抗网络(GAN)因其生成图像的能力而变得非常受欢迎,而语言模型(例如 ChatGPT)在各个领域的使用也越来越多。这些 GAN 模型可以说是人工智能/机器学习目前主流的原因;

因为它向每个人(尤其是该领域之外的人)展示了机器学习所具有的巨大潜力。网上已经有很多关于 GAN 模型的资源,但其中大多数都集中在图像生成上。这些图像生成和语言模型需要复杂的空间或时间复杂性,这增加了额外的复杂性,使读者更难理解 GAN 的真正本质。

为了解决这个问题并使 GAN 更容易被更广泛的受众所接受,在本文的 GAN 模型示例中,我们将采取一种不同的、更实用的方法,重点关注生成数学函数的合成数据。

除了出于学习目的的简化之外,合成数据生成本身也变得越来越重要。数据不仅在业务决策中发挥着核心作用,而且数据驱动方法的用途也越来越多,比第一原理模型更受欢迎。

比如天气预报,第一个原理模型包括通过数值求解的纳维-斯托克斯方程的简化版本。然而,深度学习研究中进行天气预报的尝试在捕捉天气模式方面非常成功,并且一旦经过训练,运行起来会更容易、更快。

有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、生成模型与判别模型

在机器学习中,理解判别模型和生成模型之间的区别非常重要,因为它们是 GAN 的关键组成部分:

判别模型:

判别模型侧重于将数据分类为预定义的类别,例如将狗和猫的图像分类为各自的类别。这些模型不是捕获整个分布,而是辨别不同类别的边界。它们输出 P(y|x)(类别概率,给定输入数据的 y,x),即它们回答给定数据点属于哪个类别的问题。

生成模型:

生成模型旨在理解数据的底层结构。与区分类别的判别模型不同,生成模型学习数据的整个分布。这些模型输出 p(x|y),即它们回答了给定指定类生成该特定数据点的可能性有多大的问题。

这两个模型之间的相互作用构成了 GAN 的基础。

02、GAN—结构和组件

GAN 的关键组件包括噪声向量、生成器和鉴别器。

生成器:生成真实数据

为了生成合成数据,生成器使用随机噪声向量作为输入。为了欺骗鉴别器,生成器的目的是学习真实数据的分布并生成无法与真实数据区分开的合成数据。这里的一个问题是,对于相同的输入,它总是会产生相同的输出(想象一个图像生成器产生真实的图像,但总是相同的图像,这不是很有用)。随机噪声向量将随机性注入到过程中,从而提供生成的输出的多样性。

鉴别器:辨别真假

鉴别器就像一位受过训练来区分真实数据和虚假数据的艺术评论家。它的作用是仔细检查收到的数据并为工作真实性分配概率分数。如果合成数据看起来与真实数据相似,则鉴别器分配高概率,否则分配低概率分数。

对抗性训练:动态决斗

生成器努力学习生成鉴别器无法与真实数据区分开的合成数据。同时,鉴别器还学习并提高区分真实与合成的能力。这种动态的训练过程促使两个模型提高技能。这两个模型总是相互竞争(因此被称为对抗性),并且通过这种竞争,两个模型都在各自的角色中变得非常出色。

03、Pytorch实现GAN

在此示例中,我们在 pytorch 中实现了一个可以生成合成数据的模型。对于训练,我们有一个具有以下形状的 6 参数数据集(所有参数都绘制为参数 1 的函数)。每个参数都经过精心选择,具有显着不同的分布和形状,以增加数据集的复杂性并模仿真实世界的数据。

定义 GAN 模型组件(生成器和判别器)

import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.init as init
import pandas as pd
import numpy as np
from torch.utils.data import Dataset


# 定义单块功能
def FC_Layer_blockGen(input_dim, output_dim):
    single_block = nn.Sequential(
        nn.Linear(input_dim, output_dim),

        nn.ReLU()
    )
    return single_block
    
# 定义 GENERATOR
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh()  
        )

    def forward(self, x):
        return self.model(x)
        
#定义单个判别块
def FC_Layer_BlockDisc(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.ReLU(),
        nn.Dropout(0.4)
    )
    
# 定义判别器

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
        
        
#定义训练参数
batch_size = 128
num_epochs = 500
lr = 0.0002
num_features = 6
latent_dim = 20

# 模型初始化
generator = Generator(noise_dim, num_features)
discriminator = Discriminator(num_features)

# 损失函数和优化器
criterion = nn.BCELoss()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

模型初始化和数据处理

file_path = 'SamplingData7.xlsx'
data = pd.read_excel(file_path)
X = data.values
X_normalized = torch.FloatTensor((X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) * 2 - 1)
real_data = X_normalized


class MyDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.values.astype(float)
        self.labels = dataframe.values.astype(float)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {
            'input': torch.tensor(self.data[idx]),
            'label': torch.tensor(self.labels[idx])
        }
        return sample

# 创建数据集实例
dataset = MyDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

def weights_init(m):
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)

pretrained = False
if pretrained:
    pre_dict = torch.load('pretrained_model.pth')
    generator.load_state_dict(pre_dict['generator'])
    discriminator.load_state_dict(pre_dict['discriminator'])
else:
    # 应用权重初始化
    generator = generator.apply(weights_init)
    discriminator = discriminator.apply(weights_init)

模型训练

model_save_freq = 100

latent_dim =20
for epoch in range(num_epochs):
    for batch in dataloader:
        real_data_batch = batch['input']
        real_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
        disc_optimizer.zero_grad()
        output_real = discriminator(real_data_batch)
        loss_real = criterion(output_real, real_labels)
        loss_real.backward()

        fake_labels = torch.FloatTensor(np.random.uniform(0, 0.1, (batch_size, 1)))
        noise = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))
        generated_data = generator(noise)
        output_fake = discriminator(generated_data.detach())
        loss_fake = criterion(output_fake, fake_labels)
        loss_fake.backward()

        disc_optimizer.step()
 
        valid_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
        gen_optimizer.zero_grad()
        output_g = discriminator(generated_data)
        loss_g = criterion(output_g, valid_labels)
        loss_g.backward()
        gen_optimizer.step()

    print(f"Epoch {epoch}, D Loss Real: {loss_real.item()}, D Loss Fake: {loss_fake.item()}, G Loss: {loss_g.item()}")

模型评估和可视化结果

import seaborn as sns

synthetic_data = generator(torch.FloatTensor(np.random.normal(0, 1, (real_data.shape[0], noise_dim))))

# 绘制结果
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle('Real and Synthetic Data Distributions', fontsize=16)

for i in range(2):
    for j in range(3):
        sns.histplot(synthetic_data[:, i * 3 + j].detach().numpy(), bins=50, alpha=0.5, label='Synthetic Data', ax=axs[i, j], color='blue')
        sns.histplot(real_data[:, i * 3 + j].numpy(), bins=50, alpha=0.5, label='Real Data', ax=axs[i, j], color='orange')
        axs[i, j].set_title(f'Parameter {i * 3 + j + 1}', fontsize=12)
        axs[i, j].set_xlabel('Value')
        axs[i, j].set_ylabel('Frequency')
        axs[i, j].legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


#创建 2x3 网格的子绘图
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Comparison of Real and Synthetic Data', fontsize=16)

# Define parameter names
param_names = ['Parameter 1', 'Parameter 2', 'Parameter 3', 'Parameter 4', 'Parameter 5', 'Parameter 6']

# 各参数的散点图
for i in range(2):
    for j in range(3):
        param_index = i * 3 + j
        sns.scatterplot(real_data[:, 0].numpy(), real_data[:, param_index].numpy(), label='Real Data', alpha=0.5, ax=axs[i, j])
        sns.scatterplot(synthetic_data[:, 0].detach().numpy(), synthetic_data[:, param_index].detach().numpy(), label='Generated Data', alpha=0.5, ax=axs[i, j])
        axs[i, j].set_title(param_names[param_index], fontsize=12)
        axs[i, j].set_xlabel(f'Real Data - {param_names[param_index]}')
        axs[i, j].set_ylabel(f'Real Data - {param_names[param_index]}')
        axs[i, j].legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/448144.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RK3568 xhci主控挂死问题

串口日志 rootjenet:~# [18694.115430] xhci-hcd xhci-hcd.1.auto: xHCI host not responding to stop endpoint command. [18694.125667] xhci-hcd xhci-hcd.1.auto: xHCI host controller not responding, assume dead [18694.125977] xhci-hcd xhci-hcd.1.auto: HC died; c…

微软模拟飞行器回放功能

参考b站up主,欢迎大家去关注:https://www.bilibili.com/video/BV1Z34y1P7zz/?spm_id_from333.880.my_history.page.click&vd_source4e0b40493e2382633fab2ddc1bb1d9cc 下载网址:https://flightsim.to/file/8163/flight-recorder 坠毁检…

嘿!AI 编码新玩法上线!

随着 AI 智能浪潮到来,AI 编码助手成为越来越多开发者的必备工具,将开发者从繁重的编码工作中解放出来,极大地提高了编程效率,帮助开发者实现更快、更好的代码编写。 通义灵码正是这样一款基于阿里云通义代码大模型打造的智能编码…

如何保证消息的顺序性

先看看顺序会错乱的场景:RabbitMQ:一个 queue,多个 consumer,这不明显乱了: 解决方案:

代码背后的女性:突破性别壁垒的技术先驱

个人主页:17_Kevin-CSDN博客 收录专栏:《程序人生》 引言 在计算机科学的历史长河中,有许多杰出的女性为这个领域的发展做出了重要贡献。她们不仅在技术上取得了卓越成就,还打破了性别壁垒,为后来的女性树立了榜样。今…

22 Dytechlab Cup 2022C. Ela and Crickets(思维、找规律、模拟)

思路就是找规律 可以发现,当拐点在角落时的情况和不在角落的情况是不同 当拐点在角落时,只有目标点的横纵坐标其中的一个和它相同时,这时才可能到达。 否则,我们就简单的例子可以看一下,当一个 2 ∗ 2 2*2 2∗2的矩阵的…

伪分布HBase的安装与部署

1.实训目标 (1)熟悉掌握使用在Linux下安装伪分布式HBase。 (2)熟悉掌握使用在HBase伪分布式下使用自带Zookeeper。 2.实训环境 环境 版本 说明 Windows 10系统 64位 操作电脑配置 VMware 15 用于搭建所需虚拟机Linux系统 …

PostgreSQL容器安装

docker中的centos7中安装 选择对应的版本然后在容器中的centos7中执行下面命令 但是启动容器的时候需要注意 开启端口映射开启特权模式启动init进程 docker run -itd --name centos-postgresql -p 5433:5432 --privilegedtrue centos:centos7 /usr/sbin/init 启动然后进入后先…

ARMv8/ARMv9架构下特权程序之间的跳转模型与系统启动探析

文章目录 背景1、前言小结: 2、4个特权等级/4个安全状态之间的跳转模型小结: 3、启动时镜像之间的跳转模型小结: 4、runtime程序之间的跳转模型小结: 推荐 背景 ARMv8和ARMv9架构是ARM公司推出的先进处理器架构,被广泛…

华为ce12800交换机m-lag(V-STP模式)配置举例

配置## 标题思路 采用如下的思路配置M-LAG双归接入IP网络: 1.在Switch上配置上行接口绑定在一个Eth-Trunk中。 2.分别在SwitchA和SwitchB上配置V-STP、DFS Group、peer-link和M-LAG接口。 3.分别在SwitchA和SwitchB上配置LACP M-LAG的系统优先级、系统ID。 4.分别在…

粒子群算法优化RBF神经网络气体浓度预测

目录 完整代码和数据下载链接:粒子群算法优化RBF神经网络气体浓度预测,pso-rbf气体浓度预测(代码完整,数据齐全)资源-CSDN文库 https://download.csdn.net/download/abc991835105/88937920 RBF的详细原理 RBF的定义 RBF理论 易错及常见问题 RBF应用实例,粒子群算法优化R…

植物病害识别:YOLO水稻病害识别数据集(1000多张,3个类别,yolo标注)

YOLO水稻病害识别数据集,包含水稻白叶枯病、稻瘟病、水稻褐斑病3个常见病害类别,共1000多张图像,yolo标注完整,可直接训练。 适用于CV项目,毕设,科研,实验等 需要此数据集或其他任何数据集请私…

antv L7结合高德地图使用dome1

antv L7结合高德地图使用 一、设置底图二 、添加antv L7 中要使用的dome1. 安装L7 依赖2. 使用的dome 、以下使用的是浮动功能3. 运行后显示 自定义样式修改1. 设置整个中国地图浮动起来 自定义标注点1. 静态标注点2. 动态标注点(点位置需要自己改)3. 完…

手机群控软件开发必备源代码分享!

随着移动互联网的飞速发展,手机群控技术在市场推广、自动化测试、应用管理等领域的应用越来越广泛,手机群控软件作为一种能够同时控制多台手机设备的工具,其开发过程中,源代码的编写显得尤为重要。 1、设备连接与识别模块 设备连…

springboot学习笔记2

springmvc响应数据 页面跳转控制 开发模式介绍 快速返回逻辑视图 jsp页面创建 配置jsp视图解析器 mvc初始化 handler返回视图 转发和重定向实现 返回json数据(重点 静态资源处理 RestFull风格设计和实战 风格介绍 实战

力扣--76. 最小覆盖子串

给你一个字符串 s 、一个字符串 t 。返回 s 中涵盖 t 所有字符的最小子串。如果 s 中不存在涵盖 t 所有字符的子串,则返回空字符串 "" 。 注意: 对于 t 中重复字符,我们寻找的子字符串中该字符数量必须不少于 t 中该字符数量。如…

水电站泄洪闸预警系统技术改造项目方案

一、工期安排 2024年1月10日至1月30日,共20天,水电站泄洪闸预警系统建设项目主要以计划工作任务为依据开展并控制工期。 二、预警系统建设项目 水电站泄洪闸预警系统技术改造项目实施内容主要是在每个确定后的预警广播站点采用基础开挖预制地笼浇筑混凝…

WebPack自动吐出脚本

window.c c; window.res ""; window.flag false;c function (r) {if (flag) {window.res window.res "${r.toString()}" ":" (e[r] "") ",";}return window.c(r); }代码改进了一下,可以过滤掉重复的方…

【零基础学习01】嵌入式linux驱动中pinctrl和gpio子系统实现

大家好,为了进一步提升大家对实验的认识程度,每个控制实验将加入详细控制思路与流程,欢迎交流学习。 今天给大家分享一下,linux系统里面pinctrl和gpio子系统控制实验,操作硬件为I.MX6ULL开发板。 第一:pinctrl和gpio子系统简介 Linux系统是一个庞大又完善的系统,如果采用…

基于斑翠鸟优化算法(Pied Kingfisher Optimizer ,PKO)的无人机三维路径规划(MATLAB)

一、无人机路径规划模型介绍 二、算法介绍 斑翠鸟优化算法(Pied Kingfisher Optimizer ,PKO),是由Abdelazim Hussien于2024年提出的一种基于群体的新型元启发式算法,它从自然界中观察到的斑翠鸟独特的狩猎行为和共生关系中汲取灵…
最新文章