【论文笔记】MCANet: Medical Image Segmentation withMulti-Scale Cross-Axis Attention

        医疗图像分割任务中,捕获多尺度信息、构建长期依赖对分割结果有非常大的影响。该论文提出了 Multi-scale Cross-axis Attention(MCA)模块,融合了多尺度特征,并使用Attention提取全局上下文信息。

论文地址:MCANet: Medical Image Segmentation with Multi-Scale Cross-Axis Attention

代码地址:https://github.com/haoshao-nku/medical_seg

一、MCA(Multi-scale Cross-axis Attention)

MCA的结构如下,将E2/3/4通过concat连接起来(concat前先插值到同样分辨率),经过1x1的卷积后(压缩通道数来降低计算量),得到了包含多尺度信息的特征图F,然后在X和Y方向使用不同大小的卷积核进行卷积运算(比如1x11的卷积是x方向,11x1的是y方向,这里可以对着代码看,容易理解),将Q在X和Y方向交换后(这就是Cross-Axis),经过注意力模块后,将多个特征图相加,并融合E1,经过卷积后得到输出。该模块有以下特点:

1、注意力机制作用在多个不同尺度的特征图;

2、Multi-Scale x-Axis Convolution和Multi-Scale y-Axis Convolution分别关注不同轴的特征,在计算注意力时交叉计算,使得不同方向的特征都能被关注到。

MCA细节如下图,输入特征图进入x和y方向的路径,经过不同大小的卷积后进行融合,然后跨轴(x和y轴的Q交换)计算Attention,最后得到输出特征图。

二、代码

MCA的代码如下所示,总体来说比较简单:

from audioop import bias
from pip import main
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from mmseg.registry import MODELS
from einops import rearrange
from ..utils import resize
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x,h,w):
    return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma+1e-5) * self.weight

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias

class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
    
class Attention(nn.Module):
    def __init__(self, dim, num_heads,LayerNorm_type,):
        super(Attention, self).__init__()
        self.num_heads = num_heads   
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))   

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1)      
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(
            dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(
            dim, dim, (21, 1), padding=(10, 0), groups=dim)

    def forward(self, x):
        b,c,h,w = x.shape   
        x1 = self.norm1(x)
        attn_00 = self.conv0_1(x1)
        attn_01= self.conv0_2(x1)  
        attn_10 = self.conv1_1(x1)
        attn_11 = self.conv1_2(x1)
        attn_20 = self.conv2_1(x1)
        attn_21 = self.conv2_2(x1)   
        out1 = attn_00+attn_10+attn_20
        out2 = attn_01+attn_11+attn_21   
        out1 = self.project_out(out1)
        out2 = self.project_out(out2)  
        k1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        v1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        k2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        v2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)   
        q2 = rearrange(out1, 'b (head c) h w -> b head w (h c)', head=self.num_heads) 
        q1 = rearrange(out2, 'b (head c) h w -> b head h (w c)', head=self.num_heads)       
        q1 = torch.nn.functional.normalize(q1, dim=-1)
        q2 = torch.nn.functional.normalize(q2, dim=-1)
        k1 = torch.nn.functional.normalize(k1, dim=-1)
        k2 = torch.nn.functional.normalize(k2, dim=-1)          
        attn1 = (q1 @ k1.transpose(-2, -1))
        attn1 = attn1.softmax(dim=-1)   
        out3 = (attn1 @ v1) + q1      
        attn2 = (q2 @ k2.transpose(-2, -1))
        attn2 = attn2.softmax(dim=-1)   
        out4 = (attn2 @ v2) + q2                         
        out3 = rearrange(out3, 'b head h (w c) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out4 = rearrange(out4, 'b head w (h c) -> b (head c) h w', head=self.num_heads, h=h, w=w)       
        out =  self.project_out(out3)  + self.project_out(out4) + x
          
        return out

@MODELS.register_module()
class MCAHead(BaseDecodeHead):
    def __init__(self,in_channels,image_size,heads,c1_channels,
                 **kwargs):
        super(MCAHead, self).__init__(in_channels,input_transform = 'multiple_select',**kwargs)
        self.image_size = image_size
        self.decoder_level = Attention(in_channels[1],heads,LayerNorm_type = 'WithBias')
        self.align = ConvModule(
            in_channels[3],
            in_channels[0],
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.squeeze = ConvModule(
            sum((in_channels[1],in_channels[2],in_channels[3])),
            in_channels[1],
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                in_channels[1] + in_channels[0],
                in_channels[3],
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                in_channels[3],
                in_channels[3],
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))             
    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        inputs = [resize(
                level,
                size=self.image_size,
                mode='bilinear',
                align_corners=self.align_corners
            ) for level in inputs]
        y1 = torch.cat([inputs[1],inputs[2],inputs[3]], dim=1)
        x = self.squeeze(y1)  
        x = self.decoder_level(x)
        x = torch.cat([x,inputs[0]], dim=1) 
        x = self.sep_bottleneck(x)
        
        output = self.align(x)  
        output = self.cls_seg(output)
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/258992.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

融云观察:给 ChatGPT 加上声音和脸庞,AI 社交的多模态试验

(👆点击获取行业首款《社交泛娱乐出海作战地图》) 如果将短剧的爆火简单粗暴地归因为剧情上头、狗血反转和精妙卡点,那 GenAI 世界这一年来可以说是一部短剧 Live Show。关注【融云全球互联网通信云】了解更多 这厢 Open AI 宫斗…

别再唱衰Python了,未来十年Python的“霸榜”地位依旧!

直接说结论!不管唱反调的人怎么唱衰,Python 在下一个十年仍然十分重要,并且依旧会与时俱进。 我们都知道 Python 是一门了不起的编程语言,它改变了编程的游戏规则,将编程的格局提升到了一个完全不同的层次。 Python 的…

【C++初阶】学习string类的模拟实现

目录 前言:一、创建文件和类二、实现string类2.1 私有成员和构造函数2.2 析构函数2.3 拷贝构造函数2.3.1 写法12.3.2 写法2 2.4 赋值重载函数2.4.1 写法12.4.2 写法2 2.5 迭代器遍历访问2.6 下标遍历访问2.7 reserve2.8 resize2.9 判空和清理2.10 尾插2.10.1 尾插字…

HTML CSS 进度条

1 原生HTML标签 <meter>&#xff1a;显示已知范围的标量值或者分数值<progress>&#xff1a;显示一项任务的完成进度&#xff0c;通常情况下&#xff0c;该元素都显示为一个进度条 1.1 <meter> <html><head><style>meter{width:200px;}…

新能源车企年底冲刺KPI,只能抓住“价格战”做文章?

新能源汽车行业的价格战似乎看不到尽头。 自特斯拉吹响号角后&#xff0c;今年以来&#xff0c;业内已然开启了几轮颇具规模的价格战。 如今进入年底&#xff0c;价格战不仅没有消停&#xff0c;还愈打愈烈。据不完全统计&#xff0c;12月&#xff0c;已有20多家车企先后开启…

Nginx快速入门:访问日志access.log参数详解 |访问日志记录自定义请求头(三)

0. 引言 在企业的生产环境中&#xff0c;我们时常需要通过nginx的访问日志来统计流量、排查调用问题等&#xff0c;而nginx默认的日志格式所包含的信息远无法满足我们使用&#xff0c;因此常常需要对日志进行自定义&#xff0c;所以今天我们就来看如何自定义nginx的访问日志格…

anaconda 安装 使用 pytorch onnx onnxruntime

一&#xff1a;安装 如果不是 x86_64&#xff0c;需要去镜像看对应的版本 安装 Anaconda 输入命令 bash Anaconda3-2021.11-Linux-x86_64.sh 然后输入 yes 表示同意 确认安装的路径&#xff0c;一般直接回车安装在默认的 /home/你的名字/anaconda3 很快就安装完毕。输入 yes…

星融元中标华夏银行项目,助力金融数据中心可视网建设工作

近日&#xff0c;星融元成功入围华夏银行国产品牌网络流量汇聚分流器&#xff08;TAP&#xff09;设备供应商&#xff0c;在助力头部金融机构构建数据中心可视网络的建设工作中&#xff0c;星融元又一次获得全国性股份制银行客户的青睐。 华夏银行作为全国性股份制商业银行积极…

如何在Ubuntu系统中安装VNC并结合内网穿透实现远程访问桌面

文章目录 前言1. ubuntu安装VNC2. 设置vnc开机启动3. windows 安装VNC viewer连接工具4. 内网穿透4.1 安装cpolar【支持使用一键脚本命令安装】4.2 创建隧道映射4.3 测试公网远程访问 5. 配置固定TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址5.3 测试…

3d云渲染动画、效果图的速度,对比本地电脑渲染速度区别

与使用个人电脑进行渲染相比&#xff0c;3D云渲染服务擁有其无可比拟的优势。云端的服务器配置通常超出个人电脑&#xff0c;具有更强大的运算力和多任务并行处理的能力&#xff0c;使得同时执行多个渲染作业成为可能。这一点在处理图形复杂度高和数据量巨大的渲染项目时尤为显…

CEC2013(python):五种算法(OOA、WOA、GWO、DBO、HHO)求解CEC2013(python代码)

一、五种算法简介 1、鱼鹰优化算法OOA 2、鲸鱼优化算法WOA 3、灰狼优化算法GWO 4、蜣螂优化算法DBO 5、哈里斯鹰优化算法HHO 二、5种算法求解CEC2013 &#xff08;1&#xff09;CEC2013简介 参考文献&#xff1a; [1] Liang J J , Qu B Y , Suganthan P N , et al. Pro…

图片编辑文字用什么软件?带你了解这5个

图片编辑文字用什么软件&#xff1f;在当今数字化的时代&#xff0c;图片编辑已经成为我们日常生活中不可或缺的一部分。有时候&#xff0c;我们需要在图片上添加文字&#xff0c;以增强图片的视觉效果或传达特定的信息。那么&#xff0c;有哪些可以在图片上编辑文字的软件呢&a…

Java数据结构-模拟ArrayList集合思想,手写底层源码(1),底层数据结构是数组,编写add添加方法,正序打印和倒叙打印

package com.atguigu.structure; public class Demo02_arrayList {public static void main(String[] args) {MyGenericArrayListV1 arrayListV1 new MyGenericArrayListV1();//arr.add(element:100,index:1);下标越界&#xff0c;无法插入//初始化&#xff08;第一次添加&…

Spring Cloud Gateway请求路径修改指南:详解ServerWebExchange的完美解决方案及代码示例

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

python实现贪吃蛇游戏

文章目录 1、项目说明2、项目预览3、开发必备4、贪吃蛇代码实现4.1、窗口和基本参数实现4.2、绘制背景4.3、绘制墙壁4.4、绘制贪吃蛇4.5、绘制食物4.6、实现长度信息显示4.7、定义游戏暂停界面4.8、定义贪吃蛇死亡界面4.9、实现贪吃蛇碰撞效果4.10、实现添加食物功能4.11、实现…

jQuery —— 自定义四位数验证弹框

在提交表单发送请求前&#xff0c;想要校验下&#xff0c;但不想用第三方插件&#xff0c;就自己写了个自定义数字校验码弹框&#xff0c;更稳定些&#xff0c;样式有点low&#xff0c;记录下。 没什么硬性要求的话&#xff0c;可以使用第三方插件&#xff0c;会方便许多样式也…

SQL学习笔记+MySQL+SQLyog工具教程

文章目录 1、前言2、SQL基本语言及其操作2.1、CREATE TABLE – 创建表2.2、DROP TABLE – 删除表2.3、INSERT – 插入数据2.4、SELECT – 查询数据2.5、SELECTDISTINCT – 去除重复值后查询数据2.6、SELECTWHERE – 条件过滤2.7、AND & OR – 运算符2.8、ORDER BY – 排序2…

科研院校和研究所都在用功率放大器做哪些实验

科研院校和研究所在科研工作中常常使用功率放大器进行实验。功率放大器是一种电子设备&#xff0c;其主要功能是将输入信号的功率增加到预定的输出功率水平&#xff0c;并保持信号的波形不失真。它在各个学科领域都有广泛的应用&#xff0c;包括通信、无线电、雷达、生物医学等…

Mac安装Nginx

一起学习 1、确认你的电脑是否安装homebrew&#xff0c;打开电脑终端 输入&#xff1a; /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"2、确认homebrew是否安装成功&#xff0c;在终端输入&#xff1a; br…

Linux中使用HTTP协议进行网络通信的示例——你的“网络信使”

大家好&#xff0c;今天我们要聊聊在Linux中如何使用HTTP协议进行网络通信。听起来有点高大上&#xff0c;但其实并不难&#xff0c;让我们一起来看看&#xff01; 首先&#xff0c;我们要明白HTTP协议是什么。HTTP&#xff0c;全名为超文本传输协议&#xff08;Hypertext Tra…
最新文章