YOLOv5改进 | 添加ECA注意力机制 + 更换主干网络之ShuffleNetV2

前言:Hello大家好,我是小哥谈。本文给大家介绍一种轻量化部署改进方式,即在主干网络中添加ECA注意力机制和更换主干网络之ShuffleNetV2,希望大家学习之后,能够彻底理解其改进流程及方法~!🌈 

     目录

🚀1.基础概念

🚀2.添加位置

🚀3.添加步骤

🚀4.改进方法

💥💥步骤1:common.py文件修改

💥💥步骤2:yolo.py文件修改

💥💥步骤3:创建自定义yaml文件

💥💥步骤4:修改自定义yaml文件

💥💥步骤5:验证是否加入成功

💥💥步骤6:修改默认参数

🚀1.基础概念

ECA注意力机制:

ECA注意力机制是一种用于提升卷积神经网络特征表示能力的方法。它通过嵌入式通道注意力模块,在保持高效性的同时,引入了通道注意力机制。具体来说,ECA注意力机制在通道维度上增加了注意力机制,以提升特征表示的能力。与SE注意力机制不同的是,ECA注意力机制只包含一个操作——excitation,而没有squeeze操作。这使得ECA注意力机制更加轻量级,适用于计算资源有限的场景。

ECA的结构主要分为两个部分:通道注意力模块和嵌入式通道注意力模块。

🍀(1)通道注意力模块

通道注意力模块是ECA的核心组成部分,它的目标是根据通道之间的关系,自适应地调整通道特征的权重。该模块的输入是一个特征图(Feature Map),通过全局平均池化得到每个通道的全局平均值,然后通过一组全连接层来生成通道注意力权重。这些权重被应用于输入特征图的每个通道,从而实现特征图中不同通道的加权组合。最后,通过一个缩放因子对调整后的特征进行归一化,以保持特征的范围。

🍀(2)嵌入式通道注意力模块

嵌入式通道注意力模块是ECA的扩展部分,它将通道注意力机制嵌入到卷积层中,从而在卷积操作中引入通道关系。这种嵌入式设计能够在卷积操作的同时,进行通道注意力的计算,减少了计算成本。具体而言,在卷积操作中,将输入特征图划分为多个子特征图,然后分别对每个子特征图进行卷积操作,并在卷积操作的过程中引入通道注意力。最后,将这些卷积得到的子特征图进行合并,得到最终的输出特征图。

ShuffleNetV2网络:

ShuffleNetV2是一种轻量级的神经网络模型,它是ShuffleNetV1的改进版本。ShuffleNetV2主要采用了两种技术通道分离组卷积。通道分离是指将输入的通道分成两个部分,分别进行不同的计算,然后再将它们合并在一起。这种方法可以减少计算量,提高模型的效率。组卷积是指将卷积操作分成多个小组,每个小组只处理一部分通道,然后再将它们合并在一起。这种方法可以减少参数量,提高模型的泛化能力。


🚀2.添加位置

本文的改进是基于YOLOv5-6.0版本,关于其网络结构具体如下图所示:

本文的改进是在主干网络中添加ECA注意力机制更换主干网络之ShuffleNetV2,具体添加位置如下图所示:

所以,本节课改进后的网络结构图具体如下图所示:


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:common.py文件修改

步骤2:yolo.py文件修改

步骤3:创建自定义yaml文件

步骤4:修改自定义yaml文件

步骤5:验证是否加入成功

步骤6:修改默认参数


🚀4.改进方法

💥💥步骤1:common.py文件修改

common.py中添加ECA注意力机制模块ShuffleNetV2模块,所要添加模块的代码如下所示,将其复制粘贴到common.py文件末尾的位置。

ECA注意力机制代码:

# ECA
class ECA(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, c1, c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Multi-scale information fusion
        y = self.sigmoid(y)
        return x * y.expand_as(x)

ShuffleNetV2模块代码:

# 更换主干网络之shuffleNetV2
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
class CBRM(nn.Module):
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    def forward(self, x):
        return self.maxpool(self.conv(x))

class ShuffleNetV2(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(ShuffleNetV2, self).__init__()
        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )
    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)  # 按照维度1进行split
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out

💥💥步骤2:yolo.py文件修改

首先在yolo.py文件中找到parse_model函数这一行,加入ECACBAMShuffleNetV2。具体如下图所示:

💥💥步骤3:创建自定义yaml文件

models文件夹中复制yolov5s.yaml,粘贴并重命名为yolov5s_ECA_ShuffleNetV2.yaml具体如下图所示:

💥💥步骤4:修改自定义yaml文件

本步骤是修改yolov5s_ECA_ShuffleNetV2.yaml,根据改进后的网络结构图进行修改。

由下面这张图可知,当添加ECA注意力机制和更换主干网络之ShuffleNetV2之后,后面的层数会发生相应的变化,需要修改相关参数。

备注:层数从0开始计算,比如第0层、第1层、第2层......🍉 🍓 🍑 🍈 🍌 🍐  

综上所述,修改后的完整yaml文件如下所示:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  # Shuffle_Block: [out, stride]
  [[ -1, 1, CBRM, [ 32 ] ], # 0-P2/4
   [ -1, 1, ShuffleNetV2, [ 128, 2 ] ],  # 1-P3/8
   [ -1, 3, ShuffleNetV2, [ 128, 1 ] ],  # 2
   [ -1, 1, ShuffleNetV2, [ 256, 2 ] ],  # 3-P4/16
   [ -1, 7, ShuffleNetV2, [ 256, 1 ] ],  # 4
   [ -1, 1, ShuffleNetV2, [ 512, 2 ] ],  # 5-P5/32
   [ -1, 3, ShuffleNetV2, [ 512, 1 ] ],  # 6
   [-1, 1, ECA, [512]],  # 7
   [-1, 1, SPPF, [1024, 5]],  # 8
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 12

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 2], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 16 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 19 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 9], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 22 (P5/32-large)

   [[16, 19, 22], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

💥💥步骤5:验证是否加入成功

yolo.py文件里,将配置改为我们刚才自定义的yolov5s_ECA_ShuffleNetV2.yaml

修改1,位置位于yolo.py文件165行左右,具体如图所示:

修改2,位置位于yolo.py文件363行左右,具体如下图所示:

配置完毕之后,点击“运行”,结果如下图所示:

由运行结果可知,与我们前面更改后的网络结构图相一致,证明添加成功了!✅ 

说明:由上图可以看出,添加ECA注意力机制和更换主干网络之ShuffleNetV2之后,参数量大大减少,所以,该种改进方式适合于轻量化部署。

💥💥步骤6:修改默认参数

train.py文件中找到parse_opt函数,然后将第二行 '--cfg的default改为 'models / yolov5s_ECA_ShuffleNetV2.yaml',然后就可以开始进行训练了。🎈🎈🎈 

结束语:关于更多YOLOv5学习知识,可参考专栏:《YOLOv5:从入门到实战》🍉 🍓 🍑 🍈 🍌 🍐

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/214029.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用idea如何快速的搭建ssm的开发环境

文章目录 唠嗑部分言归正传1、打开idea&#xff0c;点击新建项目2、填写信息3、找到pom.xml先添加springboot父依赖4、添加其他依赖5、编写启动类、配置文件6、连接创建数据库、创建案例表7、安装MybatisX插件8、逆向工程9、编写controller10、启动项目、测试 结语 唠嗑部分 小…

剪切空间与归一化设备坐标【NDC】

有了投影变换的知识&#xff0c;我们现在可以讨论剪切空间&#xff08;Clip Space&#xff09;和 归一化设备坐标&#xff08;NDC&#xff1a;Normalized Device Coordinates&#xff09;。 为了理解这些主题&#xff0c;我们还需要深入了解齐次坐标的有趣世界。 NSDT工具推荐&…

解决:UnboundLocalError: local variable ‘js’ referenced before assignment

解决&#xff1a;UnboundLocalError: local variable ‘js’ referenced before assignment 文章目录 解决&#xff1a;UnboundLocalError: local variable js referenced before assignment背景报错问题报错翻译报错位置代码报错原因解决方法今天的分享就到此结束了 背景 在使…

Python、Stata、SPSS怎么学?推荐一波学习资料

1.Python学习推荐书目 关于Python机器学习&#xff0c;推荐学习杨维忠、张甜所著的&#xff0c;清华大学出版社出版的《Python机器学习原理与算法实现》&#xff0c;以及张甜、杨维忠所编著的&#xff0c;清华大学出版社出版的《Python数据科学应用从入门到精通》&#xff0c;…

柯桥英语口语学习,日常生活用语军大衣用英语怎么说?

那么军大衣跟羽绒服用英语怎么说呢&#xff1f; 跟商英君一起学习一下吧&#xff01; 01 "军大衣"用英语怎么说&#xff1f; 军大衣在英语表达中 也有专门的词汇 即military coat 或 military style cotton coats military有“军人、军事;军事的、军用的…”的…

【Java Web学习笔记】3 - JavaScript入门

项目代码 https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/javascript 零、JavaScript引出 JavaScript 教程 官方文档 1. JavaScript能改变HTML内容&#xff0c;能改变HTML属性&#xff0c;能改变HTML样式(CSS),能完成页面的数据验证。 <!DOCTYPE html>…

notepad ++ 用法大全【程序员必会高级用法】

目录 1&#xff1a;notepad 介绍 2&#xff1a; 快捷键 3&#xff1a; notepad 实用插件 1&#xff1a;notepad 介绍 notepad是一款免费且开源的文本编辑器&#xff0c;可运行在Windows系统上。它支持多种编程语言&#xff0c;包括C、C、Java、Python等等。Notepad具有许多实…

1949-2021年全国31省铁路里程数据

1949-2021年全国31省铁路里程数据 1、时间&#xff1a;1949-2021年 2、指标&#xff1a;时间、省份、铁路里程 3、范围&#xff1a;包括31省 4、数据缺失情况说明&#xff1a;西藏2005年之前存在缺失&#xff0c;其余30省份1978-2020年无缺失 5、来源&#xff1a;各省统计…

Python生产者消费者模型

额滴名片儿 &#x1f388; 博主&#xff1a;一只程序猿子 &#x1f388; 博客主页&#xff1a;一只程序猿子 博客主页 &#x1f388; 个人介绍&#xff1a;爱好(bushi)编程&#xff01; &#x1f388; 创作不易&#xff1a;如喜欢麻烦您点个&#x1f44d;或者点个⭐&#xff01…

反序列化漏洞详解(三)

目录 一、wakeup绕过 二、引用 三、session反序列化漏洞 3.1 php方式存取session格式 3.2 php_serialize方式存取session格式 3.3 php_binary方式存取session格式 3.4 代码演示 3.5 session例题获取flag 四、phar反序列化漏洞 4.1 phar常识 4.2 代码演示 4.3 phar例…

KDE环境文件夹user-dirs为英文

KDE环境文件夹user-dirs 修改KDE主页文件夹为英文 该文件路径 ~/.config/user-dirs.dirs打开后会发现里面的内容如下 # This file is written by xdg-user-dirs-update # If you want to change or add directories, just edit the line youre # interested in. All local …

XSS漏洞原理

XSS漏洞介绍&#xff1a; 跨站脚本攻击XSS(Cross Site Scripting)&#xff0c;为了不和层叠样式表(Cascading Style Sheets, CSS)的缩写混淆&#xff0c;故将跨站脚本攻击缩写为XSS。恶意攻击者往Web页面里插入恶意Script代码&#xff0c;当用户浏览该页面时&#xff0c;嵌入We…

SpringMVC常用注解和用法总结

目标&#xff1a; 1. 熟悉使用SpringMVC中的常用注解 目录 前言 1. Controller 2. RestController 3. RequestMapping 4. RequestParam 5. PathVariable 6. SessionAttributes 7. CookieValue 前言 SpringMVC是一款用于构建基于Java的Web应用程序的框架&#xff0c;它通…

3dMax拼图生成工具Puzzle2D使用教程

Puzzle2D for 3dsMax拼图生成工具使用教程 Puzzle2D简介&#xff1a; 2D拼图随机生成器&#xff08;英文&#xff1a;Puzzle2D&#xff09; &#xff0c;是一款由#沐风课堂#用MAXScript脚本语言开发的3dsMax建模小工具&#xff0c;可以随机创建2D可编辑样条线拼图图形。可批量…

3、在链式存储结构上建立一棵二叉排序树。

3、在链式存储结构上建立一棵二叉排序树。 分析&#xff1a; &#xff08;1&#xff09;定义二叉排序树的结点。 &#xff08;2&#xff09;插入操作&#xff1a;在建立二叉排序树的过程中&#xff0c;需要一个插入操作&#xff0c;用于将新的元素插入到树中。 插入操作的核心思…

openGauss学习笔记-140 openGauss 数据库运维-例行维护-例行维护表

文章目录 openGauss学习笔记-140 openGauss 数据库运维-例行维护-例行维护表140.1 相关概念140.2 操作步骤140.3 维护建议 openGauss学习笔记-140 openGauss 数据库运维-例行维护-例行维护表 为了保证数据库的有效运行&#xff0c;数据库必须在插入/删除操作后&#xff0c;基于…

P-Tuning v2论文概述

P-Tuning v2论文概述 P-Tuning v2论文概述前言微调的限制性P-Tuning的缺陷P-Tuning v2 摘要论文十问NLU任务优化点实验数据集预训练模型实验结果消融实验 结论 P-Tuning v2论文概述 前言 微调的限制性 微调&#xff08;fine-tuning&#xff09;是一种在预训练模型基础上进行目…

【Docker】容器数据持久化及容器互联

一、Docker容器的数据管理 1.1、什么是数据卷 数据卷是经过特殊设计的目录&#xff0c;可以绕过联合文件系统&#xff08;UFS&#xff09;&#xff0c;为一个或者多个容器提供访问&#xff0c;数据卷设计的目的&#xff0c;在于数据的永久存储&#xff0c;它完全独立于容器的…

获取所有的 font-awesome图标, 用于本地选择使用

访问 font-awesome 首页 Font Awesome 4.7.0 675款图标,Font Awesome,奥森图标,Font Awesome 4.7.0,Font Awesome中文站,Font Awesome IE7兼容处理,Font Awesome图标搜索,Font Awesome中文站-ThinkCMF 调出控制台 , 执行下面的脚步 // 获取所有的图标元素 var icons docu…

《YOLOv5原创自研》专栏介绍 CSDN独家改进创新实战专栏目录

YOLOv5原创自研 https://blog.csdn.net/m0_63774211/category_12511931.html &#x1f4a1;&#x1f4a1;&#x1f4a1;全网独家首发创新&#xff08;原创&#xff09;&#xff0c;适合paper &#xff01;&#xff01;&#xff01; &#x1f4a1;&#x1f4a1;&#x1f4a1;…
最新文章