LSTM从入门到精通(形象的图解,详细的代码和注释,完美的数学推导过程)

先附上这篇文章的一个思维导图


  1. 什么是RNN

按照八股文来说:RNN实际上就是一个带有 记忆的时间序列的预测模型

RNN的细胞结构图如下:

softmax激活函数只是我举的一个例子,实际上得到y<t>也可以通过其他的激活函数得到

其中a<t-1>代表t-1时刻隐藏状态,a<t>代表经过X<t>这一t时刻的输入之后,得到的新的隐藏状态。公式主要是a<t> = tanh(Waa * a<t-1> + Wax * X<t> + b1) ;大白话解释一下就是,X<t>是今天的吊针,a<t-1>是昨天的发烧度数39,经过今天这一针之后,a<t>变成38度。这里的记忆体现在今天的38度是在前一天的基础上,通过打吊针来达到第二天的降温状态。


1.1 RNN的应用

由于RNN的记忆性,我们最容易想到的就是RNN在自然语言处理方面的应用,譬如下面这张图,提前预测出下一个字。

除此之外,RNN的应用还包括下面的方向:

  1. 语言模型:RNN被广泛应用于语言模型的建模中,例如自然语言处理、机器翻译、语音识别等领域。

  1. 时间序列预测:RNN可以用于时间序列预测,例如股票价格预测、气象预测、心电图信号预测等。

  1. 生成模型:RNN可以用于生成模型,例如文本生成、音乐生成、艺术创作等。

  1. 强化学习:RNN可以用于强化学习中,例如在游戏、机器人控制和决策制定等领域。


1.2 RNN的缺陷

想必大家一定听说过LSTM,没错,就是由于RNN的尿性,所以才出现LSTM这一更精妙的时间序列预测模型的设计。但是我们知己知彼才能百战百胜,因此我还是决定详细分析一下RNN的缺陷,看过一些资料,但是只是肤浅的提到了梯度消失和梯度爆炸,没有实际的数学推导,这可不是一个求学之人应该有的态度!

主要的缺陷是两个:

  1. 长期依赖问题导致的梯度消失:众所周知RNN模型是一个具有记忆的模型,每一次的预测都和当前输入以及之前的状态有关,但是我们试想,如果我们的句子很长,他在第1000个记忆细胞还能记住并很好的利用第1个细胞的记忆状态吗?答案显然是否定的

  1. 梯度爆炸


1.2.1 梯度消失和梯度爆炸的详细公式推导

敲黑板(手写公式推导,大家最迷糊的地方):

根据下面图示的例子,我手写并反复检查了自己的过程(下图),请各位看官务必认真看看,理解起来并不难,对于别的文章随口一提的梯度消失和梯度爆炸实在是透彻太多啦!!!

我们假设损失函数 ,Y是实际值,O是预测值;首先,我们假设只有 三层,然后通过三层我们就能以此类推找出规律。反向传播我们需要对Wo,Wx,Ws,b四个变量都求偏导,在这里我们主要对Wx求偏导,其他三个以此类推,就很简单了。为了表示更清晰,笔者使用紫色的x表示乘法。

根据推导的公式我们得到一个指数函数,我们在高中时候就学到过指数函数的变化系数是极大的,因此在t趋于比较大的时候(也就是我们的句子比较长的时候),如果比1小不少,那么模型的一部分梯度会趋于0,因此优化会几乎停止;同理,如果比1大一些,那么模型的部分梯度会极大,导致模型和的变化无法控制,优化毫无意义。


  1. 什么是LSTM

八股文解释:LSTM(长短时记忆网络)是一种常用于处理序列数据的深度学习模型,与传统的 RNN(循环神经网络)相比,LSTM引入了三个门( 输入门、遗忘门、输出门,如下图所示)和一个 细胞状态(cell state),这些机制使得LSTM能够更好地处理序列中的长期依赖关系。注意:小蝌蚪形状表示的是sigmoid激活函数
Ct是细胞状态(记忆状态), 是输入的信息, 是隐藏状态(基于 得到的)

用最朴素的语言解释一下三个门,并且用两门考试来形象的解释一下LSTM:

  1. 遗忘门:通过x和ht的操作,并经过sigmoid函数,得到0,1的向量,0对应的就代表之前的记忆某一部分要忘记,1对应的就代表之前的记忆需要留下的部分 ===>代表复习上一门线性代数所包含的记忆,通过遗忘门,忘记掉和下一门高等数学无关的内容(比如矩阵的秩)

  1. 输入门:通过将之前的需要留下的信息和现在需要记住的信息相加,也就是得到了新的记忆状态。===>代表复习下一门科目高等数学的时候输入的一些记忆(比如洛必达法则等等),那么已经线性代数残余且和高数相关的部分(比如数学运算)+高数的知识=新的记忆状态

  1. 输出门:整合,得到一个输出===>代表高数所需要的记忆,但是在实际的考试不一定全都发挥出来考到100分。因此,则代表实际的考试分数

为了便于大家理解,附上几张非常好的图供大家理解完整的数据处理的流程:

遗忘门:

输入门:

细胞状态:

输出门:


2.1 LSTM的模型结构

这里有两张别的博主的很好的图,我在初学的时候也是恍然大悟:

图的出处

解释一下pytorch训练lstm所使用的参数:

  1. 这是利用pytorch调用LSTM所使用的参数

output,(h_n,c_n) = lstm (x, [ht_1, ct_1]),一般直接放入x就好,后面中括号的不用管
  1. 这是作为x(输入)喂给LSTM的参数

x:[seq_length, batch_size, input_size],这里有点反人类,batch_size一般都是放在开始的位置
  1. 这是pytorch简历LSTM是所需参数

lstm = LSTM(input_size,hidden_size,num_layers)

2.2 LSTM相比RNN的优势

LSTM的反向传播的数学推导很繁琐,因为涉及到的变量很多,但是LSTM确实是可以在一定程度上解决梯度消失和梯度爆炸的问题。我简单说一下,RNN的连乘主要是W的连乘,而W是一样的,因此就是一个指数函数(在梯度中出现指数函数并不是一件友好的事情);相反,LSTM的连乘是的偏导的不断累乘,如果前后的记忆差别不大,那偏导的值就是1,那就是多个1相乘。当然,也可能出现某一一些偏导的值很大,但是一定不会很多(换句话说,一句话的前后没有逻辑,那完全没有训练的必要)。


2.3 pytorch实现LSTM对股票的预测(实战)

需要安装一下tushare的金融方面的数据集,代码的注解我已经写的很清楚了

#!/usr/bin/python3
# -*- encoding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import tushare as ts
import pandas as pd
import torch
from torch import nn
import datetime
import time

DAYS_FOR_TRAIN = 10


class LSTM_Regression(nn.Module):
    """
        使用LSTM进行回归

        参数:
        - input_size: feature size
        - hidden_size: number of hidden units
        - output_size: number of output
        - num_layers: layers of LSTM to stack
    """

    def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, _x):
        x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)
        s, b, h = x.shape
        x = x.view(s * b, h)
        x = self.fc(x)
        x = x.view(s, b, -1)  # 把形状改回来
        return x


def create_dataset(data, days_for_train=5) -> (np.array, np.array):
    """
        根据给定的序列data,生成数据集

        数据集分为输入和输出,每一个输入的长度为days_for_train,每一个输出的长度为1。
        也就是说用days_for_train天的数据,对应下一天的数据。

        若给定序列的长度为d,将输出长度为(d-days_for_train+1)个输入/输出对
    """
    dataset_x, dataset_y = [], []
    for i in range(len(data) - days_for_train):
        _x = data[i:(i + days_for_train)]
        dataset_x.append(_x)
        dataset_y.append(data[i + days_for_train])
    return (np.array(dataset_x), np.array(dataset_y))


if __name__ == '__main__':
    t0 = time.time()
    data_close = ts.get_k_data('000001', start='2019-01-01', index=True)['close']  # 取上证指数的收盘价
    data_close.to_csv('000001.csv', index=False) #将下载的数据转存为.csv格式保存
    data_close = pd.read_csv('000001.csv')  # 读取文件

    df_sh = ts.get_k_data('sh', start='2019-01-01', end=datetime.datetime.now().strftime('%Y-%m-%d'))
    print(df_sh.shape)

    data_close = data_close.astype('float32').values  # 转换数据类型
    plt.plot(data_close)
    plt.savefig('data.png', format='png', dpi=200)
    plt.close()

    # 将价格标准化到0~1
    max_value = np.max(data_close)
    min_value = np.min(data_close)
    data_close = (data_close - min_value) / (max_value - min_value)

    # dataset_x
    # 是形状为(样本数, 时间窗口大小)
    # 的二维数组,用于训练模型的输入
    # dataset_y
    # 是形状为(样本数, )
    # 的一维数组,用于训练模型的输出。
    dataset_x, dataset_y = create_dataset(data_close, DAYS_FOR_TRAIN)  # 分别是(1007,10,1)(1007,1)

    # 划分训练集和测试集,70%作为训练集
    train_size = int(len(dataset_x) * 0.7)

    train_x = dataset_x[:train_size]
    train_y = dataset_y[:train_size]

    # 将数据改变形状,RNN 读入的数据维度是 (seq_size, batch_size, feature_size)
    train_x = train_x.reshape(-1, 1, DAYS_FOR_TRAIN)
    train_y = train_y.reshape(-1, 1, 1)

    # 转为pytorch的tensor对象
    train_x = torch.from_numpy(train_x)
    train_y = torch.from_numpy(train_y)

    model = LSTM_Regression(DAYS_FOR_TRAIN, 8, output_size=1, num_layers=2)  # 导入模型并设置模型的参数输入输出层、隐藏层等

    model_total = sum([param.nelement() for param in model.parameters()])  # 计算模型参数
    print("Number of model_total parameter: %.8fM" % (model_total / 1e6))

    train_loss = []
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    for i in range(200):
        out = model(train_x)
        loss = loss_function(out, train_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

        # 将训练过程的损失值写入文档保存,并在终端打印出来
        with open('log.txt', 'a+') as f:
            f.write('{} - {}\n'.format(i + 1, loss.item()))
        if (i + 1) % 1 == 0:
            print('Epoch: {}, Loss:{:.5f}'.format(i + 1, loss.item()))

    # 画loss曲线
    plt.figure()
    plt.plot(train_loss, 'b', label='loss')
    plt.title("Train_Loss_Curve")
    plt.ylabel('train_loss')
    plt.xlabel('epoch_num')
    plt.savefig('loss.png', format='png', dpi=200)
    plt.close()

    # torch.save(model.state_dict(), 'model_params.pkl')  # 可以保存模型的参数供未来使用
    t1 = time.time()
    T = t1 - t0
    print('The training time took %.2f' % (T / 60) + ' mins.')

    tt0 = time.asctime(time.localtime(t0))
    tt1 = time.asctime(time.localtime(t1))
    print('The starting time was ', tt0)
    print('The finishing time was ', tt1)




    # for test
    model = model.eval()  # 转换成评估模式
    # model.load_state_dict(torch.load('model_params.pkl'))  # 读取参数

    # 注意这里用的是全集 模型的输出长度会比原数据少DAYS_FOR_TRAIN 填充使长度相等再作图
    dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN)  # (seq_size, batch_size, feature_size)
    dataset_x = torch.from_numpy(dataset_x)

    pred_test = model(dataset_x)  # 全量训练集
    # 的模型输出 (seq_size, batch_size, output_size)
    pred_test = pred_test.view(-1).data.numpy()
    pred_test = np.concatenate((np.zeros(DAYS_FOR_TRAIN), pred_test))  # 填充0 使长度相同
    assert len(pred_test) == len(data_close)

    plt.plot(pred_test, 'r', label='prediction')
    plt.plot(data_close, 'b', label='real')
    plt.plot((train_size, train_size), (0, 1), 'g--')  # 分割线 左边是训练数据 右边是测试数据的输出
    plt.legend(loc='best')
    plt.savefig('result.png', format='png', dpi=200)
    plt.close()

2.4 小问题:为什么采用tanh函数,不能都用sigmoid函数吗

先放上两个函数的图形:

  1. Sigmoid函数比Tanh函数收敛饱和速度慢

  1. Sigmoid函数比Tanh函数值域范围更窄

  1. tanh的均值是0,Sigmoid均值在0.5左右,均值在0的数据显然更便于数据处理

  1. tanh的函数变化敏感区间更大

  1. 对两者求导,发现tanh对计算的压力更小,直接是1-原函数的平方,不需要指数操作

使用该问的图请标明出处,创作不易,希望收获你的赞赞

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/717.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言/动态通讯录

本文使用了malloc、realloc、calloc等和内存开辟有关的函数。 文章目录 前言 二、头文件 三、主界面 四、通讯录功能函数 1.全代码 2.增加联系人 3.删除联系人 4.查找联系人 5.修改联系人 6.展示联系人 7.清空联系人 8.退出通讯录 总结 前言 为了使用通讯录时&#xff0c;可以…

Opencv项目实战:22 物体颜色识别并框选

目录 0、项目介绍 1、效果展示 2、项目搭建 3、项目代码展示与部分讲解 Color_trackbar.py bgr_detector.py test.py 4、项目资源 5、项目总结 0、项目介绍 本次项目要完成的是对物体颜色的识别并框选&#xff0c;有如下功能&#xff1a; &#xff08;1&#xff09;…

线程池的使用:如何写出高效的多线程程序?

目录1.线程池的使用2.编写高效的多线程程序Java提供了Executor框架来支持线程池的实现&#xff0c;通过Executor框架&#xff0c;可以快速地创建和管理线程池&#xff0c;从而更加方便地编写多线程程序。 1.线程池的使用 在使用线程池时&#xff0c;需要注意以下几点&#xff…

GDAL python教程基础篇(7)OGR空间计算

1.空间计算 地理数据处理&#xff08;geoprocessing&#xff09;计算函数&#xff1a; 多边形&#xff08;Polygon&#xff09;&#xff1a; 1、交&#xff1a;poly3.Intersection(poly2) 2、并&#xff1a;poly3.Union(poly2) 3、差&#xff1a;poly3.Difference(poly2) 4、补…

python打包成apk界面设计,python打包成安装文件

大家好&#xff0c;给大家分享一下如何将python程序打包成apk文件&#xff0c;很多人还不知道这一点。下面详细解释一下。现在让我们来看看&#xff01; 1、如何用python制作十分秒加减的apk 如何用python制作十分秒加减的apk&#xff1f;用法:. apk包放入apk文件目录,然后输入…

Linux基础命令大全(下)

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a;小刘主页 ♥️每天分享云计算网络运维课堂笔记&#xff0c;努力不一定有收获&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️夕阳下&#xff0c;是最美的绽放&#xff0…

走进哈希心房

目录 哈希的概念 哈希函数 哈希冲突和解决方法 闭散列 插入 查找 删除 开散列 插入 查找 删除 哈希表&#xff08;开散列&#xff09;整体代码 位图 位图模拟实现思路分析&#xff1a; 位图应用 布隆过滤器 本文介绍unordered系列的关联式容器&#xff0c;unor…

安卓手机也可以使用新必应NewBing

没有魔法安卓手机也可以使用新必应NewBing 目前知道的是安卓手机 安卓手机先安装一个猴狐浏览器 打开手机自带浏览器&#xff0c;搜索关键词&#xff1a;猴狐浏览器&#xff0c;找到官网 也可以直接复制这个网址 狐猴浏览器 lemurbrowser CoolAPK 我的手机是荣耀安卓手机…

【正点原子FPGA连载】 第三十三章基于lwip的tftp server实验 摘自【正点原子】DFZU2EG_4EV MPSoC之嵌入式Vitis开发指南

第三十三章基于lwip的tftp server实验 文件传输是网络环境中的一项基本应用&#xff0c;其作用是将一台电子设备中的文件传输到另一台可能相距很远的电子设备中。TFTP作为TCP/IP协议族中的一个用来在客户机与服务器之间进行文件传输的协议&#xff0c;常用于无盘工作站、路由器…

「ML 实践篇」分类系统:图片数字识别

目的&#xff1a;使用 MNIST 数据集&#xff0c;建立数字图像识别模型&#xff0c;识别任意图像中的数字&#xff1b; 文章目录1. 数据准备&#xff08;MNIST&#xff09;2. 二元分类器&#xff08;SGD&#xff09;3. 性能测试1. 交叉验证2. 混淆矩阵3. 查准率与查全率4. P-R 曲…

2023年腾讯云服务器配置价格表(轻量服务器、CVM云服务器、GPU云服务器)

目前腾讯云服务器分为轻量应用服务器、云服务器云服务器云服务器CVM和GPU云服务器&#xff0c;首先介绍一下这三种服务。 1、腾讯云云服务器&#xff08;Cloud Virtual Machine&#xff0c;CVM&#xff09;提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源&#xff…

【经验总结】10年的嵌入式开发老手,到底是如何快速学习和使用RT-Thread的?(文末赠书5本)

【经验总结】一位近10年的嵌入式开发老手&#xff0c;到底是如何快速学习和使用RT-Thread的&#xff1f; RT-Thread绝对可以称得上国内优秀且排名靠前的操作系统&#xff0c;在嵌入式IoT领域一直享有盛名。近些年&#xff0c;物联网产业的大热&#xff0c;更是直接将RT-Thread这…

python绘制图像中心坐标二维分布曲线

数据和代码如下所示&#xff1a; import pandas as pd import numpy as np import matplotlib.pyplot as plt import xlrd from scipy.stats import multivariate_normal from mpl_toolkits.mplot3d import Axes3D np.set_printoptions(suppressTrue)# 根据均值、标准差,求指定…

SuperMap iMobile for Android 地图开发(一)

第一步&#xff1a;创建 Android Studio 项目 第一步&#xff1a;创建 Android Studio 项目 Android Studio 有两种创建项目的方法。 第一种是在 Android Studio起始页选择“Start a new Android Studio Project”。 第二种是在 Android Studio 主页选择“File”–>“New P…

数仓建模—主题域和主题

主题域和主题 前面在这个专题的第一篇,也就是数仓建模—数仓初识中我们就提到了一个概念—主题,这个概念其实在数仓的定义中也有提到 数据仓库是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合,用于支持管理决策。 今天我们主要来探究一下,数仓的主题到底是…

Multi-Camera Color Correction via Hybrid Histogram Matching直方图映射

文章目录Multi-Camera Color Correction via Hybrid Histogram Matching1. 计算直方图&#xff0c; 累计直方图&#xff0c; 直方图均衡化2. 直方图规定化&#xff0c;直方图映射。3. 实验环节3.1 输入图像3.2 均衡化效果3.3 映射效果4. 针对3实验环节的伪影 做处理和优化&…

ChatGPT研究分析:GPT-4做了什么

前脚刚研究了一轮GPT3.5&#xff0c;OpenAI很快就升级了GPT-4&#xff0c;整体表现有进一步提升。追赶一下潮流&#xff0c;研究研究GPT-4干了啥。本文内容全部源于对OpenAI公开的技术报告的解读&#xff0c;通篇以PR效果为主&#xff0c;实际内容不多。主要强调的工作&#xf…

九种跨域方式实现原理(完整版)

前言 前后端数据交互经常会碰到请求跨域&#xff0c;什么是跨域&#xff0c;以及有哪几种跨域方式&#xff0c;这是本文要探讨的内容。 一、什么是跨域&#xff1f; 1.什么是同源策略及其限制内容&#xff1f; 同源策略是一种约定&#xff0c;它是浏览器最核心也最基本的安…

如何发布自己的npm包

一、什么是npm npm是随同nodejs一起安装的javascript包管理工具&#xff0c;能解决nodejs代码部署上的很多问题&#xff0c;常见的使用场景有以下几种&#xff1a; ①.允许用户从npm服务器下载别人编写的第三方包到本地使用。 ②.允许用户从npm服务器下载并安装别人编写的命令…

K_A18_001 基于STM32等单片机采集MQ2传感参数串口与OLED0.96双显示

K_A18_001 基于STM32等单片机采集MQ2传感参数串口与OLED0.96双显示一、资源说明二、基本参数参数引脚说明三、驱动说明IIC地址/采集通道选择/时序对应程序:四、部分代码说明1、接线引脚定义1.1、STC89C52RCMQ2传感参模块1.2、STM32F103C8T6MQ2传感参模块五、基础知识学习与相关…