【一起撸个DL框架】4 反向传播求梯度

CSDN个人主页:清风莫追
欢迎关注本专栏:《一起撸个DL框架》

文章目录

  • 4 反向传播求梯度🥥
    • 4.1 简介
    • 4.2 导数与梯度
    • 4.3 链式法则
    • 4.4 示例:y=2x+1的梯度

4 反向传播求梯度🥥

4.1 简介

上一篇:【一起撸个DL框架】3 前向传播

前面我们已经介绍了前向传播,而本节即将介绍的反向传播中的自动微分机制,可以说是深度学习框架的一个核心功能。因为计算图中的参数正是按照着梯度的指引来更新的。

4.2 导数与梯度

说到“梯度”与“导数”这两个概念,有些同学可能已经有些模糊了。在一元函数的情况下,两者几乎可以混为一谈,然而在多元函数的情况下梯度与导数的概念是有区别的。例如二元函数 f ( x , y ) f(x,y) f(x,y)

  • 它沿着二维平面内的每一个方向都会有一个方向导数,方向导数的结果是一个数值,代表沿着该方向的变化率
  • 而它只有一个梯度,梯度的结果是一个向量 ▽ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ y ) \triangledown f(x,y)=(\frac{\partial f}{\partial x},\frac{\partial f}{\partial y}) f(x,y)=(xf,yf) ,它同时指示了两个信息:变化率最大的方向和相应的变化率。

4.3 链式法则

简单描述就是“嵌套函数的导,等于各层求导的乘积”。下面是数学的表达: z = f ( y ) , y = g ( x ) z=f(y),y=g(x) z=f(y),y=g(x) ,则 ∂ z ∂ x = ∂ z ∂ y ∂ y ∂ x \frac{\partial z}{\partial x}=\frac{\partial z}{\partial y} \frac{\partial y}{\partial x} xz=yzxy 。计算图就相当于多层的嵌套函数,线性函数可以表示成加法和乘法节点的嵌套。

problem: 梯度不是相对于函数而言的吗?例如 ▽ f ( w , x ) \triangledown f(w,x) f(w,x)。为什么会有“w的梯度”这种概念呢?

在计算图中求一个节点的梯度,只需要将结果节点对子节点的梯度子节点对自己的梯度乘起来就可以了。

在这里插入图片描述
图1:链式法则

4.4 示例:y=2x+1的梯度

为了实现反向传播,我们需要在节点类中加入几个方法。

class Node:
    def __init__(self, parent1=None, parent2=None) -> None:
        self.parent1 = parent1
        self.parent2 = parent2
        self.value = None

        self.grad = None  # 在其它结点求梯度时可能再次用到本结点的梯度
        self.children = []

        parents = [self.parent1, self.parent2]
        for parent in parents:
            if parent is not None:
                parent.children.append(self)

    def set_value(self, value):
        self.value = value
    def compute(self):
        '''抽象方法,在具体的类中重写'''
        pass
    def forward(self):
        for parent in [self.parent1, self.parent2]:
            if parent.value is None:
                parent.forward()
        self.compute()
        return self.value
    def get_parent_grad(self, parent):
        '''求本节点对于父节点的梯度,抽象方法'''
        pass
    def get_grad(self):
        '''求结果节点对本节点的梯度'''
        # 结果结点返回单位值,而不是self.value
        if not self.children:
            return 1
        if self.grad is not None:
            return self.grad
        else:
            self.grad = 0
            for i in range(len(self.children)):
                grad1 = self.children[i].get_parent_grad(parent=self)  # 子节点对自己的梯度
                grad2 = self.children[i].get_grad()  # 结果节点对子节点的梯度
                self.grad += grad1 * grad2
            return self.grad


class Varrible(Node):
    def __init__(self) -> None:
        super().__init__()

class Add(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value + self.parent2.value
    def get_parent_grad(self, parent):
        return 1

class Mul(Node):
    def __init__(self, parent1=None, parent2=None) -> None:
        super().__init__(parent1, parent2)
    def compute(self):
        self.value = self.parent1.value * self.parent2.value
    def get_parent_grad(self, parent):
        '''从parent1,2改成parents时,需要重写'''
        if parent == self.parent1:
            return self.parent2.value
        elif parent == self.parent2:
            return self.parent1.value
        else:
            raise "get which is not a parent of mul node"
        

在之前的基础上,我们在Node类中新增了两个属性self.gradself.children,它们都将在get_grad()方法中发挥作用。

然后我们还在Node类中新增了两个方get_parent_grad()get_grad(),分别用来求子节点对于子节点的父节点(即本节点)的梯度,和结果节点对于本节点的梯度。get_parent_grad()是一个抽象方法,还需在Add类和Mul类中进行具体的实现,实现的方式比较简单,大家阅读源代码即可。

关于get_grad()方法,在本节点没有子节点时,就判断本节点为结果节点,梯度设置为单位值1。在本节点梯度已经存在时,就不再进行递归求值,而直接返回已经保存的梯度值。为什么要保存梯度值?一个节点可以有多个需要求梯度的父节点,这种情况下就会多次用到同一个节点的梯度,每次都递归到结果节点来求值显然是不必要的,于是我们使用空间来换时间。

在下面的示例代码中,我们求 w w w x x x的梯度,就两次用到了mul节点的梯度。

if __name__ == '__main__':
    # 搭建计算图: y=2x+1
    x1 = Varrible()
    w1 = Varrible()
    mul = Mul(x1, w1)
    b = Varrible()
    add = Add(mul, b)
    # 给参数赋值
    w1.set_value(2)
    b.set_value(1)
    # 使用计算图计算
    x1.set_value(int(input("请输入x:")))
    y = add.forward()
    print(f"y: {y}")
    # 反向传播求梯度
    w_grad = w1.get_grad()
    x_grad = x1.get_grad()
    print(f"w_grad: {w_grad}, x_grad: {x_grad}")

    '''
    请输入x:3
    y: 7
    w_grad: 3, x_grad: 2
    '''

然后我们再次搭建了函数 y = 2 x + 1 y=2x+1 y=2x+1的计算图,首先进行了前向传播,目的是检查计算图是否正确实现了函数y=2x+1的功能。然后进行反向传播求得了 w w w x x x的梯度。

这一节代码的结构已经逐渐变得复杂,一些细节的设计可能需要反复揣摩才能明白,同时代码也还存在一些有待改进的地方,例如Node类的__init__()方法中,每个节点只会有两个父节点的设定其实是不太合适的。


下一篇:【一起撸个DL框架】5 实现:自适应线性单元

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/17741.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【OpenCV】 2D-2D:对极几何算法原理

2D-2D匹配: 对极几何 SLAM十四讲笔记1 1.1 对极几何數學模型 考虑从两张图像上观测到了同一个3D点,如图所示**。**我们希望可以求解相机两个时刻的运动 R , t R,t R,t。 假设我们要求取两帧图像 I 1 , I 2 I_1,I_2 I1​,I2​之间的运动,设第一帧到第二帧的运动为…

全国快递物流 API 实现快递单号自动识别的原理解析

概述 全国快递物流 API 是一种提供快递物流单号查询的接口,涵盖了包括申通、顺丰、圆通、韵达、中通、汇通等600快递公司的数据。该 API 的目标是为快递公司、电商、物流平台等提供便捷、快速、准确的快递物流信息查询服务。 数据采集和处理 全国快递物流 API 的…

自定义控件 (?/N) - 颜料 Paint

参考来源 一、颜色 1.1 直接设置颜色 1.1.1 setColor( ) public void setColor(ColorInt int color) paint.setColor(Color.RED) paint.setColor(Color.parseColor("#009688")) 1.1.2 setARGB( ) public void setARGB(int a, int r, int g, int b) paint.se…

Packet Tracer – 研究 VLAN 实施

Packet Tracer – 研究 VLAN 实施 地址分配表 设备 接口 IP 地址 子网掩码 默认网关 S1 VLAN 99 172.17.99.31 255.255.255.0 不适用 S2 VLAN 99 172.17.99.32 255.255.255.0 不适用 S3 VLAN 99 172.17.99.33 255.255.255.0 不适用 PC1 NIC 172.17.10.2…

数字化转型导师坚鹏:数字化转型背景下的企业人力资源管理

企业数字化转型背景下的企业人力资源管理 课程背景: 很多企业存在以下问题: 不清楚企业数字化转型目前的发展阶段与重要应用? 不知道企业数字化转型给企业人力资源管理带来哪些机遇与挑战? 不知道企业数字化转型背景下如何…

SQL注入攻防入门详解

毕业开始从事winform到今年转到 web ,在码农届已经足足混了快接近3年了,但是对安全方面的知识依旧薄弱,事实上是没机会接触相关开发……必须的各种借口。这几天把sql注入的相关知识整理了下,希望大家多多提意见。 (对于…

系统集成项目管理工程师 下午 真题 及考点(2020年下半年)

文章目录 2020年下半年试题一:第10章 项目质量管理,规划质量管理过程的输入试题二:第9章 项目成本管理,典型:EAC ACETC AC(BAC-EV)/CPI BAC/CPI试题三:第18章 项目风险管理&#x…

吴恩达ChatGPT网课笔记Prompt Engineering——训练ChatGPT前请先训练自己

吴恩达ChatGPT网课笔记Prompt Engineering——训练ChatGPT前请先训练自己 主要是吴恩达的网课,还有部分github的prompt-engineering-for-developers项目,以及部分自己的经验。 一、常用使用技巧 prompt最好是英文的,如果是中文的prompt&am…

【网站架构】Nginx 4层、7层代理配置,正向代理、反向代理详解

大家好,欢迎来到停止重构的频道。 本期我们讨论网络代理。 在往期《大型网站 安全性》介绍过,出于网络安全的考虑,一般大型网站都需要做网络区域隔离,以防止攻击者直接操控服务器。 网站系统的应用及数据库都会放在这个网络安全…

【Python习题集6】类与对象

类与对象 一、实验内容二、实验总结 一、实验内容 1.设计一个Circle类来表示圆,这个类包含圆的半径以及求面积和周长的函数。在使用这个类创建半径为1~10的圆,并计算出相应的面积和周长。 半径为1的圆,面积: 3.14 周长: 6.28 半径为2的圆&am…

云原生介绍

本博客地址:https://security.blog.csdn.net/article/details/130540430 一、云原生的概念 云原生的整体概念思路是三统一,即统一基础平台、统一软件架构、统一开发流程。 基于统一的基础平台、软件架构以及开发流程,数字化转型和云化转型能…

详解:搭建常见问题(FAQ)的步骤?

许多的Web用户都更加偏向于可信赖的FAQ页面,以此作为快速查找更多信息的方法。因为用户时间的紧缺,并且想知道产品的功能和能够提供的服务。构造精巧的FAQ页面是提供人们寻求信息的绝妙方法,而且还可以提供更多的信息。这就是为什么FAQ页面对…

Chrome远程调试

最近接触到Chrome远程调试相关内容,记录一下。 场景:使用Chrome远程调试Chromium。当能够控制目标主机执行命令之后,可以在该主机上建立全局代理,然后在自己这一边开启浏览器监听,接着在目标机器上执行 chrome.exe --…

3.13 结构体嵌套、大小及位域

目录 结构体嵌套结构体 结构体的大小 位域 结构体嵌套结构体 含义 结构体中的成员可以是另一个结构体 语法 struct 结构体名 { struct 结构体名 成员名; }; 结构体中共同的变量可以单独放出来,单独封装一个结构体 结构体的大小 字节对齐 含义 …

Bark:基于转换器的文本到音频模型

Bark是由Suno创建的一个基于转换器的文本到音频模型。Bark可以生成高度逼真的多语言语音以及其他音频,包括音乐、背景噪音和简单的音效。该模型还可以产生非语言交流,如大笑、叹息和哭泣。为了支持研究社区,我们正在提供对预先训练的模型检查…

【c语言】字符串的基本概念 | 字符串存储原理

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 给大家跳段街舞感谢支持&#xff01;ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ ኈ ቼ …

【Java EE】-博客系统(前端页面)

作者&#xff1a;学Java的冬瓜 博客主页&#xff1a;☀冬瓜的主页&#x1f319; 专栏&#xff1a;【JavaEE】 分享: 且视他人如盏盏鬼火&#xff0c;大胆地去走你的道路。——史铁生《病隙碎笔》 主要内容&#xff1a;博客系统 登陆页面&#xff0c;列表页面&#xff0c;详情页…

AArch32 AArch64 Registers map详细解析与索引

1、AArch32 Registers AArch32 系统寄存器索引。 例如第一个寄存器ACTLR点击后解析如下&#xff1a; 2、AArch64 Registers AArch64 系统寄存器索引。 3、AArch32 Operations AArch32 系统指令索引。 例如第一个寄存器ACTLR_EL1点击后解析如下&#xff1a; 4、AArch…

Inception模型实现孤立手语词的识别

实现孤立手语词的识别流程如下&#xff0c;在实际研究中&#xff0c;本章将着重研究第三阶段内容&#xff0c;也就是模型的设计与实现过程&#xff0c;目的是提高手语图像的识别准确率。 Inception模型实现 Inception模型是谷歌研究人员在2014年提出的一个深度卷…

一文把 JavaScript 中的 this 聊得明明白白

文章目录 1.this 是什么&#xff1f;2.this的指向2.1 全局上下文的 this 指向2.2 函数&#xff08;普通函数&#xff09;上下文中的 this 指向2.3 事件处理程序中的 this 指向2.4 以对象的方式调用时 this 的指向2.5 构造函数中的 this 指向2.6 在 类上下文中 this 的指向。2.7…
最新文章