YOLOv9独家原创改进|使用DySample超级轻量的动态上采样算子


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!!


一、DySample论文摘要

       尽管最近的基于内核的动态上采样器如CARAFE、FADE和SAPA取得了令人印象深刻的性能提升,但它们引入了大量的工作量,主要是由于时间消耗大的动态卷积和用于生成动态内核的额外子网络。 此外,FADE和SAPA对高分辨率特征的需求在一定程度上限制了它们的应用场景。为了解决这些问题,研究人员绕过了动态卷积,并从点采样的角度来表述上采样,这更加节省资源并可以用PyTorch中的标准内置函数轻松实现。与之前的基于内核的动态上采样相比,DySample不需要自定义的CUDA包,并且参数、FLOPs、GPU内存和延迟都要少得多。除了轻量级的特点之外,DySample在五个密集预测任务(语义分割、目标检测、实例分割、全景分割和单目深度估计)中都优于其他上采样器。DySample的应用领域也更广泛,可以适用于各类图像处理任务,有效提升图像处理的效率和质量。

适用检测目标:   通用上采样算子


二、DySample模块详解

        论文地址:  https://arxiv.org/abs/2308.15085

 2.1 模块简介

        DySample的主要思想:   点采样

 总结:一种新的超轻量化上采样算子,发表于ICCV2023

DySample模块的原理图


三、DySample模块使用教程

3.1 DySample模块的代码

try:
    from mmcv.cnn import build_activation_layer, build_norm_layer
    from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
    from mmengine.model import constant_init, normal_init
except ImportError as e:
    pass


class DySample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        # normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid(h, h, indexing='ij')).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid(coords_w, coords_h, indexing='ij')
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)

3.2 在YOlO v9中的添加教程

阅读YOLOv9添加模块教程或使用下文操作

        1. 将YOLOv9工程中models下common.py文件中增加以下代码。


try:
    from mmcv.cnn import build_activation_layer, build_norm_layer
    from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
    from mmengine.model import constant_init, normal_init
except ImportError as e:
    pass


class DySample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        # normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid(h, h, indexing='ij')).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid(coords_w, coords_h, indexing='ij')
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)

         2. 将YOLOv9工程中models下yolo.py文件中的第718行(可能因版本变化而变化)增加以下代码。

        elif m in (DySample,):
            args.insert(0, ch[f])

3.3 运行配置文件

# YOLOv9
# Powered bu https://blog.csdn.net/StopAndGoyyy
# parameters
nc: 80  # number of classes
#depth_multiple: 0.33  # model depth multiple
depth_multiple: 1  # model depth multiple
#width_multiple: 0.25  # layer channel multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, DySample, []],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   
   
   
   # detection head

   # detect
   [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

3.4 训练过程


欢迎关注

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/426048.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实时抓取SKU商品属性详细信息API数据接口(淘宝,某音)

item_sku-获取sku详细信息 taobao.item_sku详细信息 API公共参数 请求地址: https://api-gw.onebound.cn/taobao/item_sku 名称类型必须描述keyString是调用key(演示示例)secretString是调用密钥api_nameString是API接口名称(包括在请求地…

芯来科技发布最新NI系列内核,NI900矢量宽度可达512/1024位

参考:芯来科技发布最新NI系列内核,NI900矢量宽度可达512/1024位 (qq.com) 本土RISC-V CPU IP领军企业——芯来科技正式发布首款针对人工智能应用的专用处理器产品线Nuclei Intelligence(NI)系列,以及NI系列的第一款AI专用RISC-V处理器CPU IP…

机器人 标准DH与改进DH

文章目录 1 建立机器人坐标系1.1 连杆编号1.2 关节编号1.3 坐标系方向2 标准DH(STD)2.1 确定X轴方向2.2 建模步骤2.3 变换顺序2.4 变换矩阵3 改进DH(MDH)3.1 确定X轴方向3.2 建模步骤3.3 变换顺序3.4 变换矩阵4 标准DH与改进DH区别5 Matlab示例参考链接1 建立机器人坐标系 1.1…

现代化数据架构升级:毫末智行自动驾驶如何应对年增20PB的数据规模挑战?-OceanBase案例

毫末智行是一家致力于自动驾驶的人工智能技术公司,其前身是长城汽车智能驾驶前瞻分部,以零事故、零拥堵、自由出行和高效物流为目标,助力合作伙伴重塑和全面升级整个社会的出行及物流方式。 在自动驾驶领域中,是什么原因让毫末智行…

边缘智能网关:让环境监测更智能

在环境监测领域,边缘智能网关可用于区域环境的实时监测、分析和预警,例如河湖水位监测、雨雪监测、风沙/风速监测,通过实时采集并分析环境变化数据,能够有助于对于突发、急发的各种自然灾害进行快速预警和应对。 一、边缘智能网关…

Docker 创建容器并指定时区

目录 1. 通过环境变量设置时区(推荐)2. 挂载宿主机的时区文件到容器中3. 总结 要在 Docker 容器中指定时区,可以通过两种方式来实现: 1. 通过环境变量设置时区(推荐) 在 Docker 运行时,可以通…

Unity UI实现表格渲染

前言 最近有在用Unity做前端UI, 用到了实现表格数据渲染,也就是后台给的list渲染到表格中,查看了许多资料发现比较少,因此在这里记录一下吧,希望可以帮助到大家哦。 也是第一次使用Unity,先简单介绍一下&…

类构造完成,Bean注入之后执行方法

PostConstruct 容器执行之后执行 PreDestory 在容器销毁之前执行

redis进阶(一)

文章目录 前言一、Redis中的对象的结构体如下:二、压缩链表三、跳跃表 前言 Redis是一种key/value型数据库,其中,每个key和value都是使用对象表示的。 一、Redis中的对象的结构体如下: /** Redis 对象*/ typedef struct redisO…

今日arXiv最热大模型论文:谷歌最新研究,将LLM用于回归分析任务,显著超越传统模型

回归分析是一个强大的工具,能够准确预测系统或模型的结果指标,给定一组参数。然而,传统上这些方法仅适用于特定任务。本文研究者提出了OMNIPRED框架,这是一个训练语言模型作为通用端到端回归器的框架,它可以处理来自多…

SNAP:如何批量预处理Sentinel2 L2A数据集并输出为TIFF文件?

我的需求 我目前就是希望下载哨兵2号数据,然后在SNAP中进行批量提取真彩色波段并输出为TIFF文件。 数据集下载说明 目前哨兵网站似乎进行了一大波更新,连网站都换了,网址如下: https://dataspace.copernicus.eu/ 打开后界面如…

五千字 DDL、DML、DQL、DCL 超详解

SQL语句,根据其功能,主要分为四类:DDL、DML、DQL、DCL。 DDL (Data Definition Language) 数据定义语言,用来定义数据库对象(数据库,表, 字段) DML (Data Manipulation Languag) 数据操作语言,…

想从事数据方向职场小白看过来, 数据方面的一些英文解释

想从事数据方向职场小白看过来,一些英文名词解释 文章目录 想从事数据方向职场小白看过来,一些英文名词解释 英文类解释NoSQL:ESB:ACID :Data Vault:MDM:OLAP:SCD:SBA:MP…

从嵌入式Linux到嵌入式Android

最近开始投入Android的怀抱。说来惭愧,08年就听说这东西,当时也有同事投入去看,因为恶心Java,始终对这玩意无感,没想到现在不会这个嵌入式都快要没法搞了。为了不中年失业,所以只能回过头又来学。 首先还是…

Python算法100例-2.11 换分币

完整源代码项目地址,关注博主私信源代码后可获取 1.问题描述2.问题分析3.算法设计4.确定程序框架5.完整的程序6.运行结果 1.问题描述 将5元的人民币兑换成1元、5角和1角的硬币,共有多少种不同的兑换方法。 2.问题分析 根据该…

【框架】Spring 框架重点解析

Spring 框架重点解析 1. Spring 框架中的单例 bean 是线程安全的吗? 不是线程安全的 Spring 框架中有一个 Scope 注解,默认的值是 singleton,即单例的;因为一般在 Spring 的 bean 对象都是无状态的(在生命周期中不被…

嵌入式Qt 对话框及其类型 QDialog

一.对话框的概念 对话框是与用户进行简短交互的顶层窗口。 QDialog是Qt中所有对话框窗口的基类。 QDialog继承与QWidfet是一种容器类型的组件。 QDialog的意义: QDialog作为一种专业的交互窗口而存在。 QDialog不能作为子部部件嵌入其他容器中。 QDialog是定制…

【算法集训】基础算法:枚举

一、基本理解 枚举的概念就是把满足题目条件的所有情况都列举出来,然后一一判定,找到最优解的过程。 枚举虽然看起来麻烦,但是有时效率上比排序高,也是一个不错的方法、 二、最值问题 1、两个数的最值问题 两个数的最小值&…

力扣刷题:226.反转二叉树

题目: 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1]示例 2: 输入:root [2,1,3] 输出:[2…

业务真的需要微服务吗

业务真的需要微服务吗 要说过去十年最火热的软件体系是什么,个人认为莫过于“微服务架构“了。从一线互联网架构师,到刚接触计算机软件不久的学生几乎都或多或少的了解过”微服务“相关知识了,其中在最出名的微服务体系要数 spring cloud 了…
最新文章