Haar小波下采样模块

论文原址:Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect

原文代码:HWD/HWD.py at main · apple1986/HWD (github.com)

介绍 

深度卷积神经网络 (DCNN) 通常采用标准的下采样操作,例如最大池化、平均池化和跨步卷积,这可能会导致信息丢失。丢失的信息,如边界和纹理,对于语义分割可能是必不可少的。为了缓解这个问题,一般有下面四种方法:

  1. 通过跳过连接到解码器子网(如U-Net、LCU-Net、CENet、LinkNet和RefineNet )。
  2. 提取具有空间金字塔池化或扩展卷积的多尺度特征图到融合模块中(如DeepLab、PSPNet、PCPLP-Net、BiSenet和ICNet)。
  3. 向编码器提供多模态图像(如DiSegNet、MMADT、CANet和CCFFNet)。
  4. 增加先验信息。轮廓增强关注模块,旨在从CT图像中提取边界和形状线索,以细化分割区域。

这些方法的主要目的是通过基于多尺度、先验指导、多模态等各种策略提供更多的学习信息或特征,帮助下采样特征与分割标签之间建立良好的关系。

因此,是否可以设计一个保留信息的下采样模块,使DCNNs中尽可能多地保留信息进行语义分割?这就是作者的想法。 

下采样模块

最大池化与平均池化

池化过程类似于卷积过程。在这个示意图中,我们看到对一个 4x4 的特征图邻域进行操作,使用了一个 2x2 的滤波器,步长为2进行扫描。这个过程被称为最大池化(Max Pooling),其中选择邻域内的最大值并输出到下一层。

常用的 max pooling 参数是 S=2、f=2,其效果是将特征图的高度和宽度减半,而通道数保持不变。

如上图所示,描述的是对一个 4x4 的特征图邻域内的数值进行操作。使用了一个 2x2 的滤波器,步长为2进行扫描,计算邻域内数值的平均值并将其输出到下一层。这种操作被称为平均池化(Mean Pooling)。

"""
Copyright (c) 2023, Auorui.
All rights reserved.The Torch implementation of average pooling and maximum pooling has been compared with the official Torch implementation
"""
import torch
import torch.nn as nn__all__ = ["MaxPool2d", "AvgPool2d"]class MaxPool2d(nn.Module):"""池化层计算公式:output_size = [(input_size−kernel_size) // stride + 1]"""def __init__(self, kernel_size, stride):super(MaxPool2d, self).__init__()self.kernel_size = kernel_sizeself.stride = stridedef max_pool2d(self, input_tensor, kernel_size, stride):batch_size, channels, height, width = input_tensor.size()output_height = (height - kernel_size) // stride + 1output_width = (width - kernel_size) // stride + 1output_tensor = torch.zeros(batch_size, channels, output_height, output_width)for i in range(output_height):for j in range(output_width):# 获取输入张量中与池化窗口对应的部分window = input_tensor[:, :,i * stride: i * stride + kernel_size, j * stride: j * stride + kernel_size]output_tensor[:, :, i, j] = torch.max(window.reshape(batch_size, channels, -1), dim=2)[0]return output_tensordef forward(self, input_tensor):return self.max_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)class AvgPool2d(nn.Module):"""池化层计算公式:output_size = [(input_size−kernel_size) // stride + 1]"""def __init__(self, kernel_size, stride):super(AvgPool2d, self).__init__()self.kernel_size = kernel_sizeself.stride = stridedef avg_pool2d(self, input_tensor, kernel_size, stride):batch_size, channels, height, width = input_tensor.size()output_height = (height - kernel_size) // stride + 1output_width = (width - kernel_size) // stride + 1output_tensor = torch.zeros(batch_size, channels, output_height, output_width)for i in range(output_height):for j in range(output_width):# 获取输入张量中与池化窗口对应的部分window = input_tensor[:, :,i * stride: i * stride + kernel_size, j * stride:j * stride + kernel_size]output_tensor[:, :, i, j] = torch.mean(window.reshape(batch_size, channels, -1), dim=2)return output_tensordef forward(self, input_tensor):return self.avg_pool2d(input_tensor, kernel_size=self.kernel_size, stride=self.stride)if __name__=="__main__":# input_data = torch.rand((1, 3, 3, 3))input_data = torch.Tensor([[[[0.3939, 0.8964, 0.3681],[0.5134, 0.3780, 0.0047],[0.0681, 0.0989, 0.5962]],[[0.7954, 0.4811, 0.3329],[0.8804, 0.3986, 0.3561],[0.2797, 0.3672, 0.6508]],[[0.6309, 0.1340, 0.0564],[0.3101, 0.9927, 0.5554],[0.0947, 0.2305, 0.8299]]]])print(input_data.shape)kernel_size = 3stride = 1MaxPool2d1 = nn.MaxPool2d(kernel_size, stride)output_data_with_torch_max = MaxPool2d1(input_data)AvgPool2d1 = nn.AvgPool2d(kernel_size, stride)output_data_with_torch_avg = AvgPool2d1(input_data)AvgPool2d2 = AvgPool2d(kernel_size, stride)output_data_with_torch_Avg = AvgPool2d2(input_data)MaxPool2d2 = MaxPool2d(kernel_size, stride)output_data_with_torch_Max = MaxPool2d2(input_data)# output_data_with_max = max_pool2d(input_data, kernel_size, stride)# output_data_with_avg = avg_pool2d(input_data, kernel_size, stride)print("\ntorch.nn pooling Output:")print(output_data_with_torch_max,"\n",output_data_with_torch_max.size())print(output_data_with_torch_avg,"\n",output_data_with_torch_avg.size())print("\npooling Output:")print(output_data_with_torch_Max,"\n",output_data_with_torch_Max.size())print(output_data_with_torch_Avg,"\n",output_data_with_torch_Avg.size())# 直接使用bool方法判断会因为浮点数的原因出现偏差print(torch.allclose(output_data_with_torch_max,output_data_with_torch_Max))print(torch.allclose(output_data_with_torch_avg,output_data_with_torch_Avg))# tensor([[[[0.8964]],       # output_data_with_max#          [[0.8804]],#          [[0.9927]]]])# tensor([[[[0.3686]],       # output_data_with_avg#           [[0.5047]],#           [[0.4261]]]])

在这里,简单地与PyTorch官方的实现进行了比对,成功的进行复现。

跨步卷积

import torch
import torch.nn as nnclass StridedConvolution(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):super(StridedConvolution, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)self.relu = nn.ReLU(inplace=True)self.is_relu = is_reludef forward(self, x):x = self.conv(x)if self.is_relu:x = self.relu(x)return xif __name__ == '__main__':input_data = torch.rand((1, 3, 64, 64))strided_conv = StridedConvolution(3, 64)output_data = strided_conv(input_data)print("Input shape:", input_data.shape)print("Output shape:", output_data.shape)

对输入进行跨步卷积,并根据 is_relu 参数选择是否添加ReLU激活函数。在构建卷积神经网络时经常被用于下采样步骤,以减小特征图的尺寸。

Haar小波下采样

这一部分就直接参考的作者的代码,与池化不同的是,这里它是要指定输入输出几个通道。

"""
Haar Wavelet-based Downsampling (HWD)Original address of the paper: https://www.sciencedirect.com/science/article/abs/pii/S0031320323005174
Code reference: https://github.com/apple1986/HWD/tree/main
"""
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForwardclass HWDownsampling(nn.Module):def __init__(self, in_channel, out_channel):super(HWDownsampling, self).__init__()self.wt = DWTForward(J=1, wave='haar', mode='zero')self.conv_bn_relu = nn.Sequential(nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True),)def forward(self, x):yL, yH = self.wt(x)y_HL = yH[0][:, :, 0, ::]y_LH = yH[0][:, :, 1, ::]y_HH = yH[0][:, :, 2, ::]x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)x = self.conv_bn_relu(x)return xif __name__ == '__main__':downsampling_layer = HWDownsampling(3, 64)input_data = torch.rand((1, 3, 64, 64))output_data = downsampling_layer(input_data)print("Input shape:", input_data.shape)print("Output shape:", output_data.shape)

Haar小波变换是一种基于小波的信号处理方法,它将信号分解成低频和细节高频两个部分。在图像处理中,Haar小波通常用于图像压缩和特征提取,代码中使用的DWTForward模块中离散小波变换,通过选择 yH 中的不同方向上的高频分量,构建了新的特征图。将原始低频分量 yL 与新构建的高频分量拼接在一起。最后通过一个包含卷积、批归一化和ReLU激活函数的序列处理最终的特征图。

实验验证

这是作者论文中做的实验,这样看起来,似乎HWD在细节上确实是比池化和跨步卷积效果要好。

这里因为我也用我自己的数据进行了实验:

最大池化效果

平均池化效果

跨步卷积效果 

HDW效果

从肉眼上来看,HDW的效果确实要比其他的效果要好一些。

下面是我做实验的代码,感兴趣的可以在自己的数据上面进行实验,我觉得用于交通和医学上应该会有比较好的效果。

import cv2
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
from pytorch_wavelets import DWTForwardclass StridedConvolution(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, is_relu=True):super(StridedConvolution, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)self.relu = nn.ReLU(inplace=True)self.is_relu = is_reludef forward(self, x):x = self.conv(x)if self.is_relu:x = self.relu(x)return xclass HWDownsampling(nn.Module):def __init__(self, in_channel, out_channel):super(HWDownsampling, self).__init__()self.wt = DWTForward(J=1, wave='haar', mode='zero')self.conv_bn_relu = nn.Sequential(nn.Conv2d(in_channel * 4, out_channel, kernel_size=1, stride=1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True),)def forward(self, x):yL, yH = self.wt(x)y_HL = yH[0][:, :, 0, ::]y_LH = yH[0][:, :, 1, ::]y_HH = yH[0][:, :, 2, ::]x = torch.cat([yL, y_HL, y_LH, y_HH], dim=1)x = self.conv_bn_relu(x)return xclass DeeperCNN(nn.Module):def __init__(self):super(DeeperCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.batch_norm1 = nn.BatchNorm2d(16)self.relu = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)# self.pool1 = HWDownsampling(16, 16)self.pool1 = StridedConvolution(16, 16, is_relu=True)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.batch_norm2 = nn.BatchNorm2d(32)# self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)# self.pool2 = HWDownsampling(32, 32)self.pool2 = StridedConvolution(32, 32, is_relu=True)self.conv6 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.pool1(self.relu(self.batch_norm1(self.conv1(x))))print(x.shape)x = self.pool2(self.relu(self.batch_norm2(self.conv2(x))))print(x.shape)x = self.conv6(x)return ximage_path = r'D:\PythonProject\Crack_classification_training_script\data\base\val\crack\2416.png'
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)transform = transforms.Compose([transforms.ToTensor()])
input_image = transform(image).unsqueeze(0)
import numpy as np
model = DeeperCNN()
output = model(input_image)
print("Output shape:", output.shape)input_image = input_image.squeeze(0).permute(1, 2, 0).numpy()
output_image = output.squeeze(0).permute(1, 2, 0).detach().numpy()
output_image = output_image / output_image.max()
output_image = np.clip(output_image, 0, 1)plt.subplot(1, 2, 1)
plt.imshow(input_image)
plt.title('Input Image')plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.title('Output Image')plt.show()

总结 

在论文当中,作者也做了大量的消融实验去证实这个下采样模块的有效性,建议大家去看看原著作,或许会有更多的收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/640285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot中一些注解

springboot中一些注解 1:项目启动时会去扫描启动的注解,一般是启动时就想要被加载的方法: 2:springBoot中MSApplication启动类的一些其他注解: EnableAsync:这是一个Spring框架的注解,它用于开启方法异步调用的功能。当…

2017年认证杯SPSSPRO杯数学建模B题(第一阶段)岁月的印记全过程文档及程序

2017年认证杯SPSSPRO杯数学建模 跨年龄人脸识别模型的建立与分析 B题 岁月的印记 原题再现: 对同一个人来说,如果没有过改变面容的疾病、面部外伤或外科手术等经历,年轻和年老时的面容总有很大的相似性。人们在生活中也往往能够分辨出来两…

2.【SpringBoot3】用户模块接口开发

文章目录 开发模式和环境搭建开发模式环境搭建 1. 用户注册1.1 注册接口基本代码编写1.2 注册接口参数校验 2. 用户登录2.1 登录接口基本代码编写2.2 登录认证2.2.1 登录认证引入2.2.2 JWT 简介2.2.3 登录功能集成 JWT2.2.4 拦截器 3. 获取用户详细信息3.1 获取用户详细信息基本…

一周时间,开发了一款封面图生成工具

介绍 这是一款封面图的制作工具,根据简单的配置即可生成一张好看的封面图,目前已有七款主题可以选择。做这个工具的初衷来自平时写文章,都为封面图发愁,去图片 网站上搜索很难找到满意的,而且当你要的图如果要搭配上文…

【JavaEE进阶】 关于⽇志框架(SLF4J)

文章目录 🌳SLF4j🌲⻔⾯模式(外观模式)🚩⻔⾯模式的定义🚩⻔⾯模式的优点 🍃关于SLF4J框架🚩不引⼊⽇志⻔⾯🚩引⼊⽇志⻔⾯ ⭕总结 🌳SLF4j SLF4J不同于其他⽇志框架,它不是⼀个真正…

构建高效外卖系统:技术实践与代码示例

外卖系统在现代社会中扮演着重要的角色,为用户提供了便捷的用餐解决方案。在这篇文章中,我们将探讨构建高效外卖系统的技术实践,同时提供一些基础的代码示例,帮助开发者更好地理解和应用这些技术。 1. 技术栈选择 构建外卖系统…

BP蓝图映射到C++笔记1

教程链接:示例1:CompleteQuest - 将蓝图转换为C (epicgames.com) 1.常用的引用需要记住,如图所示。 2.蓝图中可以调用C函数,也可以实现C函数 BlueprintImplementableEvent:C只创建,不实现,在蓝图中实现 B…

C++提高编程---模板---类模板

目录 一、类模板 1.模板 2.类模板的作用 3.语法 4.声明 二、类模板和函数模板的区别 三、类模板中成员函数的创建时机 四、类模板对象做函数参数 五、类模板与继承 六、类模板成员函数类外实现 七、类模板分文件编写 八、类模板与友元 九、类模板案例 一、类模板 …

软件测试的需求人才越来越多,为什么大家还是不太愿意走软件测试的道路?

🔥 交流讨论:欢迎加入我们一起学习! 🔥 资源分享:耗时200小时精选的「软件测试」资料包 🔥 教程推荐:火遍全网的《软件测试》教程 📢欢迎点赞 👍 收藏 ⭐留言 &#x1…

【动态规划】【C++算法】801. 使序列递增的最小交换次数

作者推荐 【动态规划】【广度优先搜索】【状态压缩】847 访问所有节点的最短路径 本文涉及知识点 动态规划汇总 数组 LeetCode801使序列递增的最小交换次数 我们有两个长度相等且不为空的整型数组 nums1 和 nums2 。在一次操作中,我们可以交换 nums1[i] 和 num…

路飞项目--03

二次封装Response模块 # drf提供的Response,前端想接收到的格式 {code:xx,msg:xx} 后端返回,前端收到: APIResponse(tokneasdfa.asdfas.asdf)---->{code:100,msg:成功,token:asdfa.asdfas.asdf} APIResponse(code101,msg用户不存在) ---…

学习笔记-李沐动手学深度学习(一)(01-07,概述、数据操作、tensor操作、数学基础、自动求导)

个人随笔 第三列是 jupyter记事本 官方github上啥都有(代码、jupyter记事本、胶片) https://github.com/d2l-ai 多体会 【梯度指向的是值变化最大的方向】 符号 维度 (弹幕说)2,3,4越后面维度越低 4…

Java 面向对象案例 02 (黑马)

代码: public class foodTest {public static void main(String[] args) {//1、构建一个数组food[] arr new food[3];//2、创建三个商品对象food f1 new food("apple","123",3.2,500);food f2 new food("pear","456",4…

临时工说:AI 人工智能化对于DBA 的工作的影响

这开头还是介绍一下群,如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis, Oceanbase, Sql Server等有问题,有需求都可以加群群内,可以解决你的问题。加群请联系 liuaustin3 ,(共1900人左右 1 2 3 4 5&#xf…

ChatGPT:关于 OpenAI 的 GPT-4工具,你需要知道的一切

ChatGPT:关于 OpenAI 的 GPT-4工具,你需要知道的一切 什么是GPT-3、GPT-4 和 ChatGPT?ChatGPT 可以做什么?ChatGPT-4 可以做什么?ChatGPT 的费用是多少?GPT-4 与 GPT-3.5 有何不同?ChatGPT 如何…

开源堡垒机JumpServer本地安装并配置公网访问地址

文章目录 前言1. 安装Jump server2. 本地访问jump server3. 安装 cpolar内网穿透软件4. 配置Jump server公网访问地址5. 公网远程访问Jump server6. 固定Jump server公网地址 前言 JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。JumpS…

ONLYOFFICE服务器无法连接,请联系管理员问题解决

1、现象 部署好了nextcloud和onlyoffice后,新建文本文档报错ONLYOFFICE服务器无法连接,请联系管理员。 用快捷键“F12”进入控制台,点开错误提示栏,找到有“api.js“文件,“https://ONLYOFFICED的地址/web-apps/apps/…

书法AI全自动切字+识别算法2.0版发布,草书篆书行书楷书识别准确率超过90%,覆盖书法单字30万张

我们开发的业界识别最准覆盖作品最全的书法AI小程序上线了 书法AI全自动切字识别算法2.0版发布,草书篆书行书楷书识别准确率超过90%,准确率甩百度OCR一条街,覆盖书法单字30万张,遥遥领先同行 我们还可为客户提供书法AI全自动切字a…

借助文档控件Aspose.Words,将 Word DOC/DOCX 转换为 TXT

在文档处理领域,经常需要将 Word 文档转换为更简单的纯文本格式。无论是出于数据提取、内容分析还是兼容性原因,将 Word(.doc、.docx)文件转换为纯文本(.txt)的能力对于开发人员来说都是一项宝贵的技能。在…

87230系列USB连续波功率探头

01 87230 USB连续波功率探头 产品综述: 87230/87231/87232/87233系列USB功率探头是一款基于USB2.0全速/高速自适应接口的二极管检波式功率探头,内部采用高性能处理芯片,通过各种校准和补偿技术,使得探头具有频率范围宽、功率动…