机器学习算法实战案例:Informer实现多变量负荷预测

文章目录

      • 机器学习算法实战案例系列
      • 答疑&技术交流
      • 1 实验数据集
      • 2 如何运行自己的数据集
      • 3 报错分析

机器学习算法实战案例系列

  • 机器学习算法实战案例:确实可以封神了,时间序列预测算法最全总结!

  • 机器学习算法实战案例:时间序列数据最全的预处理方法总结

  • 机器学习算法实战案例:GRU 实现多变量多步光伏预测

  • 机器学习算法实战案例:LSTM实现单变量滚动风电预测

  • 机器学习算法实战案例:LSTM实现多变量多步负荷预测

  • 机器学习算法实战案例:CNN-LSTM实现多变量多步光伏预测

  • 机器学习算法实战案例:BiLSTM实现多变量多步光伏预测

  • 机器学习算法实战案例:VMD-LSTM实现单变量多步光伏预测

  • 机器学习算法实战案例:VMD-LSTM实现单变量多步光伏预测(升级版)

答疑&技术交流

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

本文完整代码、相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

​方式①、微信搜索公众号:Python学习与数据挖掘,后台回复:加群
方式②、添加微信号:dkl88194,备注:来自CSDN + 技术交流

1 实验数据集

实验数据集采用数据集4:2016年电工数学建模竞赛负荷预测数据集,数据集包含日期、最高温度℃ 、最低温度℃、平均温度℃ 、相对湿度(平均) 、降雨量(mm)、日需求负荷(KWh),时间间隔为1H。

在使用数据之前相对数据进行处理,用其他数据集时也是同样的处理方法。首先读取数据,发数据不是UTF-8格式,通过添加encoding = 'gbk’读取数据,模型传入的数据必须是UTF-8格式

df= pd.read_table('E:\\课题\\08数据集\\2016年电工数学建模竞赛负荷预测数据集\\2016年电工数学建模竞赛负荷预测数据集.txt',encoding = 'gbk')

然后检查数据是否有缺失值:

df.isnull().sum()

发现数据存在少量缺失值,分析数据特点,可以通过前项或后项填充填补缺失值:

df = df.fillna(method='ffill')

后面需要将表格列名改为英文,时间列名为date,不然后面运行时会报错:

df.columns = ["date","max_temperature(℃)","Min_temperature(℃ )","Average_temperature(℃)","Relative_humidity(average)","Rainfall(mm)","Load"]

最后将数据按UTF-8格式保存

load.to_csv('E:\\课题\\08数据集\\2016年电工数学建模竞赛负荷预测数据集\\2016年电工数学建模竞赛负荷预测数据集_处理后.csv', index=False,encoding = 'utf-8')

最后可视化看一下数据:

load.drop(['date'], axis=1, inplace=True)

cols = list(load.columns)

fig = plt.figure(figsize=(16,6))

plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=0.8)

for i in range(len(cols)):

    ax = fig.add_subplot(3,2,i+1)

    ax.plot(load.iloc[:,i])

    ax.set_title(cols[i])

    # plt.subplots_adjust(hspace=1)

2 如何运行自己的数据集

前面两篇文章介绍了论文的原理、代码解析和官方数据集训练和运行,那么大家在利用模型训练自己的数据集的时候需要修改的几处地方。

parser.add_argument('--data', type=str, default='custom', help='data')

parser.add_argument('--root_path', type=str, default='./data/Load/', help='root path of the data file')

parser.add_argument('--data_path', type=str, default='load.csv', help='data file')

parser.add_argument('--features', type=str, default='MS', help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')

parser.add_argument('--target', type=str, default='Load', help='target feature in S or MS task')

parser.add_argument('--freq', type=str, default='h', help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
  • data:必须填写 default=‘custom’,也就是改为自定义的数据
  • root_path:填写数据文件夹路径
  • data_path:填写具体的数据文件名
  • features:前面有讲解,features有三个选项(M,MS,S),分别是多元预测多元,多元预测单元,单元预测单元,具体是看你自己的数据集。
  • target:就是你数据集中你想要知道那列的预测值的列名,这里改为Load
  • freq:就是你两条数据之间的时间间隔。
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length of Informer encoder')

parser.add_argument('--label_len', type=int, default=48, help='start token length of Informer decoder')

parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')
  • seq_len:用过去的多少条数据来预测未来的数据
  • label_len:可以裂解为更高的权重占比的部分要小于seq_len
  • pred_len:预测未来多少个时间点的数据
parser.add_argument('--enc_in', type=int, default=6, help='encoder input size')

parser.add_argument('--dec_in', type=int, default=6, help='decoder input size')

parser.add_argument('--c_out', type=int, default=1, help='output size')
  • enc_in:你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
  • dec_in:同上
  • c_out:这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
# 字典data_parser中包含了不同数据集的信息,键值为数据集名称('ETTh1'等),对应一个包含.csv数据文件名

    'ETTh1':{'data':'ETTh1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},

    'ETTh2':{'data':'ETTh2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},

    'ETTm1':{'data':'ETTm1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},

    'ETTm2':{'data':'ETTm2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},

    'WTH':{'data':'WTH.csv','T':'WetBulbCelsius','M':[12,12,12],'S':[1,1,1],'MS':[12,12,1]},

    'ECL':{'data':'ECL.csv','T':'MT_320','M':[321,321,321],'S':[1,1,1],'MS':[321,321,1]},

    'Solar':{'data':'solar_AL.csv','T':'POWER_136','M':[137,137,137],'S':[1,1,1],'MS':[137,137,1]},

    'Custom':{'data':'load.csv','T':'Load','M':[137,137,137],'S':[1,1,1],'MS':[6,6,1]},

预测结果保存在result文件下,保存格式为numpy,可以通过下面的脚本进行可视化预测结果:

import matplotlib.pyplot as plt

file_path1 = "results/informer_ETTh1_ftM_sl96_ll48_pl24_dm512_nh8_el2_dl1_df2048_atprob_fc5_ebtimeF_dtTrue_mxTrue_test_0/true.npy"

file_path2 = "results/informer_ETTh1_ftM_sl96_ll48_pl24_dm512_nh8_el2_dl1_df2048_atprob_fc5_ebtimeF_dtTrue_mxTrue_test_1/pred.npy"

data1 = np.load(file_path1)

data2 = np.load(file_path2)

    true_value.append(data2[0][i][6])

    pred_value.append(data1[0][i][6])

df = pd.DataFrame({'real': true_value, 'pred': pred_value})

df.to_csv('results.csv', index=False)

fig = plt.figure(figsize=( 16, 8))

plt.plot(df['real'], marker='o', markersize=8)

plt.plot(df['pred'], marker='o', markersize=8)

plt.tick_params(labelsize = 28)

plt.legend(['real','pred'],fontsize=28)

最后预测的效果如下,发现并不是太好,后续看参数调优后是否能提升模型预测效果。

3 报错分析

报错1:UnicodeDecodeError: ‘utf-8’ codec can’t decode bytes in position 56-57: invalid continuation byte,具体来说,‘utf-8’ 编解码器无法解码文件中的某些字节,因为它们不符合 UTF-8 编码的规则。

  File "D:\Progeam Files\python\lib\site-packages\pandas\io\parsers\c_parser_wrapper.py", line 93, in __init__

    self._reader = parsers.TextReader(src, **kwds)

  File "pandas\_libs\parsers.pyx", line 548, in pandas._libs.parsers.TextReader.__cinit__

  File "pandas\_libs\parsers.pyx", line 637, in pandas._libs.parsers.TextReader._get_header

  File "pandas\_libs\parsers.pyx", line 848, in pandas._libs.parsers.TextReader._tokenize_rows

  File "pandas\_libs\parsers.pyx", line 859, in pandas._libs.parsers.TextReader._check_tokenize_status

  File "pandas\_libs\parsers.pyx", line 2017, in pandas._libs.parsers.raise_parser_error

UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 56-57: invalid continuation byte

解决办法:

(1) 根据提示,要将数据更改’utf-8’格式,最简便的方法将数据用记事本打开,另存为是通过UTF-8格式保存

(2) 尝试使用其他编解码器(如 ‘latin1’)来读取文件,或者在读取文件时指定正确的编码格式。

**报错2:ValueError: list.remove(x): x not in list,**试从列表中删除两个元素,但是这两个元素中至少有一个不在列表中。

File "E:\课题\07代码\Informer2020-main\Informer2020-main\data\data_loader.py", line 241, in __read_data__

cols = list(df_raw.columns); cols.remove(self.target); cols.remove('date')

ValueError: list.remove(x): x not in list

解决办法:在没有找到具体原因的时候可以在删除元素之前先检查一下列表中是否包含要删除的元素,或者使用 try-except 语句来捕获异常,以便在元素不存在时不会导致程序中断。通过检查,数据中的列名最好改为英文,避免产生乱码。

    cols=self.cols.copy()

    cols.remove(self.target)

    cols = list(df_raw.columns)

    print(cols)  # 输出列的内容

    if self.target in cols:

        cols.remove(self.target)

        print(f"{self.target} not in columns")

        cols.remove('date')

        print("date not in columns")

    cols = list(df_raw.columns); cols.remove(self.target); cols.remove('date')

df_raw = df_raw[['date']+cols+[self.target]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/328188.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux shell编程学习笔记39:df命令

0 前言1 df命令的功能、格式和选项说明 1.1 df命令的功能1.2 df命令的格式1.3 df命令选项说明 2 df命令使用实例 2.1 df:显示主要文件系统信息2.2 df -a:显示所有文件系统信息2.3 df -t[]TYPE或--type[]TYPE:显示TYPE指定类型的文件系统信…

AIGC实战——像素卷积神经网络(PixelCNN)

AIGC实战——像素卷积神经网络 0. 前言1. PixelCNN 工作原理1.1 掩码卷积层 1.2 残差块2. 训练 PixelCNN3. PixelCNN 分析4. 使用混合分布改进 PixelCNN小结系列链接 0. 前言 像素卷积神经网络 (Pixel Convolutional Neural Network, PixelCNN) 是于 2016 年提出的一种图像生成…

LINUX基础培训九之网络管理

前言、本章学习目标 了解LINUX网络接口和主机名配置的方法熟悉网络相关的几个配置文件了解网关和路由熟悉网络相关的命令使用 一、网络IP地址配置 在Linux中配置IP地址的方法有以下这么几种: 1、图形界面配置IP地址(操作方式如Windows系统配置IP&…

机器学习:线性回归模型的原理、应用及优缺点

一、原理 线性回归是一种统计学和机器学习中常用的方法,用于建立变量之间线性关系的模型。其原理基于假设因变量(或响应变量)与自变量之间存在线性关系。 下面是线性回归模型的基本原理: 模型拟合: 通过最小二乘法&…

1、机器学习模型的工作方式

第一步,如果你是机器学习新手。 本课程所需数据集夸克网盘下载链接:https://pan.quark.cn/s/9b4e9a1246b2 提取码:uDzP 文章目录 1、简介2、决策树优化3、继续1、简介 我们将从机器学习模型如何工作以及如何使用它们的概述开始。如果你以前做过统计建模或机器学习,这可能感…

【Web】CTFSHOW 文件上传刷题记录(全)

期末考完终于可以好好学ctf了,先把这些该回顾的回顾完,直接rushjava! 目录 web151 web152 web153 web154-155 web156-159 web160 web161 web162-163 web164 web165 web166 web167 web168 web169-170 web151 如果直接上传php文…

生物制药厂污水处理需要哪些工艺设备

生物制药厂是一种特殊的工业场所,由于其生产过程中涉及的有机物较多,导致废水中含有高浓度的有机物和微生物等污染物,因此需要采用一些特殊的工艺设备来进行污水处理。本文将介绍生物制药厂污水处理中常用的工艺设备。 首先,对于生…

Java NIO (二)NIO Buffer类的重要方法(备份)

1 allocate()方法 在使用Buffer实例前,我们需要先获取Buffer子类的实例对象,并且分配内存空间。需要获取一个Buffer实例对象时,并不是使用子类的构造器来创建,而是调用子类的allocate()方法。 public class AllocateTest {static…

【FastAPI】路径参数(二)

预设值 如果你有一个接收路径参数的路径操作,但你希望预先设定可能的有效参数值,则可以使用标准的 Python Enum 类型。 导入 Enum 并创建一个继承自 str 和 Enum 的子类。通过从 str 继承,API 文档将能够知道这些值必须为 string 类型并且能…

PromptCast-时间序列预测的好文推荐

前言 这是关于大语言模型和时间序列预测结合的好文推荐,发现这篇文章,不仅idea不错和代码开源维护的不错,论文也比较详细(可能是顶刊而不是顶会,篇幅大,容易写清楚),并且关于它的Br…

STM32+HAL库驱动ADXL345传感器(SPI协议)

STM32HAL库驱动ADXL345传感器(SPI协议) ADXL345传感器简介实物STM32CubeMX配置SPI配置片选引脚配置串口配置 特别注意(重点部分)核心代码效果展示 ADXL345传感器简介 ADXL345 是 ADI 公司推出的基于 iMEMS 技术的 3 轴、数字输出加…

Spring Security- 基于角色的访问控制

基于角色 或权限 进行访问控制 hasAuthority方法 如果当前的主体具有指定的权限,则返回true,否则返回false 修改配置类 //当前登录用户 只有具备admins权限才可以访问这个路径.antMatchers("/test/index").hasAuthority("admins") 代码如下: package c…

达芬奇调色软件DaVinci Resolve Studio 18 中文激活版

DaVinci Resolve Studio 18是一款功能强大的视频编辑软件,它可以帮助用户轻松完成视频剪辑、调色、音频处理和特效合成等任务。 软件下载:DaVinci Resolve Studio 18 中文激活版下载 这款软件具有友好的用户界面和易于使用的功能,使得用户能够…

云服务器CVM_云主机_云计算服务器_弹性云服务器

腾讯云服务器CVM提供安全可靠的弹性计算服务,腾讯云明星级云服务器,弹性计算实时扩展或缩减计算资源,支持包年包月、按量计费和竞价实例计费模式,CVM提供多种CPU、内存、硬盘和带宽可以灵活调整的实例规格,提供9个9的数…

如何安装“MySQL在虚拟机ubuntu”win10系统?

1、 更新列表 sudo apt-get update 2、 安装MySQL服务器 sudo apt-get install mysql-server 3、 安装MySQL客户端 sudo apt-get install mysql-client 4、 配置MySQL sudo mysql_secure_installation 5、 测试MySQL systemctl status mysql.service MySQL数据库基本…

transbigdata笔记:轨迹停止点和行程提取

1 traj_stay_move——标识停靠点和行程 1.1 方法介绍 如果两个连续轨迹数据点(栅格化处理之后)之间的持续时间超过设定的阈值,将其视为停靠点。两个停靠点之间的时间段被视为一个行程 1.2 使用方法 transbigdata.traj_stay_move(data, pa…

从零开始搭建ubuntu 16.04 pwndocker环境

1.安装VMware-tools 1.1遇到问题 在使用 VMware Workstation时遇到了VMware Tools不能安装的问题,具体表现为:在要安装VMware Tools的虚拟机上右键 ----》安装VMware Tools(T)… 为灰色,不能够点击。 1.2解决方案    1. 关闭虚拟机&…

设计Twitter时间线和搜索功能

设计Twitter时间线和搜索功能 设计 facebook feed 和 设计 facebook search是相同的问题 第一步:定义用例和约束 定义问题的需求和范围,询问问题去声明用例和约束,讨论假设 ps: 没有一个面试官会展示详细的问题,我们需要定义一些用…

【软件测试】学习笔记-测试基础架构

这篇文章探讨什么是测试基础架构。 什么是测试基础架构? 测试基础架构指的是,执行测试的过程中用到的所有基础硬件设施以及相关的软件设施。因此,我们也把测试基础架构称之为广义的测试执行环境。通常来讲,测试基础 架构主要包括…

Leetcode23-数组能形成多少数对(2341)

1、题目 给你一个下标从 0 开始的整数数组 nums 。在一步操作中,你可以执行以下步骤: 从 nums 选出 两个 相等的 整数 从 nums 中移除这两个整数,形成一个 数对 请你在 nums 上多次执行此操作直到无法继续执行。 返回一个下标从 0 开始、长…
最新文章