手写一个RNN前向传播以及反向传播

前向传播

根据公式

st = tanh (Uxt + Wst-1 + ba)

ot = softmax(Vst + by )

m = 3 词的个数   n = 5

import numpy as np
import tensorflow as tf
# 单个cell 的前向传播过程
# 两个输入,x_t,s_prev,parameters
def rnn_cell_forward(x_t,s_prev,parameters):
    """
    单个cell 的前向传播过程
    :param x_t: 当前T时刻的序列输入
    :param s_prev: 上一个cell的隐藏层状态输入
    :param parameters: cell中参数,字典
    :return: 隐层输出 s_next,out_pred,cache
    """
    # 取出参数
    U = parameters["U"]
    W = parameters["W"]
    V = parameters["V"]
    ba = parameters["ba"]
    by = parameters["by"]
    # 根据公式计算
    # 隐层输出计算
    s_next = np.tanh(np.dot(U,x_t) + np.dot(W,s_prev) + ba)
    # 计算cell的输出
    out_pred = tf.nn.softmax(np.dot(V,s_next) + by)
    # 记录每层的值,用于反向传播计算使用
    cache = (s_next,s_prev,x_t,parameters)

    return s_next,out_pred,cache
if __name__ == '__main__':
    # forward
    np.random.seed(1)
    # 定义该cell的输入
    x_t = np.random.randn(3, 1,)
    s_prev = np.random.randn(5, 1)
    # 定义参数
    W = np.random.randn(5, 5)
    U = np.random.randn(5, 3)
    V = np.random.randn(3, 5)
    ba = np.random.randn(5, 1)
    by = np.random.randn(3, 1)
    parameters = {"U": U, "W": W, "V": V, "ba": ba, "by": by}
    s_next, out_pred, caches = rnn_cell_forward(x_t, s_prev, parameters)
    print("s_next = ", s_next)
    print("s_next.shapr = ", s_next.shape)
    print("out_pred =", out_pred)
    print("out_pred.shape = ",out_pred.shape)

单个cell反向传播

根据图我们能够知道需要计算的梯度变量有哪些

ds_next:表示当前cell的损失对输出s的导数

dtanh:表示当前cel的损失对激活函数的导数

dx_t:表示当前cell的损失对输入xt的导数。

dU:表示当前cell的损失对U的导数

ds_prev:表示当前cell的损失对上一个cell的输入的导数

dW:表示当前cell的损失对W的导数

dba:表示当前cell的损失对dba的导数

表示公式:

def rnn_cell_forward(x_t,s_prev,parameters):
    """
    单个cell 的前向传播过程
    :param x_t: 当前T时刻的序列输入
    :param s_prev: 上一个cell的隐藏层状态输入
    :param parameters: cell中参数,字典
    :return: 隐层输出 s_next,out_pred,cache
    """
    # 取出参数
    U = parameters["U"]
    W = parameters["W"]
    V = parameters["V"]
    ba = parameters["ba"]
    by = parameters["by"]
    # 根据公式计算
    # 隐层输出计算
    s_next = np.tanh(np.dot(U,x_t) + np.dot(W,s_prev) + ba)
    # 计算cell的输出
    out_pred = tf.nn.softmax(np.dot(V,s_next) + by)
    # 记录每层的值,用于反向传播计算使用
    cache = (s_next,s_prev,x_t,parameters)

    return s_next,out_pred,cache
def rnn_cell_backward(ds_next, cache):
    """
    对单个cell进行反向传播
    :param ds_next: 当前隐层输出结果相对于损失的导数
    :param cache: 每个cell的缓存
    :return:gradients
    """

    # 获取缓存值
    (s_next, s_prev, x_t, parameters) = cache
    print(type(parameters))

    # 获取参数
    U = parameters["U"]
    W = parameters["W"]
    # V = parameters["V"]
    # ba = parameters["ba"]
    # by = parameters["by"]

    # 计算tanh的梯度通过对s_next
    dtanh = (1 - s_next ** 2) * ds_next

    # 计算U的梯度值
    dx_t = np.dot(U.T, dtanh)

    dU = np.dot(dtanh, x_t.T)

    # 计算W的梯度值
    ds_prev = np.dot(W.T, dtanh)
    dW = np.dot(dtanh, s_prev.T)

    # 计算b的梯度
    dba = np.sum(dtanh,axis=1,keepdims= 1)

    # 梯度字典
    gradients = {"dtanh" : dtanh,"dx_t": dx_t, "ds_prev": ds_prev, "dU": dU, "dW": dW, "dba": dba}

    return gradients

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/579444.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

每日OJ题_DFS回溯剪枝⑧_力扣494. 目标和

目录 力扣494. 目标和 解析代码(path设置成全局) 解析代码(path设置全局) 力扣494. 目标和 494. 目标和 难度 中等 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联…

SpringBoot + Vue实现Github第三方登录

前言:毕业设计终于好了,希望能有空多写几篇 1. 获取Github账号的Client ID和Client secrets 首先点击这个链接进入Github的OAuth Apps页面,页面展示如下: 之后我们可以创建一个新的apps: 填写资料: 创建之后就可以获…

WebGIS面试题(第六期)-GeoServer

WebGIS面试题(第六期) 以下题目仅为部分题目,全部题目在公众号 {GISer世界} ,答案仅供参考!!! 因为本人之前做过相关项目用到了GeoServer,因此在简历上写了熟悉GeoServer。所以在相关面试中都有问到,所以我…

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器(Http板块)

【项目】仿muduo库One Thread One Loop式主从Reactor模型实现高并发服务器(Http板块) 一、思路图二、Util板块1、Splite板块(分词)(1)代码(2)测试及测试结果i、第一种测试ii、第二种…

[论文阅读] 3D感知相关论文简单摘要

Adaptive Fusion of Single-View and Multi-View Depth for Autonomous Driving 提出了一个单、多视图融合深度估计系统,它自适应地集成了高置信度的单视图和多视图结果 动态选择两个分支之间的高置信度区域执行融合 提出了一个双分支网络,即一个以单…

查看笔记本电池容量/健康状态

1. 打开命令行提示符 快捷键“win R”后输入“cmd” 2. 在命令提示符中输入命令 “powercfg /batteryreport" 并回车 3. 查看文件 最后就可以看到笔记本的电池使用报告了

Promises: JavaScript异步编程的救星

Promises: JavaScript异步编程的救星 Promises(承诺)是JavaScript中处理异步操作的一种机制,它提供了一种更优雅和可读性更高的方式来处理异步代码。Promises的实现原理基于一种称为"Promise/A"规范的约定,该规范定义了…

[蓝桥杯2024]-Reverse:rc4解析(对称密码rc4)

无壳 查看ida 这里应该运行就可以得flag,但是这个程序不能直接点击运行 按照伪代码写exp 完整exp: keylist(gamelab) content[0xB6,0x42,0xB7,0xFC,0xF0,0xA2,0x5E,0xA9,0x3D,0x29,0x36,0x1F,0x54,0x29,0x72,0xA8, 0x63,0x32,0xF2,0x44,0x8B,0x85,0x…

visual studio2022,开发CMake项目添加rabbitmq库,连接到远程计算机并进行开发于调试

1.打开visual studio installer 。安装“用于 Windows 的 C CMake 工具” 2.新建CMake项目 3.点击VS的“工具”—>"选项“—>“跨平台”—>”连接管理器“,添加远程计算机。用来将VS编辑的代码传到服务器进行编译–连接—运行(调试)。 …

BIO、NIO与AIO

一 BIO 同步并阻塞(传统阻塞型),服务器实现模式为一个连接一个线程,即客户端有连接请求时服务器端就需要启动一个线程进行处理. BIO(Blocking I/O,阻塞I/O)模式是一种网络编程中的I/O处理模式。在BIO模式中&#xf…

鸿蒙内核源码分析(任务调度篇) | 任务是内核调度的单元

任务即线程 在鸿蒙内核中,广义上可理解为一个任务就是一个线程 官方是怎么描述线程的 基本概念 从系统的角度看,线程是竞争系统资源的最小运行单元。线程可以使用或等待CPU、使用内存空间等系统资源,并独立于其它线程运行。 鸿蒙内核每个…

[蓝桥杯2024]-PWN:fd解析(命令符转义,标准输出重定向)

查看保护 查看ida 这里有一次栈溢出,并且题目给了我们system函数。 这里的知识点没有那么复杂 完整exp: from pwn import* pprocess(./pwn) pop_rdi0x400933 info0x601090 system0x400778payloadb"ca\\t flag 1>&2" print(len(paylo…

SAP PP学习笔记07 - 作业手顺(工艺路线Routing)

上一章讲了BOM的相关知识。 SAP PP学习笔记07 - 简单BOM,派生BOM,多重BOM,批量修改工具 CEWB_sap半成品有多个bom-CSDN博客 本章来讲作业手顺(工艺路线Routing)的相关知识。 1,作业手顺(工艺路线 Routing…

四、线段、矩形、圆、椭圆、自定义多边形、边缘轮廓和文本绘制(OpenCvSharp)

功能实现: 对指定图片上进行绘制线段、矩形、圆、椭圆、自定义多边形、边缘轮廓以及自定义文本 一、布局 用到了一个pictureBox和八个button 二、引入命名空间 using System; using System.Collections.Generic; using System.Drawing; using System.Windows.F…

Dockerfile镜像构建实战

一、构建Apache镜像 cd /opt/ #建立工作目录 mkdir /opt/apache cd apache/vim Dockerfile #基于的基础镜像 FROM centos:7 #维护镜像的用户信息 MAINTAINER this is apache image <cyj> #镜像操作指令安装Apache软件 RUN yum install -y httpd #开启80端口 EXPOSE 80 #…

远程桌面连接不上个别服务器的问题分析与解决方案

在日常的IT运维工作中&#xff0c;远程桌面连接&#xff08;RDP&#xff0c;Remote Desktop Protocol&#xff09;是我们经常使用的工具之一&#xff0c;用于管理和维护远程服务器。然而&#xff0c;有时我们可能会遇到无法连接到个别服务器的情况。针对这一问题&#xff0c;我…

《Kafka 3.x.x 入门到精通》

Kafka 3.x.x 入门到精通 Kafka是一个由Scala和Java语言开发的&#xff0c;经典高吞吐量的分布式消息发布和订阅系统&#xff0c;也是大数据技术领域中用作数据交换的核心组件之一。以高吞吐&#xff0c;低延迟&#xff0c;高伸缩&#xff0c;高可靠性&#xff0c;高并发&#x…

【论文浅尝】Porting Large Language Models to Mobile Devices for Question Answering

Introduction 移动设备上的大型语言模型(LLM)增强了自然语言处理&#xff0c;并支持更直观的交互。这些模型支持高级虚拟助理、语言翻译、文本摘要或文本中关键术语的提取(命名实体提取)等应用。 LLMs的一个重要用例也是问答&#xff0c;它可以为大量的用户查询提供准确的和上…

LeetCode 热题 100 题解:二叉树部分(1 ~ 5)

题目一&#xff1a;二叉树的中序遍历&#xff08;No. 948&#xff09; 94. 二叉树的中序遍历 - 力扣&#xff08;LeetCode&#xff09; 题目难度&#xff1a;简单 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;roo…

【Django】初识Django快速上手

Django简介 Django是一个高级的、开源的Python Web框架&#xff0c;旨在快速、高效地开发高质量的Web应用程序 https://developer.mozilla.org/zh-CN/docs/Learn/Server-side/Django/Introduction 安装Django pip install Django如果要知道安装的Django的版本&#xff0c;可…