[PyTorch][chapter 46][LSTM -1]

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建


一  背景简介:

       1.1  RNN

         RNN 忽略o_t,L_t,y_t 模型可以简化成如下

      

       

          图中Rnn Cell 可以很清晰看出在隐藏状态h_t=f(x_t,h_{t-1})

            得到 h_t后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的h_{t+1}

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             f_t=\sigma(W_fh_{t-1}+U_{t}x_t+b_f)

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i)

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    a_t=\widetilde{c_t}=tanh(W_a h_{t-1}+U_ax_t+b_a)

           2.4 更新: 记忆单元

               c_t=f_t \odot c_{t-1}+i_t \odot a_t

             2.5  更新: 输出门

                决定是否使用隐藏值

                 o_t=\sigma(W_oh_{t-1}+U_ox_t+b_0)  

           2.6. 隐藏状态

                h_t=o_t \odot tanh(c_t)

           2.7  模型输出

                  \hat{y_t}=\sigma(Vh_t+b)

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况


三  LSTM 反向传播推导

      3.1 定义两个\delta_t

             \delta_h^t=\frac{\partial L}{\partial h_t}

            \delta_c^t=\frac{\partial L}{\partial C_t}

    3.2  定义损失函数

            损失函数L(t)分为两部分: 

             时刻t的损失函数 l(t)

             时刻t后的损失函数L(t+1)

              L(t)=\left\{\begin{matrix} l(t)+L(t+1), if: t<T\\ l(t), if: t=T \end{matrix}\right.

      3.3 最后一个时刻\tau

              

 这里面要注意这里的o^{\tau}= Vh_{\tau}+c

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   

   dl= tr((\frac{\partial L^{\tau}}{\partial h^{\tau}})^Tdh^{\tau})  ... 公式1: 微分和迹的关系

       =tr((\delta_h^{\tau})^Tdh^{\tau})

     因为

    h^{\tau}=o^{\tau} \odot tanh(c^{\tau})

   dh_T=o^{\tau}\odot(d(tanh (c^{\tau})))

           =o^{\tau} \odot (1-tanh^2(c^{\tau})) \odot dc^{\tau}

     带入上面公式1:

      dl= tr((\delta_h^{\tau})^T (o^{\tau}\odot(1-tanh^2(c^{\tau}))\odot dc^{\tau})

           =tr((\delta_h^{\tau} \odot o^{\tau} \odot(1-tanh^2(c^{\tau}))^Tdc^{\tau})

    所以

3.4   链式求导过程

       求导结果:

 

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

 ------------------------------------------------------------   

 第一项: 

             

         h_{t+1}=o_{t+1}\odot tanh(c_{t+1})

         o_{t+1}=\sigma(W_oh_t+U_ox_{t+1}+b_0)

        设 a_{t+1}=W_oh_t+U_ox_{t+1}+b_0

           则    \frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial a_{t+1}}\frac{\partial a_{t+1}}{\partial h_{t}}

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    \frac{\partial h_{t+1}}{\partial o_{t+1}}=diag(tanh(c^{t+1})) 是一个对角矩阵

                  o=\begin{bmatrix} \sigma(a_1)\\ \sigma(a_2) \\ .... \\ \sigma(a_n) \end{bmatrix}

                 \frac{\partial o_{t+1}}{\partial a_{t+1}}=diag(o_{t+1}\odot(1-o_{t+1}))

                 \frac{\partial a_{t+1}}{\partial h_{t}}=W_o

                 几个连乘起来就是第一项

               

第二项

    c_{t+1}=f_{t+1}\odot c_t+i_{t+1}\odot a_{t+1}

   f_{t+1}=\sigma(W_fh_t+U_tx_{t+1}+b_f)

   i_{t+1}=\sigma(W_ih_t+U_i x_{t+1}+b_i)

  a_{t+1}=tanh(W_a h_t +U_ax_t +b_a)

参考:

   h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

其中:

\frac{\partial h_{t+1}}{\partial c^{t+1}}=diag(o^{t+1}\odot (1-tanh^2(c^{t+1}))

\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial c_{t+1}}\frac{\partial c_{t+1}}{\partial f_{t+1}}\frac{\partial f_{t+1}}{\partial h_{t}}

 \frac{\partial c_{t+1}}{\partial f_{t+1}}=diag(c^{t})

 \frac{\partial a_{t+1}}{\partial h_{t}}=diag(f_t \odot(1-f_t))W_f

其它也是相似,就有了上面的求导结果


四  为什么能解决梯度消失

    

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如f_{t}=1,\hat{c_t}=0 时候,\frac{\partial c_t}{\partial c_{t-1}}=1,不会存在梯度消失。

       


五 模型的搭建

   

    我们最后发现:

    O_t,C_t,H_t 的维度必须一致,都是hidden_size

    通过C_t,则 I_t,F_t,\tilde{c} 最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023

@author: chengxf2
"""

import torch
from torch import nn
from d21 import torch as d21


def normal(shape,devices):
    
    data = torch.randn(size= shape, device=devices)*0.01
    
    return data


def get_lstm_params(input_size, hidden_size,categorize_size,devices):
    


    
    #隐藏门参数
    W_xf= normal((input_size, hidden_size), devices)
    W_hf = normal((hidden_size, hidden_size),devices)
    b_f = torch.zeros(hidden_size,devices)
    
    #输入门参数
    W_xi= normal((input_size, hidden_size), devices)
    W_hi = normal((hidden_size, hidden_size),devices)
    b_i = torch.zeros(hidden_size,devices)
    

    
    #输出门参数
    W_xo= normal((input_size, hidden_size), devices)
    W_ho = normal((hidden_size, hidden_size),devices)
    b_o = torch.zeros(hidden_size,devices)
    
    #临时记忆单元
    W_xc= normal((input_size, hidden_size), devices)
    W_hc = normal((hidden_size, hidden_size),devices)
    b_c = torch.zeros(hidden_size,devices)
    
    #最终分类结果参数
    W_hq = normal((hidden_size, categorize_size), devices)
    b_q = torch.zeros(categorize_size,devices)
    
    
    params =[
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q]
    
    for param in params:
        
        param.requires_grad_(True)
        
    return params

def init_lstm_state(batch_size, hidden_size, devices):
    
    cell_init = torch.zeros((batch_size, hidden_size),device=devices)
    hidden_init = torch.zeros((batch_size, hidden_size),device=devices)
    
    return (cell_init, hidden_init)


def lstm(inputs, state, params):
    [
        W_xf,W_hf,b_f,
        W_xi,W_hi,b_i,
        W_xo,W_ho,b_o,
        W_xc,W_hc,b_c,
        W_hq,b_q] = params    
    
    (H,C) = state
    outputs= []
    
    for x in inputs:
        
        #input gate
        I = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)
        F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)
        O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)
        
        C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)
        C = F*C+I*C_tmp
        
        H = O*torch.tanh(C)
        Y = (H@W_hq)+b_q
        
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0),(H,C)
        
    

def main():
    batch_size,num_steps =32, 35
    train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)

    

if __name__ == "__main__":
    
     main()


 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/62870.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【windows】windows上如何使用linux命令?

前言 windows上的bat命令感觉不方便&#xff0c;想在windows上使用linux命令。 有人提供了轮子&#xff0c;本文简单介绍一些该轮子的安装与使用&#xff0c;希望能够帮助到和我有一起需求的网友。 我的答案是busybox。 1.安装busybox.exe 在这个网站上安装busybox busyb…

Windows下安装Kafka(图文记录详细步骤)

Windows下安装Kafka Kafka简介一、Kafka安装前提安装Kafka之前&#xff0c;需要安装JDK、Zookeeper、Scala。1.1、JDK安装&#xff08;version&#xff1a;1.8&#xff09;1.1.1、JDK官网下载1.1.2、JDK网盘下载1.1.3、JDK安装 1.2、Zookeeper安装1.2.1、Zookeeper官网下载1.2.…

FPGA纯verilog实现 LZMA 数据压缩,提供工程源码和技术支持

目录 1、前言2、我这儿已有的FPGA压缩算法方案3、FPGA LZMA数据压缩功能和性能4、FPGA LZMA 数据压缩设计方案输入输出接口描述数据处理流程LZ检索器数据同步LZMA 压缩器 为输出LZMA压缩流添加文件头 5、vivado仿真6、福利&#xff1a;工程代码的获取 1、前言 说到FPGA的应用&…

计算机二级Python基本操作题-序号45

1. 键盘输入一组水果名称并以空格分隔&#xff0c;共一行。 示例格式如下&#xff1a; 苹果 芒果 草莓 芒果 苹果 草莓 芒果 香蕉 芒果 草莓 统计各类型的数量&#xff0c;从数量多到少的顺序输出类型及对应数量&#xff0c;以英文冒号分隔&#xff0c;每个类型行。输出结果保存…

学生管理系统(升级版)

import java.util.ArrayList; import java.util.Random; import java.util.Scanner;public class Demo_学生管理系统 {public static void main(String[] args) {ArrayList<User> list new ArrayList<>();Scanner sc new Scanner(System.in);while (true) {Syste…

Hive创建外部表详细步骤

① 在hive中执行HDFS命令&#xff1a;创建/data目录 hive命令终端输入&#xff1a; hive> dfs -mkdir -p /data; 或者在linux命令终端输入&#xff1a; hdfs dfs -mkdir -p /data; ② 在hive中执行HDFS命令&#xff1a;上传/emp.txt至HDFS的data目录下&#xff0c;并命名为…

励志长篇小说《周兴和》书连载之八 处心积虑揽工程

处心积虑揽工程 如何去揽工程&#xff0c;周兴和其实早就谋划好了。 那一天&#xff0c;周兴和与爱人王琼华的姐夫严忠伦、大舅子王安全提了10来只大公鸡&#xff0c;背着两只狗腿&#xff0c;以及城里人喜欢的干豇豆等山区土特产&#xff0c;坐车来到省城成都。下了车&#x…

zookeeper --- 基础篇

一、zookeeper简介 1.1、什么是zookeeper zookeeper官网&#xff1a;https://zookeeper.apache.org/ 大数据生态系统里的很多组件的命名都是某种动物或者昆虫&#xff0c;他是用来管 Hadoop&#xff08;大象&#xff09;、Hive(蜜蜂)、Pig(小 猪)的管理员。顾名思义就是管理…

侧边栏的打开与收起

侧边栏的打开与收起 <template><div class"box"><div class"sideBar" :class"showBox ? : controller-box-hide"><div class"showBnt" click"showBox!showBox"><i class"el-icon-arrow-r…

『PostgreSQL』在 PostgreSQL中创建只读权限和读写权限的账号

&#x1f4e3;读完这篇文章里你能收获到 理解在 PostgreSQL 数据库中创建账号的重要性以及如何进行账号管理掌握在 PostgreSQL 中创建具有只读权限和读写权限的账号的步骤和方法学会使用 SQL 命令来创建账号、为账号分配适当的权限以及控制账号对数据库的访问级别了解如何确保…

css position: sticky;实现上下粘性布局,中间区域滚动

sticky主要解决的问题 1、使用absolute和fixed中间区域需要定义高度2、使用absolute和fixed底部需要写padding-bottom 避免列表被遮挡住一部分&#xff08;底部是浮窗的时候&#xff0c;需要动态的现实隐藏&#xff09; <!DOCTYPE html> <html lang"en"&…

初学者自学python哪本书好,python教程自学全套

大家好&#xff0c;小编来为大家解答以下问题&#xff0c;python怎么自学,可以达到什么程度&#xff0c;初学者自学python哪本书好&#xff0c;现在让我们一起来看看吧&#xff01; 前言 Python是一个非常适合自学&#xff0c;0基础的话从入门到精通也只需要花3-4个月PYTHON库“…

Arch Linux 使用桥接模式上网

如果我们想要将虚拟机与物理主机同一网段&#xff0c;并且像物理机器一样被其他设备访问&#xff0c;则需要以桥接模式上网&#xff0c;这个时候&#xff0c;物理主机就必须配置为使用网桥上网了。 注意&#xff1a;这里我们使用了 NetworkManager 网络管理工具中的 nmcli 来进…

虚拟机重启网络服务失败 Failed to start LSB:Bring up/down networking.

许久没有打开虚拟机了&#xff0c;今天一开打发现无法ping通网络 使用 ip addr 也获取不到ip信息 重启网络服务提示我 使用 systemctl status network.service 命令查看 出现以下报错 百度各种解决方案无效&#xff0c;才发现我为了加快电脑开机速度&#xff0c;把虚拟机的一些…

【源码分析】Nacos服务端如何更新以及保存注册表信息?

文章目录 我们知道服务注册到Nacos之后&#xff0c;Nacos是需要对这些服务实例信息进行保存的&#xff0c;那么Nacos是如何保存的呢&#xff1f; 首先我们先分析Nacos的注册表的结构。 我们知道Nacos有namespace&#xff0c;group&#xff0c;cluster三个分级&#xff0c;他们都…

Llama 2 with langchain项目详解(一)

Llama 2 with langchain项目详解(一) 2023年2月25日,美国Meta公司发布了Llama 1开源大模型。随后,于2023年7月18日,Meta公司发布了Llama 2开源大模型,该系列包括了70亿、130亿和700亿等不同参数规模的模型。相较于Llama 1,Llama 2的训练数据增加了40%,上下文长度提升至…

uniapp 返回上一页并刷新

如要刷新的是mine页面 在/pages/mine/improveInfo页面修改信息&#xff0c;点击保存后跳转到个人中心&#xff08;/pages/mine/index&#xff09;页面并刷新更新数据 点击保存按钮时执行以下代码&#xff1a; wx.switchTab({url: /pages/mine/index }) // 页面重载 let pages …

出现Error: Cannot find module ‘compression-webpack-plugin‘错误

错误&#xff1a; 解决&#xff1a;npm install --save-dev compression-webpack-plugin1.1.12 版本问题

powerdesigner各种字体设置;preview字体设置;sql字体设置

1.设置左侧菜单&#xff1a; 步骤如下&#xff1a; tools —> general options —> fonts —> defalut UI font ,选择字体样式及大小即可&#xff0c;同下图。 2.设置preview字体大小&#xff08;sql预览&#xff09; 步骤如下&#xff1a; tools —> general o…

无涯教程-Perl - chop函数

描述 此函数从EXPR,LIST的每个元素或$_(如果未指定值)中删除最后一个字符。 语法 以下是此函数的简单语法- chop VARIABLEchop( LIST )chop返回值 此函数返回从EXPR中删除的字符,并且在列表context中,从LIST的最后一个元素中删除该字符。 例 以下是显示其基本用法的示例…
最新文章