RepVGG学习笔记

📣 论文下载地址:https://arxiv.org/abs/2101.03697
📣 官方源码(Pytorch实现):https://github.com/DingXiaoH/RepVGG

0 前言

🐾 R e p V G G RepVGG RepVGG最大的创新之处是:结构重参数化。结构重参数化是一个新的概念,在训练和推理阶段采用不同的策略。比如:在训练阶段,使用一个类似 R e s N e t ResNet ResNet的多分支模型,而在推理时转化成 V G G VGG VGG的单路模型。

1 结构重参数化

🐾 结构重参数化就是在训练阶段采用多分支模型,而在推理测试阶段使用单分支模型,可以理解为如下图所示,训练时模型的基本组件为 A A A,而在测试阶段模型的基本组件为 B B B
在这里插入图片描述
🐾 那么关键的问题来了,训练的多分支模型是如何转化到测试的单分支模型的呢?👍 👍 👍 全文参考 强力推荐:结构重参数化主要分为两步,第一步主要是将 C o n v 2 D Conv2D Conv2D算子和 B N BN BN算子融合以及将只有 B N BN BN的分支转换成一个 C o n v 2 D 3 × 3 Conv2D 3\times3 Conv2D3×3分支;第二步将多个分支上的 C o n v 2 D 3 × 3 Conv2D 3\times3 Conv2D3×3卷积分支融合成单路结构的一个 3 × 3 3\times3 3×3卷积。

1.1 结构重参数化第一步(将 C o n v 2 D Conv2D Conv2D算子和 B N BN BN算子融合以及将只有 B N BN BN的分支转换成一个 C o n v 2 D Conv2D Conv2D算子)

  • Conv2D + BN ——> Conv2D
    BN ——> Conv2D
    卷积计算方式: y i = W i ∗ x i + b y_i=W_i*x_i + b yi=Wixi+b,在模型训练阶段,若忽略偏置的影响,则只有公式: y i = W i ∗ x i y_i=W_i*x_i yi=Wixi。标准化 B N BN BN的计算公式为: y i = x i − u i ( δ i ) 2 + ϵ × γ i + β i y_i = \frac{x_i - u_i}{\sqrt{(\delta_i)^2+\epsilon}}\times \gamma_i+\beta_i yi=(δi)2+ϵ xiui×γi+βi,将 B N BN BN层的计算公式按类似卷积计算公式展开就为: y i = γ i δ i 2 + ϵ x i + ( β i − u i ⋅ γ i δ i 2 + ϵ ) y_i= \frac{\gamma_i}{\sqrt{\delta_i^2+\epsilon}}x_i+(\beta_i - \frac{u_i \cdot \gamma_i}{\sqrt{\delta_i^2+\epsilon}}) yi=δi2+ϵ γixi+(βiδi2+ϵ uiγi),在所有的模型中,都是卷积层后接标准化(归一化)层,那么就有卷积层的输出是 B N BN BN层的输入,也就是说在忽略卷积偏置的条件下, y i = W i ∗ x i y_i=W_i*x_i yi=Wixi就是 B N BN BN计算公式中 x i x_i xi的输入。
    上步骤说到卷积在忽略偏置的条件下,计算公式为: y i = W i ∗ x i y_i=W_i*x_i yi=Wixi,就输入特征图和卷积核来说,卷积计算的理论过程是对应位置相乘后相加,卷积核移位,再同理计算下一个值。如下图所示:
    在这里插入图片描述

据上图所示,得到 y 1 y_1 y1位置的计算值为: w 1 ∙ x 1 + w 2 ∙ x 2 + w 3 ∙ x 3 + w 4 ∙ x 5 + w 5 ∙ x 6 + w 6 ∙ x 7 + w 7 ∙ x 9 + w 8 ∙ x 10 + w 9 ∙ x 11 w_1∙x_1+w_2∙x_2+w_3∙x_3+w_4∙x_5+w_5∙x_6+w_6∙x_7+w_7∙x_9+w_8∙x_{10}+w_9∙x_{11} w1x1+w2x2+w3x3+w4x5+w5x6+w6x7+w7x9+w8x10+w9x11,以该点位置为例,该值输入到 B N BN BN层中,计算公式为: ( w 1 ∙ x 1 + w 2 ∙ x 2 + w 3 ∙ x 3 + w 4 ∙ x 5 + w 5 ∙ x 6 + w 6 ∙ x 7 + w 7 ∙ x 9 + w 8 ∙ x 10 + w 9 ∙ x 11 ) ⋅ γ 1 δ 1 2 + ϵ + ( β 1 − u 1 ⋅ γ 1 δ 1 2 + ϵ ) (w_1∙x_1+w_2∙x_2+w_3∙x_3+w_4∙x_5+w_5∙x_6+w_6∙x_7+w_7∙x_9+w_8∙x_{10}+w_9∙x_{11}) \cdot \frac{\gamma_1}{\sqrt{\delta_1^2+\epsilon}}+(\beta_1 - \frac{u_1 \cdot \gamma_1}{\sqrt{\delta_1^2+\epsilon}}) (w1x1+w2x2+w3x3+w4x5+w5x6+w6x7+w7x9+w8x10+w9x11)δ12+ϵ γ1+(β1δ12+ϵ u1γ1)。到这里 C o n v 2 D Conv2D Conv2D就和 B N BN BN层融合在一个计算公式内了。通过公式也可以发现,该公式就是加了偏置的卷积计算公式,所以就可以融合为一个卷积分支了。(理解到这里就算差不多了,但最关键的还是代码了,接着看~~)

🐾 根据学习发现,原论文中的 c o n v 2 d + B N conv2d + BN conv2d+BN或者 B N BN BN,都是转换为 3 × 3 3\times3 3×3卷积,然后再将转化的多个 3 × 3 3\times3 3×3卷积进行融合,为一个 3 × 3 3\times3 3×3卷积,才算结束。那么就有这几种转化需要理解:

  • 1 × 1 卷积 1\times1 卷积 1×1卷积 ——> 3 × 3 卷积 3\times3 卷积 3×3卷积
    原始 1 × 1 1\times1 1×1卷积,只需要将卷积核设置 p a d = 1 pad=1 pad=1的0填充,就可以将 1 × 1 1\times1 1×1卷积转换为 3 × 3 3\times3 3×3卷积。
    如下图所示:
    在这里插入图片描述
  • B N 标准化 BN标准化 BN标准化 ——> 3 × 3 卷积 3\times3 卷积 3×3卷积
    只有 B N BN BN的模型结构,需要构建一个卷积层,然后再根据 C o n v 2 D + B N Conv2D + BN Conv2D+BN转换公式进行转换,得到 3 × 3 3\times3 3×3卷积层。

1.2 结构重参数化第二步(多分支的 3 × 3 3\times3 3×3卷积融合成一个 3 × 3 3\times3 3×3卷积)

在这里插入图片描述

根据图例,融合后的3分支卷积的计算过程可以理解为: y 1 = w 1 ∗ x 1 + b 1 、 y 2 = w 2 ∗ x 1 + b 2 、 y 3 = w 3 ∗ x 1 + b 3 y_1=w_1*x_1+b_1、y_2=w_2*x_1+b_2、y_3=w_3*x_1+b_3 y1=w1x1+b1y2=w2x1+b2y3=w3x1+b3默认输入都是 x 1 x_1 x1,多分支再进行融合可计算为: y i = ( w 1 + w 2 + w 3 ) ∗ x 1 + ( b 1 + b 2 + b 3 ) y_i=(w_1+w_2+w_3)*x_1+(b_1+b_2+b_3) yi=(w1+w2+w3)x1+(b1+b2+b3)
至此,整个的融合过程就结束了。

2 代码分析

🐾 下面就根据大佬的博文的代码来详细剖析下, C o n v 2 D + B N Conv2D + BN Conv2D+BN 转换为 3 × 3 3\times3 3×3卷积。

from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn

def main():
    torch.random.manual_seed(0)
    f1 = torch.randn(1, 2, 3, 3)
    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight            # Conv2D 卷积核
    running_mean = module.bn.running_mean  # BN 均值
    running_var = module.bn.running_var    # BN 方差
    gamma = module.bn.weight               # BN  gamma 可学习参数
    beta = module.bn.bias                  # BN  beta  可学习参数
    eps = module.bn.eps                    # BN  eps   可学习参数
    std = (running_var + eps).sqrt()       # BN  根号下方差加eps(很小的数)
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

上述代码中可以看到是定义了一个 n n . S e q u e n t i a l nn.Sequential nn.Sequential序列容器,里面包含 3 × 3 3\times3 3×3卷积和 B N BN BN归一化,此时卷积是忽略偏置系数的,而这个卷积可以理解为是定义在模型训练阶段的,然后就是通过公式: ( x 1 1 ⋅ k 1 1 + x 2 2 ⋅ k 1 1 + x 3 3 ⋅ k 3 3 + x 4 4 ⋅ k 4 4 ) ⋅ γ 1 δ 1 2 + ϵ + ( β 1 − u 1 ⋅ γ 1 δ 1 2 + ϵ ) (x_1^1 \cdot k_1^1+x_2^2 \cdot k_1^1+x_3^3 \cdot k_3^3+x_4^4 \cdot k_4^4) \cdot \frac{\gamma_1}{\sqrt{\delta_1^2+\epsilon}}+(\beta_1 - \frac{u_1 \cdot \gamma_1}{\sqrt{\delta_1^2+\epsilon}}) (x11k11+x22k11+x33k33+x44k44)δ12+ϵ γ1+(β1δ12+ϵ u1γ1) 来进行二者的融合。根据代码再结合此公式就很好理解了,融合后的 3 × 3 3\times3 3×3卷积的偏置系数为: β 1 − u 1 ⋅ γ 1 δ 1 2 + ϵ \beta_1 - \frac{u_1 \cdot \gamma_1}{\sqrt{\delta_1^2+\epsilon}} β1δ12+ϵ u1γ1,需要注意的是融合后的卷积是有偏置系数的,而且是在模型推理阶段。然后就是卷积核的参数,其计算的公式为:待融合的卷积权重参数 × γ 1 δ 1 2 + ϵ \times \frac{\gamma_1}{\sqrt{\delta_1^2+\epsilon}} ×δ12+ϵ γ1,可以发现融合过程是严格按着上述公式计算得来的。新的卷积就融合了 C o n v 2 d D + B N Conv2dD + BN Conv2dD+BN的所有参数信息。(值得注意的就是:融合前的卷积是不带有偏置的,而融合后的卷积要设置偏置为 T r u e True True

📝 📝 📝

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/16405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2023五一数学建模竞赛选题人数公布

数据来源自,各个平台人数投票统计,仅供参考。 具体数值比例为: 题号人数A504B1174C1905 目前,五一数模竞赛C题半成品论文基本完成制作(累计35页,10000字),注:蓝色字体…

three.js学习 06 - 结合GSAP(补间动画)设置各种动画效果(运动效果与双击暂停动画等效果)

1. GSAP简介 GSAP👍🏼是前端业内非常有名的一个动效库,有大量的优秀的网站都在使用它。它不仅能在原生JS的环境下使用,也能配合各种当前流行的框架进行使用。 通过使用它,非常多原本实现起来很有难度的交互动画效果&a…

一文吃透Http协议

Http 协议 1. 初始 Http Http 协议 , 是应用层最为广泛使用的协议 , Http 就是浏览器和服务器之间的桥梁. Http 是基于 TCP 协议实现的 , 通常我们输入搜索框中的网址 (URL) , 浏览器就会根据这个 URL 构造出一个 Http 请求 , 发送给服务器. 服务器就会返回一个 Http 响应(包…

基于空间矢量脉宽调制(SVPWM)的并网逆变器研究(Simulink)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Doris(23):Doris的函数—字符串函数

1 append_trailing_char_if_absent(VARCHAR str, VARCHAR trailing_char) 如果s字符串非空并且末尾不包含c字符,则将c字符附加到末尾。 trailing_char只包含一个字符,如果包含多个字符,将返回NULL select append_trailing_char_if_absent(a,c);select append_trailing_cha…

RabbitMQ 工作队列模式 Work Queue Demo

工作队列模式,一个消息只能有一个消费者消费 生产者发送20条消息 消费者有两个 第一个消费 睡一秒取一个 第二个睡2秒取 public class WorkConsumerTest1 {public static void main(String[] args) throws IOException, TimeoutException {//1 创建连接工厂ConnectionFactor…

SpringCloud01

SpringCloud01 微服务入门案例 实现步骤 导入数据 实现远程调用 MapperScan("cn.itcast.order.mapper") SpringBootApplication public class OrderApplication {public static void main(String[] args) {SpringApplication.run(OrderApplication.class, args);}…

ETL工具 - Kettle 转换算子介绍

一、Kettle 转换算子 上篇文章对 Kettle 中的输入输出算子进行了介绍,本篇文章继续对转换算子进行讲解。 下面是上篇文章的地址: ETL工具 - Kettle 输入输出算子介绍 转换是ETL里面的T(Transform),主要做数据转换&am…

快解析动态域名解析,实现外网访问内网数据库

今天跟大家分享一下如何借助快解析动态域名解析,在两种特定网络环境下,实现外网访问内网mysql数据库。 第1种网络环境:路由器分配的是动态公网IP,且有路由器登录管理权限。如何实现外网访问内网mysql数据库? 针对这种…

如何自制云平台,并实现远程访问控制?

除了阿里、腾讯各种云,计算机大神们都想自己搭建IoT云平台。今天小编跟大家分享一种用UbuntuEMQXNode-RED方式自制IoT云平台的方法,并实现无公网IP随时访问远程数据! 第一步 Step1搭建EMQX服务器 1.搭建IoT平台需要一个服务器,这…

大公司为什么禁止SpringBoot项目使用Tomcat?

前言 在SpringBoot框架中,我们使用最多的是Tomcat,这是SpringBoot默认的容器技术,而且是内嵌式的Tomcat。同时,SpringBoot也支持Undertow容器,我们可以很方便的用Undertow替换Tomcat,而Undertow的性能和内…

thinkphp6结合layui增删改查综合案列

文章目录 技术栈实现代码实现数据库 本案例适合新手,特别是杠刚入门thinkphp和layui,但又不是特别熟悉这类 主要实现登录退出功能,用户模块的增删改查功能,分页功能是layui表单自带功能 效果图 左侧的菜单栏我没有写对应的页面&am…

最好的物联网教程:软硬结合——从零打造物联网

在大学里不同专业有着不同的追求:机械类与强电类专业学生追求的是 “机电合一” ,既懂机械又懂电气,整个电气机械自动化便能打通。弱电类专业学生追求的是 “软硬结合” ,既懂硬件又懂软件,整个电子产品便能打通。我作…

微服务保护 笔记分享【黑马笔记】

微服务保护 1.初识Sentinel 1.1.雪崩问题及解决方案 1.1.1.雪崩问题 微服务中,服务间调用关系错综复杂,一个微服务往往依赖于多个其它微服务。 如图,如果服务提供者I发生了故障,当前的应用的部分业务因为依赖于服务I&#xff…

文件的使用

文章目录 1.概念1.1定义:1.2分类1.2.1程序文件1.2.2数据文件1.概念2.存储方式 1.3文件名 2.文件的使用2.1文件指针2.2开闭函数2.3顺序读写2.3.1何为读写2.3.2读写函数1.字符输出fputc(输出到文件 写到文件中)2.字符输入fgetc(输入…

Spring 的创建和使用

目录 一. 创建 Spring项目 二. 存储 Bean 对象到Spring中 1. 添加Spring配置文件 2. 创建一个 Bean 对象 3. 将 Bean 存储到 Spring 容器中 三. 从 Spring 中获取并使用 Bean 对象 1. 创建 Spring 上下文 1.1 使用 ApplicationContext 作为Spring上下文 1.2 使用 Bea…

线性回归模型(7大模型)

线性回归模型(7大模型) 线性回归是人工智能领域中最常用的统计学方法之一。在许多不同的应用领域中,线性回归都是非常有用的,例如金融、医疗、社交网络、推荐系统等等。 在机器学习中,线性回归是最基本的模型之一&am…

深入理解 Linux 内核

Linux 内核系列文章 Linux 内核设计与实现 深入理解 Linux 内核 Linux 设备驱动程序 Linux设备驱动开发详解 文章目录 Linux 内核系列文章前言一、内存寻址1、内存地址2、硬件中的分段(1)段选择符 3、Linux 中的分段(1)Linux GDT&…

IPsec中IKE与ISAKMP过程分析(快速模式-消息1)

IPsec中IKE与ISAKMP过程分析(主模式-消息1)_搞搞搞高傲的博客-CSDN博客 IPsec中IKE与ISAKMP过程分析(主模式-消息2)_搞搞搞高傲的博客-CSDN博客 IPsec中IKE与ISAKMP过程分析(主模式-消息3)_搞搞搞高傲的博客…

九款顶级AI工具推荐

ChatGPT OpenAI开发的最强对话系统 地址:chat.openai.com ChatGPT能够在同一个会话期间内回答上下文相关的后续问题。其在短时间内引爆全球的原因在于,在网友们晒出的截图中,ChatGPT不仅能流畅地与用户对话,甚至能写诗、撰文、编…