机器学习入门--LSTM原理与实践

LSTM模型

长短期记忆网络(Long Short-Term Memory,LSTM)是一种常用的循环神经网络(RNN)变体,特别擅长处理长序列数据和捕捉长期依赖关系。本文将介绍LSTM模型的数学原理、代码实现和实验结果,并使用pytorch和sklearn的数据集进行验证。

数学原理

遗忘门(Forget Gate)

遗忘门的作用是决定前一时间步的细胞状态中哪些信息需要被遗忘。具体计算公式为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf表示遗忘门权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt是当前时间步的输入, b f b_f bf是遗忘门的偏置向量, σ \sigma σ表示sigmoid函数。

输入门 (Input Gate)

输入门的作用是决定当前时间步的输入中哪些信息将被加入到细胞状态中。具体计算公式为:
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
其中, W i W_i Wi 表示输入门的权重矩阵, h t − 1 h_{t-1} ht1表示前一时间步的隐藏状态, x t x_t xt表示当前时间步的输入, b i b_i bi是输入门的偏置向量, σ \sigma σ是sigmoid函数。

更新单元 (Candidate Cell State)

更新单元计算出一个候选的单元状态,用于更新当前时间步的单元状态。具体计算公式为:
C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_{t} = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)
其中, W C W_C WC是更新单元的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt表示当前时间步输入, b C b_C bC是更新单元偏置向量。

细胞状态更新(Cell State Update)

通过遗忘门、输入门和更新单元的计算结果,可以更新当前时间步的细胞状态。具体计算公式为:
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t
其中, f t f_t ft是遗忘门的输出, C t − 1 C_{t-1} Ct1是前一时间步的单元状态, i t i_t it是输入门的输出, C ~ t \tilde{C}_t C~t是更新单元的输出。

输出门(Output Gate)

输出门的作用是决定当前时间步的隐藏状态。具体计算公式为:
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
其中, W o W_o Wo是输出门的权重矩阵, h t − 1 h_{t-1} ht1是前一时间步的隐藏状态, x t x_t xt 表示当前时间步输入, b o b_o bo 是输出门偏置向量。

隐藏状态更新(Hidden State Update)

通过输出门和细胞状态计算出当前时间步的隐藏状态。具体计算公式为:
h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)
其中, o t o_t ot是输出门的输出, C t C_t Ct代表当前时间步的单元状态。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)

# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# 定义LSTM模型
class LSTMNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMNet, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out

input_size = X.shape[2]
hidden_size = 32
output_size = 1
model = LSTMNet(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        loss_list.append(loss.item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of LSTM Training')
plt.show()

# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码中,我们首先加载并标准化波士顿房价数据集,然后定义了一个包含LSTM层和全连接层的LSTMNet模型。通过使用均方误差作为损失函数和Adam优化器进行训练,我们展示了如何训练和预测LSTM模型。最后,通过matplotlib库绘制了损失曲线(如下图所示),并对新数据点进行了预测。
LSTM-损失函数

总结

LSTM作为一种强大的循环神经网络模型,在处理长序列数据和捕捉长期依赖关系方面表现出色。通过本文的介绍和实验,我们深入探讨了LSTM的数学原理、代码实现和应用实例。通过使用pytorch和sklearn的数据集进行实验,我们验证了LSTM模型在房价预测任务中的有效性和性能优势。希望本文能帮助读者更好地理解和应用LSTM模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391059.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MongoDB从入门到实战之.NET Core使用MongoDB开发ToDoList系统(3)-系统数据集合设计

前言 前几章教程我们把ToDoList系统的基本框架搭建好了,现在我们需要根据我们的需求把ToDoList系统所需要的系统集合(相当于关系型数据库中的数据库表)。接下来我们先简单概述一下这个系统主要需要实现的功能以及实现这些功能我们需要设计那些…

平时积累的FPGA知识点(10)

平时在FPGA群聊等积累的FPGA知识点,第10期: 41 ZYNQ系列芯片的PL中使用PS端送过来的时钟,这些时钟名字是自动生成的吗? 解释:是的。PS端设置的是ps_clk,用report_clocks查出来的时钟名变成了clk_fpga_0&a…

NX二次开发树列表双击快速进入编辑状态

先将这几个树列表回调注释给解开 int TreeColumn0;//定义一个全局边量记录点击的那一列NXOpen::BlockStyler::Tree::BeginLabelEditState OnBeginLabelEditCallback(NXOpen::BlockStyler::Tree *tree,NXOpen::BlockStyler::Node *node,int columID) {if(columnIDTreeColumnID)…

The method toList() is undefined for the type Stream

The method toList() is undefined for the type Stream &#xff08;JDK16&#xff09; default List<T> toList() { return (List<T>) Collections.unmodifiableList(new ArrayList<>(Arrays.asList(this.toArray()))); }

磁体发条概念

使用磁体发条&#xff08;也称为磁弹簧或磁蓄能器&#xff09;作为储能装置是一个有趣的概念&#xff0c;它利用电磁感应原理来存储和释放能量。磁体发条的基本原理是通过旋转一个强磁体&#xff0c;使其通过一个线圈的中心&#xff0c;从而在线圈中产生电流。当磁体停止旋转时…

平时积累的FPGA知识点(11)

平时在FPGA群聊等积累的FPGA知识点,第11期: 51 可以把dcp文件封装到自己ip里吗? 解释:不可以 52 fifo的异步复位要做异步复位同步释放吗? 解释:要跟写时钟同步,所以需要在ip外部做一下同步释放 53 vivado报错 Phase 6.1 Hold Fix Iter Phase 6.1.1 Update Timing …

下一代Windows系统曝光:基于GPT-4V,Agent跨应用调度,代号UFO

下一代Windows操作系统提前曝光了&#xff1f;&#xff1f; 微软首个为Windows而设的智能体&#xff08;Agent&#xff09; 亮相&#xff1a; 基于GPT-4V&#xff0c;一句话就可以在多个应用中无缝切换&#xff0c;完成复杂任务。整个过程无需人为干预&#xff0c;其执行成功…

PANTONE(R)_colorist 潘通色号查询软件

PANTONE_colorist 潘通色号查询 查找颜色一栏输入对应的色号&#xff0c;即可显示对应的图案。 下载 https://download.csdn.net/download/jintaihu/19340557

2024年最新onlyfans虚拟信用卡订阅教程

一、Onlyfans是什么&#xff1f; OnlyFans是一个允许创作者分享自己的独家内容的平台&#xff0c;简称o站。这个平台允许创作者创建一个订阅服务&#xff0c;粉丝需要支付费用才能访问其独家内容。 本文将教你如何使用虚拟卡在OnlyFans上进行充值。 二、如何使用虚拟卡支付 O…

【测试】测试概念篇和基础篇

目 录 一.了解软件测试的基础概念1.需求2.测试用例3.BUG 二.开发模型和测试模型1.瀑布模型2.螺旋模型3.增量模型和迭代模型4.敏捷模型 三.软件测试模型V模型W模型 四.BUG篇1. 如何合理的创建 bug2. bug 级别3. bug 的生命周期4. 跟开发产生争执怎么办 一.了解软件测试的基础概念…

springboot190基于springboot框架的工作流程管理系统的设计与实现

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的 适用于计算机类毕业设计&#xff0c;课程设计参考与学习用途。仅供学习参考&#xff0c; 不得用于商业或者非法用途&#xff0c;否则&#xff0c;一切后果请用户自负。 看运行截图看 第五章 第四章 获取资料方式 **项…

循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例

Markdown 进阶操作 查看更多学习笔记&#xff1a;GitHub&#xff1a;LoveEmiliaForever Mermaid官网 由于CSDN对某些Mermaid或Markdown语法不支持&#xff0c;因此我的某些效果展示使用图片进行 下面的笔记内容全部是我根据Mermaid官方文档学习的&#xff0c;因为是初学者所以…

分布式锁redisson

文章目录 1. 分布式锁1.1 基本原理和实现方式对比synchronized锁在集群模式下的问题多jvm使用同一个锁监视器分布式锁概念分布式锁须满足的条件分布式锁的实现 1.2 基于Redis的分布式锁获取锁&释放锁操作示例 基于Redis实现分布式锁初级版本ILock接口SimpleRedisLock使用示…

解决LeetCode编译器报错的技巧:正确处理位操作中的数据类型

一天我在leetcode上刷题时&#xff0c;遇到了这样的题目&#xff1a; 随即我写了如下的代码&#xff1a; int convertInteger(int A, int B) {int count 0;int C A ^ B;int flag 1;while(flag){if (C & flag){count;}flag<<1;}return count;} 但LeetCode显示如下…

Linux常见指令(一)

一、基本指令 1.1ls指令 语法 &#xff1a; ls [ 选项 ][ 目录或文件 ] 功能&#xff1a;对于目录&#xff0c;该命令列出该目录下的所有子目录与文件。对于文件&#xff0c;将列出文件名以及其他信息。 常用选项&#xff1a; -a 列出目录下的所有文件&#xff0c;包括以 .…

第9章 网络编程

9.1 网络通信协议 通过计算机网络可以实现多台计算机连接&#xff0c;但是不同计算机的操作系统和硬件体系结构不同&#xff0c;为了提供通信支持&#xff0c;位于同一个网络中的计算机在进行连接和通信时必须要遵守一定的规则&#xff0c;这就好比在道路中行驶的汽车一定要遵…

实验二 物理内存管理-实验部分

目录 一、知识点 1、计算机体系结构/内存层次 1.1、计算机体系结构 1.2、地址空间&地址生成 1.3、伙伴系统&#xff08;Buddy System&#xff09; 2、非连续内存分配 2.1、段式存储 2.2、页式存储 2.3、快表和多级页表 2.4、段页式存储 3、X86的特权级与MMU 3.…

如何让Obsidian实现电脑端和安卓端同步

Obsidian是一款知名的笔记软件&#xff0c;支持Markdown语法&#xff0c;它允许用户在多个设备之间同步文件。要在安卓设备上实现同步&#xff0c;可以使用remote save插件&#xff0c;以下是具体操作步骤&#xff1a; 首先是安装电脑端的obsidian&#xff0c;然后依次下载obs…

【Java】文件操作与IO

文件操作与IO Java中操作文件针对文件系统的操作File类概述字段构造方法方法及示例 文件内容的读写 —— 数据流Java提供的 “流” API文件流读写文件内容InputStream 示例读文件示例1&#xff1a;将文件完全读完的两种方式示例二&#xff1a;读取汉字 写文件谈谈 OutputStream…

进程状态

广义概念&#xff1a; 从广义上来讲&#xff0c;进程分为新建、运行、阻塞、挂起、退出五个状态&#xff0c;其中新建和退出两个状态可以直接理解字面意思。 运行状态&#xff1a; 这里涉及到运行队列的概念&#xff0c;CPU在读取数据的时候&#xff0c;需要把内存中的进程放入…