YoloV8改进策略:卷积改进|DOConv轻量卷积,即插即用|适用各种场景

摘要

本文使用DOConv卷积,替换YoloV8的常规卷积,轻量高效,即插即用!改进方法非常简单。
在这里插入图片描述

DO-Conv(Depthwise Over-parameterized Convolutional Layer)是一种深度过参数化的卷积层,用于提高卷积神经网络(CNN)的性能。它的核心思想是在训练阶段使用额外的深度卷积来增强卷积层,其中每个输入通道与不同的二维核进行卷积。这两个卷积的组合构成了一个过度参数化,因为它增加了可学习的参数,而结果的线性操作可以用单个卷积层来表示。在推理阶段,DO-Conv可以融合到常规卷积层中,使得计算量与常规卷积层的计算量完全相同。

DO-Conv可以作为一种即插即用的模块,用于替代CNN中的常规卷积层,以提高在各种计算机视觉任务(如图像分类、语义分割和对象检测)上的性能。通过实验证明,使用DO-Conv不仅可以加速网络的训练过程,还能在多种计算机视觉任务中取得比使用传统卷积层更好的结果。

论文链接:https://arxiv.org/pdf/2006.12030.pdf

代码

# coding=utf-8
import math
import torch
import numpy as np
from torch.nn import init
from itertools import repeat
from torch.nn import functional as F
from torch._jit_internal import Optional
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import collections


class DOConv2d(Module):
    """
       DOConv2d can be used as an alternative for torch.nn.Conv2d.
       The interface is similar to that of Conv2d, with one exception:
            1. D_mul: the depth multiplier for the over-parameterization.
       Note that the groups parameter switchs between DO-Conv (groups=1),
       DO-DConv (groups=in_channels), DO-GConv (otherwise).
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(DOConv2d, self).__init__()

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))

        #################################### Initailization of D & W ###################################
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if M * N > 1:
            self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
            init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
            self.D.data = torch.from_numpy(init_zero)

            eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
            d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
            if self.D_mul % (M * N) != 0:  # the cases when D_mul > M * N
                zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
                self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
            else:  # the case when D_mul = M * N
                self.d_diag = Parameter(d_diag, requires_grad=False)
        ##################################################################################################

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            ######################### Compute DoW #################
            # (input_channels, D_mul, M * N)
            D = self.D + self.d_diag
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum outputs (out_channels // groups, in_channels, M * N),
            # which is reshaped to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
        else:
            # in this case D_mul == M * N
            # reshape from
            # (out_channels, in_channels // groups, D_mul)
            # to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(self.W, DoW_shape)
        return self._conv_forward(input, DoW)


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse
_pair = _ntuple(2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/566806.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Win10 搭建 YOLOv8 运行环境(20240423)

一、环境要求 1、Python&#xff0c;版本要求>3.7 2、PyTorch&#xff0c;版本要求>1.7。PyTorch 是一个开源的深度学习平台&#xff0c;为人工智能研究提供了一个灵活的、易于使用的工具集。YOLOv8 是基于 PyTorch 框架实现的&#xff0c;所以需要安装 PyTorch。 3、CUD…

【nginx】nginx启动显示80端口占用问题的解决方案

目录 &#x1f305;1. 问题描述 &#x1f30a;2. 解决方案 &#x1f305;1. 问题描述 在启动nginx服务的时候显示内容如下&#xff1a; sudo systemctl status nginx 问题出现原因&#xff1a; 根据日志显示&#xff0c;Nginx 服务启动失败&#xff0c;主要原因是无法绑定…

Ultralytics YOLOv8 英伟达™ Jetson®处理器部署

系列文章目录 前言 本综合指南提供了在英伟达 Jetson设备上部署Ultralytics YOLOv8 的详细攻略。此外&#xff0c;它还展示了性能基准&#xff0c;以证明YOLOv8 在这些小巧而功能强大的设备上的性能。 备注 本指南使用Seeed Studio reComputer J4012进行测试&#xff0c;它基于…

[Vue warn]: useModel() called with prop “xxx“ which is not declared

我们在使用vue3里面的defineModel的时候可能会出现这个问题&#xff0c;原因是我们使用的 kebab-case 形式的属性名&#xff0c;我也不知道是不是vue3设定这个api的时候设置的不支持&#xff0c;我没找到相关文档&#xff0c;不过我们把 kebab-case 的形式改为 驼峰命名法 或者…

【WEB前端2024】开源元宇宙:乔布斯3D纪念馆-第9课-摆件美化

【WEB前端2024】开源元宇宙&#xff1a;乔布斯3D纪念馆-第9课-摆件美化 使用dtns.network德塔世界&#xff08;开源的智体世界引擎&#xff09;&#xff0c;策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智体世界引擎&#…

架构师系列-MYSQL调优(七)- 索引单表优化案例

索引单表优化案例 1. 建表 创建表 插入数据 下面是一张用户通讯表的表结构信息,这张表来源于真实企业的实际项目中,有接近500万条数据. CREATE TABLE user_contacts (id INT(11) NOT NULL AUTO_INCREMENT,user_id INT(11) DEFAULT NULL COMMENT 用户标识,mobile VARCHAR(50…

李沐53_语言模型——自学笔记

语言模型 1.预测文本序列出现的概率 2.应用在做预训练模型 3.生成文本&#xff0c;给定前面几个词&#xff0c;不断生成后续文本 4.判断多个序列中哪个更常见 真实数据集的统计 《时光机器》数据集构建词表&#xff0c; 并打印前10个最常用的&#xff08;频率最高的&…

Zabbix监控系统:基础配置及部署代理服务器

目录 前言 一、自定义监控内容 1、在客户端创建自定义key 2、在服务端验证新建的监控项 3、在web界面创建自定义监控项模版 3.1 创建模版 3.2 创建应用集&#xff08;用于管理监控项&#xff09; 3.3 创建监控项 3.4 创建触发器 3.5 创建图形 3.6 将主机与模板关联…

Python | Leetcode Python题解之第43题字符串相乘

题目&#xff1a; 题解&#xff1a; class Solution:def multiply(self, num1: str, num2: str) -> str:if num1 "0" or num2 "0":return "0"m, n len(num1), len(num2)ansArr [0] * (m n)for i in range(m - 1, -1, -1):x int(num1[i…

【技术干货】润石红外额温枪方案芯片功能介绍

手持红外额温枪框图中&#xff0c;以电池采用9V为例&#xff0c;先通过一个高压LDO RS3002 把电池电压转为3V&#xff0c;供整个系统使用&#xff0c;包括为 MCU&#xff0c;背光灯&#xff0c;运放 等器件供电&#xff0c;然后再用一个低功耗LDO RS3236 从3V 降为1.5V&#…

Excel如何计算时间差

HOUR(B1-A1)&"小时 "&MINUTE(B1-A1)&"分钟 "&SECOND(B1-A1)&"秒"

10种常用的JS数组循环及其应用场景

文章的更新路线&#xff1a;JavaScript基础知识-Vue2基础知识-Vue3基础知识-TypeScript基础知识-网络基础知识-浏览器基础知识-项目优化知识-项目实战经验-前端温习题&#xff08;HTML基础知识和CSS基础知识已经更新完毕&#xff09; 正文 在JavaScript中&#xff0c;数组是一种…

如何在PostgreSQL中实现基于哈希的分区表以及其优势是什么

文章目录 一、基于哈希的分区表实现二、基于哈希的分区表优势 PostgreSQL是一个功能强大的开源关系型数据库管理系统&#xff0c;它支持多种分区策略&#xff0c;包括基于范围的分区、基于列表的分区以及基于哈希的分区。本文将重点讨论如何在PostgreSQL中实现基于哈希的分区表…

青否数字人直播带货源码有哪些功能?

青否数字人直播源码功能如下&#xff1a; 1、青否数字人克隆源码的克隆效果媲美真人 将真人录制的2-6分钟视频上传至克隆端后台&#xff0c;系统便会自动启动自动克隆。3-5小时后&#xff0c;即可生成一个与本人在形象、表情及动作上1&#xff1a;1的数字人。 2、在声音克隆上&…

Vue 3中的ref和toRefs:响应式状态管理利器

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

[MySQL]运算符

1. 算术运算符 (1). 算术运算符 : , -, *, / 或 DIV, % 或MOD. (2). 例 : (3). 注 : DUAL是伪表.可以看到4/2结果为小数&#xff0c;并不会截断小数部分.(可能与其他语言不同&#xff0c;比如java中&#xff0c;两个操作数如果是整数&#xff0c;则计算得到的也是整数&…

面试经典150题——二叉树展开为链表

​ 1. 题目描述 2. 题目分析与解析 2.1 思路一 因为题目中提到&#xff1a;展开后的单链表应该与二叉树 先序遍历 顺序相同&#xff0c;那么我们是不是就可以先先序遍历&#xff0c;然后按照先序遍历的节点一个一个赋值&#xff1f; 其实最简单的思路就是用一个结构按顺序存…

加速大数据分析:Apache Kylin使用心得与最佳实践详解

Apache Kylin 是一个开源的分布式分析引擎&#xff0c;提供了Hadoop之上的SQL接口和多维分析&#xff08;OLAP&#xff09;能力以支持大规模数据。它擅长处理互联网级别的超大规模数据集&#xff0c;并能够进行亚秒级的查询响应时间。Kylin 的主要使用场景包括大数据分析、交互…

Web前端安全问题分类综合以及XSS、CSRF、SQL注入、DoS/DDoS攻击、会话劫持、点击劫持等详解,增强生产安全意识

前端安全问题是指发生在浏览器、单页面应用、Web页面等前端环境中的各类安全隐患。Web前端作为与用户直接交互的界面&#xff0c;其安全性问题直接关系到用户体验和数据安全。近年来&#xff0c;随着前端技术的快速发展&#xff0c;Web前端安全问题也日益凸显。因此&#xff0c…

Altair:Python数据可视化库的魅力之旅

目录 一、引言 二、Altair概述 三、Altair的核心特性 1.声明式语法 2.丰富的图表类型 3.交互式与响应式 4.无缝集成 四、案例与代码实践 案例一&#xff1a;使用Altair绘制折线图 案例二&#xff1a;使用Altair绘制热力图 五、新手入门指南 1.安装与导入 2.数据准…
最新文章