时间序列生成数据,TransformerGAN

        简介:这个代码可以用于时间序列修复和生成。使用transformer提取单变量或者多变时间窗口的趋势分布情况。然后使用GAN生成分布类似的时间序列。

        此外,还实现了基于prompt的数据生成,比如指定生成某个月份的数据、某半个月的数据、某一个星期的数据。

1、模型架构

        如下图所示,生成器和鉴别器都使用Transformer的编码器部分提取时间序列的特征,然后鉴别器使用这些进行二分类、生成器使用这些特征生成伪造的数据。

        重点:在下面的图的基础上,我还添加了基于提示的生成代码,类似于AI提示绘画一样,因此可以指定生成一月份、二月份等任意指定周期的数据。

2、训练GAN的代码

        下面是GAN的训练部分。

# 训练GAN
num_epochs = 100
for epoch in range(num_epochs):
    for real_x,x_g,zz in loader: # 分别是真实值real_x、提示词信息x_g、噪声zz
        real_data = real_x
        noisy_data = x_g
        # Train Discriminator
        optimizer_D.zero_grad()
        out = discriminator(real_data)
        real_loss = criterion(discriminator(real_data), torch.ones(real_data.size(0), 1))
        fake_data = generator(noisy_data,zz)
        fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(fake_data.size(0), 1))
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_data), torch.ones(fake_data.size(0), 1))
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

3、生成器代码

class Generator(nn.Module):
    def __init__(self, seq_len=8, patch_size=2, channels=1, num_classes=9, latent_dim=100, embed_dim=10, depth=1,
                 num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):
        super(Generator, self).__init__()
        self.channels = channels
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.depth = depth
        self.attn_drop_rate = attn_drop_rate
        self.forward_drop_rate = forward_drop_rate
        
        self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))
        self.blocks = Gen_TransformerEncoder(
                         depth=self.depth,
                         emb_size = self.embed_dim,
                         drop_p = self.attn_drop_rate,
                        )

        self.deconv = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0)
        )

    def forward(self, z):
        x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
        x = x + self.pos_embed
        H, W = 1, self.seq_len
        x = self.blocks(x)
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        output = self.deconv(x.permute(0, 3, 1, 2))
        output = output.view(-1, self.channels, H, W)
        return output

4、生成数据和真实数据分布对比

        使用PCA和TSNE对生成的时间窗口数据进行降维,然后scatter这些二维点。如果生成的真实数据的互相混合在一起,说明模型学习到了真东西,也就是模型伪造的数据和真实数据分布是一样的,美滋滋。从下面的PCA可以看出,两者的分布还是近似的。

        进一步的,可以拟合两个二维正态分布,然后计算他们的KL散度作为一个评价指标。

5、生成数据展示

        上面是真实数据、下面是伪造的数据。由于只有几百个样本,以及参数都没有进行调整,但是效果还不错。

6、损失函数变化情况

        模型还是学习到了一点东西的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/580909.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt | 窗口的显示及可见性|标题、透明度、启用/禁用|窗口标志、设置其他属性|获取窗口部件、设置父部件|鼠标光标

​显示事件:QEvent::show,处理函数为 showEvent(QShowEvent*) 隐藏事件:QEvent::hide,处理函数为 hideEvent(QHideEvent* ) 01 QWidget 类中与可见性有关的属性 visible:bool 访问函数: bool isVisible() const; virtual void setVisible(bool visible); 02 QWid…

同事上班这样摸鱼,我坐边上咋看他都在专心写代码啊

我边上有个同事,我坐他边上,但是每天看着他都眉头紧锁,忙的不亦乐乎,但终于有一天,我发现了他上班摸鱼的秘诀。 我劝你千万不要学会这4招,要不就该不好好上班了。 目录 1 上班看电影? 2 上班…

LeetCode - LCR 179.查找总价格为目标值的两个商品

一. 题目链接 LeetCode - LCR 179. 查找总价格为目标值的两个商品 解法(双指针 - 对撞指针): 算法思路: 注意到本题是升序的数组,因此可以用「对撞指针」优化时间复杂度。 算法流程: 初始化left &#…

算法入门ABC

前言 初学算法时真的觉得这东西晦涩难懂,貌似毫无用处!后来的后来,终于渐渐明白搞懂算法背后的核心思想,能让你写出更加优雅的代码。就像一首歌唱的那样:后来,我总算学会了如何去爱,可惜你早已远…

Hotcoin Academy 市场洞察-2024年4月15日-21日

加密货币市场表现 BTC ETF在本周出现净流出,大盘有较大跌幅,BTC一度跌破60000美金,ETH一度跌破2800美金,整体以横盘为主,行情在周末有略微回升趋势。BTC市占率创21年4月来新高,目前市值1.28万亿&#xff0c…

ElasticSearch教程入门到精通——第六部分(基于ELK技术栈elasticsearch 7.x+8.x新特性)

ElasticSearch教程入门到精通——第六部分(基于ELK技术栈elasticsearch 7.x8.x新特性) 1. Elasticsearch优化1.1 硬件选择1.1 分片策略1.1.1 分片策略——合理设置分片数1.1.2 分片策略——推迟分片分配 1.2 路由选择1.2.1 路由选择——不带routing查询1…

哪款洗地机最好用?2024年四大口碑一流品牌推荐

随着人们生活质量的提升,人们的扫地、拖地都可以用智能清洁工具来高效完成,像洗地机它集合了扫地、拖地、自清洁等功能,让我们摆脱了每次打扫卫生就像打仗一样,忙活半小时下来腰酸背痛的窘境。所以越来越多的家庭纷纷开始用洗地机…

84.柱形图中最大的矩阵

二刷终于能过了. 思路解析: 不愧是hard,第一步就很难想, 对于每一个矩阵,我们要想清楚怎么拿到最大矩阵, 对于每个height[i],我们需要找到left和right,left是i左边第一个小于height[i]的,right是右边第一个小于height[i]的,那么他的最大矩阵就是height[i] * (right-left-…

鸿蒙launcher浅析

鸿蒙launcher浅析 鸿蒙launcher源码下载鸿蒙launcher模块launcher和普通的应用ui展示的区别 鸿蒙launcher源码下载 下载地址如下: https://gitee.com/openharmony/applications_launcher 鸿蒙launcher模块 下载页面已经有相关文件结构的介绍了 使用鸿蒙编辑器D…

国外企业使用生成式人工智能实例100

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Welcome to nginx!怎么解决?

要解决 "welcome to nginx!" 错误,需要检查虚拟主机配置,启用虚拟主机,重新加载 nginx,如果无法找到虚拟主机配置文件,则创建默认页面并重新加载 nginx,这样错误消息将消失,网站将正常…

数据结构之顺顺顺——顺序表

1.浅谈数据结构 相信我们对数据结构都不陌生,我们之前学过的数组就是最基础的数据结构,它大概就长这样: 数组 而作为最简单的数据结构,数组只能帮助我们实现储存数据这一个功能,随着学习的深入,和问题的日渐…

Qt | 标准、复选、单选、工具、命令按钮大全

01、QPushButton QPushButton 类(标准按钮) 示例 3:默认按钮与自动默认按钮 02、QCheckBox QCheckBox 类(复选按钮) 1、复选按钮的第三状态(见右图 Qt5.10.1 的选中状态):是指除了选中 和未选中状态之外的第三种状态,这种状态用来指示“不变”,表 示用户既不选中也不取…

专栏目录【政安晨的机器学习笔记】

目录 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: 政安晨的机器学习笔记 希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正! 本篇是作者政安晨的专栏《政安晨的机器学习笔记》的…

Python学习笔记------模块和包

Python模块 简介与作用 Python模块是一个Python文件,以.py结尾,模块能定义函数、类和变量,模块里也包含可执行的代码 模块的作用:Python中有很多各种不同的模块,每个模块都可以帮我们快速实现一些功能,我…

grafana监控模板 regex截取ip地址

查看prometheus的node服务启动指标up,也可以查看其他的服务 配置监控模板 配置正则截取ip regex截取ip地址 /.*instance"([^"]*):9100*/ #提取(instance")开头,(:9001)结束字段

北京车展“第一枪”:长安汽车发布全球首款量产可变新汽车

4月25日,万众瞩目的2024北京国际汽车展览会在中国国际展览中心如期而至。作为中国乃至全球汽车行业的盛宴,本次车展也吸引了无数业内人士的高度关注。 此次北京车展以“新时代 新汽车”为主题,汇聚了1500余家主流车企及零部件制造商&#xff…

Laravel 6 - 第十七章 配置数据库

​ 文章目录 Laravel 6 - 第一章 简介 Laravel 6 - 第二章 项目搭建 Laravel 6 - 第三章 文件夹结构 Laravel 6 - 第四章 生命周期 Laravel 6 - 第五章 控制反转和依赖注入 Laravel 6 - 第六章 服务容器 Laravel 6 - 第七章 服务提供者 Laravel 6 - 第八章 门面 Laravel 6 - …

Kettle 中将图片url转换为Base64

背景 我遇到了一个应用场景需要将订阅kafka数据中的一个字段(图片url)转换为base64 然后进行下一步操作。 实现方式 我这边的实现方式是使用javaScript去实现的 图形化逻辑如下: 这一步就是实现url转换为base64 json input的步骤&#xf…

vulnhub靶场之driftingblues-6

一.环境搭建 1.靶场描述 get flags difficulty: easy about vm: tested and exported from virtualbox. dhcp and nested vtx/amdv enabled. you can contact me by email for troubleshooting or questions. 2.靶场下载 https://www.vulnhub.com/entry/driftingblues-6,6…
最新文章