双重注意力模块 DoubleAttention | A2-Nets: Double Attention Networks

在这里插入图片描述

论文名称: A 2 A^2 A2-Nets: Double Attention Networks》

论文地址:https://arxiv.org/pdf/1810.11579.pdf


学习捕捉远距离关系对于图像/视频识别是基础性的。现有的CNN模型通常依靠增加深度来建模这些关系,这在很大程度上效率低下。在这项工作中,我们提出了“双重注意力块”,这是一种新颖的组件,它可以从输入图像/视频的整个时空空间聚合和传播有信息的全局特征,使得后续的卷积层可以高效地访问整个空间的特征。该组件设计了一个双重注意力机制的两个步骤,第一步通过二阶注意力池将整个空间的特征聚集到一个紧凑的集合中,第二步通过另一个注意力机制自适应地选择和分配特征到每个位置。所提出的双重注意力块易于采用,并可以方便地插入到现有的深度神经网络中。我们进行了大量的消融研究和实验证明其性能。在图像识别任务中,使用我们的双重注意力块装备的ResNet-50ImageNet-1k数据集上胜过了更大的ResNet-152架构,参数数量减少了40%以上,FLOPs也减少了。在动作识别任务中,我们提出的模型在KineticsUCF-101数据集上取得了最先进的结果,并具有比最近的研究工作更高的效率。捕捉长距离关系是图像/视频识别的基础。现有的卷积神经网络 (CNN) 模型通常依赖于增加深度来建模这种关系,这种方法效率低下。在这项工作中,我们提出了 “双重注意力块”,这是一种新组件,它从输入图像/视频的整个时空中聚合和传播有用的全局特征,使后续的卷积层能够有效地访问整个空间的特征。该组件采用双重注意力机制,分两个步骤进行。第一步通过二阶注意力池化将整个空间的特征聚集到一个紧凑的集合中,第二步通过另一种注意力机制自适应地选择和分配特征到每个位置。所提出的双重注意力块易于使用,并且可以方便地插入现有的深度神经网络中。

我们在图像和视频识别任务上进行了广泛的消融研究和实验,以评估其性能。在图像识别任务中,装备了双重注意力块的 ResNet-50ImageNet-1k 数据集上表现优于更大的 ResNet-152 架构,同时参数数量减少了超过40%FLOPs 也更少。在动作识别任务中,我们的模型在 KineticsUCF-101 数据集上取得了最新的结果,效率显著高于近期的其他工作。


问题背景

在图像和视频识别领域,深度卷积神经网络(CNN)面临着有效捕获长距离关系的挑战。传统的CNN模型通常依赖增加网络深度来处理这些长距离关系,这会导致效率低下和计算成本增加。此外,局部感受野的局限性也可能导致信息传播不充分,从而影响网络的性能。因此,需要一种能够在保持低计算开销的同时有效捕获长距离关系的方法。


核心概念

本文提出了“双重注意力块”(Double Attention Block),也称为A2-Net。这是一种新的组件,能够从整个空间或时空中聚合和传播信息,帮助后续的卷积层更有效地访问全局特征。双重注意力块通过两步注意力机制实现其目标:第一步使用二阶注意力池化从整个空间中聚合特征,第二步通过另一个注意力机制自适应地选择和分配特征到每个位置。这种设计使得卷积层能够有效地感知整个空间,从而提高图像和视频识别的性能。


模块的操作步骤

在这里插入图片描述

所提出的双重注意力块的计算图。所有卷积核的大小均为 1 × 1 × 1。我们将这个双重注意力块插入现有的卷积神经网络中,例如残差网络,以构成 A2-Net


双重注意力块的操作步骤包括两个主要部分:特征聚合和特征分配。

  • 在特征聚合部分,模块使用二阶注意力池化从整个空间中选择关键特征。与传统的平均池化或最大池化不同,二阶注意力池化能够捕获和保留更复杂的关系。
  • 在特征分配部分,模块根据每个位置的局部特征自适应地分配特征,而不是像SENet那样将相同的全局特征分配到每个位置。这样,每个位置可以根据其需要接收不同的特征,从而增强整体性能。

文章贡献

本文的主要贡献在于提出了一种新的注意力机制,能够同时捕获长距离特征间的依赖关系,并以较低的计算和内存开销实现高效的特征分配。双重注意力块可以方便地插入现有的深度神经网络中,提高其性能。此外,通过在ImageNet-1kKineticsUCF-101等数据集上的实验,作者证明了双重注意力块在图像和视频识别任务上的有效性。


实验结果与应用

在实验结果部分,本文通过对比不同网络架构、调整参数和测试各种场景,证明了双重注意力块在提高网络性能方面的有效性。在ImageNet-1k分类任务中,使用双重注意力块的ResNet-50超越了更大的ResNet-152,并显著降低了计算成本。在视频识别任务中,双重注意力块在KineticsUCF-101数据集上取得了最佳性能,表明其在不同视觉任务中的广泛适用性。


对未来工作的启示

双重注意力块的成功展示了通过高效注意力机制增强深度神经网络性能的潜力。未来的研究可以探索将双重注意力块与其他类型的神经网络结合,或将其用于其他任务,如自然语言处理和音频分析。此外,研究人员可以进一步优化这种机制,以提高其在移动设备和资源受限环境中的性能。


代码

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class DoubleAttention(nn.Module):

    def __init__(self, in_channels, c_m, c_n, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct
        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        assert c == self.in_channels
        A = self.convA(x)  # b,c_m,h,w
        B = self.convB(x)  # b,c_n,h,w
        V = self.convV(x)  # b,c_n,h,w
        tmpA = A.view(b, self.c_m, -1)
        attention_maps = F.softmax(B.view(b, self.c_n, -1), dim=-1)
        attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=-1)
        # step 1: feature gating
        global_descriptors = torch.bmm(
            tmpA, attention_maps.permute(0, 2, 1)
        )  # b.c_m,c_n
        # step 2: feature distribution
        tmpZ = global_descriptors.matmul(attention_vectors)  # b,c_m,h*w
        tmpZ = tmpZ.view(b, self.c_m, h, w)  # b,c_m,h,w
        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)

        return tmpZ


if __name__ == "__main__":
    input = torch.randn(64, 256, 8, 8)
    model = DoubleAttention(in_channels=256, c_m=128, c_n=128, reconstruct=True)
    output = model(input)
    print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/579329.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字旅游打造个性化旅行体验,科技让旅行更精彩:借助数字技术,旅行者可以定制专属旅行计划,享受个性化的旅行体验

目录 一、引言 二、数字旅游的兴起与发展 三、数字技术助力个性化旅行体验 1、智能推荐系统:精准匹配旅行者需求 2、定制化旅行计划:满足个性化需求 3、实时互动与分享:增强旅行体验 四、科技提升旅行便捷性与安全性 1、移动支付与电…

boot https ssl 使用http协议访问报错

在springboot中配置ssl以后, 再次使用http访问对应的接口就会报错 可以考虑如下设置,将http访问的端口重定向到https对应的端口 import org.apache.catalina.Context; import org.apache.catalina.connector.Connector; import org.apache.tomcat.util…

图像处理ASIC设计方法 笔记18 轮廓跟踪算法的硬件加速方案

目录 1排除伪孤立点(断裂链表)方法1 限制链表的长度方法2 增加判断条件排除断裂链表方法3 排除不必要跟踪的轮廓(推荐用这个方法) P129 轮廓跟踪算法的硬件加速方案 1排除伪孤立点(断裂链表) 如果图像中某…

探索开源的容器引擎--------------Docker容器操作

目录 一、Docker 容器操作 1.1容器创建 1.2查看容器的运行状态 1.3启动容器 1.4创建并启动容器 1.4.1当利用 docker run 来创建容器时, Docker 在后台的标准运行过程是: 1.4.2在后台持续运行 docker run 创建的容器 1.4.3创建容器并持续运行容器…

Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读 Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482 笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题&#…

Kafka 3.x.x 入门到精通(07)——Java应用场景——SpringBoot集成

Kafka 3.x.x 入门到精通(07)——Java应用场景——SpringBoot集成 4. Java应用场景——SpringBoot集成4.1 创建SpringBoot项目4.1.1 创建SpringBoot项目4.1.2 修改pom.xml文件4.1.3 在resources中增加application.yml文件 4.2 编写功能代码4.2.1 创建配置…

debian配置BIND DNS服务器

前言 局域网内有很多台主机,IP难以记忆。 而修改hosts文件又难以做到配置共享和统一,需要一台内网的DNS服务器。 效果展示 这里添加了一个域名hello.dog,将其指向为192.168.1.100。 同时,外网的域名不会受到影响,…

基于粒子滤波器的电池剩余使用寿命计算matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 粒子滤波器基础 4.2 电池剩余使用寿命建模与预测 4.3 粒子滤波器在电池寿命预测中的应用 5.完整工程文件 1.课题概述 基于粒子滤波器的电池剩余使用寿命计算。根据已知的数据,预测未来…

前端框架编译器之模板编译

编译原理概述 编译原理:是计算机科学的一个分支,研究如何将 高级程序语言 转换为 计算机可执行的目标代码 的技术和理论。 高级程序语言:Python、Java、JavaScript、TypeScript、C、C、Go 等。计算机可执行的目标代码:机器码、汇…

高级IO|从封装epoll服务器到实现Reactor服务器|Part1

从封装epoll_server到实现reactor服务器(part1) 项目复习:从封装epoll_server到实现reactor服务器(part1)EPOLL模式服务器初步 select, poll, epoll的优缺点epoll的几个细节封装epoll_server基本框架先写好创建监听套接字和创建epoll模型可以Accept了吗&#xff1f…

鸿蒙OpenHarmony【轻量系统 运行】 (基于Hi3861开发板)

运行 联网配置 由于Hi3861为WLAN模组,您可以在版本编译及烧录后,通过如下操作,使开发板实现联网功能。 保持Windows工作台和Hi3861 WLAN模组的连接状态,确认串口终端显示正常。 复位Hi3861 WLAN模组,终端界面显示“…

网络攻击日益猖獗,安全防护刻不容缓

“正在排队登录”、“账号登录异常”、“断线重连”......伴随着社交软件用户的一声声抱怨,某知名社交软件的服务器在更新上线2小时后,遭遇DDoS攻击,导致用户无法正常登录。在紧急维护几小时后,这款软件才恢复正常登录的情况。 这…

视频通话实时换脸:支持训练面部模型 | 开源日报 No.235

iperov/DeepFaceLive Stars: 19.7k License: GPL-3.0 DeepFaceLive 是一个用于 PC 实时流媒体或视频通话的人脸换装工具。 可以使用训练好的人脸模型从网络摄像头或视频中交换面部。提供多个公共面部模型,包括 Keanu Reeves、Mr. Bean 等。支持自己训练面部模型以…

基于数据挖掘的斗鱼直播数据可视化分析系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 随着网络直播平台的兴起,斗鱼直播作为其中的佼佼者,吸引了大量用户和观众。为了更好地理解和分析斗鱼直播中的数据,本项目介绍了一个基于数据挖掘的斗鱼直播数据…

【图解计算机网络】简单易懂的https原理解析

简单易懂的https原理解析 https与http的区别混合加密对称加密非对称加密混合加密解析混合加密问题 摘要算法数字证书数字证书原理为什么通过CA证书可以解决中间人攻击的问题呢? https握手流程 https与http的区别 http是明文传输的,非常不安全&#xff0…

呆马科技——智慧应急执法监管平台

在当今社会,安全生产的重要性日益凸显。对于各级政府和企事业单位,当务之急是如何高效地对突发事件进行执法管理。平台应运而生,旨在通过信息化、智能化技术,提升安全管理的效率与准确性。 一、平台特点 整合各类平台的信息资源&…

添加github SSH Key

添加github SSH Key 使用 SSH 协议,您可以连接远程服务器和服务并对其进行身份验证。使用 SSH 密钥,您可以连接到 GitHub,而无需在每次访问时提供您的用户名和个人访问令牌。您还可以使用 SSH 密钥来签署提交。 #3224333333qq.com替换为你自己…

6.NVIC中断配置(ST的精简ARM中断体系)

void NVIC_SetPriorityGrouping(uint32_t PriorityGroup)//设置优先级分组,整个项目共用一个分组 uint32_t NVIC_EncodePriority (uint32_t PriorityGroup, uint32_t PreemptPriority, uint32_t SubPriority) //计算优先级编码值,(组号&…