CBAM注意力机制(结构图加逐行代码注释讲解)

学CBAM前建议先学会SEnet(因为本篇涉及SEnet的重合部分会略加带过)->传送门

⒈结构图

下面这个是自绘的,有些许草率。。。

因为CBAM机制是由通道和空间两部分组成的,所以有这两个模块(左边是通道注意力机制,右边是空间注意力机制)

下面这两个是官方论文里的:

⒉机制流程讲解

SEnet只关注了通道注意力机制而忽略了空间上的一些简单特征,相比之下,CBAM将通道注意力机制和空间注意力机制进行一个结合,对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理,而是是先通道后空间,也就是第一张结构图表达的意思。

①首先是通道机制:

对于输入特征层,分别作全局最大池化和全局平均池化,输出结果分别送入一个共享全连接层(官方源码在这里和SEnet的全连接层一模一样),为什么叫共享全连接层?因为最大池化和平均池化的两条路线用的是这同一个全连接层。然后对两个结果(maxout和avgout)做加法,最后进行归一化操作,获得通道上的权重矩阵。

②然后是空间机制:

对于输入特征层,在每一个特征点的通道上取最大值和平均值,(这里和通道机制的最大池化和平均池化完全不同,通道机制里是在H、W两个维度求最大或平均,空间机制是在C一个维度上求最大和平均。)然后对两个结果(maxout和avgout)做拼接,也就是maxout的1*H*W与avgout的1*H*W进行拼接,得到2*H*W的张量,因此紧接着下一步就要进行一个7*7的卷积(conv)将通道压缩回1,最后还是进行归一化操作,获得空间上的权重矩阵。

③整体上:

对于输入特征层,输入特征层先乘上通道机制的输出权重(channel_out),然后再乘上空间上的输出权重(spatial_out)

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summary
 
class ChannelModule(nn.Module):
    def __init__(self, inputs, ratio=16):
        super(ChannelModule, self).__init__()
        _, c, _, _ = inputs.size()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        x = self.maxpool(inputs).view(inputs.size(0), -1)#nc
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#nchw
        y = self.avgpool(inputs).view(inputs.size(0), -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(maxout + avgout)
 
 
class SpatialModule(nn.Module):
    def __init__(self):
        super(SpatialModule, self).__init__()
        self.maxpool = torch.max
        self.avgpool = torch.mean
        self.concat = torch.cat
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#n1hw
        avgout = self.avgpool(inputs, dim=1, keepdim=True)#n1hw
        outs = self.concat([maxout, avgout], dim=1)#n2hw
        outs = self.conv(outs)#n1hw
        return self.sigmoid(outs)
 
 
class CBAM(nn.Module):
    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)
        self.spatial_out = SpatialModule()
 
    def forward(self, inputs):
        outs = self.channel_out(inputs) * inputs
        return self.spatial_out(outs) * outs

 解释:

①依赖包和SEnet解释的一样。

②整体上看,将通道机制和空间机制分别封装成类,再封装一个CBAM类来对这两个机制调用,其中用到的__init__构造方法(python称魔术方法)和foward函数(前向传播过程),这些模板和上面介绍SEnet时是一模一样的。

先来看通道机制:

class ChannelModule(nn.Module):#继承nn模块的Module类
    def __init__(self, inputs, ratio=16):#self必写,inputs接收输入特征张量,ratio是通道衰减因子
        super(ChannelModule, self).__init__()#调用父类构造
        _, c, _, _ = inputs.size()#获取通道数
        self.maxpool = nn.AdaptiveMaxPool2d(1)#nn模块的自适应二维最大池化
        self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )#这个共享全连接的3层和SEnet的一模一样,这里借助Sequential这个容器把这3个层整合在一起,方便forward函数去执行,直接调用share_liner(x)相当于直接执行了里面这3层
        self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数
 
    def forward(self, inputs):
        x = self.maxpool(inputs).view(inputs.size(0), -1)#对于输入特征张量,做完最大池化后再重塑形状,view的第一个参数inputs.size(0)表示第一维度,显然就是n;-1表示会自适应的调整剩余的维度,在这里就将原来的(n,c,1,1)调整为了(n,c*1*1),后面才能送入全连接层(fc层)
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#做完全连接后,再用unsqueeze解压缩,也就是还原指定维度,这里用了两次,分别还原2维度的h,和3维度的w
        y = self.avgpool(inputs).view(inputs.size(0), -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)#y走的平均池化路线的代码和x是一样的解释
        return self.sigmoid(maxout + avgout)#最后相加两个结果并作归一化

再来看空间机制:(重复的模板就不再反复赘述了)

class SpatialModule(nn.Module):
    def __init__(self):
        super(SpatialModule, self).__init__()
        self.maxpool = torch.max
        self.avgpool = torch.mean
        #和通道机制不一样!这里要进行的是在C这一个维度上求最大和平均,分别用的是torch库里的max方法和mean方法
        self.concat = torch.cat#torch的cat方法,用于拼接两个张量
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)#nn模块的二维卷积,其中的参数分别是:输入通道(2),输出通道(1),卷积核大小(7*7),步长(1),灰度填充(3)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#maxout接收特征点的最大值很好理解,为什么还要一个占位符?因为torch.max不仅返回张量最大值,还会返回索引,索引用不着所以直接忽略,dim=1表示在维度1(也就是nchw的c)上求最大值,keepdim=True表示要保持原来张量的形状
        avgout = self.avgpool(inputs, dim=1, keepdim=True)#torch.mean则只返回张量的平均值,至于参数的解释和上面是一样的
        outs = self.concat([maxout, avgout], dim=1)#torch.cat方法,传入一个列表,将列表中的张量在指定维度,这里是维度1(也就是nchw的c)拼接,即n*1*h*w拼接n*1*h*w得到n*2*h*w
        outs = self.conv(outs)#卷积压缩上面的n*2*h*w,又得到n*1*h*w
        return self.sigmoid(outs)

 最后看整体:

class CBAM(nn.Module):
    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)#获得通道权重
        self.spatial_out = SpatialModule()#获得空间权重
 
    def forward(self, inputs):
        outs = self.channel_out(inputs) * inputs #先乘上通道权重
        return self.spatial_out(outs) * outs #在乘完通道权重的基础上再乘上空间权重

⒋测试结果

大问题没有,但还是少了一些关键层,尤其是空间机制那里的拼接maxout和avgout,通道变为2再用卷积压缩回1的过程都没体现。。。只能说summary确实不太好使,或者说我没用对?网络层简写导致的?(最不可能是这个原因,因为我拿官方的源码测试也是summary出这些结果)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/165415.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【MATLAB】全网唯一的11种信号分解+模糊熵(近似熵)联合算法全家桶

有意向获取代码,请转文末观看代码获取方式~ 大家吃一顿火锅的价格便可以拥有18种信号分解算法,绝对不亏,知识付费是现今时代的趋势,而且都是我精心制作的教程,有问题可随时反馈~也可单独获取某一算法的代码&#xff0…

【LeetCode刷题-树】--654.最大二叉树

654.最大二叉树 给定一个不重复的整数数组 nums 。 最大二叉树 可以用下面的算法从 nums 递归地构建: 1.创建一个根节点,其值为 nums 中的最大值。 2.递归地在最大值 左边 的 子数组前缀上 构建左子树。 3.递归地在最大值 右边 的 子数组后缀上 构建右子树。 返回…

【LeetCode刷题-树】--998.最大二叉树II

998.最大二叉树II /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* …

8.1 Windows驱动开发:内核文件读写系列函数

在应用层下的文件操作只需要调用微软应用层下的API函数及C库标准函数即可,而如果在内核中读写文件则应用层的API显然是无法被使用的,内核层需要使用内核专有API,某些应用层下的API只需要增加Zw开头即可在内核中使用,例如本章要讲解…

SQL单表复杂查询where、group by、order by、limit

1.1SQL查询代码如下&#xff1a; select job as 工作类别,count(job) as 人数 from tb_emp where entrydate <2015-01-01 group by job having count(job) > 2 order by count(job) limit 1,1where entrydate <‘2015-01-01’ 表示查询日期小于2015-01-01的记录…

Python大数据之linux学习总结——day11_ZooKeeper

ZooKeeper ZK概述 ZooKeeper概念: Zookeeper是一个分布式协调服务的开源框架。本质上是一个分布式的小文件存储系统 ZooKeeper作用: 主要用来解决分布式集群中应用系统的一致性问题。 ZooKeeper结构: 采用树形层次结构&#xff0c;ZooKeeper树中的每个节点被称为—Znode。且树…

随机过程-张灏

文章目录 导论随机过程相关 学习视频来源&#xff1a;https://www.bilibili.com/video/BV18p4y1u7NP?p1&vd_source5e8340c2f2373cdec3870278aedd8ca4 将持续更新—— 第一次更新&#xff1a;2023-11-19 导论 教材&#xff1a;《随机过程及其应用》陆大絟.张颢 参考&…

X3DAudio1_7.dll丢失原因,X3DAudio1_7.dll丢失怎样解决分享

X3DAudio1_7.dll是一款由微软公司开发的音频处理库&#xff0c;主要用于实现三维音频效果。这个库主要应用于游戏开发、多媒体应用等领域&#xff0c;它可以使得音频更加真实、自然地表现出空间感。如果在使用过程中遇到X3DAudio1_7.dll丢失的问题&#xff0c;可以尝试以下五个…

指南:关于帮助中心需要注意的一些细节

在现代商业环境中&#xff0c;帮助中心已经成为企业提供客户支持和解决问题的重要方式之一。然而&#xff0c;建立一个高效的帮助中心并不简单。除了选择合适的软件平台和工具之外&#xff0c;还需要注意一些细节&#xff0c;以确保能够真正帮助客户并提高客户满意度。 | 1.设计…

辅助解决小白遇到的电脑各种问题

写这个纯属是为了让电脑小白知道一些电脑上的简单操作&#xff0c;勿喷&#xff01;&#xff01;&#xff01; 一&#xff1a;当小白遇到电脑程序不完全退出怎么办&#xff1f; 使用软件默认的退出方式 此处拿百度网盘举例&#xff1a; 用户登录网盘后&#xff1a; 如果直接点…

spring常见面试题总结

1、spring是什么 Spring&#xff1a;是一个轻量级的IOC和AOP的java开发框架&#xff0c;为了简化企业级开发而生。核心就是控制反转和面向切面编程。 IOC&#xff1a;控制反转&#xff08;Inverse of Control&#xff09;&#xff0c;以前项目都是在哪儿用到对象 在哪儿new&a…

Vite - 静态资源处理 - json文件导入

直接就说明白了 vite 中对json 文件直接当作一个模块来解析。 可以直接导入使用&#xff01; 可以直接导入使用&#xff01; 可以直接导入使用&#xff01;json文件中的key&#xff0c;直接被作为一个属性&#xff0c;可以单独被导入。 因此&#xff0c;导入json文件有两种方式…

Rust与其他语言对比:优势在哪里?

大家好&#xff01;我是lincyang。 今天&#xff0c;我们将深入探讨Rust语言与其他编程语言比较的优势&#xff0c;并通过具体的代码示例和性能数据来加深理解。 Rust与其他语言的比较 1. 内存安全性 Rust&#xff1a;采用所有权系统&#xff0c;编译器在编译时检查内存安全…

自动驾驶学习笔记(十)——Cyber通信

#Apollo开发者# 学习课程的传送门如下&#xff0c;当您也准备学习自动驾驶时&#xff0c;可以和我一同前往&#xff1a; 《自动驾驶新人之旅》免费课程—> 传送门 《Apollo Beta宣讲和线下沙龙》免费报名—>传送门 文章目录 前言 Cyber通信 编写代码 编译程序 运行…

AI绘画使用Stable Diffusion(SDXL)绘制三星堆风格的图片

一、前言 三星堆文化是一种古老的中国文化&#xff0c;它以其精湛的青铜铸造技术闻名&#xff0c;出土文物中最著名的包括青铜面具、青铜人像、金杖、玉器等。这些文物具有独特的艺术风格&#xff0c;显示了高度的工艺水平和复杂的社会结构。 青铜面具的巨大眼睛和突出的颧骨&a…

asp.net mvc点餐系统餐厅管理系统

1. 主要功能 ① 管理员、收银员、厨师的登录 ② 管理员查看、添加、删除菜品类型 ③ 管理员查看、添加、删除菜品&#xff0c;对菜品信息进行简介和封面的修改 ④ 收银员浏览、搜索菜品&#xff0c;加入购物车后进行结算&#xff0c;生成订单 ⑤ 厨师查看待完成菜品信息…

漫谈广告机制设计 | 万剑归宗:聊聊广告机制设计与收入提升的秘密(3)

​书接上文漫谈广告机制设计 | 万剑归宗&#xff1a;聊聊广告机制设计与收入提升的秘密&#xff08;2&#xff09;&#xff0c;我们聊到囚徒困境是完全信息静态博弈&#xff0c;参与人存在占优策略&#xff0c;最终达到占优均衡&#xff0c;并且是对称占优均衡。接下来我们继续…

Jmeter——结合Allure展示测试报告

在平时用jmeter做测试时&#xff0c;生成报告的模板&#xff0c;不是特别好。大家应该也知道allure报告&#xff0c;页面美观。 先来看效果图&#xff0c;报告首页&#xff0c;如下所示&#xff1a; 报告详情信息&#xff0c;如下所示&#xff1a; 运行run.py文件&#xff0c;…

多功能神器,强劲升级,太极2.x你值得拥有!

嗨&#xff0c;大家好&#xff0c;今天给大家分享一个好用好玩的软件。那就是太极2.x软件&#xff0c;最近在1.0版本上进行了全新升级&#xff0c;升级后的功能更强更稳定&#xff0c;轻度用户使用基本功能就已经足够了&#xff0c;我们一起来看看吧&#xff01; 首页 首页左…

8.5 Windows驱动开发:内核注册表增删改查

注册表是Windows中的一个重要的数据库&#xff0c;用于存储系统和应用程序的设置信息&#xff0c;注册表是一个巨大的树形结构&#xff0c;无论在应用层还是内核层操作注册表都有独立的API函数可以使用&#xff0c;而在内核中读写注册表则需要使用内核装用API函数&#xff0c;如…
最新文章