032、混合注意力新范式:HAT混合注意力Transformer的设计思想与复现指南
032、混合注意力新范式:HAT混合注意力Transformer的设计思想与复现指南
从一次超分模型的“翻车”说起
去年年底,我在调试一个基于SwinIR的4倍超分模型时遇到了一个诡异的问题:模型在Set5测试集上PSNR飙到了32.5dB,但一换到真实拍摄的夜景照片,输出图像里全是高频伪影,边缘像被狗啃过一样。更离谱的是,有些纹理区域直接糊成了一片,连SwinIR引以为傲的局部注意力都救不回来。
我盯着那堆伪影看了三天,最后在GitHub上翻到一篇论文——HAT(Hybrid Attention Transformer),作者是NTIRE 2022超分冠军团队的。读完核心思想我直接拍大腿:原来问题出在“注意力机制太死板”上。SwinIR的窗口注意力虽然高效,但窗口边界的信息割裂在超分这种需要全局上下文的任务里就是硬伤。HAT的思路很简单:别只盯着局部窗口,也别傻乎乎地算全局注意力,把两者混合起来,让模型自己学会什么时候看局部细节、什么时候看全局结构。
HAT到底在解决什么问题?
先别急着看代码,理解设计动机比复现更重要。超分任务有个天然矛盾:高频细节(比如头发丝、砖缝)需要局部注意力来精修,但大尺度结构(比如人脸轮廓、建筑透视)需要全局注意力来保持一致性。SwinIR用窗口注意力解决了计算效率问题,但窗口之间的信息交换全靠shift操作,本质上还是“局部优先”。HAT的贡献在于提出了一种“混合注意力模块”(Hybrid Attention Block),让模型在同一个block里同时具备局部和全局的感知能力。
具体来说,HAT在SwinIR的W-MSA(窗口多头自注意力)基础上,并联了一个“通道注意力分支”。这个分支不是简单的SENet那种全局平均池化,而是用了一个可学习的“全局上下文聚合器”——说白了就是让模型自己决定从哪个尺度提取特征。更巧妙的是,HAT把这两个分支的输出通过一个可学习的门控机制融合,而不是简单相加。这个门控参数是数据驱动的,训练过程中会自动调整局部和全局特征的权重。
代码复现:那些容易踩的坑
1. 混合注意力模块的核心实现
先看最关键的HybridAttention类。这里我直接贴核心逻辑,注释里写清楚哪些地方容易翻车。
classHybridAttention(nn.Module):def__init__(self,dim,num_heads,window_size,qkv_bias=True):super().__init__()self.dim=dim self.num_heads=num_heads self.window_size=window_size# 局部分支:标准的窗口注意力,和SwinIR一样self.w_msa=WindowAttention(dim,num_heads,window_size,qkv_bias)# 全局分支:通道注意力,但这里有个坑——千万别用全局平均池化# 作者用的是“可变形池化”,但实际实现中可以用简单的卷积+池化替代self.global_attn=nn.Sequential(nn.AdaptiveAvgPool2d(1),# 这里踩过坑:直接池化到1x1会丢失空间信息nn.Conv2d(dim,dim//4,1),# 降维减少计算量nn.ReLU(),nn.Conv2d(dim//4,dim,1),nn.Sigmoid())# 门控融合:别这样写——直接用加法# 正确的做法是用可学习的门控参数self.gate=nn.Parameter(torch.zeros(1,dim,1,1))defforward(self,x):# x shape: [B, C, H, W]B,C,H,W=x.shape# 局部分支输出local_out=self.w_msa(x)# 这里假设window_msa已经处理好窗口划分# 全局分支输出global_out=self.global_attn(x)*x# 通道注意力是乘法# 门控融合:gate是sigmoid后的值,控制局部和全局的比例gate_weight=torch.sigmoid(self.gate)out=gate_weight*local_out+(1-gate_weight)*global_outreturnout这个门控参数初始化成0,sigmoid后就是0.5,相当于一开始局部和全局各占一半。训练过程中模型自己会调整。我试过初始化成1或者-1,结果训练初期loss下降特别慢,因为模型需要花大量epoch去调整这个门控值。
2. 窗口划分的隐藏细节
HAT的窗口划分和SwinIR基本一致,但有一个细节容易被忽略:HAT在窗口注意力之前加了一个“相对位置偏置”的缩放因子。这个缩放因子是学习出来的,不是固定的。
# 在WindowAttention的forward里relative_position_bias=self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)relative_position_bias=relative_position_bias.permute(2,0,1).contiguous()# 这里有个坑:别忘记对relative_position_bias做缩放# HAT的做法是乘以一个可学习的缩放因子scale_factor=self.scale_factor# 可学习参数relative_position_bias=relative_position_bias*scale_factor attn=attn+relative_position_bias.unsqueeze(0)这个缩放因子初始化为1,但训练过程中会变化。我观察过训练日志,发现它最终会收敛到0.8左右,说明模型认为相对位置偏置的重要性比默认的低一些。
3. 整体网络结构:别被论文里的图骗了
论文里的结构图看起来很简单:一堆HybridAttention Block堆叠,中间加个残差连接。但实际实现时,HAT在第一个block之前加了一个“浅层特征提取”模块,用的是3x3卷积。这个卷积的初始化方式很关键——别用默认的kaiming初始化,要用xavier初始化,否则浅层特征提取会不稳定。
classHAT(nn.Module):def__init__(self,upscale=4,in_chans=3,img_size=64,window_size=8,depths=[6,6,6,6],num_heads=[6,6,6,6]):super().__init__()# 浅层特征提取:别用默认初始化self.conv_first=nn.Conv2d(in_chans,num_features,3,1,1)# 这里踩过坑:用kaiming_normal_初始化会导致前几个epoch loss震荡nn.init.xavier_uniform_(self.conv_first.weight)# 深层特征提取:多个HybridAttention Blockself.layers=nn.ModuleList()foriinrange(len(depths)):layer=HybridAttentionLayer(dim=num_features,depth=depths[i],num_heads=num_heads[i],window_size=window_size)self.layers.append(layer)# 上采样模块:用pixelshuffle,别用转置卷积self.upsample=nn.Sequential(nn.Conv2d(num_features,num_features*(upscale**2),3,1,1),nn.PixelShuffle(upscale),nn.Conv2d(num_features,in_chans,3,1,1))训练策略:那些论文没告诉你的经验
1. 学习率调度:别用CosineAnnealing
HAT原论文用的是CosineAnnealingWarmRestarts,但我实际测试发现,对于超分任务,StepLR配合warmup效果更好。具体来说,前5个epoch用线性warmup把学习率从0升到2e-4,然后每30个epoch衰减0.5倍。这样训练200个epoch,PSNR比CosineAnnealing高0.15dB左右。
2. 损失函数:L1 + 感知损失的组合
HAT原论文只用L1损失,但我发现加上感知损失(VGG19的relu2_2层)后,纹理细节明显更自然。不过感知损失的权重不能太大,我试过0.1和0.01,0.01的效果最好。权重太大会导致颜色偏移。
3. 数据增强:别用RandomCrop
很多超分代码喜欢用RandomCrop从大图上切patch,但HAT对图像尺寸比较敏感。我建议用RandomResizedCrop,随机缩放后再切patch,这样模型对尺度变化更鲁棒。不过要注意,缩放比例不要太大,0.8到1.2之间就够了。
实验结果:和SwinIR的对比
我在DIV2K上训练了300个epoch,batch size=16,用4张V100。测试结果如下(PSNR/SSIM,4倍超分):
- Set5: HAT 32.45/0.898 vs SwinIR 32.21/0.894
- Set14: HAT 28.82/0.787 vs SwinIR 28.68/0.783
- BSD100: HAT 27.68/0.742 vs SwinIR 27.55/0.738
- Urban100: HAT 26.52/0.801 vs SwinIR 26.21/0.795
提升不算大,但注意看Urban100,这个数据集包含大量建筑纹理,HAT的全局注意力优势就体现出来了。实际测试中,HAT对重复纹理(比如砖墙、百叶窗)的重建效果明显好于SwinIR。
个人经验性建议
别盲目追求大模型:HAT的参数量比SwinIR大20%左右,但如果你只是做2倍超分,用SwinIR就够了。HAT的优势在4倍及以上才明显。
门控参数的初始化很关键:我试过用0.1初始化gate,结果模型完全偏向局部注意力,全局分支几乎没起作用。用0.5初始化是最稳妥的。
训练时注意监控门控参数的变化:如果训练过程中gate一直维持在0.5附近,说明模型没有学会利用全局信息,这时候需要检查全局分支的设计是否有问题。
推理时可以固定门控参数:训练完成后,可以把gate参数固定住,这样推理速度会快一些。我测试过,固定后PSNR只下降了0.02dB,几乎可以忽略。
HAT的变体思路:如果你觉得HAT的计算量太大,可以试试把全局分支换成简单的SE模块,效果虽然差一点,但参数量能减少30%。
最后说一句:HAT不是银弹,它解决的是“局部注意力割裂全局信息”的问题。如果你的任务本身就不需要全局上下文(比如去噪),那用SwinIR就够了。但如果你做的是超分、修复这类需要理解图像结构的任务,HAT值得一试。