时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

官方论文地址->官方论文地址

官方代码地址->官方代码地址

个人修改代码->个人修改的代码已经上传CSDN免费下载

一、本文介绍

本文给大家带来是DLinear模型,DLinear是一种用于时间序列预测(TSF)的简单架构,DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测(值得一提的是DLinear的出现是为了挑战Transformer在实现序列预测中有效性)本文的讲解内容包括:模型原理、数据集介绍、参数讲解、模型训练和预测、结果可视化、训练个人数据集,讲解顺序如下->

预测类型->单元预测、多元预测

适用对象->如果你的配置不是很好这个模型应该很适合你因为参数量很小训练速度很快

二、模型原理

DLinear模型出现是为了调整Transformer的有效性从而存在,Transformer的设计都十分的复杂和需要大量的参数,所以作者提出了一种简单的结构DLinear(参数量我实验过程中确实非常小)

DLinear的核心思想是将时间序列分解为趋势和剩余序列,并分别使用两个单层线性网络对这两个序列进行建模以进行预测。

具体地,DLinear如何工作的关键点如下

  1. 时间序列分解:DLinear将输入的时间序列分解为两部分——趋势部分和剩余部分。这种分解有助于分别处理时间序列中的长期趋势和短期波动。

  2. 单层线性网络:对于趋势和剩余序列,DLinear分别使用两个单层的线性网络进行建模。这种简单的架构使得DLinear在处理时间序列时既高效又有效。

  3. 预测任务:在进行预测时,DLinear结合这两个网络的输出来生成最终的时间序列预测。

总结->可以看出DLinear的核心结构真的十分简单就包括一个分解和两个线性网络进行建模最后经过一个简单的相加就输出了结果。

模型的网络结构图如下所示->

图片分析->可以看到和我们上面讲的一样,数据从输入进来经过两个分支,一个为趋势性一个为剩余序列,然后分别经过一个线性层处理(这里的提到的线性层就是普通的全连接层),然后将结果进行简单的拼接就完成了结果的输出(这就这样的简单模型结果比过程十分复杂的Transformer模型效果要好->我自己实验效果确实要好,我拿2020年的bestpaper和普通的Transformer都进行了对比效果确实要有提升)。

下面的图片是一个简单的线性层(普通的全连接层)提取数据的过程图->

这里把模型的代码结构放出来方便大家根据讲解和代码进行对比。

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

class Model(nn.Module):
    """
    Decomposition-Linear
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Decompsition Kernel Size
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)
        self.individual = configs.individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            
            for i in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))

                # Use this two lines if you want to visualize the weights
                # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
                # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
            
            # Use this two lines if you want to visualize the weights
            # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
            # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        seasonal_init, trend_init = self.decompsition(x)
        seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
        if self.individual:
            seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
            trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
            for i in range(self.channels):
                seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
                trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)

        x = seasonal_output + trend_output
        return x.permute(0,2,1) # to [Batch, Output length, Channel]

我看论文的内容大比分都是对比实验,因为DLinear的产生就是为了质疑Transformer所以他和各种Transformer的模型进行对比试验,因为本篇文章就是DLinear的实战案例,对比的部分我就不讲了,大家有兴趣可以看看论文内容在最上面我已经提供了链接。 

三、数据集介绍

所用到的数据集为某公司的业务水平评估和其它参数具体的内容我就介绍了估计大家都是想用自己的数据进行训练模型,这里展示部分图片给大家提供参考。

四、参数讲解

模型的参数如下(大部分都是一些公共参数并不涉及模型)->

parser = argparse.ArgumentParser(description='DLinearNet Multivariate Time Series Forecasting')
    # basic config
    parser.add_argument('--train', type=bool, default=True, help='Whether to conduct training')
    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')
    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')
    parser.add_argument('--show_results', type=bool, default=True, help='Whether show forecast and real results graph')
    parser.add_argument('--model', type=str, default='SCINet',help='Model name')

    # data loader
    parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='MS',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./models/', help='location of model models')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=126, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=64, help='start token length')
    parser.add_argument('--pred_len', type=int, default=4, help='prediction sequence length')

    # model
    parser.add_argument('--individual', action='store_true', default=False,
                        help='DLinear: a linear layer for each variate(channel) individually')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=1, help='output size')
    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')

    # optimization
    parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')

    # GPU
    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--device', type=int, default=0, help='gpu')

模型的详细参数讲解如下(如果你想训练你自己的数据集可以仔细看看)->

参数名称参数类型参数讲解
0trainbool是否进行训练,如果你单纯只想进行预测设置为False即可,
1rollingforecastbool是否进行滚动预测,如果是则设置为True,如果不进行滚动预测则进行正常的预测
2rolling-data-pathstr如果进行滚动预测则需要添加新的和训练文件相同格式的数据
3show_resultsbool是否保存预测值和真实值的对比
4modelstr定义的模型名称
5root_pathstr这个才是你文件的路径,不要到具体的文件,到目录级别即可。
6data_pathstr这个填写你文件的具体名称。
7featuresstr这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
8targetstr这个是你数据集中你想要预测那一列数据,假设我预测的是油温OT列就输入OT即可。
9freqstr时间的间隔,你数据集每一条数据之间的时间间隔。
10checkpointsstr训练出来的模型保存路径
11seq_lenint用过去的多少条数据来预测未来的数据
12label_lenint可以理解为更高的权重占比的部分要小于seq_len
13pred_lenint预测未来多少个时间点的数据
14enc_inint你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
15dec_inint同上
16individualbool这个就是我们上面提到的两个线性层,如果为True我们则对每一个通道用单独的线性层处理,False则为所有的通道用一个线性层
17c_outint这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
18dropoutfloat这个应该都理解不说了,丢弃的概率,防止过拟合的。
19embedstr时间特征的编码方式,默认为"timeF"
20activationstr激活函数
21num_workersint线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
22train_epochsint训练的次数
23batch_sizeint一次往模型力输入多少条数据
24learning_ratefloat学习率。
25lossstr     损失函数,默认为"mse"
26lradjstr     学习率的调整方式,默认为"type1"
27use_gpubool是否使用GPU训练,根据自身来选择
28gpuintGPU的编号

五、模型训练和预测

1.项目目录结构

项目的目录构造如下->

其中data为训练用的数据放的地方,layers为模型结构存放的地方,models为训练保存的训练模型,results为可视化结果保存的图片和滚动预测的结果,util为一些工具。 

2.模型训练

当我们经过上面的参数讲解之后,我们可以开始训练模型了,控制台输出如下->

3.滚动预测 

这里进行滚动预测的控制台输出->

4.结果展示 

运行结果后,结果保存到同级目录下(下图为预测值和真实值的对比)-> 

5.结果分析 

可以看到预测值和真实值之间的差距还可以,但是这个模型的参数量少得可怜,不得不得质疑Transformer模型的有效性~

六、训练你个人数据集

这个模型我在写的过程中为了节省大家训练自己数据集,我基本上把大部分的参数都写好了,需要大家注意的就是如果要进行滚动预测下面的参数要设置为True。

    parser.add_argument('--rollingforecast', type=bool, default=True, help='rolling forecast True or False')

如果上面的参数设置为True那么下面就要提供一个进行滚动预测的数据集该数据集的格式要和你训练模型的数据集格式完全一致(重要!!!),如果没有可以考虑在自己数据的尾部剪切一部分,不要粘贴否则数据模型已经训练过了的话预测就没有效果了。 

    parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='rolling data file')

其它的没什么可以讲的了大部分的修改操作在参数讲解的部分我都详细讲过了,这里的滚动预测可能是大家想看的所以摘出来详细讲讲。 

总结

到此本文已经全部讲解完成了,希望能够帮助到大家,在这里也给大家推荐一些我其它的博客的时间序列实战案例讲解,其中有数据分析的讲解就是我前面提到的如何设置参数的分析博客,最后希望大家订阅我的专栏,本专栏均分文章均分98,并且免费阅读。

概念理解 

15种时间序列预测方法总结(包含多种方法代码实现)

数据分析

时间序列预测中的数据分析->周期性、相关性、滞后性、趋势性、离群值等特性的分析方法

机器学习——难度等级(⭐⭐)

时间序列预测实战(四)(Xgboost)(Python)(机器学习)图解机制原理实现时间序列预测和分类(附一键运行代码资源下载和代码讲解)

深度学习——难度等级(⭐⭐⭐⭐)

时间序列预测实战(五)基于Bi-LSTM横向搭配LSTM进行回归问题解决

时间序列预测实战(七)(TPA-LSTM)结合TPA注意力机制的LSTM实现多元预测

时间序列预测实战(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)

时间序列预测实战(十一)用SCINet实现滚动预测功能(附代码+数据集+原理介绍)

Transformer——难度等级(⭐⭐⭐⭐)

时间序列预测模型实战案例(八)(Informer)个人数据集、详细参数、代码实战讲解

时间序列预测模型实战案例(一)深度学习华为MTS-Mixers模型

个人创新模型——难度等级(⭐⭐⭐⭐⭐)

时间序列预测实战(十)(CNN-GRU-LSTM)通过堆叠CNN、GRU、LSTM实现多元预测和单元预测

传统的时间序列预测模型(⭐⭐)

时间序列预测实战(二)(Holt-Winter)(Python)结合K-折交叉验证进行时间序列预测实现企业级预测精度(包括运行代码以及代码讲解)

时间序列预测实战(六)深入理解ARIMA包括差分和相关性分析

融合模型——难度等级(⭐⭐⭐)

时间序列预测实战(九)PyTorch实现融合移动平均和LSTM-ARIMA进行长期预测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/134474.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uni-app点击按钮弹出提示框-uni.showModal(OBJECT),选择确定和取消

参考文档: https://uniapp.dcloud.io/api/ui/prompt?idshowmodal 显示模态弹窗,可以只有一个确定按钮,也可以同时有确定和取消按钮。类似于一个API整合了 html 中:alert、confirm。 uni.showModal({title: 提示,content: 这是一…

【计算机网络笔记】IP分片

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

openssl研发之base64编解码实例

一、base64编码介绍 Base64编码是一种将二进制数据转换成ASCII字符的编码方式。它主要用于在文本协议中传输二进制数据,例如电子邮件的附件、XML文档、JSON数据等。 Base64编码的特点如下: 字符集: Base64编码使用64个字符来表示二进制数据…

Leetcode刷题详解—— 目标和

1. 题目链接:494. 目标和 2. 题目描述: 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联起所有整数,可以构造一个 表达式 : 例如,nums [2, 1] ,可…

出电子书了!

熟悉小灰的小伙伴们都知道,小灰曾经创作了三本算法有关的图书,分别是《漫画算法》、《漫画算法Python篇》、《漫画算法2》。 如今,这三本书在全网的销量超过10W册,可以说是IT领域最畅销的图书之一。 小灰的这三本算法书&#xff0…

Linux系统编程,Linux中的文件读写文件描述符

文章目录 Linux系统编程,Linux中的文件读写操作1.open函数,打开文件 Linux系统编程,Linux中的文件读写操作 1.open函数,打开文件 我们来看下常用的open函数 这个函数最终返回一个文件描述符struct file 我们查看一下它的Ubuntu…

数据结构之双向链表

目录 引言 链表的分类 双向链表的结构 双向链表的实现 定义 创建新节点 初始化 打印 尾插 头插 判断链表是否为空 尾删 头删 查找与修改 指定插入 指定删除 销毁 顺序表和双向链表的优缺点分析 源代码 dlist.h dlist.c test.c 引言 数据结构…

LeetCode【923】三数之和的多种可能性

题目&#xff1a; 思路&#xff1a; https://www.jianshu.com/p/544cbb422300 代码&#xff1a; int threeSumMulti(vector<int>& A, int target) {//Leetcode923:三数之和的多钟可能//initialize some constint kMod 1e9 7;int kMax 100;//calculate frequenc…

图论12-无向带权图及实现

文章目录 带权图1.1带权图的实现1.2 完整代码 带权图 1.1带权图的实现 在无向无权图的基础上&#xff0c;增加边的权。 使用TreeMap存储边的权重。 遍历输入文件&#xff0c;创建TreeMap adj存储每个节点。每个输入的adj节点链接新的TreeMap&#xff0c;存储相邻的边和权重 …

138.随机链表的复制(LeetCode)

深拷贝&#xff0c;是指将该链表除了正常单链表的数值和next指针拷贝&#xff0c;再将random指针进行拷贝 想法一 先拷贝出一份链表&#xff0c;再对于每个节点的random指针&#xff0c;在原链表进行遍历&#xff0c;找到random指针的指向&#xff0c;最后完成拷贝链表random…

Java自学第10课:JavaBean和servlet基础

目录 目录 1 JavaBean &#xff08;1&#xff09;概念 &#xff08;2&#xff09;分类 &#xff08;3&#xff09;使用 2 servlet &#xff08;1&#xff09;代码结构 &#xff08;2&#xff09;常用接口 &#xff08;3&#xff09;如何开发 1 新建servlet 2 配置 1…

19 异步通知

一、异步通知 1. 异步通知简介 阻塞和非阻塞两种方式都是需要应用程序去主动查询设备的使用情况。 异步通知类似于驱动可以主动报告自己可以访问&#xff0c;应用程序获取信号后会从驱动设备中读取或写入数据。 异步通知最核心的就是信号&#xff1a; #define SIGHUP 1 /* 终…

C详细的字符串函数

但行前路&#xff0c;莫问归期 要注意的是&#xff0c;要使用下边所讲的函数要包含头文件<string.h> 文章目录 strlenstrcpystrncpy strcatstrncat strcmpstrncmp strstrstrtokstrerror字符串大小写转换struprstrlwr memcpymemmovememcmp strlen 求字符串的长度 函数参…

Typescript -尚硅谷

基础 1.ts是以js为基础构建的语言&#xff0c;是一个js的超集(对js进行了扩展)&#xff1b; 2.ts(type)最主要的功能是在js的基础上引入了类型的概念; Js的类型是只针对于值而言&#xff0c;ts的类型是针对于变量而言 Ts可以被编译成任意版本的js&#xff0c;从而进一步解决了…

MySQL Command Line Client 运行闪退问题解决,缺少my.ini文件

MySQL Command Line Client 运行闪退问题解决&#xff1a; 问题排查&#xff1a; 1.找到Command Line Client的路径位置&#xff0c;并查看属性&#xff0c;步骤截图&#xff1a; 查看属性&#xff1a; 查看属性中的目标路径&#xff1a; 2.进入属性中的目标路径&#xff0c;…

基于SSM+Vue的电子商城的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

记录--vue3 setup 中国省市区三级联动options最简洁写法,无需任何库

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 在写页面的时候&#xff0c;发现表单里面有一个省市区的 options 组件要写&#xff0c;因为表单很多地方都会用到这个地址选择&#xff0c;我便以为很简单嘛。 虽然很简单的一个功能&#xff0c;但是网…

C#中的扩展方法---Extension

C#中扩展方法是C# 3.0/.NET 3.x 新增特性&#xff0c;能够实现向现有类型中“添加”方法&#xff0c;以下主要介绍C#中扩展方法的声明及使用。 1、扩展方法的声明 扩展方法使能够向现有类型“添加”方法&#xff0c;而无需创建新的派生类型、重新编译或以其他方式修改原始类型…

安全通信网络(设备和技术注解)

网络安全等级保护相关标准参考《GB/T 22239-2019 网络安全等级保护基本要求》和《GB/T 28448-2019 网络安全等级保护测评要求》 密码应用安全性相关标准参考《GB/T 39786-2021 信息系统密码应用基本要求》和《GM/T 0115-2021 信息系统密码应用测评要求》 1网络架构 1.1保证网络…

开发知识点-Python

Python从小白到入土 python渗透测试安全工具开发锦集Python安全工具编程基础第一章 Python在网络安全中的应用第一节 Python黑客领域的现状第二节 我们可以用Python做什么第三节 第一章课程内容总结 第二章 python安全应用编程入门第一节 Python正则表达式第二节 Python Web编程…
最新文章