CBAM解析及代码(Pytorch)

CBAM,全称Convolutional Block Attention Module,是一种注意力机制模块,用于增强卷积神经网络(CNN)的特征表达能力。该模块由通道注意力模块和空间注意力模块两部分组成,能够分别关注输入特征图的通道信息和空间信息,进而提升模型对于重要特征的关注度。

在通道注意力模块中,CBAM通过全局平均池化和最大池化操作捕获通道间的依赖关系,生成两个通道描述子。这两个描述子随后通过共享的全连接层和ReLU激活函数进行变换,再经过Sigmoid函数得到通道注意力权重。这些权重与原始特征图相乘,实现通道维度的特征重标定。

空间注意力模块则关注特征图的空间位置信息。它首先对特征图进行通道维度的平均池化和最大池化操作,生成两个空间描述子。这两个描述子经过一个卷积层进行融合,再通过Sigmoid函数得到空间注意力权重。这些权重与原始特征图相乘,实现对空间位置的特征重标定。

CBAM模块可以轻松地嵌入到现有的卷积神经网络架构中,如ResNet、VGG等,通过增强模型的注意力能力,提升其在图像分类、目标检测等任务上的性能。同时,CBAM还具有良好的可解释性,有助于理解模型在决策过程中的关注点,为深度学习模型的可视化和解释提供了有力的工具

一、通道注意力

相同视角下,取不同的池化值,然后就是通道的压缩及扩展,最后通过sigmoid得到最终权重。其实这里模式很像SE,但也不能照搬SE,所以用了个多分支。

class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        # 使用自适应池化缩减map的大小,保持通道不变
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  #(1) 表示输出的高度和宽度都被设置为 1。
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(),
            nn.Linear(in_features=mid_channel, out_features=channel)
        )
        self.sigmoid = nn.Sigmoid()
        # self.act=SiLU()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)

二、空间注意力

 

 刚才那个是通道上的压缩及放缩,那这里就是空间特征图上,依然采用两种池化方式。

# 空间注意力模块
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        # self.act=SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # map尺寸不变,缩减通道
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out

该论文采用的创新,从两个不同的视角建立注意力,这个出发点还是不错的,但实际比如在小数据集上的效果怎么样,那么就需要你自己斟酌了。

三、CBAM_ResNet(Pytorch)

# ------------------------#
# CBAM模块的Pytorch实现
# ------------------------#

# 通道注意力模块
import torch.nn as nn
import torch
class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        # 使用自适应池化缩减map的大小,保持通道不变
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  #(1) 表示输出的高度和宽度都被设置为 1。
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(),
            nn.Linear(in_features=mid_channel, out_features=channel)
        )
        self.sigmoid = nn.Sigmoid()
        # self.act=SiLU()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)


# 空间注意力模块
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        # self.act=SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # map尺寸不变,缩减通道
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out


# CBAM模块
class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out
from CBAM import CBAM
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class CBAMBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。
        super(CBAMBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.cbam = CBAM(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.cbam(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class CBAMBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。
        super(CBAMBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.cbam(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


def cbam_resnet18(num_classes=1_000):
    model = ResNet(CBAMBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet34(num_classes=1_000):
    model = ResNet(CBAMBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet50(num_classes=1_000, pretrained=False):
    model = ResNet(CBAMBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if pretrained:
        model.load_state_dict(load_state_dict_from_url(
            "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl"))
    return model


def cbam_resnet101(num_classes=1_000):
    model = ResNet(CBAMBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet152(num_classes=1_000):
    model = ResNet(CBAMBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


if __name__ == "__main__":
    inputs = torch.randn(2, 3, 224, 224)
    model = cbam_resnet50(pretrained=False)
    # outputs = model(inputs)
    # print(outputs.size())
    print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/482351.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

算法思想总结:模拟算法

一、模拟算法的总结 1、本质:比葫芦画瓢 2、特点:思路较简单,根据题目要求即可,代码量和细节较多 3、解决方法: (1) 模拟算法流程,在草稿纸上进行演算 (2)…

GAMMA数据处理问题(七)

phase_sim_orb报这个错是什么原因呢,说是我的hgt文件和模拟的干涉图行数不匹配,之前geocode生成hgt的参数不是在mli.par文件中看吗,为什么会出现行数不匹配的情况啊,难道不是par文件中里面看???…

【C++庖丁解牛】二叉搜索树(Binary Search Tree,BST)

🍁你好,我是 RO-BERRY 📗 致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 🎄感谢你的陪伴与支持 ,故事既有了开头,就要画上一个完美的句号,让我们一起加油 目录 1. 二叉搜索树概念2. 二叉…

结构体内存对齐 offsetof 枚举 联合体

文章目录 结构体结构体内存对齐结构体嵌套结构体内存对齐的原因修改默认对齐数设置默认对齐数 #pragma pack() offsetof() 是宏 offset偏移量 of是谁的偏移量。计算结构体成员相对于结构体的起始位置偏移量是几。 结构体传参值传递地址传递 位段枚举联合 联合体 共用体联合体大…

【JS】深度学习JavaScript

💓 博客主页:从零开始的-CodeNinja之路 ⏩ 收录文章:【JS】深度学习JavaScript 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 一:JavaScript1.1 JavaScript是什么1.2 JS的引入方式1.3 JS变量1.4 数据类型1.5 …

LeetCode 热题 100 | 堆(二)

目录 1 什么是优先队列 1.1 优先队列与堆的关系 1.2 如何定义优先队列 1.3 如何使用优先队列 1.4 如何设置排序规则 2 347. 前 K 个高频元素 2.1 第 2 步的具体实现 2.2 举例说明 2.3 完整代码 3 215. 数组中的第 K 个最大元素 - v2 菜鸟做题,语…

【漏洞复现】科立讯通信指挥调度平台editemedia.php sql注入漏洞

漏洞描述 在20240318之前的福建科立讯通信指挥调度平台中发现了一个漏洞。该漏洞被归类为关键级别,影响文件/api/client/editemedia.php的未知部分。通过操纵参数number/enterprise_uuid可导致SQL注入。攻击可能会远程发起。 免责声明 技术文章仅供参考,任何个人和组织使…

2024公认口碑最好的洗地机有哪些?若看重清洁力,这四款最值得买

每当我们要清洁卫生时,是否总是感到腰酸背痛、疲劳不堪,甚至头昏眼花?地板是家中的重要门面,不容忽视的卫生焦点。如今,我们终于多了一位家务打扫的救星——家用洗地地机。一次操作,即可完成扫地除尘、地除…

Git 分布式版本控制系统基本概念和操作命令

目录 Git 基本概念 功能特点 工作流程 操作命令 新建代码库 配置 增删文件 代码提交 分支 标签 查看信息 远程同步 撤销 其他 小结 Git Git 是一个开源的分布式版本控制系统,用于跟踪文件的变更历史。它最初由 Linux Torvalds 设计,用于…

1+x中级题目练习复盘(八)

SQL 语句中进行 group by 分组时,可以不写 where 子句 在使用 select 语句进行查询分组时,如果希望去掉不满足条件的分组,使用 having 子句File 类的 isDirectory() 方法可以判断文件是否为目录 在使用 select 语句进行查询分组时&#xff0…

二.寄存器

1. 2. 例如:h即为high(高位),l即为low(低位) 3.一个字是两个字节 4.在写一条汇编指令或一个寄存器的名称时不区分大小写。 5.al,ah,ax在接受汇编指令时,并不相等&…

33-Java服务定位器模式 (Service Locator Pattern)

Java服务定位器模式 实现范例 服务定位器模式(Service Locator Pattern)用于想使用 JNDI 查询定位各种服务的时候考虑到为某个服务查找 JNDI 的代价很高,服务定位器模式充分利用了缓存技术在首次请求某个服务时,服务定位器在 JNDI…

十三、MySQL基于GTID的半同步复制

目录 一、MySQL半同步复制 一、三种复制方式比较 1、异步复制 2、同步复制 3、半同步复制 4、半同步复制比较 5、半同步复制的特点 二、搭建半同步复制 1、如果不清楚Plugin的目录,用如下查找: 2、所有数据库服务器,安装半同步插件…

【Go实现】实践GoF的23种设计模式:解释器模式

上一篇:【Go实现】实践GoF的23种设计模式:适配器模式 简单的分布式应用系统(示例代码工程):https://github.com/ruanrunxue/Practice-Design-Pattern–Go-Implementation 简介 解释器模式(Interpreter Pat…

【STM32嵌入式系统设计与开发】——6矩阵按键应用(4x4)

这里写目录标题 一、任务描述二、任务实施1、SingleKey工程文件夹创建2、函数编辑(1)主函数编辑(2)LED IO初始化函数(LED_Init())(3)开发板矩阵键盘IO初始化(ExpKeyBordInit())&…

JVM本地方法

本地方法接口 NAtive Method就是一个java调用非java代码的接口 本地方法栈(Native Method Statck) Java虚拟机栈用于管理Java方法的调用,而本地方法栈用于管理本地方法的调用。 本地方法栈,也是线程私有的。 允许被实现成固定或…

Matlab|基于分布式ADMM算法的考虑碳排放交易的电力系统优化调度研究

目录 1 主要内容 目标函数 计算步骤 节点系统 2 部分代码 3 程序结果 4 下载链接 1 主要内容 程序完全复现文献《A Distributed Dual Consensus ADMM Based on Partition for DC-DOPF with Carbon Emission Trading》,建立了一个考虑碳排放交易的最优模型&am…

【测试开发学习历程】MySQL分组查询与子查询 + MySQL表的联结操作

目录 1 MySQL分组查询与子查询 1.1 数据分组查询 1.2 过滤分组 1.3 分组结果排序 1.4 select语句中子句的执行顺序 1.5 子查询 2 MySQL表的联结操作 2.1 关系表 2.2 表联结 2.3 笛卡尔积 2.4 内部联结 2.5 外联结 2.6 自联结 2.7 组合查询 1 MySQL分组查询与子查询…

Java学习路线一条龙

说在前面 讲真,虽然我是正规计算机专业出身,但十多年来,Java这语言和它那一大堆配套的工具、框架,变化得太快了。 我也是一边学新的,一边扔旧的,忙得不可开交。 现在回想起来,走过的弯路、浪费…

2024年【危险化学品经营单位安全管理人员】新版试题及危险化学品经营单位安全管理人员模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 危险化学品经营单位安全管理人员新版试题考前必练!安全生产模拟考试一点通每个月更新危险化学品经营单位安全管理人员模拟考试题题目及答案!多做几遍,其实通过危险化学品经营单位安…
最新文章