即插即用模块:Convolutional Triplet注意力模块(论文+代码)

目录

一、摘要

二、创新点总结

三、代码详解


论文:https://arxiv.org/pdf/2010.03045v2

代码:https://github. com/LandskapeAI/triplet-attention

一、摘要

由于注意机制具有在通道或空间位置之间建立相互依赖关系的能力,近年来在各种计算机视觉任务中得到了广泛的研究和应用。在本文中,我们研究了轻量级但有效的注意机制,并提出了三重注意,这是一种利用三分支结构捕获跨维交互来计算注意权重的新方法。对于输入张量,三元组注意力通过旋转操作建立维度间依赖关系,然后进行残差变换,并以可忽略不计的计算开销对通道间和空间信息进行编码。我们的方法简单高效,可以作为附加模块轻松插入经典骨干网。我们证明了我们的方法在各种具有挑战性的任务上的有效性,包括在ImageNet-1k上的图像分类和在MSCOCO和PASCAL VOC数据集上的目标检测。此外,我们通过视觉检查GradCAM和GradCAM++结果,提供了对三重注意力性能的广泛洞察。对我们方法的经验评估支持了我们在计算注意力权重时捕获跨维度依赖关系的重要性的直觉。

用三个分支捕获跨维交互的三重注意的抽象表示。给定输入张量,三重态注意力通过旋转输入张量,然后进行残差变换来捕获维间依赖关系。

二、创新点总结

不同注意模块的比较:

(a)挤压激励(SE)模块;

(b)卷积块注意模块(CBAM);

(c)全局上下文模块;

(d)(我们的)三重注意力模块。

特征映射表示为特征维度,例如:

C × H × W表示通道号为C、高H、宽W的特征映射。⊗表示矩阵乘法,⊙表示广播元素明智乘法,⊕表示广播元素明智加法。

图3 三重注意力的图解,它由三个分支。顶部分支负责计算跨通道维度C和空间W的注意力权重,中间捕捉的是C和H之间的权重,类似的底部分支获取的是H和W的空间相关性。在前两个分支中,我们采用旋转操作来建立通道维度与空间维度中的任一个之间的连接,最后通过简单的平均来合计权重。

跨维度交互:计算通道注意力的传统方法包括计算singular权重,通常是输入张量中每个通道的标量,然后使用奇异权重统一缩放这些特征映射。虽然这个计算通道注意力的过程被证明是非常轻量级的和非常成功的,但是在考虑这种方法有一个重要的确实。为了计算这些通道的奇异权重,通过执行全局平均池化,输入张量被空间分解为每一个通道一个像素。这导致了空间信息的主要损失,因此当计算对这些单像素通道的关注时,通道尺寸和空间尺寸之间的相互依赖性不存在。CBAM引入了空间注意作为通道注意力的补充模块,简而言之,空间注意力告诉’通道的什么地方聚焦,通道注意力告诉’聚焦在哪个通道。然而,这个过程的缺点是通道注意力和空间注意力是分离的,并且彼此独立地计算。因此不考虑两者之间的任何关系。受到建立空间注意力方式的启发,我们提出了跨维度交互的概念,通过捕捉输入张量的空间维度和通道维度之间的交互来解决这个缺点。我们在三重注意力中引入了跨维度相互作用,通过三个分支分别获得张量(C,H)、(C,W)和(H,W)维之间的依赖关系。

Z-pool : 这里的Z池化层负责将第0个维度缩减为两个维度,方法是将该维度上的平均池化和最大池化要素串联起来。这使得该层能够保留实际张量的丰富表示,同时缩小其深度,以使进一步的计算变得轻量级。在数学上可以用如下公式:

其中0d是发生最大和平均池化的操作的第0维度。例如,一个形状张量为(CxHxW)最后可以生成一个形状张量(2xHxW)的张量。

Triplet Attention:给定上述定义的操作,我们将三重注意定义为一个三分支模块,它接受一个输入张量并输出一个相同形状的细化张量。给定一个输入张量X ∈ R(CxHxW),我们首先把它传递给提出三重注意模块中的每一个。

在第一个分支中,我们构建了高度维度和通道维度之间的交互。为此,输入X沿着H轴逆时针旋转90度。这个旋转张量表示为形状(WxHxC),X1然后通过一个Z-pool,随后被简化为形状为(2 x H x C) ,X1然后通过内核大小为7X7的标准卷积层,随后就是批量归一化层,其提供维度的中间输出(1 x H x C)。然后通过张量穿过sigmod激活层(σ)来生成最终的注意力权重。随后将生成的注意力权重应用于X1,然后沿H轴顺时针旋转90°,以保持x的原始形状输入。

同样的,在第二个分支中,我们沿着W轴逆时针旋转90°。旋转张量X2可以用(H x C x W)表示,并通过一个Z池化层。因此张量被简化为形状x2为(2xCxW)。X2 通过 由核大小为k x k 定义的标准卷积层,随后是批量归一化层,其输出形状的张量(1 x C x W)。然后通过使该张量通过sigmod激活层来获得注意力权重,然后简单地应用于X2,并且输出随后沿着W轴顺时针旋转90°,以保持与输入X相同的形状。

对于最后一个分支,输入张量X的通道被Z池化为两个。然后,该形状的简化张量X3(2 x H x W)通过由核大小k定义的标准卷积层,随后是批量归一化层。输出通过sigmod激活层(σ)生成形状注意力权重(1 x H x W),然后应用于输入X。然后,由三个分支中的每一个生成的形状的精细张量(C x H x W)通过简单平均来聚集。

总之输入张量X ∈ R(C x H x W)的三重注意力中获得的精细注意力应用张量y的过程可以由以下等式表示:

其中σ代表sigmod激活函数;ψ1、ψ2和ψ3代表三重注意的三个分支中由核大小k定义的标准二维卷积层。简单的来说y可以变为:

其中ω1、ω2和ω3是在三重注意中计算的三个交叉维度注意权重。等式中的y1和y2 在上述等式中,代表90°顺时针旋转以保持(C × H × W)的原始输入形状。

三、代码详解

代码:

import torch
import torch.nn as nn

# 定义一个基础的卷积模块
class BasicConv(nn.Module):
    def __init__(
        self,
        in_planes,      # 输入通道数
        out_planes,     # 输出通道数
        kernel_size,    # 卷积核大小
        stride=1,       # 步长
        padding=0,      # 填充
        dilation=1,     # 空洞率
        groups=1,       # 分组卷积的组数
        relu=True,      # 是否使用ReLU激活函数
        bn=True,        # 是否使用批标准化
        bias=False,     # 卷积是否添加偏置
    ):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        # 定义卷积层
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # 可选的批标准化层
        self.bn = (
            nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
            if bn
            else None
        )
        # 可选的ReLU激活层
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

# 定义一个通道池化模块
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
        )

# 定义一个空间门控模块
class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale

# 定义一个三元注意力模块
class TripletAttention(nn.Module):
    def __init__(
        self,
        gate_channels,         # 门控通道数
        reduction_ratio=16,    # 缩减比率
        pool_types=["avg", "max"],  # 池化类型
        no_spatial=False,      # 是否禁用空间门控
    ):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out

论文解读:大佬

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/610209.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Web功能测试之表单、搜索测试

初入职场接触功能测试老是碰到以下情况不知道怎么写测试用例: 一个界面很多搜索条件怎么写用例? 下拉框测试如何考虑测试点? 上传要考虑哪些验证点?...... 所以这篇主要是整理关于web测试之表单、搜索测试的相关要点。 一、表…

小程序(三)

十三、自定义组件 (二)数据方法声明位置 在js文件中 A、数据声明位置:data中 B、方法声明位置methods中,这点和普通页面不同! Component({/*** 组件的属性列表*/properties: {},/*** 组件的初始数据*/data: {isCh…

离线使用evaluate

一、目录 步骤demorouge-n 含义 二、实现 步骤 离线使用evaluate: 1. 下载evaluate 文件:https://github.com/huggingface/evaluate/tree/main2. 离线使用 路径/evaluate-main/metrics/rougedemo import evaluate离线使用evaluate: 1. 下载evaluate 文件&…

# 从浅入深 学习 SpringCloud 微服务架构(十五)

从浅入深 学习 SpringCloud 微服务架构(十五) 一、SpringCloudStream 的概述 在实际的企业开发中,消息中间件是至关重要的组件之一。消息中间件主要解决应用解耦,异步消息,流量削锋等问题,实现高性能&…

Java中使用RediSearch进行高效数据检索

RediSearch是一款构建在Redis上的搜索引擎,它为Redis数据库提供了全文搜索、排序、过滤和聚合等高级查询功能。通过RediSearch,开发者能够在Redis中实现复杂的数据搜索需求,而无需依赖外部搜索引擎。本文将介绍如何在Java应用中集成并使用Red…

3. 多层感知机算法和异或门的 Python 实现

前面介绍过感知机算法和一些简单的 Python 实践,这些都是单层实现,感知机还可以通过叠加层来构建多层感知机。 2. 感知机算法和简单 Python 实现-CSDN博客 1. 多层感知机介绍 单层感知机只能表示线性空间,多层感知机就可以表示非线性空间。…

切割链表 问题的讲解和实现(带哨兵位)

一:题目 二:思路讲解 先 将小于x的放进一个新的链表,将≥x的也放进另一个新链表。 然后 第一个新链表的尾节点的next链接到第二个新链表的哨兵节点的next,因为本身不存在哨兵位。 最后 链接完成后的链表的最后一个节点的next一…

python数据分析——数据的选择和运算

数据的选择和运算 前言一、数据选择NumPy的数据选择一维数组元素提取示例 多维数组行列选择、区域选择示例 花式索引与布尔值索引布尔索引示例一示例二 花式索引示例一示例二 Pandas数据选择Series数据获取DataFrame数据获取列索引取值示例一示例二 取行方式示例loc() 方法示例…

天地图2024版正式启用!首次开放多时相影像专题,可直接查看历史影像

近日,国家基础地理信息中心正式发布了天地图2024版。 新版本更新2米分辨率遥感影像794万平方公里、优于1米分辨率遥感影像655万平方公里,在线服务2米分辨率遥感影像覆盖全部陆地国土、优于1米遥感影像覆盖陆地国土达到98%,同时新增多时相遥感…

航空科技:探索飞机引擎可视化技术的新视界

随着航空技术的飞速发展,飞机引擎作为航空器最为关键的部件之一,其性能直接影响到飞机的安全性、经济性和环保性。因此,飞机引擎可视化技术的应用日益成为航空行业研究和发展的热点。 通过图扑将复杂的飞机引擎结构和工作原理以直观、生动的…

Linux中的日志系统简介

在的Linux系统上使用的日志系统一般为rsyslogd。rsyslogd守护进程既能接收用户进程输出的日志,又能接收内核日志。用户进程是通过调用syslog函数生成系统日志的。该函数将日志输出到一个UNIX本地域socket类型(AF_UNIX)的文件/dev/log中&#…

动联再掀创新风潮!P92 Max智能POS机惊艳发布

当下,智能支付与零售行业正经历着深刻变革,移动支付、无人支付等新型支付方式在我国广泛应用,显著优化了消费者的支付体验,同时也为零售行业带来新的发展契机。动联,凭借其在身份认证领域的深厚技术底蕴与创新精神&…

CRM客户关系管理系统源码部署/售后更新/搭建上线维护

基于ThinkPHPFastAdmin开发的CRM客户关系管理系统,专门为企业销售团队量身定制的工具,它能够有效地管理跟进客户,提高销售业绩!提供无加密源代码,可以自行根据不同企业的需求进行开发定制。Uniapp版本(高级授权)支持编…

微服务保护-学习笔记

微服务保护 1 Sentinel入门1.1 雪崩问题1.1.1 什么是雪崩问题1.1.2 如何解决雪崩设置超时时间仓壁模式熔断降级限流(预防措施) 1.2 服务保护技术对比1.3 Sentinel介绍与安装Sentinel 特征Sentinel 安装 1.4 Sentinel整合微服务 2 流量控制3 熔断隔离和降…

【通信协议解析】WiFi协议解析

WiFi协议解析 一、发展由来 Wi-Fi,又称“无线网络”,是Wi-Fi联盟的商标,一个基于IEEE 802.11标准的无线局域网技术。“Wi-Fi”常写作“WiFi”或“Wifi”,但是这些写法并没有被Wi-Fi联盟认可。Wi-Fi产品经由Wi-Fi联盟的一家独立授…

MT7516A-ASEMI变频器专用MT7516A

编辑:ll MT7516A-ASEMI变频器专用MT7516A 型号:MT7516A 品牌:ASEMI 封装:MT-5 最大重复峰值反向电压:1600V 最大正向平均整流电流(Vdss):75A 功率(Pd):大功率 芯片个数:5 引…

数据结构与算法学习笔记八-二叉树的顺序存储表示法和实现(C语言)

目录 前言 1.数组和结构体相关的一些知识 1.数组 2.结构体数组 2.二叉树的顺序存储表示法和实现 1.定义 2.初始化 3.先序遍历二叉树 4.中序遍历二叉树 5.后序遍历二叉树 6.完整代码 前言 二叉树的非递归的表示和实现。 1.数组和结构体相关的一些知识 1.数组 在C语…

百倍潜力股Aleo即将上线,布局正当时!牛市来时,你得有币!

前言 在加密货币市场,2024年被众多市场专家预测为迎来新一轮牛市的关键年份。这一预测背后,潜藏着多种可能推动牛市的因素。其中,下一次比特币(BTC)的减半事件,以及2024年 BTC 现货ETF的推出,都…

二叉树-堆

树 在数据库中,树是一种数据结构,用于组织和存储数据,使得可以高效地进行插入、删除和查找操作。它通常用于表示层次关系或者有序集合。 基本概念 节点:树结构中的每个元素都称为节点。 根节点:树的最顶端节点。 子…

智能奶柜:健康生活新风尚

智能奶柜:健康生活新风尚 在快节奏的都市生活中,健康与便利成为了现代人的双重追求。而在这两者交汇之处,智能奶柜应运而生,它不仅是科技与生活的完美融合,更是日常营养补给的智慧之选。 清晨的第一缕温暖 —— 新鲜…
最新文章