【论文阅读】RS-Mamba for Large Remote Sensing Image Dense Prediction(附Code)

论文作者提出了RS-Mamba(RSM)用于高分辨率遥感图像遥感的密集预测任务。RSM设计用于模拟具有线性复杂性的遥感图像的全局特征,使其能够有效地处理大型VHR图像。它采用全向选择性扫描模块,从多个方向对图像进行全局建模,从多个方向捕捉大的空间特征。

论文链接:https://arxiv.org/abs/2404.02668

code链接:https://github.com/walking-shadow/Official_Remote_Sensing_Mamba

2D全向扫描机制是本研究的主要创新点。作者考虑到遥感影像地物多方向的特点,在VMamba2D双向扫描机制的基础上增加了斜向扫描机制。

 以下是作者针对该部分进行改进的代码:

def antidiagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_gather(tensor):
    # 取出矩阵所有反斜向的元素并拼接
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 使用gather进行索引选择
    return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W)

def diagonal_scatter(tensor_flat, original_shape):
    # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (shift + torch.arange(W, device=tensor_flat.device)) % W  # 利用广播创建索引矩阵[H, W]
    # 扩展索引以适应B和C维度
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 创建一个空的张量来存储反向散布的结果
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_根据expanded_index将元素放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

def antidiagonal_scatter(tensor_flat, original_shape):
    # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # 创建一个列向量[H, 1]
    index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # 利用广播创建索引矩阵[H, W]
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # 初始化一个与原始张量形状相同、元素全为0的张量
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # 使用scatter_将元素根据索引放回原位
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor

class CrossScan(torch.autograd.Function):
    # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        # 添加横向和竖向的扫描
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
    
        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        return xs
    
    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W))

        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)

        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)

        # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
        y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W))

        y_res = y_rb + y_da
        return y_res.view(B, D, -1)
        # return y
    
    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        # xs = x.new_empty((B, 4, C, L))
        xs = x.new_empty((B, 8, C, L))

        # 横向和竖向扫描
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # xs = xs.view(B, 4, C, H, W)

        # 提供斜向和反斜向的扫描
        xs[:, 4] = diagonal_gather(x.view(B,C,H,W))
        xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W))
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        # return xs
        return xs.view(B, 8, C, H, W)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/552946.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue的生命周期的详解

Vue的生命周期是每个使用Vue框架的前端人员都需要掌握的知识,以此作为记录。 Vue的生命周期就是vue实例从创建到销毁的全过程,也就是new Vue() 开始就是vue生命周期的开始。Vue 实例有⼀个完整的⽣命周期,也就是从开始创建、初始化数据、编译…

基于51单片机点滴输液控制系统LCD显示( proteus仿真+程序+设计报告+讲解视频)

基于51单片机点滴输液控制系统LCD显示 1. 主要功能:2. 讲解视频:3. 仿真设计4. 程序代码5. 设计报告6. 设计资料内容清单&&下载链接 基于51单片机点滴输液控制系统LCD显示( proteus仿真程序设计报告讲解视频) 仿真图proteus7.8及以上…

基于开源IM即时通讯框架MobileIMSDK:RainbowChat v11.5版已发布

关于MobileIMSDK MobileIMSDK 是一套专门为移动端开发的开源IM即时通讯框架,超轻量级、高度提炼,一套API优雅支持UDP 、TCP 、WebSocket 三种协议,支持iOS、Android、H5、小程序、Uniapp、标准Java平台,服务端基于Netty编写。 工…

09 SQL进阶 -- SQL高级处理 -- 窗口函数等

1. 窗口函数 1.1 窗口函数概念及基本的使用方法 窗口函数也称为 OLAP 函数。OLAP 是 OnLine AnalyticalProcessing 的简称,意思是对数据库数据进行实时分析处理。 为了便于理解,称之为窗口函数。常规的 SELECT 语句都是对整张表进行查询,而窗口函数可以让我们有选择的去某…

创建一个javascript公共方法的npm包,js-tool-big-box,发布到npm上,一劳永逸

前端javascript的公共方法太多了,时间日期的,数值的,字符串的,搞复制的,搞网络请求的,搞数据转换的,几乎就是每个新项目,有的拷一拷,没有的继续写,放个utils目…

每日一题---OJ题: 用队列实现栈

片头 嗨! 小伙伴们,大家好! 今天我们一起来看看这道OJ题---用队列实现栈,话不多说,我们马上开始~ emmm,题目看似有点长,我们一起来分析分析! 我们都知道,栈的特点是后进先出(Last In First Out,简称 LIFO ),队列的特点是先进先出(First In First Out 简称 FIFO),明明两者的性…

【记录】Python|Selenium 下载 PDF 不预览不弹窗(2024年)

版本: Chrome 124Python 12Selenium 4.19.0 版本与我有差异不要紧,只要别差异太大比如 Chrome 用 57 之前的版本了,就可以看本文。 如果你从前完全没使用过、没安装过Selenium,可以参考这篇博客《【记录】Python3|Sele…

求π的近似值(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <math.h>int main() {//初始化变量值&#xff1b;int symbol 1;double denominator 1.0, sum 0, term 1.0;//循…

Excel文件解析(Java)

一、概述 在应用程序的开发过程中&#xff0c;经常需要使用 Excel文件来进行数据的导入或导出。所以&#xff0c;在通过Java语言实现此类需求的时候&#xff0c;往往会面临着Excel文件的解析(导入&#xff09;或生成&#xff08;导出)。 在Java技术生态圈中&#xff0c…

JVM运行时内存溢出以及解决办法

JVM有哪些运行时数据区 JVM运行时数据区有程序计数器、本地方法栈虚拟机栈、堆、元数据区、直接内存。 其中只有程序计数器不是内存溢出&#xff0c;其他的都有可能会产生内存溢出。 栈内存溢出 当方法的调用深度过深&#xff0c;可能会导致栈内存溢出。 一般是发生在递归调…

Elasticsearch:如何将 MongoDB 数据引入 Elastic Cloud

作者&#xff1a;Hemendra Singh Lodhi Elastic Cloud 是由 Elastic 提供的基于云的托管服务。Elastic Cloud 允许客户在亚马逊网络服务 (AWS)、谷歌云平台 (GCP) 和微软 Azure 上部署、管理和扩展他们的 Elasticsearch 集群。 MongoDB 是一种流行的 NoSQL 文档导向数据库&am…

IDEA最好用插件推荐

1 背景 俗话说&#xff1a;“工欲善其事必先利其器”&#xff0c;本问介绍几款强大实用的 IDEA 插件&#xff0c;助力大家开发。 希望大家做一个聪明又努力的人&#xff0c;而不只是一个努力的人。 以下插件大都可以通过 IDEA 自带的插件管理中心安装&#xff0c;如果搜不到可以…

算法|最大堆、最小堆和堆排序的实现(JavaScript)

一些概念 堆&#xff1a;特殊的完全二叉树&#xff0c;具有特定性质的完全二叉树。大根堆&#xff1a;父节点 > 子节点小根堆&#xff1a;父节点 < 子节点 二叉堆也属于完全二叉树&#xff0c;所以可以用数组表示。 若下标从1开始&#xff0c;左节点为 2*i &#xff0…

类和对象-封装-设计案例1-立方体类

#include<bits/stdc.h> using namespace std; class Cube{public://设置长void setL(int l){m_Ll;} //获取长int getL(){return m_L;}//设置宽 void setW(int w){m_Ww;}//获取宽 int getW(){return m_W;}//设置高 void setH(int h){m_Hh;}//获取高int getH(){return m_H;…

【机器学习300问】72、神经网络的隐藏层数量和各层神经元节点数如何影响模型的表现?

评估深度学习的模型的性能依旧可以用偏差和方差来衡量。它们反映了模型在预测过程中与理想情况的偏离程度&#xff0c;以及模型对数据扰动的敏感性。我们简单回顾一下什么是模型的偏差和方差&#xff1f; 一、深度学习模型的偏差和方差 偏差&#xff1a;衡量模型预测结果的期望…

JAVAEE—UDP协议TCP协议/三次握手四次挥手

文章目录 UDP协议UDP协议的段格式UDP的传输过程校验和无连接 TCP协议TCP报文的格式段有连接可靠性确认应答超时重传如果ACK丢了呢&#xff1f; 序号和确认序号 连接的构建和断开连接的构建&#xff08;三次握手&#xff09;三次握手的作用为什么握手是三次&#xff0c;而不是四…

微信小程序的常用API ①

前言&#xff1a;什么是微信小程序的API&#xff1f; &#xff08;1&#xff09;微信小程序的API是由宿主环境提供的。通俗来说API是一种接口函数&#xff0c;把函数封装起来给开发者使用&#xff0c;这样好多功能都无需开发者去实现&#xff0c;直接调用即可。 &#xff08;…

工业电脑在ESOP工作站行业应用

ESOP工作站行业应用 项目背景 E-SOP是实现作业指导书电子化&#xff0c;并统一管理和集中控制的一套管理信息平台。信迈科技的ESOP终端是一款体积小巧功能齐全的高性价比工业电脑&#xff0c;上层通过网络与MES系统连接&#xff0c;下层连接显示器展示作业指导书。ESOP控制终…

Covalent Network(CQT)宣布推出面向 Cronos 生态的捐赠计划与 API 积分,为 Web3 创新赋能

为了促进 Web3 领域的创新&#xff0c;Covalent Network&#xff08;CQT&#xff09;宣布将其捐赠计划向 Cronos 生态系统中的开发者拓展。这一战略性举措&#xff0c;旨在通过向 Cronos 网络中基于 Covalent Network&#xff08;CQT&#xff09;API 构建的项目提供支持和资源&…

OpenHarmony实战开发-如何使用Navigation实现多设备适配。

介绍 在应用开发时&#xff0c;一个应用需要适配多终端的设备&#xff0c;使用Navigation的mode属性来实现一套代码&#xff0c;多终端适配。 效果图预览 使用说明 将程序运行在折叠屏手机或者平板上观看适配效果。 实现思路 本例涉及的关键特性和实现方案如下&#xff1a…
最新文章