YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块

目录

    • 一、注意力机制介绍
      • 1、什么是注意力机制?
      • 2、注意力机制的分类
      • 3、注意力机制的核心
    • 二、GCN 模块
      • 1、GCN 模块的原理
      • 2、实验结果
      • 3、应用示例
    • 三、DAN模块
      • 1、DAN模块的原理
      • 2、实验结果
      • 3、应用示例

大家好,我是哪吒。

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。


在机器学习和自然语言处理领域,随着数据的不断增长和任务的复杂性提高,传统的模型在处理长序列或大型输入时面临一些困难。传统模型无法有效地区分每个输入的重要性,导致模型难以捕捉到与当前任务相关的关键信息。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。

一、注意力机制介绍

1、什么是注意力机制?

注意力机制(Attention Mechanism)是一种在机器学习和自然语言处理领域中广泛应用的重要概念。它的出现解决了模型在处理长序列或大型输入时的困难,使得模型能够更加关注与当前任务相关的信息,从而提高模型的性能和效果。

本文将详细介绍注意力机制的原理、应用示例以及应用示例。

2、注意力机制的分类

类别描述
全局注意力机制(Global Attention)在计算注意力权重时,考虑输入序列中的所有位置或元素,适用于需要全局信息的任务。
局部注意力机制(Local Attention)在计算注意力权重时,只考虑输入序列中的局部区域或邻近元素,适用于需要关注局部信息的任务。
自注意力机制(Self Attention)在计算注意力权重时,根据输入序列内部的关系来决定每个位置的注意力权重,适用于序列中元素之间存在依赖关系的任务。
Bahdanau 注意力机制全局注意力机制的一种变体,通过引入可学习的对齐模型,对输入序列的每个位置计算注意力权重。
Luong 注意力机制全局注意力机制的另一种变体,通过引入不同的计算方式,对输入序列的每个位置计算注意力权重。
Transformer 注意力机制自注意力机制在Transformer模型中的具体实现,用于对输入序列中的元素进行关联建模和特征提取。

3、注意力机制的核心

注意力机制的核心思想是根据输入的上下文信息来动态地计算每个输入的权重。这个过程可以分为三个关键步骤:计算注意力权重、对输入进行加权和输出。首先,计算注意力权重是通过将输入与模型的当前状态进行比较,从而得到每个输入的注意力分数。这些注意力分数反映了每个输入对当前任务的重要性。对输入进行加权是将每个输入乘以其对应的注意力分数,从而根据其重要性对输入进行加权。最后,将加权后的输入进行求和或者拼接,得到最终的输出。注意力机制的关键之处在于它允许模型在不同的时间步或位置上关注不同的输入,从而捕捉到与任务相关的信息。

🏆YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

二、GCN 模块

1、GCN 模块的原理

GCN(Global Context Network)模块是一种用于计算机视觉领域的深度学习模型中的注意力机制。它由 Tsinghua 大学的 Cao et al. 在 2019 年提出,旨在通过给神经网络提供全局上下文信息来提高图像分类、分割、检测等任务的性能。

GCN模块的核心思想是利用自适应的全局平均池化(Adaptive Global Average Pooling),根据每个通道的重要性对其进行加权,将全局范围内的并行卷积特征映射融合成一个全局语义向量,从而增强模型对局部和全局特征的感知能力。

GCN模块的具体实现如下所示:

import torch.nn as nn
import torch.nn.functional as F

class ContextBlock2d(nn.Module):

    def __init__(self, in_channels, ratio, pooling_type='att', fusion_types=('channel_add', )):
        super(ContextBlock2d, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels // ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1),
                nn.LayerNorm([self.in_channels, 1, 1])
            )
        
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1),
                nn.LayerNorm([self.in_channels, 1, 1]),
                nn.Sigmoid()
            )

        if 'spatial' in fusion_types:
            self.spatial_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1])
            )
        
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, C, 1, 1]
        spatial_output = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        # [N, C, 1, 1]
        spatial_output = self.spatial_conv(spatial_output)
        # [N, C, 1, 1]
        spatial_output = F.relu(spatial_output, inplace=True)
        # [N, C, 1, 1]
        spatial_output = F.interpolate(spatial_output, size=(height, width), mode='nearest')
        # [N, 1, H, W]
        output = F.softmax(spatial_output, dim=1)
        return output

    def forward(self, x):
        batch, channel, height, width = x.size()
        
        # calculate the input tensor for calculating correlation matrix
        input_x = x
        
        if self.pooling_type == 'avg':
            # N x C x 1 x 1
            context_mask = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        elif self.pooling_type == 'att':
            # N x C x 1 x 1
            context_mask = F.relu(self.conv_mask(x))
        
        # N x C x 1 x 1
        context_mask = F.interpolate(context_mask, size=(height, width), mode='nearest')

        # N x C x H x W
        context = x * context_mask
        if 'channel_mul' in self.fusion_types:
            # N x C x 1 x 1
            avg_context = torch.sum(context, dim=(2, 3), keepdim=True) / (height * width)
        # N x C x H x W
        context = context * avg_context
         if 'channel_add' in self.fusion_types:
        # N x C x H x W
        channel_add_term = self.channel_add_conv(context)
        # N x C x H x W
        context = context + channel_add_term
    
        output = context

        if 'spatial' in self.fusion_types:
        # N x 1 x H x W
        spatial_attention = self.spatial_pool(x)
        # N x C x H x W
        output = output * spatial_attention
    
        return output

在这个实现中,ContextBlock2d 类接受输入张量 x,并提供了以下四种融合策略:

  • channel_add:通过一个卷积层和 LayerNorm 实现的通道级别加法操作。

  • channel_mul:通过一个卷积层、Sigmoid 激活函数和 LayerNorm 实现的通道级别乘法操作。

  • spatial:通过自适应平均池化和一组卷积层实现的空间级别的特征融合。

  • att:通过一组卷积层和 Softmax 函数实现的注意力机制。

GCN模块的实现中用到了以下技巧:

  • 自适应平均池化:对于不同大小的输入,使用自适应的池化核大小以得到固定大小的输出。
  • Sigmoid 激活函数:对于 channel_mul 融合策略,使用 Sigmoid 函数将权重限制在 (0, 1) 范围内。
  • LayerNorm:对于 channel_add 和 channel_mul 融合策略,使用 LayerNorm 对特征图进行归一化操作。
  • Softmax 函数:对于 att 融合策略,使用 Softmax 函数计算注意力值。

2、实验结果

在Kinetics验证集上,使用R50作为骨干的Slow-only基线下,GCNet和NLNet的结果如下:

methodTop-1 AccTop-5 Acc#params(M)FLOPs(G)
baseline74.9491.9032.4539.29
+5 NL75.9592.2939.8159.60
+5 SNL75.7692.4436.1339.32
+5 GC75.8592.2534.3039.31
+all GC76.0092.3442.4539.35

3、应用示例

在 YOLOv5 中,GCNet 模块被应用于 CSPDarknet53 特征提取器中,以增强模型的感受野和上下文信息。具体来说,GCNet 模块是通过在通道维度上进行全局上下文编码来实现的。

下面是在 YOLOv5 中使用 GCNet 模块的应用示例:

import torch.nn as nn
from models.common import Conv

class GCNet(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(GCNet, self).__init__()
        self.conv1x1 = Conv(in_channels, in_channels // reduction, 1)
        self.conv3x3 = Conv(in_channels, in_channels // reduction, 3, padding=1)
        self.global_context = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            Conv(in_channels // reduction, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.conv1x1(x)
        feat2 = self.conv3x3(x)
        gc = self.global_context(feat2)
        feat = feat1 * gc.expand_as(feat1) + feat2
        return feat

class CSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, darknet_lite=False):
        super(CSPBlock, self).__init__()
        self.conv1 = Conv(in_channels, out_channels, 1)
        self.conv2 = Conv(in_channels, out_channels, 1)
        self.conv3 = Conv(out_channels * 2, out_channels, 1)
        self.conv4 = Conv(out_channels * 2, out_channels * 2, 3, padding=1, groups=out_channels * 2)
        self.conv5 = Conv(out_channels * 2, out_channels, 1)
        self.layers = nn.Sequential(*[
            ResBlock(out_channels, darknet_lite) for _ in range(num_blocks)
        ])
        self.gc_block = nn.Sequential(
            Conv(out_channels, out_channels // 2, 1),
            GCNet(out_channels // 2),
            Conv(out_channels, out_channels, 1)
        )

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(x)
        feat2 = self.layers(feat2)
        feat2 = self.gc_block(feat2)
        feat2 = torch.cat([feat2, feat1], dim=1)
        feat2 = self.conv3(feat2)
        feat2 = self.conv4(feat2)
        feat2 = self.conv5(feat2)
        return feat2

在上述代码中,我们首先定义了一个 GCNet 类和一个 CSPBlock 类。GCNet 类是实现全局上下文编码的模块,而 CSPBlock 是 YOLOv5 中的一个基本块。

在 CSPBlock 中,我们使用 GCNet 模块来增强模型的感受野和上下文信息。具体来说,我们将 GCNet 模块放置于 CSPBlock 的末尾,并将其输入特征图和经过卷积操作的另一份特征图进行拼接,最后再通过几个卷积层输出特征图。这样做可以使模型更好地处理不同尺度的物体。

三、DAN模块

1、DAN模块的原理

在这里插入图片描述

DANet(Dual Attention Network)模块是一种新型的注意力机制,广泛应用于计算机视觉领域中的图像分割任务。该模块由位置注意力模块和通道注意力模块组成,能够自适应地对输入图像中的关键区域进行加强,从而提高了图像分割的精度。

在这里插入图片描述

在位置注意力模块中,通过构建一个全局上下文信息嵌入层来获取位置感知初始特征,然后使用空间转换网络(Spatial Transform Network,STN)来自适应地调整这些特征的空间位置,以使其更好地匹配目标对象的形状和大小。进一步地,通过使用一个门控方案,位置注意力模块可以选择性地增强或抑制每个特征通道的激活,以便更好地突出目标对象。

在这里插入图片描述

在通道注意力模块中,首先提取特征图的全局信息,并通过一个门控方案将其与特征图中每个通道的激活相乘,以得到通道加权的响应。接着,在通道特征响应映射(Channel Feature Response Map,CFRM)中,使用类似于SENet(Squeeze and Excitation Network)的方式来生成通道注意力图。最后,将通道注意力图与特征图相乘,以获得增强后的特征响应。

2、实验结果

在PASCAL VOC 2012和Cityscapes等多个数据集上进行的大量实验表明,DANet模块相对于其他主流方法具有更好的图像分割性能。例如,在Cityscapes数据集中,使用DANet模块的分割网络在Mean IoU指标上达到了81.5%,优于其他方法。

在这里插入图片描述

3、应用示例

以下是一个使用DANet模块的应用示例片段,其中包括了DANet模块的定义和在YOLOv5中的应用:

import torch.nn as nn

class DANet(nn.Module):
    def __init__(self, in_channels):
        super(DANet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        proj_query = self.conv1(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.conv2(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.conv3(x).view(batch_size, -1, width * height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x

        return out

# 在YOLOv5中使用DANet模块
class YOLOv5(nn.Module):
    def __init__(self):
        super(YOLOv5, self).__init__()
        # ... 略去其他层的定义
        self.da_conv1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.1),
            DANet(in_channels=512),  # 添加DANet模块
        )
        # ... 略去其他层的定义

    def forward(self, x):
        # ... 略去前面的网络部分
        x = self.da_conv1(x)
        # ... 略去后续的网络部分
        return x

在上述代码中,我们首先定义了一个DANet模块。在该模块中,将输入张量x通过三个卷积层,得到了三个特征张量:proj_queryproj_key和proj_value。其中,proj_queryproj_key用来计算注意力权重矩阵(attention),而proj_value作为输出特征的候选。

在DANet模块的forward函数中,我们首先计算了proj_query和proj_key之间的点积,得到了能量矩阵energy,然后对energy进行softmax操作,得到了注意力权重矩阵attention。最后,我们将proj_valueattention进行矩阵乘法操作,得到了加权后的输出特征张量out。

在YOLOv5模型中,我们通过在卷积层之间插入DANet模块来提高特征表示能力。例如,在上述代码中,我们定义了一个包含DANet模块的卷积层da_conv1,并在其中使用了DANet模块来增强特征表示能力。

参考论文:

  1. https://arxiv.org/abs/1904.11492
  2. https://arxiv.org/abs/1809.02983

在这里插入图片描述

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,🚀内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。

🏆华为OD机试(JAVA)真题(A卷+B卷)

每一题都有详细的答题思路、详细的代码注释、样例测试,订阅后,专栏内的文章都可看,可加入华为OD刷题群(私信即可),发现新题目,随时更新,全天CSDN在线答疑。

🏆哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师。

🏆往期回顾:

1、YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

2、YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

3、YOLOv5结合BiFPN,如何替换YOLOv5的Neck实现更强的检测能力?

4、YOLOv5结合BiFPN:BiFPN网络结构调整,BiFPN训练模型训练技巧

5、YOLOv7升级换代:EfficientNet骨干网络助力更精准目标检测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/27770.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AMC12和高考数学哪个更难?知识点有哪些不同?

AMC12和高考数学哪个更难?知识点有哪些不同?今天小编给大家来详细介绍一下! 难度对比 从难度上看,高考数学的计算量更大,并且知识点比AMC10/12超前,需要用到极限和微积分的知识。 反观AMC10/12不需要用到…

数据结构与算法之美 | 栈

栈结构:后进者先出,先进者后出 栈是一种“操作受限”的线性表 当某个数据集合只涉及在一端插入和删除数据,并且满足后进先出、先进后出的特性,这时我们就应该首选“栈”这种数据结构 栈的实现 使用数组实现:顺序栈…

初探图神经网络——GNN

title: 图神经网络(GNN) date: tags: 随笔知识点 categories:[学习笔记] 初探图神经网络(GNN) 文章来源:https://distill.pub/2021/gnn-intro/ 前言:说一下为什么要写这篇文章,因为自己最近一直听说“图神经网络”,但是一直不了…

pycharm使用之torch_sparse安装

正式安装之前要先查看一下torch的版本 一、查看torch版本 1、winR ,输入cmd 2、输入python 3、 输入import torch,然后输入torch.__version__,最后回车 可以看到我的torch版本是1.10.0 二、下载合适的torch_sparse版本 1、打开链接 https…

接口反应慢优化

遇到某个功能,页面转圈好久,需要优化 1.F12 查看接口时间 2.看参数 总共耗时9.6s Waiting for sercer response 时间是2秒 Content Download 7秒 慢在Content Download F12查看接口响应 显示Failed to load response data:Request content was e…

spark入门 高可用部署HA(五)

一、standalone基于修改部署 https://blog.csdn.net/weixin_43205308/article/details/131070277?spm1001.2014.3001.5501 二、安装ZOOKEEPER zookeeper 安装下载与集群 三、修改conf下的spark-env.sh vim conf/spark-env.sh注释以下内容(根据自己环境修改&am…

visual studio 2022,ADO.NET 实体数据模型添加 sqlite数据库对象

文章目录 前言前期环境博客github 文档解析文件安装说明文件下载省流版nuget环境配置成功标志sqlite连接测试 前言 我们知道ADO.NET 实体数据模型特别适合动态开发数据库。因为ADO.NET可以使用DB First 开发 我们在开发一个程序的时候,经常会动态更新数据库字段&a…

算法模板(3):搜索(4):高等图论

高等图论 有向图的强连通分量 相关概念 强连通分量:Strongly Connected Component (SCC).对于一个有向图顶点的子集 S S S,如果在 S S S 内任取两个顶点 u u u 和 v v v,都能找到一条 u u u 到 v v v 的路径,那么称 S S…

C++多态和文件读写

C黑马,每天1.5倍速2个视频(1小时),看到9月1日完成314个视频 目录 🔑多态 🌳基本语法 🌳原理剖析 🌳案例1 -- 计算器类 🌳纯虚函数和抽象类 🌳案例2 --…

redis知识复习

redis知识复习 redis基础知识一. redis的认识1. 非关系型数据库 与 传统数据库 的区别2. 安装redis并设置自启动3. 熟悉命令行客户端4. 熟悉图形化工具RDM 二. redis的命令与数据结构1. 数据结构介绍2. redis通用命令(熟练掌握) 三. redis的Java客户端1.…

SpringBoot整合Flyway实现数据库的初始化和版本管理

文章目录 一、Flyway1、介绍2、业务痛点3、个人理解 二、SpringBoot整合flyway1、整合2、SQL文件命名3、版本号校验算法4、工作流程5、注意事项 一、Flyway 1、介绍 Flyway 是一款开源的数据库版本管理工具。它可以很方便的在命令行中使用,或者在Java应用程序中引入…

【MySQL】数据表的基本操作

目录 1. 创建表 2. 创建表案例 2.1 创建一个users表 2.2 查看表结构 2.3 修改表 3. 删除表 MySQL🌷 1. 创建表 语法: CREATE TABLE table_name (field1 datatype,field2 datatype,field3 datatype ) character set 字符集 collate 校验规则 engine 存储…

chatgpt赋能python:如何升级Python的pip版本

如何升级Python的pip版本 如果你使用Python来进行程序开发,那么你一定需要用到pip,它是Python的包管理器,用于安装和管理各种Python库。 不过,一旦你开始使用pip,你可能会遇到一个问题:你的pip版本可能会…

几种技巧让大模型(ChatGPT、文心一言)帮你提高写代码效率!

代码神器 自从大模型推出来之后,似乎没有什么工作是大模型不能做的。特别是在文本生成、文案写作、代码提示、代码生成、代码改错等方面都表现出不错的能力。下面我将介绍运用大模型写代码的几种方式,帮助程序员写出更好的代码!(…

利用AI点亮副业变现:5个变现实操案例的启示

AI变现副业实操案例 宝宝起名服务AI科技热点号头像壁纸职业头像收徒:萌娃头像定制头像平台挂载 小说推广号流量营销号百家号AI共创计划公众号流量主 知识付费知识星球小报童: 整体思维导图: 在这里先分享五个实操案例: 宝宝起名服务AI科技热…

cvte 前端一面 凉经

cvte 前端一面 凉经 原文面试题地址:https://www.nowcoder.com/discuss/353159272857018368?sourceSSRsearch 1. vuex原理 和vuerouter的原理差不多 2. vuerouter的原理 ​ 首先在main.js中,import router from ‘./router’ 引入在router文件夹下面…

学习WooCommerce跨境电商社交媒体营销

WooCommerce 长期以来一直为电子商务店主提供多样化的服务。大约 500 万家商店啓用安装了免费的 WooCommerce 插件。 官方 WooCommerce 插件从 WordPress.org 下载了161,908,802次,并且还在增加。 超过5,106,506 个网站正在使用 WooCommerce。 本文网址: https…

一文搞懂什么是Docker

一、什么是Docker 微服务虽然具备各种各样的优势,但服务的拆分通用给部署带来了很大的麻烦。 分布式系统中,依赖的组件非常多,不同组件之间部署时往往会产生一些冲突。在数百上千台服务中重复部署,环境不一定一致,会遇…

LVS+Keepalived 群集

目录 一、keepalived概述 1.keepalived工作原理 2.keepalived体系主要模块及其作用 3.判断服务器主备,及如何配置浮动IP 二、keepalived的抢占与非抢占模式 三、部署LVSkeepalived 1.配置负载调度器(主备相同) 1.1配置keepalived&…

NVM安装教程

我是小荣,给个赞鼓励下吧! NVM安装教程 简介 nvm 是node.js的版本管理器,设计为按用户安装,并按 shell 调用。nvm适用于任何符合 POSIX 的 shell(sh、dash、ksh、zsh、bash),特别是在这些平台…