Swin-UMamba—基于 Mamba 的 UNet 和基于 ImageNet 的预训练—论文精读和代码实践

Swin-UMamba

  • 期刊分析
    • 摘要
    • 贡献
    • 方法
      • Swin-UMamba整体框架
      • 1. 基于 Mamba 的 VSS 块
      • 2. 集成基于 ImageNet 的预训练
      • 3. Swin-UMamba解码器
      • 4. Swin-UMamba†:带有基于 Mamba 的解码器的 Swin-UMamba
    • 实验
    • 代码实践
  • 可借鉴参考

期刊分析

Swin-UMamba只是名字中含有swin,和swin没有一点关系,代码中也无swin相关内容
发布 arXiv 2024.2.5
代码:https://github.com/JiarunLiu/Swin-UMamba

摘要

准确的医学图像分割需要集成多尺度信息从局部特征到全局依赖性。然而,现有方法对远程全局信息进行建模具有挑战性,其中卷积神经网络(CNN)受到其局部感受野的限制,而视觉变换器(ViT)则受到其注意力机制的高二次复杂度的影响。最近,基于 Mamba 的模型因其在长序列建模方面令人印象深刻的能力而受到极大关注。多项研究表明,这些模型在各种任务中都可以优于流行的视觉模型,提供更高的准确性、更低的内存消耗和更少的计算负担。然而,现有的基于 Mamba 的模型大多是从头开始训练的,并没有探索预训练的力量,而预训练已被证明对于数据高效的医学图像分析非常有效。本文介绍了一种基于 Mamba 的新型模型 Swin-UMamba,该模型专为医学图像分割任务而设计,利用了基于 ImageNet 的预训练的优势。我们的实验结果揭示了基于 ImageNet 的训练在增强基于 Mamba 的模型性能方面的重要作用。与 CNN、ViT 和最新的基于 Mamba 的模型相比,Swin-UMamba 表现出了巨大的优越性能。值得注意的是,在腹部 MRI、肠镜检查和显微镜检查数据集上,Swin-UMamba 的平均得分比最接近的 U-Mamba 好 3.58%。


贡献

  1. 据我们所知,我们是首次尝试发现预训练的基于 Mamba 的网络在医学图像分割任务中的影响。我们的实验验证了基于 ImageNet 的预训练在基于 Mamba 的网络的医学图像分割中发挥着重要作用,有时至关重要。
  2. 我们提出了一种基于 Mamba 的新网络,名为 Swin-UMamba,用于医学图像分割,该网络是专门为统一预训练模型的功能而设计的。此外,我们提出了一种变体结构 SwinUMamba†,具有更少的网络参数和更低的 FLOPs,同时保持有竞争力的性能。
  3. 我们的结果表明,Swin-UMamba 和 Swin-UMamba† 都可以优于以前的分割模型,包括 CNN、ViT 和最新的基于 Mamba 的模型,并且具有显着的优势,凸显了基于 ImageNet 的预训练和所提出的架构在医学图像分割任务中的有效性。

方法

Swin-UMamba整体框架

在这里插入图片描述
其中蓝色框中就是加载的预训练权重!

1. 基于 Mamba 的 VSS 块

VMamba中解释图:
在这里插入图片描述
VM-UNet中解释图:
在这里插入图片描述

1. VSS块在VM-UNet中有着详细的解释,不同与Vit的扫描方式,其按照四种方向进行拆分并合并。
2. 本文提出的Swin-UMamba是在编码器部分采用了四层VSS;而提出的SwinUMamba†是在解码器部分也采用了VSS

2. 集成基于 ImageNet 的预训练

上面先介绍了VSS基础模块,接着介绍Swin-UMamba中如何使用VSS的,以及如何将预训练模型插入到网络中
1. Swin-UMamba的编码器可以分为5个阶段。第一阶段是茎阶段**(stem stage)**。它包含一个用于 2×下采样的卷积层,内核为 7×7,填充大小为 3,步幅大小为 2。在卷积层之后采用了 实例归一化。 Swin-UMamba 的第一阶段与 VMamba 不同,因为我们更喜欢渐进的下采样过程,其中每个阶段进行 2× 下采样。第二阶段使用补丁大小为 2 × 2 的补丁嵌入层,将特征分辨率保持在原始图像的 1/4×,这与 VMamba 中的嵌入特征相同。后续阶段遵循 VMamba-Tiny 的设计,其中每个阶段由用于 2×下采样的补丁合并层和用于高级特征提取的多个 VSS 块组成。与 ViT 不同,由于 VSS 块的因果性质,我们没有采用 Swin-UMamba 中的位置嵌入[24]
2. stage-2到stage-5的VSS块的数量分别为{2,2,9,2}。每个阶段之后的特征尺寸都会以二次方的方式增加。阶段,结果为 D = {48, 96, 192, 384, 768}。我们使用 VMamba-Tiny 中的 ImageNet 预训练权重来初始化 VSS 块和补丁合并层,如图 1 所示。值得注意的是,由于补丁大小和输入通道的差异,补丁嵌入块没有使用预训练权重进行初始化。

因为是按层进行预训练权重加载的,这和平时那种直接加载整个网络的权重不同,大大增加了理解代码的负担🐕

3. Swin-UMamba解码器

上面的1介绍了VSS块,2介绍了Swin-UMamba的编码器和预训练权重使用,这里的3介绍了解码器部分。
1. 带有残差连接的额外卷积块来处理跳过连接特征,这里作者是使用了monai框架中的UnetResBlock实现的,暂时没太搞懂。
2. 每个尺度的额外分割头用于深度监督。

4. Swin-UMamba†:带有基于 Mamba 的解码器的 Swin-UMamba

上面的1介绍了VSS块,2介绍了Swin-UMamba的编码器和预训练权重使用,3介绍了解码器部分,这里的4介绍的是Swin-Umamba变体
在这里插入图片描述
1. 我们使用 4×4 补丁嵌入层,直接将输入图像从 H ×W ×C 投影到形状为 H/4 ×H/4 × 96 的特征图中,遵循 VMamba [24]。值得注意的是,Swin-UMamba† 中的最后一个补丁扩展块是 4×上采样操作,镜像 4× 补丁嵌入层。剩余的patch扩展层进行2×上采样操作。源自输入图像的跳过连接和 Swin-UMamba 中的 2× 下采样特征被删除,因为它们没有相应的解码块。此外,深度监督应用于 1×、1/4 ×、1/8 × 和 1/16 × 的分辨率,并为每个尺度结合了额外的分割头(即,将高维特征映射到 K 的 1 × 1 卷积)。结合所有这些修改,网络参数的数量从 40M 减少到 27M,并且 AbdomenMRI 数据集上的 FLOPs 从 58.4G 减少到 15.0G。

实验

我们评估了 Swin-UMamba 在三个不同的医学图像分割数据集上的性能和可扩展性,包括器官分割仪器分割细胞分割。这些数据集是在各种分辨率和图像模式中选择的,可以深入了解模型在不同医学成像场景中的功效和适应性。

文中指出:我们的主要目标是评估预训练模型对医学图像分割的影响,而不是仅仅追求最先进的 (SOTA) 性能。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
1. 实验结果如表1、表2和表3所示,在AbdomenMRI(腹部MRI器官分割)、Endoscopy(内窥镜手术器械分割)和Microscopy(细胞分割)三个数据集上,Swin-UMamba和Swin-UMamba†相比于baseline方法均有显著的提升。
2. 作者发现ImageNet预训练对Swin-UMamba和Swin-UMamba†都起到了非常重要的作用,包括更高的分割精度、更稳定的收敛、减轻过拟合问题、数据效率和较低的计算开销
例如,相较于不使用ImageNet预训练,使用ImageNet预训练能够为Swin-UMamba在Endoscopy上的DSC带来13.08%的显著提升,这可能是因为Endoscopy数据集较小,导致模型更容易过拟合。
3. Swin-UMamba仅需baseline算法1/10的训练迭代次数就能够在AbdomenMRI数据上收敛,这有利于节省训练期间的计算开销。
4. ImageNet预训练有时对模型收敛稳定性起到至关重要的作用,**Swin-UMamba†**在没有使用预训练模型时很难在AbdomenMRI数据上正常收敛,而当使用预训练模型后,**Swin-UMamba†**又能够以较少的参数量和FLOPs超过所有baseline算法,这也进一步佐证了预训练对Mamba-based模型的重要性。

代码实践

  1. 环境安装
conda create -n your_env_name python=3.10.13
conda activate your_env_name
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal-conv1d==1.1.1
pip install mamba-ssm

  1. 代码实践


import re
import time
import math
import numpy as np
from functools import partial
from typing import Optional, Union, Type, List, Tuple, Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
        
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H//2, W//2, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)

        self.selective_scan = selective_scan_fn

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True
        
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_core(self, x: torch.Tensor):
        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1) # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x)) # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        x = input + self.drop_path(self.self_attention(self.ln_1(input)))
        return x


class VSSLayer(nn.Module):
    """ A basic layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self, 
        dim, 
        depth, 
        attn_drop=0.,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        downsample=None, 
        use_checkpoint=False, 
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])
        
        if True: # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None


    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        
        if self.downsample is not None:
            x = self.downsample(x)

        return x


class VSSMEncoder(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, depths=[2, 2, 9, 2], 
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
            norm_layer=norm_layer if patch_norm else None)

        # WASTED absolute position embedding ======================
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate, 
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)
            if i_layer < self.num_layers - 1:
                self.downsamples.append(PatchMerging2D(dim=dims[i_layer], norm_layer=norm_layer))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear
        no fc.weight found in the any of the model parameters
        no nn.Embedding found in the any of the model parameters
        so the thing is, VSSBlock initialization is useless
        
        Conv2D is not intialized !!!
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x_ret = []
        x_ret.append(x)

        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for s, layer in enumerate(self.layers):
            x = layer(x)
            x_ret.append(x.permute(0, 3, 1, 2))
            if s < len(self.downsamples):
                x = self.downsamples[s](x)

        return x_ret


class SwinUMamba(nn.Module):
    def __init__(
        self,
        in_chans=3,
        out_chans=1,
        feat_size=[48, 96, 192, 384],
        drop_path_rate=0,
        layer_scale_init_value=1e-6,
        hidden_size: int = 768,
        norm_name = "instance",
        res_block: bool = True,
        spatial_dims=2,
        deep_supervision: bool = False,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_path_rate = drop_path_rate
        self.feat_size = feat_size
        self.layer_scale_init_value = layer_scale_init_value

        self.stem = nn.Sequential(
              nn.Conv2d(in_chans, feat_size[0], kernel_size=7, stride=2, padding=3),
              nn.InstanceNorm2d(feat_size[0], eps=1e-5, affine=True),
        )
        self.spatial_dims = spatial_dims
        self.vssm_encoder = VSSMEncoder(patch_size=2, in_chans=feat_size[0])
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.in_chans,
            out_channels=self.feat_size[0],
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[0],
            out_channels=self.feat_size[1],
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[1],
            out_channels=self.feat_size[2],
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[2],
            out_channels=self.feat_size[3],
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )

        self.encoder5 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[3],
            out_channels=self.hidden_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )

        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=self.hidden_size,
            out_channels=self.feat_size[3],
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[3],
            out_channels=self.feat_size[2],
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[2],
            out_channels=self.feat_size[1],
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[1],
            out_channels=self.feat_size[0],
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=self.feat_size[0],
            out_channels=self.feat_size[0],
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )

        # deep supervision support
        self.deep_supervision = deep_supervision
        self.out_layers = nn.ModuleList()
        for i in range(4):
            self.out_layers.append(UnetOutBlock(
                spatial_dims=spatial_dims, 
                in_channels=self.feat_size[i], 
                out_channels=self.out_chans
            ))

    def forward(self, x_in):
        x1 = self.stem(x_in)
        vss_outs = self.vssm_encoder(x1)
        enc1 = self.encoder1(x_in)
        enc2 = self.encoder2(vss_outs[0])
        enc3 = self.encoder3(vss_outs[1])
        enc4 = self.encoder4(vss_outs[2])
        enc_hidden = self.encoder5(vss_outs[3])
        dec3 = self.decoder5(enc_hidden, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        dec_out = self.decoder1(dec0)

        if self.deep_supervision:
            feat_out = [dec_out, dec1, dec2, dec3]
            out = []
            for i in range(4):
                pred = self.out_layers[i](feat_out[i])
                out.append(pred)
        else:
            out = self.out_layers[0](dec_out)

        return out

    @torch.no_grad()
    def freeze_encoder(self):
        for name, param in self.vssm_encoder.named_parameters():
            if "patch_embed" not in name:
                param.requires_grad = False

    @torch.no_grad()
    def unfreeze_encoder(self):
        for param in self.vssm_encoder.parameters():
            param.requires_grad = True


# def load_pretrained_ckpt(
#     model, 
#     ckpt_path="./data/pretrained/vmamba/vmamba_tiny_e292.pth"
# ):
    
#     print(f"Loading weights from: {ckpt_path}")
#     skip_params = ["norm.weight", "norm.bias", "head.weight", "head.bias", 
#                    "patch_embed.proj.weight", "patch_embed.proj.bias", 
#                    "patch_embed.norm.weight", "patch_embed.norm.weight"]    

#     ckpt = torch.load(ckpt_path, map_location='cpu')
#     model_dict = model.state_dict()
#     for k, v in ckpt['model'].items():
#         if k in skip_params:
#             print(f"Skipping weights: {k}")
#             continue
#         kr = f"vssm_encoder.{k}"
#         if "downsample" in kr:
#             i_ds = int(re.findall(r"layers\.(\d+)\.downsample", kr)[0])
#             kr = kr.replace(f"layers.{i_ds}.downsample", f"downsamples.{i_ds}")
#             assert kr in model_dict.keys()
#         if kr in model_dict.keys():
#             assert v.shape == model_dict[kr].shape, f"Shape mismatch: {v.shape} vs {model_dict[kr].shape}"
#             model_dict[kr] = v
#         else:
#             print(f"Passing weights: {k}")
        
#     model.load_state_dict(model_dict)

#     return model


# def get_swin_umamba_from_plans(
#     plans_manager: int,
#     dataset_json: dict,
#     configuration_manager: None,
#     num_input_channels: int,
#     deep_supervision: bool = True,
#     use_pretrain: bool = True
# ):
    
#     label_manager = plans_manager.get_label_manager(dataset_json)

#     model = SwinUMamba(
#         in_chans=num_input_channels,
#         out_chans=3,
#         feat_size=[48, 96, 192, 384],
#         deep_supervision=deep_supervision,
#         hidden_size=768,
#     )

#     if use_pretrain:
#         model = load_pretrained_ckpt(model)

#     return model

x = torch.randn(12, 3, 256, 256)
net = SwinUMamba(3, 1)
print(net(x).shape)

可借鉴参考

  1. 阅读U-Mamba
    Jun Ma, Feifei Li, and Bo Wang. U-mamba: Enhancing long-range dependency for biomedical image segmentation. arXiv preprint arXiv:2401.04722, 2024.
  2. 阅读Vmamba
    Yue Liu, Yunjie Tian, Yuzhong Zhao, Hongtian Yu, Lingxi Xie, Yaowei Wang, Qixiang Ye, and Yunfan Liu. Vmamba: Visual state space model. arXiv preprint arXiv:2401.10166, 2024.
  3. 阅读Segmamba
    Zhaohu Xing, Tian Ye, Yijun Yang, Guang Liu, and Lei Zhu. Segmamba: Longrange sequential modeling mamba for 3d medical image segmentation. arXiv preprint arXiv:2401.13560, 2024.
  4. 阅读Visionmamba
    Lianghui Zhu, Bencheng Liao, Qian Zhang, Xinlong Wang, Wenyu Liu, and Xinggang Wang. Vision mamba: Efficient visual representation learning with bidirectional state space model. arXiv preprint arXiv:2401.09417, 2024.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/438348.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

为什么不从独立服务器中转向云或其他方案呢?

传统的专用服务器&#xff0c;如香港服务器租赁、重庆服务器租赁等&#xff0c;是最强大、最稳定的业务托管类型之一。您将获得比任何其他托管计划更多的管理权限和卓越的性能&#xff0c;并且您可以控制整个服务器上的几乎所有内容。 当然&#xff0c;专用服务器也是在线业务…

HTML5:七天学会基础动画网页9

在进行接下来的了解之前我们先来看一下3d的xyz轴&#xff0c;下面图中中间的平面就相当于电脑屏幕&#xff0c;z轴上是一个近大远小的效果。 3d转换属性 transform 2D或3D转换 transform-origin 改变旋转点位置 transform-style 嵌套元素在3D空间如何显 …

Eclipse调试技巧 条件断点 监视

实验代码 import java.util.Scanner;public class Test {// 判断n是否为质数public static boolean isPrime(int n){if (n < 2)return false;for (int i 2; i < n; i){if (n % i 0)return false;}return true;}public static void main(String[] args){Scanner scanne…

类初步认识与对象

一&#xff0c;对于面向对象的认识 Java是一门面向对象的语言&#xff0c;一切都可以称为对象。将一个大象装进冰箱&#xff0c;甭管步骤多复杂&#xff0c;大象便是对象&#xff1b;将牛奶放进冰箱&#xff0c;牛奶便是对象&#xff1b;你我均是对像。 再比如洗一个衣服&…

JavaScript——流程控制(程序结构)

JavaScript——流程控制&#xff08;程序结构&#xff09; 流程控制就是来控制我们的代码按照什么结构顺序来执行。更倾向于一种思想结构。 流程控制分为三大结构&#xff1a;顺序结构、分支结构、循环结构 1、顺序结构 ​ 代码从上往下依次执行&#xff0c;从A到B执行&#x…

Java毕业设计 基于SpringBoot vue 疫苗咨询与预约系统

Java毕业设计 基于SpringBoot vue 疫苗咨询与预约系统 SpringBoot vue 疫苗咨询与预约系统 功能介绍 用户前端&#xff1a;首页 图片轮播 疫苗信息 条件查询 疫苗详情 点我收藏 评论 接种疫苗 疫情资讯 资讯详情 资讯评论 论坛交流 发布帖子 公告信息 公告详情 留言反馈 登录…

arm架构服务器使用Virtual Machine Manager安装的kylin v10虚拟机

本文中使用Virtual Machine Manager安装kylin v10的虚拟机 新建虚拟机 新建虚拟机 选择镜像&#xff0c;下一步 设置内存和CPU&#xff0c;下一步 选择或创建自定义存储&#xff08;默认存储位置的磁盘空间可能不够用&#xff09; 点击管理&#xff0c;打开选择存储卷页…

[linux]shell脚本语言:变量、测试、控制语句以及函数的全面详解

一、shell的概述 1、shell本质是脚本文件&#xff1a;完成批处理。 shell脚本是一种脚本语言&#xff0c;我们只需使用任意文本编辑器&#xff0c;按照语法编写相应程序&#xff0c;增加可执行权限&#xff0c;即可在安装shell命令解释器的环境下执行。shell 脚本主要用于帮助开…

【软件测试】如何申请专利?

一、专利类型 在软件测试领域&#xff0c;可以申请发明专利、实用新型专利和外观设计专利。其中&#xff0c;发明专利是最常见的专利类型&#xff0c;它保护的是软件测试方法、系统和装置等技术方案。 二、申请专利的条件 申请专利需要满足新颖性、创造性和实用性三个条件。…

饮料换购 刷题笔记

直接开个计数器mask 每当饮料现存数-1&#xff1b; cnt;且mask; 一旦mask达到3 饮料现存数 计数器清零3 代码 #include <iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; int main(){ int n; …

【AIGC】如何提高Prompt准确度

前言 随着人工智能的迅猛进展&#xff0c;AIGC&#xff08;通用人工智能聊天工具&#xff09;已成为多个行业中不可或缺的自然语言处理技术。Prompt作为AIGC系统的一项关键功能&#xff0c;在工具的有效运作中发挥了举足轻重的作用。本篇文章将深入探讨Prompt与AIGC之间的紧密…

迭代器失效问题(C++)

迭代器失效就是迭代器指向的位置已经不是原来的含义了&#xff0c;或者是指向的位置是非法的。以下是失效的几种情况&#xff1a; 删除元素&#xff1a; 此处发生了迭代器的失效&#xff0c;因为erase返回的是下一个元素的位置的迭代器&#xff0c;所以在删除1这个元素的时候&…

SAP Parallel Accounting(平行分类账业务)配置及操作手册(超详细的说明和测试)

SAP Parallel Accounting(平行分类账业务)配置及操作手册 1、Overview 为了适应不同的会计准则&#xff0c;SAP在新总账中启用了多分类账&#xff0c;&#xff08;其作用简单来说就是&#xff0c;同时一笔记账&#xff0c;会产生多个账套的凭证。&#xff09;分类账可以对应一…

Python之Web开发中级教程----搭建SSH环境

Python之Web开发中级教程----搭建SSH环境 SSH 的全称是 “安全的 Shell(Secure Shell)”&#xff0c;它功能强大、效率高&#xff0c;这个主流的网络协议用于在两个远程终端之间建立连接。让我们不要忘记它名称的“安全”部分&#xff0c;SSH 会加密所有的通信流量&#xff0c…

C语言从入门到精通 第十二章(程序的编译及链接)

写在前面&#xff1a; 本系列专栏主要介绍C语言的相关知识&#xff0c;思路以下面的参考链接教程为主&#xff0c;大部分笔记也出自该教程。除了参考下面的链接教程以外&#xff0c;笔者还参考了其它的一些C语言教材&#xff0c;笔者认为重要的部分大多都会用粗体标注&#xf…

【学习笔记】数据结构与算法06 - 堆:上堆、下堆、Top-K问题以及代码实现

知识来源&#xff1a;https://www.hello-algo.com/chapter_heap/heap/#4 文章目录 2.5 堆2.5.1 堆&#xff08;优先队列2.5.1.1 堆的常用操作 2.5.2 堆的存储与表示2.5.2.1 访问堆顶元素2.5.2.2 入堆时间复杂度 2.5.2.3 堆顶元素出堆时间复杂度 2.5.3 堆的常见应用2.5.4 建堆问…

每日OJ题_牛客_井字棋

目录 牛客_井字棋 解析代码 牛客_井字棋 井字棋__牛客网 解析代码 class Board {public:bool checkWon(vector<vector<int> > board) {// 当前玩家是否胜出&#xff01;&#xff01;&#xff01;不是有玩家胜出int row board.size(), col board[0].size();fo…

vue 常用的 UI 组件库之一:Vuetify组件库

Vuetify是一个基于Vue.js 的Material Design组件库&#xff0c;它提供了一套完整的、预构建的、可自定义的、响应式的组件&#xff0c;以便开发者可以快速构建美观且功能强大的Web应用程序。Vuetify遵循Material Design设计指南&#xff0c;提供了一系列易于使用的组件&#xf…

【STM32】HAL库 CubeMX教程---基本定时器 定时

目录 一、基本定时器的作用 二、常用型号的TIM时钟频率 三、CubeMX配置 四、编写执行代码 实验目标&#xff1a; 通过CUbeMXHAL&#xff0c;配置TIM6&#xff0c;1s中断一次&#xff0c;闪烁LED。 一、基本定时器的作用 基本定时器&#xff0c;主要用于实现定时和计数功能…

Leetcode : 147. 对链表进行插入排序

给定单个链表的头 head &#xff0c;使用 插入排序 对链表进行排序&#xff0c;并返回 排序后链表的头 。 插入排序 算法的步骤: 插入排序是迭代的&#xff0c;每次只移动一个元素&#xff0c;直到所有元素可以形成一个有序的输出列表。 每次迭代中&#xff0c;插入排序只从输…