YOLO多模态融合 | 从 DEA 到 DEFA:动态卷积+交叉注意力的创新融合

本教程基线代码为开源项目 YOLOFuse

在这里插入图片描述

请注意:并非在所有数据集上都能带来性能提升。DEFA 模块是我基于自身思路改进的——在您的数据集上是否有效,还需您自行实验验证,无法保证一定会有所增益。

一、背景与动机

在多模态目标检测场景中,RGB(可见光)与 IR(红外)两路数据往往具有互补信息。原 DEA(DEYOLO 中的 Dynamic Enhanced Attention)模块通过两个子模块 DECA(Dual-Enhancement Convolutional Attention)和 DEPA(Dual-Enhancement Pyramid Attention)对特征进行通道与空间上的交互融合,并在末端以简单相加的方式输出融合结果。尽管 DEA 在一定程度上加强了跨模态特征的交互,但也存在以下局限:

  • 融合策略固定:DECA 与 DEPA 中的卷积核、池化方式均为预设静态结构,难以自适应不同场景下的空间感受野需求。
  • 注意力机制单一:缺乏全局上下文信息建模,无法充分抓取长程依赖。
  • 融合权重耦合:通道间权重一旦计算,与空间上下文仍存在耦合,融合灵活性不足。

为此,提出 DEFA 模块,引入动态卷积、自适应多头交叉注意力与门控融合机制,旨在更精准地捕捉跨模态互补信息,并提升融合后特征的表达能力。


二、DEFA 核心设计与改进思路

2.1 动态核参数学习

传统静态卷积核对不同尺度、不同场景下的特征响应具有局限;DEFA 首先在两路特征的全局池化基础上,通过一系列 1×1 卷积与非线性激活,生成维度为 (k^2) 的动态卷积权重分布(Softmax 归一化),并在后续的 unfold/fold 操作中,将这一组可学习的卷积核系数,应用于融合后的中间特征,实现自适应的空间感受野调整。

2.2 交叉注意力交互

为进一步建模 RGB ↔ IR 间的长程依赖,DEFA 将两路特征重排为序列(B×(H·W)×C),并调用 PyTorch 的 MultiheadAttention 实现跨路交叉注意力:

vi_attn = rearrange(vi_feat, 'b c h w -> b (h w) c')
ir_attn = rearrange(ir_feat, 'b c h w -> b (h w) c')
fused_seq, _ = self.cross_attn(vi_attn, ir_attn, ir_attn)
fused_feat = rearrange(fused_seq, 'b (h w) c -> b c h w', h=H)

这样,IR 特征可为 RGB 提供全局上下文引导,反之亦然,显著增强跨模态的信息流动。

2.3 门控融合策略

融合后特征既要保留原始模态信息,又要引入对方的补充,故设计两通道门控网络:

gate = nn.Sequential(Conv2d(2C → C/r), ReLU,Conv2d(C/r → 2), Sigmoid
)
w_vi, w_ir = gate(concat([vi_feat, ir_feat]))
out_vi = w_ir * dynamic_feat + vi_feat
out_ir = w_vi * dynamic_feat + ir_feat

两个尺度不同的门控权重 wvi,wirw_{vi}, w_{ir}wvi,wir 分别控制融合特征对原始 RGB/IR 的增强比例,既可抑制噪声,也可动态放大互补信息。

2.4 特征增强模块

最后,融合输出再经过一次标准卷积 + BatchNorm + SiLU,用于平滑融合特征并做最后的非线性映射,提升特征表达能力。


三、DEFA 模块实现细节

class DEFA(nn.Module):def __init__(self, channels=512, reduction=16, num_heads=8, dynamic_k=5):super().__init__()# 1. 动态核参数学习self.dynamic_conv = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, 4*channels, 1), ReLU(),nn.Conv2d(4*channels, dynamic_k**2, 1),nn.Softmax(dim=1))# 2. 交叉注意力self.cross_attn = MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)# 3. 门控融合self.gate = nn.Sequential(Conv2d(2*channels, channels//reduction, 3, padding=1),ReLU(),Conv2d(channels//reduction, 2, 3, padding=1),Sigmoid())# 4. 特征增强self.enhance = nn.Sequential(Conv2d(channels, channels, 3, padding=1),nn.BatchNorm2d(channels),nn.SiLU())self.dynamic_k = dynamic_kdef forward(self, x):vi, ir = x# 动态卷积权重kw = self.dynamic_conv(vi + ir)               # [B, k^2, 1, 1]# 交叉注意力融合vi_seq = rearrange(vi, 'b c h w -> b (h w) c')ir_seq = rearrange(ir, 'b c h w -> b (h w) c')fused_seq, _ = self.cross_attn(vi_seq, ir_seq, ir_seq)fused = rearrange(fused_seq, 'b (h w) c -> b c h w', h=vi.size(2))# 动态卷积操作U = F.unfold( fused, kernel_size=self.dynamic_k, padding=self.dynamic_k//2 )B, CK2, N = U.shapekw_expand = kw.view(B, self.dynamic_k**2, 1).expand(-1, -1, N)dyn_feat = torch.einsum('bkn,bckn->bcn', kw_expand, U.view(B, -1, self.dynamic_k**2, N))dyn_feat = F.fold(dyn_feat, output_size=vi.shape[2:], kernel_size=1)# 门控融合g = self.gate(torch.cat([vi, ir], dim=1))     # [B,2,H,W]w_vi, w_ir = g[:,0:1], g[:,1:2]out_vi = w_ir * dyn_feat + viout_ir = w_vi * dyn_feat + ir# 最终特征增强return self.enhance(out_vi), self.enhance(out_ir)

四、相对于 DEA 的优势

  1. 自适应感受野:动态卷积核可根据当前 RGB+IR 特征自动调整,替代固定卷积核的局限。
  2. 全局长程依赖:多头交叉注意力注入了全局上下文,使模态间交互更充分。
  3. 灵活门控融合:双通道门控直接控制融合强度,更好抑制冗余信息。

五、YOLOFuse 中添加 DEFA 模块的方法

  1. YOLOFuse/ultralytics/nn/modules/layers 路径下新建文件DEFA.py放入以下代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from einops import rearrangeclass DEFA(nn.Module):"""Dynamic Enhanced Fusion Attention"""def __init__(self, channels=512, reduction=16, num_heads=8, dynamic_k=5):super().__init__()# 1. 动态核参数学习self.dynamic_conv = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, 4 * channels, 1),nn.ReLU(inplace=True),nn.Conv2d(4 * channels, dynamic_k**2, 1),nn.Softmax(dim=1))# 2. 交叉注意力机制self.cross_attn = MultiheadAttention(embed_dim=channels,num_heads=num_heads,batch_first=True)# 3. 门控融合(hidden 最少为 1)hidden = max(channels // reduction, 1)self.gate = nn.Sequential(nn.Conv2d(2 * channels, hidden, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(hidden, 2, 3, padding=1),nn.Sigmoid())# 4. 特征增强self.enhance = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1),nn.BatchNorm2d(channels),nn.SiLU(inplace=True))self.dynamic_k = dynamic_kdef forward(self, x):vi_feat, ir_feat = x  # 两路输入特征# --- 动态核参数学习 ---kernel_weights = self.dynamic_conv(vi_feat + ir_feat)  # [B, k^2, 1, 1]b, k2, _, _ = kernel_weights.size()k = self.dynamic_k# --- 交叉注意力交互 ---vi_seq = rearrange(vi_feat, 'b c h w -> b (h w) c')ir_seq = rearrange(ir_feat, 'b c h w -> b (h w) c')fused_seq, _ = self.cross_attn(vi_seq, ir_seq, ir_seq)fused_feat = rearrange(fused_seq, 'b (h w) c -> b c h w', h=vi_feat.size(2))# --- 动态卷积融合 ---U = F.unfold(fused_feat,kernel_size=k,padding=k // 2)  # [B, C * k^2, N]B, CK2, N = U.shape# 重新 reshape 权重并扩展到 Nkw = kernel_weights.view(B, k * k, 1).expand(-1, -1, N)# 爱因斯坦求和融合dynamic_feat = torch.einsum('bkn,bckn->bcn', kw, U.view(B, -1, k * k, N))dynamic_feat = F.fold(dynamic_feat,output_size=vi_feat.shape[2:],kernel_size=1,stride=1)  # [B, C, H, W]# --- 门控融合 ---gates = self.gate(torch.cat([vi_feat, ir_feat], dim=1))  # [B, 2, H, W]w_vi, w_ir = gates[:, 0:1], gates[:, 1:2]out_vi = w_ir * dynamic_feat + vi_featout_ir = w_vi * dynamic_feat + ir_feat# --- 融合两路输出 & 特征增强 ---fused = out_vi + out_ir  # 将两路融合return self.enhance(fused)if __name__ == "__main__":# 测试维度 (B=2, C=512, H=80, W=80)vi = torch.randn(2, 512, 80, 80)ir = torch.randn(2, 512, 80, 80)model = DEFA(channels=512, reduction=16, num_heads=8, dynamic_k=5)out = model((vi, ir))print(f"Output shape: {out.shape}")  # 输出应为 (2, 512, 80, 80)
  1. YOLOFuse/ultralytics/nn/tasks.py 中导包:
from ultralytics.nn.modules.layers.DEFA import DEFA

并以 nn.BatchNorm2d 为参照物添加以下代码

        elif m is DEFA:c1, c2 = ch[f[0]], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, *args[1:]]

在这里插入图片描述

  1. 新建一个新的 yaml 文件,使用这个 yaml 就可以开始训练了~
# Parameters
nc: 80 # number of classes
ch: 6
scales: # model compound scaling constants# [depth, width, max_channels]n: [0.33, 0.25, 1024]s: [0.33, 0.50, 1024]m: [0.67, 0.75, 768]l: [1.00, 1.00, 512]x: [1.00, 1.25, 512]# DEYOLO backbone
backbone:- [-1, 1, IdentityInput, []] # 0- [-1, 1, ModalitySelector, [1]] # 1 RGB- [-2, 1, ModalitySelector, [2]] # 2 IR# [from, repeats, module, args]- [1, 1, Conv, [64, 3, 2]] # 3-P1/2- [-1, 1, Conv, [128, 3, 2]] # 4-P2/4- [-1, 3, C2f_BiFocus, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 6-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 8-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 10-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 12- [2, 1, Conv, [64, 3, 2]] # 13-P1/2- [-1, 1, Conv, [128, 3, 2]] # 14-P2/4- [-1, 3, C2f_BiFocus, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 16-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 18-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 20-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 22# DEYOLO head
head:- [[7, 17], 1, DEFA, [256, 80]] # 23- [[9, 19], 1, DEFA, [512, 40]] # 24- [[12, 22], 1, DEFA, [1024, 20]] # 25- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 24], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 28- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 23], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 31 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 28], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 34 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 25], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 37 (P5/32-large)- [[31, 34, 37], 1, Detect, [nc]] # Detect(P3, P4, P5)
  1. 修改训练脚本YOLO() 里面的路径为 yaml文件路径
from ultralytics import YOLOif __name__ == "__main__":model = YOLO("YOLOFuse/ultralytics/cfg/models/fuse/DEFA.yaml")model.train(data="ultralytics/cfg/datasets/LLVIP.yaml",ch=6, # 多模态时设置为 6 ,单模态时设置为 3imgsz=640,epochs=100,batch=256,close_mosaic=0,workers=16,device="0,1,2,3,4,5,6,7",optimizer="SGD",patience=0,amp=False,cache=True, # disk 硬盘,速度稍快精度可复现;ram/True 内存,速度快但精度不可复现project="runs/train",name="DEFA",resume=False,fraction=1, # 只用全部数据的 ?% 进行训练 (0.1-1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/2622.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于SEP3203微处理器的嵌入式最小硬件系统设计

目录 1 引言 2 嵌入式最小硬件系统 3 SEP3202简述 4 最小系统硬件的选择和单元电路的设计 4.1 电源电路 4.2 晶振电路 4.3 复位及唤醒电路 4.5 存储器 4.5.1 FLASH存储 4.5.2 SDRAM 4.6 串行接口电路设计 4.7 JTAG模块 4.8 扩展功能(LED) …

PCIe RAS学习专题(3):AER内核处理流程梳理

目录 一、AER内核处理整体流程梳理 二、AER代码重要部分梳理 1、AER初始化阶段 2、中断上半部 aer_irq 3、中断下半部 aer_isr 3.1、aer_isr_one_error 3.2、find_source_device 3.3、aer_process_err_devices 3.4、handle_error_source 3.5、pcie_do_recovery 整体逻…

Window延迟更新10000天配置方案

1.点击"开始"菜单,搜索"注册表编辑器",点击"打开"。2.找到"\HKEY LOCAL MACHINE\SOFTWARE\Microsoft\WindowsUpdate\Ux\Settings"路径。3.右面空白处右键新建一个32位值,命名为FlightSettingsMaxPau…

TCP/IP 哲学:端到端的 Postel 定律

实际上这是互联网哲学,但 TCP/IP 是互联网的事实标准,也是互联网的唯一实例,因此 TCP/IP 等同于互联网。 我写过很多 TCP/IP 发展史的随笔,于宏观,我希望理解互联网何以至此,于微观,希望理解 TC…

Linux下使用原始socket收发数据包

在Linux系统中,使用非原始的socket,可以收发TCP或者UDP等网络层数据包。如果要处理网络层以下的数据包,比如ICMP、ARP等,或者更底层,比如链路层数据包,就得使用原始socket了。 创建socket 创建socket要使用…

cocosCreator2.4 Android 输入法遮挡

这里是 调用显示系统的输入法,然后在 Cocos2dxEditBox.java 创建UI,用于处理输入,这里可以看到会ui 会被系统的输入法遮挡,无法点击,是因为 计算ui位置时没有算上刘海区域,需要处理一下: private int getTo…

7.18 Java基础 |

以下内容,参考Java 教程 | 菜鸟教程,下边是我边看边记的内容,以便后续复习使用。 多态: 继承,接口就是多态的具体体现方式。生物学上,生物体或物质可以具有许多不同的形式或者阶段。 多态分为运行时多态&…

【Lua】闭包可能会导致的变量问题

先思考下面这个问题:local function counter()local count 0return function()count count 1return countend endlocal a counter() local b counter()print(a()) --> ? print(a()) --> ? print(b()) --> ? print(a()) --> ?输出结果&#xff…

网络基础12--可靠性概述及要求

一、可靠性基础概念定义可靠性(Availability) MTBF / (MTBF MTTR)MTBF(平均无故障时间):衡量系统稳定性的指标(如1年)。MTTR(平均修复时间):衡量故障响应与…

【Dv3Admin】菜单管理集成阿里巴巴自定义矢量图标库

图标选择是后台管理系统中高频功能。相比用 Element UI、Ant Design 等自带的 icon 集,阿里巴巴 iconfont.cn 支持上传和管理自定义图标,并生成矢量字体,便于统一维护和扩展。 本文目标是支持自定义 iconfont 图标的展示和选择,并…

有n棍棍子,棍子i的长度为ai,想要从中选出3根棍子组成周长尽可能长的三角形。请输出最大的周长,若无法组成三角形则输出0。

题目描述: 有n棍棍子,棍子i的长度为ai,想要从中选出3根棍子组成周长尽可能长的三角形。请输出最大的周长,若无法组成三角形则输出0。 算法为O(nlogn) 初始理解题目 首先,我们需要清楚地理解题目要求: 输入…

企业级网络综合集成实践:VLAN、Trunk、STP、路由协议(OSPF/RIP)、PPP、服务管理(TELNET/FTP)与安全(ACL)

NE综合实验4 一、实验拓扑二、实验需求 按照图示配置IP地址。Sw7和sw8之间的直连链路配置链路聚合。公司内部业务网段为vlan10和vlan20,vlan10是市场部,vlan20是技术部,要求对vlan进行命名以便区分识别;pc10属于vlan10&#xff0c…