【算法】长短期记忆网络(LSTM,Long Short-Term Memory)

这是一种特殊的循环神经网络,能够学习数据中的长期依赖关系,这是因为模型的循环模块具有相互交互的四个层的组合,它可以记忆不定时间长度的数值,区块中有一个gate能够决定input是否重要到能被记住及能不能被输出output。

原理

黄色方框内是四个神经网络层,红色圆圈是逐点算子,橙色圆圈是输入,蓝色圆圈是细胞状态。LSTM具有一个单元状态和三个门,对应选择有选择地学习、取消学习或保留来自每个单元的信息的能力。

LSTM中的单元状态通过只允许一些线性交互来帮助信息流过单元而不被改变。

每个单元都有一个输入、输出和一个遗忘门,可以将信息添加或者删除到单元状态。

在这里插入图片描述

遗忘门:使用sigmoid函数决定应该忘记来自先前单元状态的哪些信息。

输入门:分别使用sigmoid和tanh的逐点乘法运算控制信息流到当前单元状态。

输出门:最后,输出门决定哪些信息应该传递到下一个隐藏状态。

要在python中使用lstm模型,需要安装这些库:

pip install tansorflow pandas numpy matplotlib

# pandas用来数据处理
# numpy用来数值计算
# matplotlib.pyplot用于数据可视化
# MinMaxScaler从sklearn.preprocessing用于数据规范化
# Sequential,LSTM,Dense从tensorflow.keras用于构建神经网络
# mean_squared_error从sklearn.metrics用于计算模型误差

实现

  1. 生成示例数据:简单的正弦波形,
  2. 设置随机数生成的种子,确保结果可以复现,
  3. 生成一系列时间步长,
  4. 创建数据,结合正弦波和随机噪声
  5. 数据转换为DataFrame
  6. 使用Pandas的DataFrame来存储和处理生成的数据,
  7. 数据规范化:使用MinMaxScaler将数据规范化到0和1之间,这对神经网络的性能至关重要。
  8. 分割数据为训练集和测试集:确定训练集的大小(数据的80%),剩余的20%s数据作为测试集。
  9. 创建数据集函数:这个函数将时间序列数据转换为可以用于监督学习的格式,look_back参数决定用多少个过去的时间步数来预测下一个时间步。
  10. 设置look_back,并创建训练/测试数据;
  11. 使用1作为look_back的值
  12. 重塑输入数据为[样本,时间步,特征]
  13. LSTM模型在keras中需要三维输入
  14. 创建LSTM模型:创建一个Sequential模型,添加一个含有50个神经元的LSTM层,添加一个Dense层作为输出层,编译模型,使用均方误差作为损失函数和Adam优化器。

注:epoch是指训练周期。

代码如下:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error

# 生成示例数据:正弦波 + 随机噪声
np.random.seed(0)
timesteps = np.arange(0, 1000, 0.1)
data = np.sin(timesteps) + np.random.normal(scale=0.5, size=len(timesteps))

# 数据转换为DataFrame
df = pd.DataFrame(data, columns=['value'])
values = df['value'].values

# 数据规范化
scaler = MinMaxScaler(feature_range=(0, 1))
values_scaled = scaler.fit_transform(values.reshape(-1, 1))

# 分割数据为训练集和测试集
train_size = int(len(values_scaled) * 0.8)
test_size = len(values_scaled) - train_size
train, test = values_scaled[0:train_size, :], values_scaled[train_size:len(values_scaled), :]

# 创建数据集
def create_dataset(dataset, look_back=1):
    X, Y = [], []
    for i in range(len(dataset) - look_back - 1):
        a = dataset[i:(i + look_back), 0]
        X.append(a)
        Y.append(dataset[i + look_back, 0])
    return np.array(X), np.array(Y)

look_back = 1
X_train, Y_train = create_dataset(train, look_back)
X_test, Y_test = create_dataset(test, look_back)

# 重塑输入数据为 [样本, 时间步, 特征]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))

# 创建LSTM模型
model = Sequential()
model.add(LSTM(50, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')

# 训练模型
model.fit(X_train, Y_train, epochs=5, batch_size=1, verbose=2)

# 进行预测
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# 反转规范化
train_predict = scaler.inverse_transform(train_predict)
Y_train = scaler.inverse_transform([Y_train])
test_predict = scaler.inverse_transform(test_predict)
Y_test = scaler.inverse_transform([Y_test])

# 计算均方误差
train_score = np.sqrt(mean_squared_error(Y_train[0], train_predict[:,0]))
test_score = np.sqrt(mean_squared_error(Y_test[0], test_predict[:,0]))

# 可视化
plt.figure(figsize=(12, 6))
plt.plot(scaler.inverse_transform(values_scaled), label='Original Data')
plt.plot(np.append(np.zeros(train_size), train_predict[:,0]), linestyle='--', label='Training Predict')
plt.plot(np.append(np.zeros(train_size), test_predict[:,0]), linestyle='--', label='Test Predict')
plt.legend()
plt.show()

运行图如下:

在这里插入图片描述

观测

训练损失逐渐降低并趋于稳定,意味着模型正在从训练数据中学习。

在训练集和测试集上的评估速度很快,意味着模型的推断(预测)效率很高。

如果损失在后续的epoch中没有显著下降,可能意味着模型需要更多的epoch来训练。或者可能需要调整模型的结构或超参数(例如增加神经元数量、改变学习率)以进一步提高性能。

训练集和测试集的RMSE非常接近,说明模型在两者上的性能是一致的。

没有出现过拟合/欠拟合的迹象,则说明模型的泛化能力良好。

考虑到数据生成时添加了随机噪声,这个RMSE值表明模型在捕捉数据的基本趋势方面表现的不错,相对较小的RMSE表示预测的准确。

如果要改进LSTM,可以从贝叶斯超参数调优、增加更多训练周期(EPOCH)、尝试不同网络架构,或者在数据预处理时更复杂一些。

事实上,在RMSE上,SARIMA比LSTM更小,但非线性模式/利用长期依赖性的复杂时间序列数据时,LSTM更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/419782.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Sophon AutoCV推动AI应用从模型生产到高效落地

随着技术市场和应用方向的逐渐成熟,人工智能与各行各业的结合和落地逐渐进入了深水区。 虽然由于行业规模化和应用普及度的限制,人工智能在“传统”行业的落地不如消费互联网行业,但是借助人工智能为“传统”行业的发展注入新能量一直是相关…

Windows系统x86机器安装龙芯(loongarch64)3A5000虚拟机系统详细教程

本次介绍在window系统x86机器上安装loongarch64系统的详细教程。 1.安装环境准备。 首先,你得有台电脑。 配置别太差,至少4核8G内存,安装window10或者11都行(为啥不能是Window7,你要用也不是不行,你先解决…

边缘计算与任务卸载基础知识

目录 边缘计算简介任务卸载简介参考文献 边缘计算简介 边缘计算是指利用靠近数据生成的网络边缘侧的设备(如移动设备、基站、边缘服务器、边缘云等)的计算能力和存储能力,使得数据和任务能够就近得到处理和执行。 一个典型的边缘计算系统为…

未来已来:智慧餐饮点餐系统引领餐饮业的数字化转型

时下,智慧餐饮点餐系统正在引领着餐饮业迈向更高的位置。今天,小编将与大家共同探讨智慧餐饮点餐系统的发展趋势、优势以及对餐饮业的影响。 一、智慧餐饮点餐系统的发展趋势 智慧餐饮点餐系统的出现填补了这一空白,它通过引入数字化技术&a…

学习助手:借助AI大模型,学习更高效!

在当今的数字时代,人工智能(AI)的崛起已经彻底改变了我们获取信息、处理数据以及学习新知识的方式。AI大模型,特别是如OpenAI开发的GPT-4这类先进的技术,已成为学习和教育领域的一大助力。本文旨在探索如何借助AI大模型…

5G时代对于工业化场景应用有什么改善

5G 不仅仅是 4G 的技术升级,而是将平板电脑和智能手机的技术升级。除了更好的高清视频流和其他高带宽应用,消费者不会注意到很多性能差异。然而,在工业领域,5G 代表着巨大的飞跃。 在工厂和厂房内, 设备的Wi-Fi 网络经…

Python+Selenium+Unittest 之Unittest1--简介

Unittest属于是一种单元测试框架,主要用于对代码中写好的单元内容进行验证,比如写好一个函数,可以使用unittest去进行验证该函数的代码逻辑是否有问题,对于自动化来说,可以去检验每条用例的内容是否符合预期。 Unittes…

Goose:Golang中的数据库迁移工具

Goose:Golang中的数据库迁移工具 在Golang开发中,数据库迁移是一个常见的任务,用于管理数据库模式的演化和版本控制。Goose是一个轻量级的、易于使用的数据库迁移工具,专为Golang开发者设计。本文将介绍Goose的基本概念、用法和优…

php基础学习之错误处理(其二)

在实际应用中,开发者当然不希望把自己开发的程序的错误暴露给用户,一方面会动摇客户对己方的信心,另一方面容易被攻击者抓住漏洞实施攻击,同时开发者本身需要及时收集错误,因此需要合理的设置错误显示与记录错误日志 一…

代码随想录-回溯算法

组合 //未剪枝 class Solution {List<List<Integer>> ans new ArrayList<>();Deque<Integer> path new LinkedList<>();public List<List<Integer>> combine(int n, int k) {backtracking(n, k, 1);return ans;}public void back…

Python:关于数据服务中的Web API的设计

搭建类似joinquant、tushare类似的私有数据服务应用&#xff0c;有以下一些点需要注意&#xff1a; 需要说明的是&#xff0c;这里讨论的是web api前后端&#xff0c;当然还有其它方案&#xff0c;thrift&#xff0c;grpc等。因为要考虑到一鱼两吃&#xff0c;本文只探讨web ap…

Android之UI Automator框架源码分析(第九篇:UiDevice获取UiAutomation对象的过程分析)

前言 学习UiDevice对象&#xff0c;就需要看它的构造方法&#xff0c;构造方法中有UiDevice对象持有一些对象&#xff0c;每个对象都是我们分析程序的重点&#xff0c;毕竟UiDevice对象的功能&#xff0c;依赖这些组合的对象 备注&#xff1a;当前对象持有的对象&#xff0c;初…

Linux调试器-gdb使用与冯诺依曼体系结构

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 Linux调试器-gdb使用 1. 背景 2. 开始使用 冯诺依曼体系结构 总结 前言 世上有两种耀眼的光芒&#xff0c;一种是正在升起的太阳&#xff0c;一种是正在努力学…

k8s部署mysql

&#xff08;作者&#xff1a;陈玓玏&#xff09; 一、前置条件 已部署k8s&#xff0c;服务端版本为1.21.14 二、部署mysql 拉取镜像&#xff1b; docker pull mysql将账号密码等信息写到configmap&#xff0c;创建configmap&#xff1b; apiVersion: v1 kind: ConfigMap m…

视觉AIGC识别——人脸伪造检测、误差特征 + 不可见水印

视觉AIGC识别——人脸伪造检测、误差特征 不可见水印 前言视觉AIGC识别【误差特征】DIRE for Diffusion-Generated Image Detection方法扩散模型的角色DIRE作为检测指标 实验结果泛化能力和抗扰动 人脸伪造监测&#xff08;Face Forgery Detection&#xff09;人脸伪造图生成 …

android TextView 实现富文本显示

android TextView 实现富文本显示&#xff0c;实现抖音直播间公屏消息案例 使用&#xff1a; val tvContent: TextView helper.getView(R.id.tvContent)//自己根据UI业务要求&#xff0c;可以控制 图标显示 大小val levelLabel MyImgLabel( bitmap 自己业务上的bitmap )va…

卷积神经网络基本概念补充

卷积&#xff08;convolution&#xff09;、通道&#xff08;channel&#xff09; 卷积核大小一般为奇数&#xff0c;有中心像素点&#xff0c;便于定位卷积核。 步长&#xff08;stride&#xff09;、填充&#xff08;padding&#xff09; 卷积核移动的步长&#xff08;stride…

FPGA之带有进位逻辑的加法运算

module ADDER&#xff08; input [5&#xff1a;0]A&#xff0c; input [5&#xff1a;0]B&#xff0c;output[6&#xff1a;0]Q &#xff09;&#xff1b; assign Q AB&#xff1b; endmodule 综合结果如下图所示&#xff1a; 使用了6个Lut&#xff0c;&#xff0c;6个LUT分布…

定制红酒:一次满足需求的个性化服务体验

云仓酒庄洒派提供一次满足需求的个性化服务体验&#xff0c;让您的红酒定制之旅成为一段美好的记忆。 首先&#xff0c;云仓酒庄洒派深入了解每位消费者的需求。无论是对于红酒品种、年份、外包装还是其他个性化要求&#xff0c;云仓酒庄洒派都认真倾听并记录下来。这种细致入微…

Solo 开发者周刊 (第6期):

这里会整合 Solo 社区每周推广内容、产品模块或活动投稿&#xff0c;每周五发布。在这期周刊中&#xff0c;我们将深入探讨开源软件产品的开发旅程&#xff0c;分享来自一线独立开发者的经验和见解。本杂志开源&#xff0c;欢迎投稿。 产品推荐 1. 助眠类播客《静夜斋》上线 一…
最新文章