自己动手实现BatchNorm(pytorch实现)

BatchNorm可以加速模型的收敛并且缓解梯度消失问题,是深度学习领域常用的一个技术

最近仔细学习了BatchNorm的原理,因此想自己动手实现一下它,加深理解

代码如下:

import torch
import torch.nn as nn


class MyBatchNorm(nn.Module):
    # def __init__(self, batch_norm2, dim):
    def __init__(self, dim):
        super().__init__()
        # 可训练参数 gamma和beta
        self.gamma = nn.Parameter(data=torch.randn((dim)))
        self.beta = nn.Parameter(data=torch.randn((dim)))
        # 全局的均值和方差
        self.mean_whole = torch.zeros((dim))
        self.var_whole = torch.zeros((dim))
        self.lba = 0.99
        # 防止除零错误
        self.eps = 1e-7
   
    def forward(self, x):
        # 检查形状
        if x.dim() == 4:
            x = x.reshape(x.size(0), x.size(1), -1)
        assert x.dim() == 3

        # 处于训练状态
        if self.training:
            # 首先计算每个通道的均值和方差
            # (b, c, d) -> (1, c, 1)
            mean_batch = torch.mean(x, dim=[0, 2], keepdim=True)
            var_batch = torch.var(x, dim=[0, 2], keepdim=True, unbiased=False)
            # 使用滑动平均办法计算全局均值和方差
            n = x.numel() / x.size(1)
            self.mean_whole = self.lba * self.mean_whole + (1 - self.lba) * mean_batch
            self.var_whole = self.lba * self.var_whole + (1 - self.lba) * var_batch * n / (n-1)
            # 然后归一化数据
            x = (x - mean_batch) / torch.sqrt((var_batch + self.eps))
        else:
            # 归一化数据
            x = (x - self.mean_whole[None, ..., None]) / torch.sqrt((self.var_whole[None, ..., None] + self.eps))

        # 放缩平移
        x = x * self.gamma[None, ..., None] + self.beta[None, ..., None]
        return x


x = torch.randn((2, 3, 4))

batch_norm = MyBatchNorm(dim=3)
batch_norm = batch_norm.train()

b = batch_norm(x)

print(b.shape)

参考资料:

1. 原理

https://zhuanlan.zhihu.com/p/34879333

2. 代码

https://zhuanlan.zhihu.com/p/337732517

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/606889.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MAC M1 配置 Git SSH

背景 换了新笔记本,本地想要克隆github 上的项目需要配置ssh 公钥到自己的github账户中,否则使用ssh 地址克隆会报错,如下。 gitgithub.com: Permission denied (publickey). fatal: Could not read from remote repository.操作 1. 生成s…

探索大型语言模型(LLM)的世界

​ 引言 大型语言模型(LLM)作为人工智能领域的前沿技术,正在重塑我们与机器的交流方式,在医疗、金融、技术等多个行业领域中发挥着重要作用。本文将从技术角度深入分析LLM的工作原理,探讨其在不同领域的应用&#xff0…

安卓使用Fiddler抓包 2024

简介 最近试了一下安卓使用fiddler 抓包,发现https包基本都会丢失。原因是Anandroid 7版本针对ssl安全性做了加强,不认可用户的证书。我们要做的就是把fiddler导出的证书进过处理后放置到系统证书目录下面,这样才能抓包https请求。 这里使用…

【Anaconda】升级Anaconda Navigator提示JSONDecoderError,删除.condarc文件后搞定

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、报错:JSONDecoderError二、错误原因三、解决问题总结 前言 提示:这里可以添加本文要记录的大概内容: 时间长未升级Ana…

AI 绘画神器 Fooocus 本地部署指南:简介、硬件要求、部署步骤、界面介绍

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 随着人工智能技术的飞速发展,AI 绘画逐渐成为创意领域的新宠。Fooocus 作为一款免费开源的 AI 绘画工具&am…

窜货溯源采买的目的

当品牌遇到窜货时,不管是线上还是线下渠道,快速的治理方法,就是找到窜货源头,对源头进行打击,这里面有一步很关键的操作便是买货,将货品买回后做溯源,通过产品本身或者外包装上的条码&#xff0…

【Java orm 框架比较】十 新增hammer_sql_db 框架对比

迁移到(https://gitee.com/wujiawei1207537021/spring-orm-integration-compare) orm框架使用性能比较 比较mybatis-plus、lazy、sqltoy、mybatis-flex、easy-query、mybatis-mp、jpa、dbvisitor、beetlsql、dream_orm、wood、hammer_sql_db 操作数据 …

[uniapp 地图组件] 小坑:translateMarker的回调函数,会调用2次

大概率是因为旋转和移动是两个动画,动画结束后都会分别调用此函数 即使你配置了 【不旋转】它还是会调用两次, 所以此处应该是官方的bug

未来娱乐新地标?气膜球幕影院的多维体验—轻空间

在中国,一座独特的娱乐场所正在崭露头角:气膜球幕影院。这个融合了气膜建筑与激光投影技术的创新场所,不仅令人惊叹,更带来了前所未有的科幻娱乐体验。让我们一起探索这个未来的娱乐空间,感受其中的多维魅力。 现场演出…

Linux(openEuler、CentOS8)企业内网DHCP服务器搭建(固定Mac获取指定IP)

----本实验环境为openEuler系统<以server方式安装>&#xff08;CentOS8基本一致&#xff0c;可参考本文&#xff09;---- 目录 一、知识点二、实验&#xff08;一&#xff09;为服务器配置网卡和IP&#xff08;二&#xff09;为服务器安装DHCP服务软件&#xff08;三&a…

DenseCLIP论文讲解

文章目录 简介方法总体框架 &#xff08;Language-Guided Dense Prediction&#xff09;上下文感知提示 &#xff08;Context-Aware Prompting&#xff09;应用实例 论文&#xff1a;DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting 代码&#xff1…

Python3实现三菱PLC串口通讯(附源码和运行图)

基于PyQt5通过串口通信控制三菱PLC 废话不多说&#xff0c;直接上源码 """ # -*- coding:utf-8 -*- Project : Mitsubishi File : Main_Run.pyw Author : Administrator Time : 2024/05/09 下午 04:10 Description : PyQt5界面主逻辑 Software:PyCharm "…

一个注解完美实现分布式锁(AOP)

前言 学习过Spring的小伙伴都知道AOP的强大&#xff0c;本文将通过Redisson结合AOP&#xff0c;仅需一个注解就能实现分布式锁。 &#x1f36d; 不会使用aop和redisson的小伙伴可以参考&#xff1a; 【学习总结】使Aop实现自定义日志注解-CSDN博客 【学习总结】使用分布式锁和…

Vulstack红队评估(一)

文章目录 一、环境搭建1、网络拓扑2、web服务器(win7)配置3、域控&#xff08;winserver2008&#xff09;配置4、域内机器&#xff08;windows 2003&#xff09;配置5、调试网络是否通常 二、web渗透1、信息搜集2、端口扫描3、目录扫描4、弱口令5、phpmyadmin getshell日志gets…

AI时代:人工智能大模型引领科技创造新时代

目录 前言一. AI在国家战略中有着举足轻重的地位1.1 战略1.2 能源1.3 教育 二. AI在日常生活中扮演着重要角色2.1 医疗保健2.2 智能客服2.3 自动驾驶2.4 娱乐和媒体2.5 智能家居 三. AI的未来发展趋势 总结 前言 随着AI技术的进步&#xff0c;新一代的AI技术已经开始尝试摆脱依…

域名系统(DNS)、DNS 服务器和 IP 地址概念解释

​  域名系统、DNS服务器和IP地址是构成互联网基础设施的重要部分。它们共同协作&#xff0c;使得人们能够方便地使用各种网络服务&#xff0c;而无需去记住复杂的数字地址。那么&#xff0c;域名系统、DNS 服务器和 IP 地址又该如何理解?本文主要讲讲关于这几个名词的概念解…

表单设计器开源:助力提质增效的办公利器

在激烈的市场竞争之下&#xff0c;拥有过硬的技术和本领的企业&#xff0c;就能在市场中提升市场竞争力&#xff0c;斩获更多市场份额。作为提质增效的办公利器&#xff0c;低代码技术平台、表单设计器开源拥有理想的优势特点&#xff0c;如操作灵活、易维护、可视化界面等&…

夸克网盘拉新怎么做?分享网盘拉新攻略!

夸克网盘拉新怎么做&#xff1f;如何通过推广夸克网盘来赚佣金&#xff1f;相信大家应该都使用过夸克网盘&#xff0c;现在夸克网盘的拉新赚佣金活动开展的如火如荼&#xff0c;不少朋友通过夸克网盘拉新赚取收益&#xff0c;真的很香。还有一部分想要赚佣金但是不知道如何操作…

OmniReader Pro mac激活版:智慧阅读新选择,开启高效学习之旅

在追求知识的道路上&#xff0c;一款优秀的阅读工具是不可或缺的。OmniReader Pro作为智慧阅读的新选择&#xff0c;以其独特的功能和卓越的性能&#xff0c;为您开启高效学习之旅。 OmniReader Pro具备高效的文本识别和处理技术&#xff0c;能够快速准确地提取文档中的关键信息…

Python 中的 Unit testing 文件写入

在 Python 中进行单元测试时&#xff0c;有时候需要测试文件写入操作。为了模拟文件写入并进行单元测试&#xff0c;你可以使用 Python 的 unittest 模块&#xff0c;并结合 io.StringIO 或 tempfile 模块来模拟文件操作。 1、问题背景 在 Python 中&#xff0c;为 ConfigPars…