【Block总结】MAB,多尺度注意力块|即插即用

文章目录

  • 一、论文信息
  • 二、创新点
  • 三、方法
    • MAB模块解读
      • 1、MAB模块概述
      • 2、MAB模块组成
      • 3、MAB模块的优势
  • 四、效果
  • 五、实验结果
  • 六、总结
  • 代码

一、论文信息

  • 标题: Multi-scale Attention Network for Single Image Super-Resolution
  • 作者: Yan Wang, Yusen Li, Gang Wang, Xiaoguang Liu
  • 机构: 南开大学
  • 发表会议: CVPR 2024 Workshops
  • 论文链接: arXiv
  • GitHub代码库: icandle/MAN
    在这里插入图片描述

二、创新点

  • 多尺度大核注意力(MLKA): 结合了多尺度机制与大核卷积,能够有效捕捉不同尺度的信息,避免了传统方法中常见的“块状伪影”问题。

  • 门控空间注意力单元(GSAU): 通过引入门控机制,优化了空间注意力的计算,去除了不必要的线性层,从而提高了信息聚合的效率和准确性。

  • 灵活的网络结构: 通过堆叠不同数量的MLKA和GSAU模块,构建出多种复杂度的网络,以实现性能与计算量之间的平衡。

三、方法

  1. 网络架构: Multi-scale Attention Network (MAN)由三个主要模块组成:

    • 浅层特征提取模块(SF): 负责初步的特征提取。
    • 深层特征提取模块(DF): 基于多个多尺度注意力块(MAB),进一步提取丰富的特征。
    • 高质量图像重建模块: 将提取的特征用于最终的图像重建。
  2. 多尺度注意力块(MAB):

    • MLKA模块: 结合大核注意力、多个尺度机制和门控聚合,建立不同尺度之间的相关性。
    • GSAU模块: 整合空间注意力和门控机制,简化前馈网络。
  3. MLKA的功能:

    • 大核注意力: 通过分解卷积建立长距离关系。
    • 多尺度机制: 增强固定LKA以学习全尺度信息的注意力图。
    • 门控聚合: 动态调整注意力图以避免伪影。

MAB模块解读

在这里插入图片描述

1、MAB模块概述

MAB(Multi-scale Attention Block)是Multi-scale Attention Network (MAN)中的核心组件,旨在通过结合多尺度大核注意力(MLKA)和门控空间注意力单元(GSAU)来提升图像超分辨率的性能。MAB模块的设计旨在有效捕捉图像中的局部和全局特征,同时避免传统卷积网络中常见的“块状伪影”问题。

2、MAB模块组成

MAB模块主要由以下两个部分构成:

  1. 多尺度大核注意力(MLKA):

    • 功能: MLKA通过引入多尺度机制,结合大核卷积,能够在不同尺度上提取丰富的特征信息。
    • 结构:
      • 首先,MLKA使用点卷积(Point-wise convolution)调整通道数。
      • 然后,将特征分成三组,每组使用不同尺寸的大核卷积(如7×7、21×21、35×35)进行处理,膨胀率分别设置为(2,3,4)。
      • 为了避免膨胀卷积带来的“块状伪影”,MLKA引入了门控聚合机制,通过逐元素乘法将深度卷积的输出与对应组的LKA输出相结合,从而动态调整注意力图的输出。
  2. 门控空间注意力单元(GSAU):

    • 功能: GSAU旨在增强特征表示能力,通过结合空间注意力和门控机制,优化信息聚合过程。
    • 结构:
      • GSAU通常由两个分支组成,其中一个分支使用深度卷积对特征进行加权,另一个分支则通过空间自注意力机制捕捉空间上下文信息。
      • 这种设计减少了不必要的线性层,降低了计算复杂度,同时增强了特征的表达能力。

3、MAB模块的优势

  • 多尺度特征提取: 通过MLKA,MAB能够在多个尺度上提取特征,增强了模型对不同图像细节的敏感性。

  • 减少伪影: 通过门控聚合机制,MAB有效地减少了由于膨胀卷积引起的块状伪影,提升了图像重建的质量。

  • 高效的特征表示: GSAU的引入使得模型能够更好地聚合空间信息,提升了特征的表达能力,进而提高了超分辨率的效果。

四、效果

  • 性能提升: 实验结果表明,MAN在多个超分辨率基准测试中表现优异,能够与当前最先进的模型(如SwinIR)相媲美,同时在计算效率上也有显著改善。

  • 避免伪影: 通过MLKA和GSAU的结合,模型有效减少了图像重建中的伪影现象,提升了视觉效果。

五、实验结果

  • 基准测试: 论文中使用了多个数据集(如Set5、Set14、BSD100、Urban100、Manga109)进行测试,结果显示MAN在PSNR和SSIM指标上均优于传统的超分辨率模型,尤其是在高倍数放大(如×4)时表现突出。
数据集上采样因子MAN PSNR (dB)与SwinIR比较
Set5×238.42相当
×334.91略低
×432.87良好
Set14×234.44相近
×330.92略低
×429.09良好
BSD100×232.53相近
×329.65略低
×427.90良好
Urban100×233.80相近
×334.45略低
×433.73良好
Manga109×240.02相近
×335.21略低
×431.22良好

六、总结

Multi-scale Attention Network (MAN)通过结合多尺度大核注意力和门控空间注意力机制,成功提升了单幅图像超分辨率重建的性能和效率。该研究不仅解决了传统方法中的一些局限性,还为未来的超分辨率模型设计提供了新的思路。MAN在多个基准测试中的优异表现,证明了其在实际应用中的潜力,尤其是在需要高质量图像重建的场景中。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LayerNorm(nn.Module):r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.The ordering of the dimensions in the inputs. channels_last corresponds to inputs withshape (batch_size, height, width, channels) while channels_first corresponds to inputswith shape (batch_size, channels, height, width)."""def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):super().__init__()self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.eps = epsself.data_format = data_formatif self.data_format not in ["channels_last", "channels_first"]:raise NotImplementedErrorself.normalized_shape = (normalized_shape,)def forward(self, x):if self.data_format == "channels_last":return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)elif self.data_format == "channels_first":u = x.mean(1, keepdim=True)s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return xclass SGAB(nn.Module):def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):super().__init__()i_feats = n_feats * 2self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats)self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)self.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)def forward(self, x):shortcut = x.clone()# Ghost Expandx = self.Conv1(self.norm(x))a, x = torch.chunk(x, 2, dim=1)x = x * self.DWConv1(a)x = self.Conv2(x)return x * self.scale + shortcutclass GroupGLKA(nn.Module):def __init__(self, n_feats, k=2, squeeze_factor=15):super().__init__()i_feats = 2 * n_featsself.n_feats = n_featsself.i_feats = i_featsself.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)# Multiscale Large Kernel Attentionself.LKA7 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA5 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA3 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)self.proj_first = nn.Sequential(nn.Conv2d(n_feats, i_feats, 1, 1, 0))self.proj_last = nn.Sequential(nn.Conv2d(n_feats, n_feats, 1, 1, 0))def forward(self, x, pre_attn=None, RAA=None):shortcut = x.clone()x = self.norm(x)x = self.proj_first(x)a, x = torch.chunk(x, 2, dim=1)a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],dim=1)x = self.proj_last(x * a) * self.scale + shortcutreturn x# MABclass MAB(nn.Module):def __init__(self, n_feats):super().__init__()self.LKA = GroupGLKA(n_feats)self.LFE = SGAB(n_feats)def forward(self, x, pre_attn=None, RAA=None):# large kernel attentionx = self.LKA(x)# local feature extractionx = self.LFE(x)return xif __name__ == "__main__":dim=96 # 通道要被3整除# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, channels,height, width)x = torch.randn(2,dim,40,40).to(device)# 初始化 MAB模块block = MAB(dim)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68808.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】DeepSeek模型介绍与部署

原文链接:DeepSeek-V3 1. 介绍 DeepSeek-V3,一个强大的混合专家 (MoE) 语言模型,拥有 671B 总参数,其中每个 token 激活 37B 参数。 为了实现高效推理和成本效益的训练,DeepSeek-V3 采用了多头潜在注意力 (MLA) 和 De…

深度学习深度解析:从基础到前沿

引言 深度学习作为人工智能的一个重要分支,通过模拟人脑的神经网络结构来进行数据分析和模式识别。它在图像识别、自然语言处理、语音识别等领域取得了显著成果。本文将深入探讨深度学习的基础知识、主要模型架构以及当前的研究热点和发展趋势。 基础概念与数学原理…

如何实现滑动列表功能

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了沉浸式状态栏相关的内容,本章回中将介绍SliverList组件.闲话休提,让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的SliverList组件是一种列表类组件,类似我们之前介…

OpenEuler学习笔记(十七):OpenEuler搭建Redis高可用生产环境

在OpenEuler上搭建Redis高可用生产环境,通常可以采用Redis Sentinel或Redis Cluster两种方式,以下分别介绍两种方式的搭建步骤: 基于Redis Sentinel的高可用环境搭建 安装Redis 配置软件源:可以使用OpenEuler的默认软件源&#…

前沿课题推荐:提升水下导航精度的多源数据融合与算法研究

随着海洋探测技术的迅猛发展,水下地形匹配导航逐渐成为国际研究的热点领域。在全球范围内,水下导航技术的精确性对于科学探索、资源勘探及国防安全等方面都至关重要。我国在这一领域的研究与应用需求日益增长,亟需通过先进的技术手段提升水下…

浅析CDN安全策略防范

CDN(内容分发网络)信息安全策略是保障内容分发网络在提供高效服务的同时,确保数据传输安全、防止恶意攻击和保护用户隐私的重要手段。以下从多个方面详细介绍CDN的信息安全策略: 1. 数据加密 数据加密是CDN信息安全策略的核心之…

three.js+WebGL踩坑经验合集(6.1):负缩放,负定矩阵和行列式的关系(2D版本)

春节忙完一轮,总算可以继续来写博客了。希望在春节假期结束之前能多更新几篇。 这一篇会偏理论多一点。笔者本没打算在这一系列里面重点讲理论,所以像相机矩阵推导这种网上已经很多优质文章的内容,笔者就一笔带过。 然而关于负缩放&#xf…

HTB:Administrator[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…

一文讲解JVM中的G1垃圾收集器

接上一篇博文,这篇博文讲下JVM中的G1垃圾收集器 G1在JDK1.7时引入,在JDK9时取代了CMS成为默认的垃圾收集器; G1把Java堆划分为多个大小相等的独立区域Region,每个区域都可以扮演新生代(Eden和Survivor)或老…

力扣第149场双周赛

文章目录 题目总览题目详解找到字符串中合法的相邻数字重新安排会议得到最多空余时间I3440.重新安排会议得到最多空余时间II 第149场双周赛 题目总览 找到字符串中合法的相邻数字 重新安排会议得到最多空余时间I 重新安排会议得到最多空余时间II 变成好标题的最少代价 题目…

25届 信息安全领域毕业设计选题88例:前沿课题

目录 前言 毕设选题 开题指导建议 更多精选选题 选题帮助 最后 前言 大家好,这里是海浪学长毕设专题! 大四是整个大学期间最忙碌的时光,一边要忙着准备考研、考公、考教资或者实习为毕业后面临的升学就业做准备,一边要为毕业设计耗费大量精力。学长给大家整理…

【算法设计与分析】实验6:n皇后问题的回溯法设计与求解

目录 一、实验目的 二、实验环境 三、实验内容 四、核心代码 五、记录与处理 六、思考与总结 七、完整报告和成果文件提取链接 一、实验目的 针对n皇后问题开展分析、建模、评价,算法设计与优化,并进行编码实践。 掌握回溯法求解问题的思想&#…

如何为用户设置密码

[rootxxx ~]# passwd aa #交互式的为用户设置密码 或者 [rootxxx ~]# echo 123 | passwd --stdin aa #不交互式的为用户设置密码 (适用于批量的为用户更改密码,比如一次性为100个用户初始化密码)

【Vaadin flow 实战】第5讲-使用常用UI组件绘制页面元素

vaadin flow官方提供的UI组件文档地址是 https://vaadin.com/docs/latest/components这里,我简单实战了官方提供的一些免费的UI组件,使用案例如下: Accordion 手风琴 Accordion 手风琴效果组件 Accordion 手风琴-测试案例代码 Slf4j PageT…

深入理解Java引用传递

先看一段代码: public static void add(String a) {a "new";System.out.println("add: " a); // 输出内容:add: new}public static void main(String[] args) {String a null;add(a);System.out.println("main: " a);…

Elasticsearch的开发工具(Dev Tools)

目录 说明1. **Console**2. **Search Profiler**3. **Grok Debugger**4. **Painless Lab**总结 说明 Elasticsearch的开发工具(Dev Tools)在Kibana中提供了多种功能强大的工具,用于调试、优化和测试Elasticsearch查询和脚本。以下是关于Cons…

【机器学习】自定义数据集 使用scikit-learn中svm的包实现svm分类

一、支持向量机(support vector machines. ,SVM)概念 1. SVM 绪论 支持向量机(SVM)的核心思想是找到一个最优的超平面,将不同类别的数据点分开。SVM 的关键特点包括: ① 分类与回归: SVM 可以用于分类&a…

C++并行化编程

C并行化编程 C 简介 C 是一种静态类型的、编译式的、通用的、大小写敏感的、不规则的编程语言,支持过程化编程、面向对象编程和泛型编程。 C 被认为是一种中级语言,它综合了高级语言和低级语言的特点。 C 是由 Bjarne Stroustrup 于 1979 年在新泽西州美…

记6(人工神经网络

目录 1、M-P神经元2、感知机3、Delta法则4、前馈型神经网络(Feedforward Neural Networks)5、鸢尾花数据集——单层前馈型神经网络:6、多层神经网络:增加隐含层7、实现异或运算(01、10为1,00、11为0)8、线性…

网工_HDLC协议

2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头&#xf…