Haar小波下采样模块

论文原址:Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect

原文代码:HWD/HWD.py at main · apple1986/HWD (github.com)

介绍 

深度卷积神经网络 (DCNN) 通常采用标准的下采样操作,例如最大池化、平均池化和跨步卷积,这可能会导致信息丢失。丢失的信息,如边界和纹理,对于语义分割可能是必不可少的。为了缓解这个问题,一般有下面四种方法:

  1. 通过跳过连接到解码器子网(如U-Net、LCU-Net、CENet、LinkNet和RefineNet )。
  2. 提取具有空间金字塔池化或扩展卷积的多尺度特征图到融合模块中(如DeepLab、PSPNet、PCPLP-Net、BiSenet和ICNet)。
  3. 向编码器提供多模态图像(如DiSegNet、MMADT、CANet和CCFFNet)。
  4. 增加先验信息。轮廓增强关注模块,旨在从CT图像中提取边界和形状线索,以细化分割区域。

这些方法的主要目的是通过基于多尺度、先验指导、多模态等各种策略提供更多的学习信息或特征,帮助下采样特征与分割标签之间建立良好的关系。

因此,是否可以设计一个保留信息的下采样模块,使DCNNs中尽可能多地保留信息进行语义分割?这就是作者的想法。 

下采样模块

最大池化与平均池化

池化过程类似于卷积过程。在这个示意图中,我们看到对一个 4x4 的特征图邻域进行操作,使用了一个 2x2 的滤波器,步长为2进行扫描。这个过程被称为最大池化(Max Pooling),其中选择邻域内的最大值并输出到下一层。

常用的 max pooling 参数是 S=2、f=2,其效果是将特征图的高度和宽度减半,而通道数保持不变。

如上图所示,描述的是对一个 4x4 的特征图邻域内的数值进行操作。使用了一个 2x2 的滤波器,步长为2进行扫描,计算邻域内数值的平均值并将其输出到下一层。这种操作被称为平均池化(Mean Pooling)。

"""
Copyright (c) 2023, Auorui.
All rights reserved.

The Torch implementation of average pooling and maximum pooling has been compared with the official Torch implementation
"""
import torch
import torch.nn as nn

__all__ = ["MaxPool2d", "AvgPool2d"]

class MaxPool2d(nn.Module):
    """
    池化层计算公式:
        output_size = [(input_size−kernel_size) // stride + 1]
    """
    def __init__(self, kernel_size, stride):
        super(MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def max_pool2d(self, input_tensor, kernel_size, stride):
        batch_size, channels, height, width = input_tensor.size()
        output_height = (height - kernel_size) // stride + 1
        output_width = (width - kernel_size) // stride + 1
        output_tensor = torch.zeros(batch_size, channels, output_height, output_width)

        for i in range(output_height):
            for j in range(output_width):
                # 获取输入张量中与池化窗口对应的部分
                window = input_tensor[:, :,
                         i * stride: i * stride + kernel_size, j * stride: j * stride + kernel_size]
                output_tensor[:, :, i, j] = torch.max(window.reshape(batch_size, channels, -1), dim=2)[0]
        return output_tensor

    def forward(self, input_tensor):
        return self.max_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)


class AvgPool2d(nn.Module):
    """
    池化层计算公式:
        output_size = [(input_size−kernel_size) // stride + 1]
    """
    def __init__(self, kernel_size, stride):
        super(AvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def avg_pool2d(self, input_tensor, kernel_size, stride):
        batch_size, channels, height, width = input_tensor.size()
        output_height = (height - kernel_size) // stride + 1
        output_width = (width - kernel_size) // stride + 1
        output_tensor = torch.zeros(batch_size, channels, output_height, output_width)

        for i in range(output_height):
            for j in range(output_width):
                # 获取输入张量中与池化窗口对应的部分
                window = input_tensor[:, :,
                         i * stride: i * stride + kernel_size, j * stride:j * stride + kernel_size]
                output_tensor[:, :, i, j] = torch.mean(window.reshape(batch_size, channels, -1), dim=2)
        return output_tensor

    def forward(self, input_tensor):
        return self.avg_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)


if __name__=="__main__":
    # input_data = torch.rand((1, 3, 3, 3))
    input_data = torch.Tensor([[[[0.3939, 0.8964, 0.3681],
                               [0.5134, 0.3780, 0.0047],
                               [0.0681, 0.0989, 0.5962]],
                              [[0.7954, 0.4811, 0.3329],
                               [0.8804, 0.3986, 0.3561],
                               [0.2797, 0.3672, 0.6508]],
                              [[0.6309, 0.1340, 0.0564],
                               [0.3101, 0.9927, 0.5554],
                               [0.0947, 0.2305, 0.8299]]]])

    print(input_data.shape)

    kernel_size = 3
    stride = 1
    MaxPool2d1 = nn.MaxPool2d(kernel_size, stride)
    output_data_with_torch_max = MaxPool2d1(input_data)
    AvgPool2d1 = nn.AvgPool2d(kernel_size, stride)
    output_data_with_torch_avg = AvgPool2d1(input_data)
    AvgPool2d2 = AvgPool2d(kernel_size, stride)
    output_data_with_torch_Avg = AvgPool2d2(input_data)
    MaxPool2d2 = MaxPool2d(kernel_size, stride)
    output_data_with_torch_Max = MaxPool2d2(input_data)
    # output_data_with_max = max_pool2d(input_data, kernel_size, stride)
    # output_data_with_avg = avg_pool2d(input_data, kernel_size, stride)

    print("\ntorch.nn pooling Output:")
    print(output_data_with_torch_max,"\n",output_data_with_torch_max.size())
    print(output_data_with_torch_avg,"\n",output_data_with_torch_avg.size())
    print("\npooling Output:")
    print(output_data_with_torch_Max,"\n",output_data_with_torch_Max.size())
    print(output_data_with_torch_Avg,"\n",output_data_with_torch_Avg.size())
    # 直接使用bool方法判断会因为浮点数的原因出现偏差
    print(torch.allclose(output_data_with_torch_max,output_data_with_torch_Max))
    print(torch.allclose(output_data_with_torch_avg,output_data_with_torch_Avg))
    # tensor([[[[0.8964]],       # output_data_with_max
    #          [[0.8804]],
    #          [[0.9927]]]])
    # tensor([[[[0.3686]],       # output_data_with_avg
    #           [[0.5047]],
    #           [[0.4261]]]])

在这里,简单地与PyTorch官方的实现进行了比对,成功的进行复现。

跨步卷积

import torch
import torch.nn as nn

class StridedConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):
        super(StridedConvolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.is_relu = is_relu

    def forward(self, x):
        x = self.conv(x)
        if self.is_relu:
            x = self.relu(x)
        return x

if __name__ == '__main__':
    input_data = torch.rand((1, 3, 64, 64))
    strided_conv = StridedConvolution(3, 64)
    output_data = strided_conv(input_data)
    print("Input shape:", input_data.shape)
    print("Output shape:", output_data.shape)

对输入进行跨步卷积,并根据 is_relu 参数选择是否添加ReLU激活函数。在构建卷积神经网络时经常被用于下采样步骤,以减小特征图的尺寸。

Haar小波下采样

这一部分就直接参考的作者的代码,与池化不同的是,这里它是要指定输入输出几个通道。

"""
Haar Wavelet-based Downsampling (HWD)

Original address of the paper: https://www.sciencedirect.com/science/article/abs/pii/S0031320323005174
Code reference: https://github.com/apple1986/HWD/tree/main
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward

class HWDownsampling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(HWDownsampling, self).__init__()
        self.wt = DWTForward(J=1, wave='haar', mode='zero')
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv_bn_relu(x)
        return x


if __name__ == '__main__':
    downsampling_layer = HWDownsampling(3, 64)
    input_data = torch.rand((1, 3, 64, 64))
    output_data = downsampling_layer(input_data)
    print("Input shape:", input_data.shape)
    print("Output shape:", output_data.shape)

Haar小波变换是一种基于小波的信号处理方法,它将信号分解成低频和细节高频两个部分。在图像处理中,Haar小波通常用于图像压缩和特征提取,代码中使用的DWTForward模块中离散小波变换,通过选择 yH 中的不同方向上的高频分量,构建了新的特征图。将原始低频分量 yL 与新构建的高频分量拼接在一起。最后通过一个包含卷积、批归一化和ReLU激活函数的序列处理最终的特征图。

实验验证

这是作者论文中做的实验,这样看起来,似乎HWD在细节上确实是比池化和跨步卷积效果要好。

这里因为我也用我自己的数据进行了实验:

最大池化效果

平均池化效果

跨步卷积效果 

HDW效果

从肉眼上来看,HDW的效果确实要比其他的效果要好一些。

下面是我做实验的代码,感兴趣的可以在自己的数据上面进行实验,我觉得用于交通和医学上应该会有比较好的效果。

import cv2
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from pytorch_wavelets import DWTForward

class StridedConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):
        super(StridedConvolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.is_relu = is_relu

    def forward(self, x):
        x = self.conv(x)
        if self.is_relu:
            x = self.relu(x)
        return x

class HWDownsampling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(HWDownsampling, self).__init__()
        self.wt = DWTForward(J=1, wave='haar', mode='zero')
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        yL, yH = self.wt(x)
        y_HL = yH[0][:, :, 0, ::]
        y_LH = yH[0][:, :, 1, ::]
        y_HH = yH[0][:, :, 2, ::]
        x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)
        x = self.conv_bn_relu(x)
        return x

class DeeperCNN(nn.Module):
    def __init__(self):
        super(DeeperCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # self.pool1 = HWDownsampling(16, 16)
        self.pool1 = StridedConvolution(16, 16, is_relu=True)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(32)
        # self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # self.pool2 = HWDownsampling(32, 32)
        self.pool2 = StridedConvolution(32, 32, is_relu=True)

        self.conv6 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.pool1(self.relu(self.batch_norm1(self.conv1(x))))
        print(x.shape)
        x = self.pool2(self.relu(self.batch_norm2(self.conv2(x))))
        print(x.shape)
        x = self.conv6(x)
        return x

image_path = r'D:\PythonProject\Crack_classification_training_script\data\base\val\crack\2416.png'
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

transform = transforms.Compose([transforms.ToTensor()])
input_image = transform(image).unsqueeze(0)
import numpy as np
model = DeeperCNN()
output = model(input_image)
print("Output shape:", output.shape)

input_image = input_image.squeeze(0).permute(1, 2, 0).numpy()
output_image = output.squeeze(0).permute(1, 2, 0).detach().numpy()
output_image = output_image / output_image.max()
output_image = np.clip(output_image, 0, 1)

plt.subplot(1, 2, 1)
plt.imshow(input_image)
plt.title('Input Image')

plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Output Image')

plt.show()

总结 

在论文当中,作者也做了大量的消融实验去证实这个下采样模块的有效性,建议大家去看看原著作,或许会有更多的收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/340578.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CPMS靶场练习

关键:找到文件上传点,分析对方验证的手段 首先查看前端发现没有任何上传的位置,找到网站的后台,通过弱口令admin 123456可以进入 通过查看网站内容发现只有文章列表可以进行文件上传;有两个图片上传点 图片验证很严格…

《WebKit 技术内幕》学习之六(2): CSS解释器和样式布局

2 CSS解释器和规则匹配 在了解了CSS的基本概念之后,下面来理解WebKit如何来解释CSS代码并选择相应的规则。通过介绍WebKit的主要设施帮助理解WebKit的内部工作原理和机制。 2.1 样式的WebKit表示类 在DOM树中,CSS样式可以包含在“style”元素中或者使…

SpringBoot异常处理和单元测试

学习目标 Spring Boot 异常处理Spring Boot 单元测试 1.SpringBoot异常处理 1.1.自定义错误页面 SpringBoot默认的处理异常的机制:SpringBoot 默认的已经提供了一套处理异常的机制。一旦程序中出现了异常 SpringBoot 会向/error 的 url 发送请求。在 springBoot…

移动开发行业——鸿蒙OS NEXT开出繁花

1月18日,华为宣布HarmonyOS NEXT开发者预览版开放申请,根据官方注解,这个版本的鸿蒙系统有个更通俗易懂的名字——“星河版”,也被称为“纯血”鸿蒙。 根据官方解释,之所以取名星河版,寓意鸿蒙OS NEXT就像…

28、web攻防——通用漏洞SQL注入HTTP头XFFCOOKIEPOST请求

文章目录 $_GET:接收get请求,传输少量数据,URL是有长度限制的; $_POST:接收post请求; $_COOKIE:接收cookie,用于身份验证; $_REQUEST:收集通过 GET 、POST和C…

线性代数:矩阵运算(加减、数乘、乘法、幂、除、转置)

目录 加减 数乘 矩阵与矩阵相乘 矩阵的幂 矩阵转置 方阵的行列式 方阵的行列式,证明:|AB| |A| |B| 加减 数乘 矩阵与矩阵相乘 矩阵的幂 矩阵转置 方阵的行列式 方阵的行列式,证明:|AB| |A| |B|

JVM的组成部分(类加载器、运行时数据区、执行引擎、本地库接口)

目录 JVM作用 JVM构成 1.类加载器 类加载子系统: 类加载器的分类: 双亲委派机制: 2.运行时数据区 程序计数器 虚拟机栈 本地方法栈 堆 方法区 3.执行引擎 4.本地库接口 JVM作用 jvm是将字节码文件加载到虚拟机中,…

特征融合篇 | YOLOv8 引入长颈特征融合网络 Giraffe FPN

在本报告中,我们介绍了一种名为DAMO-YOLO的快速而准确的目标检测方法,其性能优于现有的YOLO系列。DAMO-YOLO是在YOLO的基础上通过引入一些新技术而扩展的,这些技术包括神经架构搜索(NAS)、高效的重参数化广义FPN(RepGFPN)、带有AlignedOTA标签分配的轻量级头部以及蒸馏增…

RV1103与FPGA通过MIPI CSI-2实现视频传输,实现网络推流

RV1103与FPGA通过MIPI CSI-2实现视频传输,实现网络推流。 一:图像格式 支持图像格式如下: [0]: NV16 (Y/CbCr 4:2:2) Size: Stepwise 64x64 - 2304x1296 with step 8/8 [1]: NV61 (Y/CrCb 4:2:2) Size: Stepwise 64x64 - 2304x1296 with …

POI及EasyExcel学习笔记

POI及EasyExcel学习笔记 组件、工具 POI-Excel概述 Apache POI 结构: HSSF - 提供读写[Microsoft Excel](https://baike.baidu.com/item/Microsoft Excel)格式档案的功能。XSSF - 提供读写Microsoft Excel OOXML格式档案的功能。HWPF &am…

美团收银餐饮版培训教程

硬件连接方式及介绍: 双屏收银机 收银一体机 双屏收银机连接图 收银一体机连接图 前台打印机 后厨打印机 标签打印机 前台打印机连接图 后厨打印机连接图 其它收银机配件 软件前期设置 1、机器联网 点开桌面的设置,点击更多,点击以太网,最上…

Linux与windows互相传输文件之rzsz命令

文章目录 关于rzsz安装软件使用命令方法一:直接拖拽方法二:直接在终端输入rz 关于rzsz 这个工具用于 windows 机器和远端的 Linux 机器通过 XShell 传输文件 安装完毕之后可以通过拖拽的方式将文件上传过去 首先看一下我们的机器可以使用网络吗&#xff…

QT upd测试

QT upd测试 本次测试将服务器和客户端写在了一个工程下&#xff0c;代码如下 widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QUdpSocket> #include<QTimer>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACE…

C++的流库

1.流的概念 “流”&#xff0c;即“流动”的意思&#xff0c;是物质从一处向另一处流动的过程。在计算机这边通常是指对一种有序连续且具有方向性的数据的抽象描述。 C 中的流一般指两个过程的统一&#xff1a; 信息从外部输入设备&#xff08;键盘&#xff09;向计算机内部…

网页设计-用户体验

Use Cases (用例) 用例是用户如何在网站上执行任务的书面描述&#xff0c;从用户的角度描述了系统响应请求时的行为。每个用例都是用户实现目标的一系列简单的步骤。简言之&#xff0c;用例是一种用于描述系统如何满足用户需求的方法。 用例的好处 1. 明确需求&#xff1a; Use…

【测试入门】测试用例经典设计方法 —— 因果图法

01、因果图设计测试用例的步骤 1、分析需求 阅读需求文档&#xff0c;如果User Case很复杂&#xff0c;尽量将它分解成若干个简单的部分。这样做的好处是&#xff0c;不必在一次处理过程中考虑所有的原因。没有固定的流程说明究竟分解到何种程度才算简单&#xff0c;需要测试…

Flutter 滚动布局:sliver模型

一、滚动布局 Flutter中可滚动布局基本都来自Sliver模型&#xff0c;原理和安卓传统UI的ListView、RecyclerView类似&#xff0c;滚动布局里面的每个子组件的样式往往是相同的&#xff0c;由于组件占用内存较大&#xff0c;所以在内存上我们可以缓存有限个组件&#xff0c;滚动…

每天五分钟计算机视觉:掌握迁移学习使用技巧

本文重点 随着深度学习的发展,迁移学习已成为一种流行的机器学习方法,它能够将预训练模型应用于各种任务,从而实现快速模型训练和优化。然而,要想充分利用迁移学习的优势,我们需要掌握一些关键技巧。本文将介绍这些技巧,帮助您更好地应用迁移学习技术。 迁移学习的关键…

HCIP-BGP选路实验

一.实验拓扑图 二.详细配置 R1 interface GigabitEthernet0/0/0 ip address 12.1.1.1 255.255.255.0interface LoopBack0 ip address 1.1.1.1 255.255.255.0interface LoopBack1 ip address 10.1.1.1 255.255.255.0bgp 1 router-id 1.1.1.1 peer 12.1.1.2 as-number 2ipv4-fa…

Unity New Input System 及其系统结构和源码浅析【Unity学习笔记·第十二】

转载请注明出处&#xff1a;&#x1f517;https://blog.csdn.net/weixin_44013533/article/details/132534422 作者&#xff1a;CSDN|Ringleader| 主要参考&#xff1a; 官方文档&#xff1a;Unity官方Input System手册与API官方测试用例&#xff1a;Unity-Technologies/InputS…
最新文章