pytorch框架:pytorch的钩子

说明

  • 在深度学习中,"钩子"通常指的是在模型训练或推理过程中插入的一些回调函数或处理程序,以执行额外的操作或监控模型的行为。这些钩子可以分为两种类型:张量钩子和模块钩子。
  1. 张量钩子(Tensor Hooks)

    张量钩子是与模型中的具体张量(tensor)相关联的。通过在张量上注册钩子,可以在张量的计算中执行自定义的操作,例如记录梯度、修改张量的值等。这对于调试、可视化和梯度的处理非常有用。在PyTorch中,可以使用register_hook方法来添加张量钩子。

    例子:

    def tensor_hook(grad):
        # 自定义操作,可以在这里处理梯度信息
        print("梯度信息:", grad)
    
    # 注册张量钩子
    tensor.register_hook(tensor_hook)
    
  2. 模块钩子(Module Hooks)

    模块钩子是与模型中的具体模块(layer、block等)相关联的。通过在模块上注册钩子,可以在模块的前向或后向传播中执行自定义操作,例如获取模块的输出、记录模块的参数等。在PyTorch中,可以使用register_forward_hookregister_backward_hook方法来添加模块钩子。

    例子:

    def forward_hook(module, input, output):
        # 自定义前向传播操作
        print("输入:", input)
        print("输出:", output)
    
    def backward_hook(module, grad_input, grad_output):
        # 自定义反向传播操作
        print("梯度输入:", grad_input)
        print("梯度输出:", grad_output)
    
    # 注册模块钩子
    module.register_forward_hook(forward_hook)
    module.register_backward_hook(backward_hook)
    

张量钩子

我们要看的第一种类型

  • 紫色为梯度,e处的梯度如果没有特殊指定的话,默认值为1
    在这里插入图片描述

  • 一旦我们向后调用e点(执行到e),通过这些节点的梯度的整个计算,对我们来说是不可接近的

  • 当它们流过时,我们无法真正检查梯度,或者如果我们想改变他们,我们只能看到梯度是什么,输出到叶注释

在这里插入图片描述

  • 这就是张量上的钩子的用武之地,它们允许我们在梯度向后流过图形时检查它们,并有可能改变它们

  • 当你加钩子的时候,在这里我们加入第一个钩子,我们称C点寄存器钩,我们通过它接受梯度的函数,可选地返回一个新的梯度,如果你不从这个函数返回任何东西,就用和之前一样的梯度把它传递下去

  • 所以当我们注册这个钩子的时候,它首先被添加到这个c张量上的向后钩子上,这是一本有序的字典,所以你把钩子加到张量上的顺序很重要,因为在向后的图表中,它们会按这个顺序被调用

  • 接下来我们再注册一个钩子,这次我们只是给它传递一个lambda函数,我就打印出一个渐变,所以它不会改变梯度
    我把它打印出来
    它将继续使用向后图中的前一个渐变
    你可以在这里看到它在向后钩子上添加了lambda函数

  • 接下来我们叫C保持,如果要在中间节点上存储渐变,所以在这个例子中,A,B和D是叶节点

  • 默认情况下,它们将是唯一获得存储到它们的渐变的节点,通过这些累加梯度节点,如果我们想要一个渐变存储在中间节点上

  • 接下来我们将创建d张量,然后我们在d张量上注册一个钩子,这里它只是一个lambda函数
    再加一百,因为它返回一个梯度,它将替换传递给它的梯度
    在这里插入图片描述

  • 所以需要注意的是,向中间节点和叶节点添加钩子是有区别的

  • 向叶子节点添加钩子时,它只是把它添加到它的向后钩子有序字典中

  • 但是当您向中间节点添加钩子时,第一次将钩子添加到它的向后钩子顺序字典中,您还通知与此张量关联的向后图中的节点,在这种情况下,你要加上这个张量的后钩到这个Pre钩子的向后节点列表
    在这里插入图片描述

模块钩子

在这里插入图片描述

  • 我们现在来看看模块上的钩子,这些会更容易理解,首先呢,一个典型的模块将有一个正向方法
  • 这里我们只接受三个输入,我们将它们相加并返回输出

在这里插入图片描述

  • 模块钩子是添加一个函数,该函数在这个前向方法之前被调用,或者在这个前向方法之后调用的函数

CG

  • 【动态图Hook机制解释】PyTorch Hooks Explained - In-depth Tutorial

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/396507.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

读十堂极简人工智能课笔记07_模拟与情感

1. 数码式考察 1.1. 制作计算机动画或游戏 1.1.1. 想怎么制作都可以 1.2. 计算机模拟 1.2.1. 目标是建造一个虚拟的实验室,其行为与现实完全一致,只是某些变量由我们来控制 1.3. 对现实世界进行建模并不容易,需要非常谨慎地收集和使用数…

微信小程序 搜索框实现模糊搜索(带模拟数据,js,wxml,wxss齐全)

最近在做一个小程序的页面,搜索框困扰了我很久,今天终于把搜索框给做了出来,记录一下过程 我主要使用的就是wx的if,当我输入框用户点击的时候,我前面的显示界面添加上false属性,然后我搜索页面显示出true的…

gRPC 备查

简介 HTTP/2 HTTP/2 的三个概念 架构 使用流程 gRPC 的接口类型 1.单一RPC 2.服务器流式RPC 3.客户端式流式RPC 4.双向流式RPC

【ARMv8M Cortex-M33 系列 8 -- RT-Thread 移植 posix pthread】

文章目录 RT-Thread POSIX PthreadRT-Thread Pthread 相关宏定义RT-Thread libc 初始化RT-Thread Pthread 测试 RT-Thread POSIX Pthread pthread是POSIX(Portable Operating System Interface)标准定义的一套线程相关的API,全称为POSIX Thr…

PDF控件Spire.PDF for .NET【安全】演示:如何在 PDF 中添加签名字段

Spire.PDF for .NET 是一款独立 PDF 控件,用于 .NET 程序中创建、编辑和操作 PDF 文档。使用 Spire.PDF 类库,开发人员可以新建一个 PDF 文档或者对现有的 PDF 文档进行处理,且无需安装 Adobe Acrobat。 E-iceblue 功能类库Spire 系列文档处…

Golang - 使用CentOS 7 安装Golang环境

文章目录 操作步骤 操作步骤 为在CentOS 7上安装Go语言环境,可以按照以下步骤进行操作: 下载Go语言包: 从官方网站 https://golang.org/dl/ 下载适用于Linux的Go语言包。 解压缩Go语言包: 使用以下命令解压缩下载的Go语言包 […

刷题Day3

🌈个人主页:小田爱学编程 🔥 系列专栏:刷题日记 🏆🏆关注博主,随时获取更多关于IT的优质内容!🏆🏆 😀欢迎来到小田代码世界~ 😁 喜欢…

给label-studio 配置sam(segment anything)ml 记录

给label-studio 配置sam(segment anything)ml 后端记录 配置ml后台下载代码下载模型文件创建环境模型转换后端服务启动 配置label-studio 前端配置模型后端连接配置标注模板标注界面使用 参考链接 配置ml后台 下载代码 git clone https://github.com/H…

机器学习---规则学习(一阶规则学习、归纳逻辑程序设计)

1. 一阶规则学习 “一阶”的目的:描述一类物体的性质、相互关系,比如利用一阶关系来挑“ 更好的”瓜,但实际应用 中很难量化颜色、 …、敲声的属性值。一般情况下可以省略全称量词。 命题逻辑:属性-值数据 色泽程度&#xff1a…

2.19学习总结

1.中位数 2.统计和 3.铺设道路 4.岛屿个数 5.冶炼金属 6.飞机降落 7.接龙数列 中位数https://www.luogu.com.cn/problem/P1168 题目描述 给定一个长度为 �N 的非负整数序列 �A,对于前奇数项求中位数。 输入格式 第一行一个正整数 &#xfff…

Spring Boot与LiteFlow:轻量级流程引擎的集成与应用含完整过程

点击下载《Spring Boot与LiteFlow:轻量级流程引擎的集成与应用含完整过程》添加链接描述 1. 前言 本文旨在介绍Spring Boot与LiteFlow的集成方法,详细阐述LiteFlow的原理、使用流程、步骤以及代码注释。通过本文,读者将能够了解LiteFlow的特…

【LeetCode: 590. N 叉树的后序遍历 + DFS】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

163邮箱发邮件

1、Jenkins安装Email Extension Plugin 2、网易邮箱里获取授权码:qa_jenkins_robot@163.com 开启POP3/SMTP 我已经配置过了,所以这里会有一个使用设备 3、配置Jenkins邮箱通知 Manage Jnekins-Configuration System Jenkins Location: Extended E-mail Notification: …

FL Studio21中文版本混音功能介绍

FL Studio 21的混音功能是其音乐制作能力中不可或缺的一部分,它为用户提供了强大的工具,以便他们可以对音轨进行细致的调整,确保音乐作品的最终呈现效果达到最佳。 FL Studio 21 Win-安装包下载如下: https://wm.makeding.com/iclk/?zonei…

图形渲染基础学习

原文链接:游戏开发入门(三)图形渲染_如果一个面只有三个像素进行渲染可以理解为是定点渲染吗?-CSDN博客 游戏开发入门(三)图形渲染笔记: 渲染一般分为离线渲染与实时渲染,游戏中我们用的都是…

指针的进阶(C语言)(上)

目录 前言 1、字符指针 2、指针数组 3、数组指针 3.1数组指针的定义 3.2 数组名VS&数组名 3.3数组指针的运用 前言 对于指针,我们已经有了初步认识(可以看我写的指针详解那一篇文章)。 简单总结一下基本概念: 1、指针就…

探索海洋世界,基于YOLOv5全系列【n/s/m/l/x】参数模型开发构建海洋场景下海洋生物检测识别分析系统

前面的博文中,开发实践过海底相关生物检测识别的项目,对于海洋场景下的海洋生物检测则很少有所涉及,这里本文的主要目的就是想要开发构建基于YOLOv5的海洋场景下的海洋生物检测识别系统。 前文相关的开发实践如下,感兴趣的话可以…

Django实战:部署项目 【资产管理系统】,Django完整项目学习研究(项目全解析,部署教程,非常详细)

导言 关于Django,我已经和大家分享了一些知识,考虑到一些伙伴需要在实际的项目中去理解。所以我上传了一套Django的项目学习源码,已经和本文章进行了绑定。大家可以自行下载学习,考虑到一些伙伴是初学者,几年前&#…

MySQL-DDL-数据库操作

目录 数据库操作查询所有数据库创建数据库使用数据库查询当前数据库删除数据库 数据库操作 DDL 英文全称是 Data Definition Language,数据定义语言,用来定义数据库对象(数据库、表)。 查询所有数据库 show databases;创建数据库 create database [ i…

C++面试宝典第30题:分发饼干

题目 假设你是一位非常棒的家长,想要给你的孩子们分发一些小饼干。但是,每个孩子最多只能给一块饼干。对每一个孩子i,都有一个胃口值gi,这是能让孩子们满足胃口的饼干的最小尺寸。对每一块饼干j,都有一个尺寸sj。如果sj >= gi,我们就可以将这个饼干j分配给孩子i,这个…