机器学习入门--双向长短期记忆神经网络(BiLSTM)原理与实践

双向长短记忆网络(BiLSTM)

BiLSTM(双向长短时记忆网络)是一种特殊的循环神经网络(RNN),它能够处理序列数据并保持长期记忆。与传统的RNN模型不同的是,BiLSTM同时考虑了过去和未来的信息,使得模型能够更好地捕捉序列数据中的上下文关系。在本文中,我们将详细介绍BiLSTM的数学原理、代码实现以及应用场景。

数学原理

LSTM(长短期记忆网络)是一种递归神经网络(RNN),通过引入门控机制来解决传统RNN中的梯度消失或梯度爆炸问题。BiLSTM是LSTM的一个变体,它在时间序列上同时运行两个LSTM,一个从前向后处理,一个从后向前处理。

LSTM的关键部分是单元状态(cell state)和各种门控机制,包括遗忘门、输入门和输出门。这些门控机制使用sigmoid函数来控制信息流动的程度。

具体来说,对于给定的时间步 t t t和输入 X t X_t Xt,LSTM计算以下值:
1.遗忘门(forget gate):
f t = σ ( W f ⋅ [ h t − 1 , X t ] + b f ) f_t = σ(W_f \cdot [h_{t-1}, X_t] + b_f) ft=σ(Wf[ht1,Xt]+bf)
2.输入门(input gate):
i t = σ ( W i ⋅ [ h t − 1 , X t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, X_t] + b_i) it=σ(Wi[ht1,Xt]+bi)
3.更新单元状态(new cell state):
C ~ t = t a n h ( W c ⋅ [ h t − 1 , X t ] + b c ) \tilde{C}_t = tanh(W_c \cdot [h_{t-1}, X_t] + b_c) C~t=tanh(Wc[ht1,Xt]+bc)
4.单元状态(cell state):
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t
5.输出门(output gate):
o t = σ ( W o ⋅ [ h t − 1 , X t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, X_t] + b_o) ot=σ(Wo[ht1,Xt]+bo)
6.隐状态(hidden state)更新:
h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)
其中, σ \sigma σ表示sigmoid函数

BiLSTM通过将LSTM层沿着时间轴前向和后向运行来计算双向隐藏状态。前向LSTM从序列的第一个元素到最后一个元素顺序计算,而后向LSTM则相反。这两个隐藏状态被连接在一起形成最终的双向隐藏状态。

双向LSTM的输出可以用于各种任务,如序列标注、文本分类等。它能够捕捉到序列中每个时间步之前和之后的上下文信息,从而提供更全面的特征表示。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)

# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# 定义BiLSTM模型
class BiLSTMNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout,output_size):
        super(BiLSTMNet, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.output_size = output_size
        
        self.lstm = nn.LSTM(input_size = self.input_size, 
                            hidden_size = self.hidden_size,
                            num_layers = self.num_layers,
                            bidirectional = True,
                            batch_first=True,
                            dropout = self.dropout)
        
        self.fc1 = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        output, (h_n, c_n) = self.lstm(x)
        out = torch.concat([h_n[-1,:,:], h_n[-2, :, :]], dim=-1)
        out = F.relu(out)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out

input_size = X.shape[2]
hidden_size = 32
output_size = 1
model = BiLSTMNet(input_size=input_size, 
                  hidden_size=hidden_size, 
                  num_layers=4,
                  dropout=0, 
                  output_size=output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        loss_list.append(loss.item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of LSTM Training')
plt.show()

# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个双向LSTM模型,用于预测波士顿房价数据集中的房价。其中,首先加载并标准化数据集,然后定义了一个自定义的BiLSTMNet模型类,该类包含了一个双向LSTM层和两个全连接层,并使用MSELoss作为损失函数和Adam作为优化器进行模型训练。在训练过程中,每100个epoch将损失值记录下来,最后使用Matplotlib将损失曲线可视化(如下图所示)。最后,使用模型对新数据点进行预测并输出结果。整个代码实现了一个简单的房价预测模型,可以通过调节模型参数和超参数进行进一步优化。
BiLSTM-损失曲线

总结

BiLSTM是一种强大的序列建模工具,它能够处理序列数据并捕捉长期依赖关系。在实践中,BiLSTM已被广泛应用于语音识别、自然语言处理和股票预测等领域。本文介绍了BiLSTM的数学原理、代码实现以及应用场景,希望读者能够通过本文了解到BiLSTM的基本知识和使用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/394193.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Flink02] Flink架构和原理

这是继第一节之后的Flink入门系列的第二篇,本篇主要内容是是:了解Flink运行模式、Flink调度原理、Flink分区、Flink安装。 1、运行模式 Flink有多种运行模式,可以运行在一台机器上,称为本地(单机)模式&am…

图像识别之ResNet(结构详解以及代码实现)

前言 在人工智能的浪潮中,深度学习已经成为了推动计算机视觉、自然语言处理等领域突破的关键技术。在这众多技术中,ResNet(残差网络)无疑是一个闪耀的名字。自从2015年Kaiming He等人提出ResNet架构以来,它不仅在图像…

【二十八】springboot整合logback实现日志管理

本章节是记录logback在springboot项目中的简单使用&#xff0c;本文将会演示如何通过logback将日志记录到日志文件或输出到控制台等管理操作。将会从以下几个方面进行讲解。最后实现将特定级别的特定日志保存到日志文件。 一、依赖 <dependency><groupId>ch.qos.l…

BMS再进阶(新能源汽车电池管理系统)

引言 一文入门BMS&#xff08;电池管理系统&#xff09;_bms电池管理-CSDN博客 BMS进阶&#xff08;Type-C、PD快充、充电IC、SOC算法、电池管理IC&#xff09;_充电ic asi aso功能-CSDN博客 本文是上面两篇博客的续篇&#xff0c;之前都是讲解一些BMS基本原理&#xff0c;…

目前2024年4核8G云服务器租用价格,阿里云PK腾讯云

4核8G云服务器多少钱一年&#xff1f;阿里云ECS服务器u1价格955.58元一年&#xff0c;腾讯云轻量4核8G12M带宽价格是646元15个月&#xff0c;阿腾云atengyun.com整理4核8G云服务器价格表&#xff0c;包括一年费用和1个月收费明细&#xff1a; 云服务器4核8G配置收费价格 阿里…

[Flink03] Flink安装

本文介绍Flink的安装步骤&#xff0c;主要是Flink的独立部署模式&#xff0c;它不依赖其他平台。文中内容分为4块&#xff1a;前置准备、Flink本地模式搭建、Flink Standalone搭建、Flink Standalong HA搭建。 演示使用的Flink版本是1.15.4&#xff0c;官方文档地址&#xff1…

PyCharm 调试过程中控制台 (Console) 窗口内运行命令 - 实时获取中间状态

PyCharm 调试过程中控制台 [Console] 窗口内运行命令 - 实时获取中间状态 1. yongqiang.py2. Debugger -> Console3. Show Python PromptReferences 1. yongqiang.py #!/usr/bin/env python # -*- coding: utf-8 -*- # yongqiang chengfrom __future__ import absolute_imp…

边坡位移监测设备:守护工程安全的前沿科技

随着现代工程建设的飞速发展&#xff0c;边坡位移监测作为预防山体滑坡、泥石流等自然灾害的重要手段&#xff0c;日益受到人们的关注。边坡位移监测设备作为这一领域的关键技术&#xff0c;以其高精度、实时监测的特点&#xff0c;成为守护工程安全的重要武器。 一、边坡位移…

某行的行走艺术: 一项基于SpringBoot的web系统后端专利

各位盆友&#xff0c;听说了没&#xff1f;某银行搞了个大新闻&#xff0c;一项新专利"Cool Walk in Web-ville"&#xff08;就是那个“基于SpringBoot的web系统后端实现方法及装置”&#xff09;&#xff0c;号称是网页后台的新时尚&#xff0c;走在了技术的最前沿。…

达梦数据库——数据迁移sqlserver-dm报错问题整理

报错情况一&#xff1a;Sql server迁移达梦连接报错’驱动程序无法通过使用安全套接字Q层(SSL)加密与SQL Server 建立安全连接。错误:“The server selected protocol version TLS10 is not accepted by client preferencesITLS127‘ 原因&#xff1a;历史版本的SOL SERVER服务…

notepad++运行python闪一下就没啦

问题&#xff1a;Notepad直接快捷键运行Python代码,出现闪一下就没了 解决措施&#xff1a; ①点击菜单运行(Run) --> 运行(Run)弹出的对话框 ②把 cmd /k python "$(FULL_CURRENT_PATH)" & ECHO. & PAUSE & EXIT 粘贴进入这个对话框内 ③点击保存&a…

机器人常用传感器分类及一般性要求

机器人传感器的分类 传感技术是先进机器人的三大要素&#xff08;感知、决策和动作&#xff09;之一。根据用途不同&#xff0c;机器人传感器可以分为两大类&#xff1a;用于检测机器人自身状态的内部传感器和用于检测机器人相关环境参数的外部传感器。 内部传感器 内部传感…

LiveGBS流媒体平台GB/T28181功能-自定义收流端口区间30000至30249UDP端口TCP端区间配置及相关端口复用问题说明

LiveGBS自定义收流端口区间30000至30249UDP端口TCP端区间配置及相关端口复用问题说明 1、收流端口配置1.1、INI配置1.2、页面配置 2、相关问题3、最少可以开放多少端口3.1、端口复用3.2、配置最少端口如下 4、搭建GB28181视频直播平台 1、收流端口配置 1.1、INI配置 可在lives…

【Jvm】运行时数据区域(Runtime Data Area)原理及应用场景

文章目录 前言&#xff1a;Jvm 整体组成 一.JDK的内存区域变迁Java8虚拟机启动参数 二.堆0.堆的概念1.堆的内存分区2.堆与GC2.1.堆的分代结构2.2.堆的分代GC2.3.堆的GC案例2.4.堆垃圾回收方式 3.什么是内存泄露4.堆栈的区别5.堆、方法区 和 栈的关系 三.虚拟机栈0.虚拟机栈概念…

如何实现批量获取电商数据自动化商品采集?如何利用电商数据API实现业务增长?

随着电子商务的快速发展&#xff0c;数据已经成为了电商行业最重要的资产之一。在这个数据驱动的时代&#xff0c;电商数据API&#xff08;应用程序接口&#xff09;的作用日益凸显。通过电商数据API&#xff0c;商家能够获取到大量关于消费者行为、产品表现、市场趋势等有价值…

SpringCloud Ribbon负载均衡的策略总结

1. 轮询策略 2. 权重轮询策略 3. 随机策略 4. 最少并发数策略 5. 在选定的负载均衡策略基础上重试机制 6. 可用性敏感策略。 7. 区域敏感策略

沁恒CH32V30X学习笔记05--串口接收中断和空闲中断组合接收数据

同步异步收发器(USART)** 包含 3 个通用同步异步收发器(USART1/2/3)和 5 个通用异步收发器(UART4/5/6/7/8) 空闲帧,空闲帧是 10 位或 11 位高电平,包含停止位。 断开帧是 10 位或 11 位低电平,后跟着停止位 引脚模式配置 引脚分配 bsp 驱动代码 bsp_uart_it.c /…

解析线性回归:从基础概念到实际应用

目录 前言1 什么是线性回归2 线性回归的一些概念2.1 样本集与样本2.2 实际值与估计值2.3 模型参数与最小二乘法2.4 残差与拟合优度 3 线性回归的应用场景3.1 销售预测3.2 医学数据分析3.3 金融市场分析 结语 前言 线性回归&#xff0c;被誉为统计学与机器学习领域的明星算法&a…

冒泡排序:原理、实现与性能分析

引言 在编程世界中&#xff0c;排序算法是不可或缺的一部分。冒泡排序作为最基本的排序算法之一&#xff0c;虽然其效率并不是最高的&#xff0c;但其实现简单、易于理解的特点使得它成为学习和理解排序算法的入门之选。本文将详细介绍冒泡排序的原理、实现方法以及性能分析&a…

python如何模拟登录Github

首先进入github登录页&#xff1a;https://github.com/login 输入账号密码&#xff0c;打开开发者工具&#xff0c;在Network页勾选上Preserve Log&#xff08;显示持续日志&#xff09;&#xff0c;点击登录&#xff0c;查看Session请求&#xff0c;找到其请求的URL与Form Da…