SE通道注意力机制模块

简介

论文原址:https://arxiv.org/pdf/1709.01507.pdf

在深度学习领域,提升模型的表征能力一直是一个关键的研究方向。SE(Squeeze-and-Excitation)模块是一种引入通道注意力机制的方法,旨在让神经网络更加关注对当前任务重要的特征。本文将介绍SE模块的原理、实现以及在深度学习中的应用。

SE模块结构

SE模块主要的两个操作就是Squeeze与Excitation操作。

首先是Squeeze操作,通过聚合跨空间维度(H × W)的特征映射来产生通道描述符,怎么理解呢?

假设有一个输入的特征映射,它的维度是H × W × C,对于每个通道,执行全局平均池化操作,具体来说,对于第i个通道,计算该通道上所有空间位置的平均值。可以通过对每个通道的所有元素取平均来实现,其中每个元素表示相应通道上的平均值,在对输入特征映射中的所有通道,都执行上述的操作,就会得到C个平均值,形成一个长度为C的向量。由此,Squeeze操作将输入的H × W × C特征映射压缩成一个C维的向量,其中每个元素表示对应通道上的平均值。这个向量可以被看作是对整个特征映射在通道维度上的"描述符"。


接着是Excitation操作,它是SE模块中的第二步,其主要目的是对上一步Squeeze得到的通道描述符进行加权,以强调重要的通道。

接着是将Squeeze得到的C维向量输入到一个全连接层(FC层),其目的是学习通道权重。这个全连接层包括一个非线性激活函数,如ReLU,以引入非线性变换。通过学习,全连接层得到的通道权重经过一个Sigmoid激活函数,将其范围限制在0到1之间。学到的权重可以被视为每个通道的激活程度。最后,将原始的C维向量与这学到的通道权重相乘,得到加权后的向量。这个加权后的向量反映了每个通道在任务中的重要性和贡献程度。这一过程使得网络能够动态调整通道的关注度,从而更有效地利用信息来完成任务。

"""
Original paper addresshttps: https://arxiv.org/pdf/1709.01507.pdf
Code originates from https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py
"""
import torch
from torch import nn

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


if __name__ == '__main__':
    input = torch.randn(1, 64, 64, 64)
    se = SELayer(channel=64, reduction=8)
    output = se(input)
    print(output.shape)

关于这一部分的理解,我也是看了许多博主的博客,很多都是一步带过,我也是综合着才理解到原来是这样的,最后会过头来再看代码似乎有一点明悟,我的理解可以与其对的上。

这里还涉及到了一个超参数reduction,指的是 SE 模块的缩减比例。据作者论文里面说的是16是一个效果比较好的值,它表示的是通道数缩减的比例为 1/16。这个比例用于控制全连接层的通道数缩减,从而影响 SE 模块的参数量和计算复杂度,通过调整 reduction 参数,可以灵活控制通道数的缩减程度。

SE-Resnet网络

由于我也是刚学完这部分的知识,应用上面还需要进行实验,所以,还是来看看原作者提供的torch版本的网络。

import torch.nn as nn
from torchvision.models import ResNet

__all__ = ["se_resnet18", "se_resnet34", "se_resnet50", "se_resnet101", "se_resnet152"]

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class SEBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


def se_resnet18(num_classes):
    model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet34(num_classes):
    model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet50(num_classes):
    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet101(num_classes):
    model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet152(num_classes):
    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model

if __name__=="__main__":
    import torchsummary
    import torch
    input = torch.ones(2, 3, 224, 224).cuda()
    net = se_resnet50(num_classes=4)
    net = net.cuda()
    out = net(input)
    print(out)
    print(out.shape)
    torchsummary.summary(net, input_size=(3, 224, 224))
    # Total params: 26,031,172

这里面基本上是与torchvision中的源码相同的,只是对原本的BasicBlock与Bottleneck在最后的卷积层后加入了SE模块。

ResNet与SE-ResNet分类性能比较实验

这一部分我是在自己的数据集上进行分类的,一共四个类,每个类别大约在300张左右,ResNet添加了官方的预训练权重,SE-Resnet是重头开始训练。(本来是预计都跑50轮的,但是在跑ResNet的时候忘记了)

下面是ResNet50的训练损失记录:

下面是SE_ResNet50的训练损失记录:

这里还记录了se_resnet的错误率记录,因为是后加的功能,所以resnet50没有。

这里出现了过拟合的情况,不过也没关系,本身这个数据量在这里,这里我也只是想要测试它的性能效果而已。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/345561.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

我每天如何使用 ChatGPT

我们都清楚互联网的运作方式——充斥着各种“爆款观点”,极端分裂的意见,恶搞和无知现象屡见不鲜。 最近,大家对于人工智能(AI)特别是大语言模型(LLMs)和生成式 AI(GenAI&#xff0…

【趣味游戏-08】20240123点兵点将点到谁就是谁(列表倒置reverse)

背景需求: 上个月,看到大4班一个孩子在玩“点兵点将点到谁就是谁”的小游戏,他在桌上摆放两排奥特曼卡片,然后点着数“点兵点将点到谁就是谁”,第10次点击的卡片,拿起来与同伴的卡片进行交换。他是从第一排…

Unity-Arduino Bluetooth Plugin蓝牙插件使用时需要注意的一些事项(附插件下载链接)

一些参考链接 1.Android 无法扫描蓝牙设备踩坑 2.权限相关 1-首先要明确你的蓝牙设备是经典蓝牙还是低功耗(BLE)蓝牙: 转载:Android蓝牙开发—经典蓝牙和BLE(低功耗)蓝牙的区别 2.如果是BLE蓝牙,需要打勾…

what is `ContentCachingRequestWrapper` does?

ContentCachingRequestWrapper 是 Spring Framework 中提供的一种包装类,它扩展了 HttpServletRequestWrapper 类,用于缓存请求体的内容。 通常在处理 HTTP 请求时,原生的 HttpServletRequest 对象中的输入流 (getInputStream()) 只能被读取一…

SpringBoot-多数据源切换和事物处理(免费)

作者原始文章: SpringBoot-多数据源切换和事物处理 最新内容和改动请看上面的文章 安装 <dependency><groupId>com.gitee.huanminabc</groupId><artifactId>dynamic-datasource</artifactId><version>1.0.3-RELEASE</version> <…

【经验分享】豆瓣小组的文章/帖子怎么删除?

#豆瓣小组的文章/帖子怎么删除&#xff1f;# 第一步&#xff1a; 手机登录豆瓣app ↓ 点右下角“我” ↓ 然后在页面点击我的小组 ↓ 点我发布的 ↓ ↓ 再任意点开一个帖子 ↓ 在文章和帖子的右上角有一个笔状的图标&#xff0c;切记不是右上角的横三点… ↓ ↓ 最后点下边的…

git 对象压缩及垃圾对象清理

git 对象压缩及垃圾对象清理 这篇文章让我们来看看 git 的对象压缩机制&#xff0c;前面的几篇文章我们提到&#xff0c;在执行 git add 命令会会把文件先通过 zlib 压缩后放入到「暂存区」&#xff0c;我们先看看这个步骤&#xff1a; 我们这个实例中有一个 1.28m 的 index.…

6.php开发-个人博客项目Tp框架路由访问安全写法历史漏洞

目录 知识点 php框架——TP URL访问 Index.php-放在控制器目录下 ​编辑 Test.php--要继承一下 带参数的—————— 加入数据库代码 --不过滤 --自己写过滤 --手册&#xff08;官方&#xff09;的过滤 用TP框架找漏洞&#xff1a; 如何判断网站是thinkphp&#x…

​比特币大跌的 2 个原因

撰文&#xff1a;秦晋 原文来自Techub News&#xff1a;​比特币大跌的 2 个原因 比特币迎来大跌&#xff01;1 月 23 日凌晨&#xff0c;比特币跌破 40000 美元&#xff0c;为去年 12 月 4 日以来首次&#xff0c;日内跌超 3%。这是自 1 月 10 日美国证监会审批通过 11 只比…

Conda python运行的包和环境管理 入门

Conda系列&#xff1a; 翻译: Anaconda 与 miniconda的区别Miniconda介绍以及安装 Conda 是一个功能强大的命令行工具&#xff0c;用于在 Windows、macOS 和 Linux 上运行的包和环境管理。 本 conda 入门指南介绍了启动和使用 conda 创建环境和安装包的基础知识。 1. 准备…

学习推荐!!HTML5+CSS3从入门到精通

获取方式&#xff1a; 《HTML 5 CSS3从入门到精通》 《HTML5CSS3从入门到精通目录》 第1章 Web开发新时代 第2章 从HTML、XHTML到HTML5 第3章 创建HTML5文档 第4章 实战HTML5表单 第5章 实战HTML5画布 第6章 HTML5音频与视频 第7章 Web存储 第8章 离线应用 第9章 Workers多线…

K8S的HPA

horiztal Pod Autoscaling&#xff1a;pod的水平自动伸缩&#xff0c;这是k8s自带的模块&#xff0c;它是根据Pod占用cpu比率到达一定的阀值&#xff0c;会触发伸缩机制 Replication controller 副本控制器&#xff1a;控制pod的副本数 Deployment controller 节点控制器&…

从开发、部署到维护:SAAS与源代码小程序的全流程对比

在数字化时代&#xff0c;小程序已成为企业开展业务的重要工具。然而&#xff0c;小程序开发过程中存在多种形式&#xff0c;其中SAAS版本小程序和源代码小程序是最常见的两种。乔拓云SaaS系统作为业界领先的SaaS服务平台&#xff0c;为企业提供高效、便捷的小程序解决方案。与…

echarts 去掉x轴或y轴中的刻度线(分割x轴数值的线)

解决方法 将 xAxis 或者 yAxis 对象下的 axisTick 属性配置 show: false&#xff0c;代码如下&#xff1a; xAxis: {type: category,data: [Mon, Tue, Wed, Thu, Fri, Sat, Sun],//添加以下配置axisTick: { show: false} },效果

【大数据】Flink 系统架构

Flink 系统架构 1.Flink 组件1.1 JobManager1.2 ResourceManager1.3 TaskManager1.4 Dispatcher 2.应用部署2.1 框架模式2.2 库模式 3.任务执行4.高可用设置4.1 TaskManager 故障4.2 JobManager 故障 Flink 是一个用于状态化并行流处理的分布式系统。它的搭建涉及多个进程&…

大厂咋做多系统数据同步方案的?

1 背景 业务线与系统越来越多&#xff0c;系统或业务间数据同步需求也越频繁。当前互联网业务系统大多MySQL数据存储与处理方案&#xff1a; 随信息时代爆炸&#xff0c;大数据量场景下慢慢凸显短板&#xff0c;如&#xff1a;需对大量数据全文检索&#xff0c;对大量数据组合…

java web mvc-02-struts2

拓展阅读 Spring Web MVC-00-重学 mvc mvc-01-Model-View-Controller 概览 web mvc-03-JFinal web mvc-04-Apache Wicket web mvc-05-JSF JavaServer Faces web mvc-06-play framework intro web mvc-07-Vaadin web mvc-08-Grails Struts2 Apache Struts是一个用于创…

Mybatis四大组件

一、Mybatis四大组件 SqlSessionFactoryBuild、SqlSessionFactory、SqlSession、Mapper。 二、SqlSession四大对象 Executor、StatementHandler、ParameterHandler、ResultSetHandler。 这里阐述一下上图的流程 Exeutor发起sql执行任务 1、先调用statementHandler中的pre…

【昕宝爸爸小模块】深入浅出之为什么POI的SXSSFWorkbook占用内存更小

➡️博客首页 https://blog.csdn.net/Java_Yangxiaoyuan 欢迎优秀的你&#x1f44d;点赞、&#x1f5c2;️收藏、加❤️关注哦。 本文章CSDN首发&#xff0c;欢迎转载&#xff0c;要注明出处哦&#xff01; 先感谢优秀的你能认真的看完本文&…

C++ -- 入门(引用)

1.引用 1.1引用的概念 引用不是新定义一个变量&#xff0c;而是给已存在变量取了一个别名&#xff0c;编译器不会为引用变量开辟内存空间&#xff0c;它和它引用的变量共用同一块内存空间。 比如&#xff1a;李逵&#xff0c;在家称为"铁牛"&#xff0c;江湖上人称&q…