【Block总结】掩码窗口自注意力 (M-WSA)

在这里插入图片描述

摘要

论文链接:https://arxiv.org/pdf/2404.07846
论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在处理图像时的局限性,特别是在图像去噪和恢复任务中。M-WSA 通过引入掩码机制,确保在计算注意力时遵循盲点要求,从而避免信息泄露。

设计原理

  1. 窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,将输入图像划分为多个不重叠的窗口。在每个窗口内,计算自注意力以捕捉局部特征。这种方法的计算复杂度相对较低,适合处理高分辨率图像。

  2. 掩码机制:为了满足盲点要求,M-WSA 在计算注意力时应用了掩码。具体而言,掩码限制了每个像素只能关注其窗口内的特定像素,从而避免了对盲点信息的访问。这一设计确保了网络在去噪时不会泄露噪声信息。

  3. 扩张卷积模拟:M-WSA 的掩码设计模仿了扩张卷积的感受野,使得网络能够在保持计算效率的同时,捕捉到更大范围的上下文信息。这种方法有效地扩展了网络的感受野,增强了特征提取能力。
    在这里插入图片描述

优势

  • 高效性:通过限制注意力计算在窗口内,M-WSA 显著降低了计算复杂度,使其适用于大规模图像处理任务。

  • 信息保护:掩码机制确保了盲点信息不被泄露,从而提高了去噪效果,特别是在处理具有空间相关噪声的图像时。

  • 灵活性:M-WSA 可以与其他网络架构结合使用,增强其在各种视觉任务中的表现,尤其是在自我监督学习和图像恢复领域。

实验结果

在多个真实世界的图像去噪数据集上进行的实验表明,M-WSA 显著提高了去噪性能,超越了传统的卷积网络和其他自注意力机制。这一结果表明,M-WSA 在处理复杂噪声模式时具有良好的适应性和有效性。

代码

Masked Window-Based Self-Attention (M-WSA) 通过结合窗口自注意力和掩码机制,为图像去噪和恢复任务提供了一种有效的解决方案。其设计不仅提高了计算效率,还确保了信息的安全性,展示了在自我监督学习中的广泛应用潜力。代码:

import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1)  # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1  # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模块oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向传播output = oca_attention(x)print("input张量形状:", x.shape)print("output张量形状:", output.shape)

DilatedOCA模块详解

代码结构

import torch
import torch.nn as nn
from einops import rearrange
  • 导入库:首先导入 PyTorch 和 einops 库。einops 用于简化张量的重排操作。

模块定义

class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
  • 初始化方法__init__ 方法定义了模块的结构。

    • dim:输入特征的通道数。

    • window_size:窗口的大小,用于空间注意力计算。

    • overlap_ratio:重叠窗口的比例,决定了窗口之间的重叠程度。

    • num_heads:空间注意力的头数。

    • dim_head:每个头的维度。

  • 层的定义

    • self.unfold:用于将输入张量展开为重叠窗口的操作。

    • self.qkv:一个 1x1 的卷积层,用于生成查询(Q)、键(K)和值(V)三个特征图。

    • self.project_out:一个 1x1 的卷积层,用于将输出特征映射回原始通道数。

    • self.rel_pos_embself.fixed_pos_emb:用于位置编码的模块,增强模型对空间位置的感知。

前向传播

def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
  • 输入形状x 的形状为 (batch_size, channels, height, width),其中 b 是批量大小,c 是通道数,hw 是图像的高度和宽度。

  • 特征提取

    • qkv = self.qkv(x):通过 qkv 层生成 Q、K、V 特征图。

    • qs, ks, vs = qkv.chunk(3, dim=1):将 Q、K、V 特征图沿通道维度分离。

  • 空间注意力计算

    • qs 被重排为适合空间注意力计算的格式。

    • ksvs 通过 unfold 操作展开为重叠窗口。

  • 分头处理

    • 使用 einops.rearrange 将 Q、K、V 的形状调整为适合多头自注意力计算的格式。
  • 计算注意力

    • qs = qs * self.scale:对 Q 进行缩放以提高稳定性。

    • spatial_attn = (qs @ ks.transpose(-2, -1)):计算注意力分数。

    • spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb():添加位置编码以增强空间感知。

    • spatial_attn = spatial_attn.softmax(dim=-1):对注意力分数进行 softmax 归一化。

  • 输出计算

    • out = (spatial_attn @ vs):使用注意力权重对 V 进行加权求和,得到最终输出。
  • 重排输出

    • out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...):将输出重排回原始形状。
  • 最终投影

    • out = self.project_out(out):通过投影层将输出映射回原始通道数。

总结

DilatedOCA 模块结合了扩张卷积和空间注意力机制,通过重叠窗口的设计增强了对图像局部特征的捕捉能力。该模块在图像处理任务中具有广泛的应用潜力,尤其是在需要精细特征提取的场景中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】统信UOS服务器安装MySQL8.0(RPM)

目录 一、下载安装包 二、安装MySQL 2.1hive适配 2.2ranger适配 3.2DolphinScheduler适配 一、下载安装包 官网下载安装包:MySQL :: MySQL Downloads 选择社区版本下载 点击MySQL Community Server 选择对应系统的MySQL版本号 统信1060a 操作系统对应 redhat8…

Jenkins简单的安装运行

一、下载 官网下载:https://www.jenkins.io/download/ 清华大学开源软件镜像站:https://mirrors.tuna.tsinghua.edu.cn/jenkins/ 官网资料丰富,介绍了各种平台安装以及下载。安装简单,按照说明来就行。下面我介绍一个非常简单的…

【CSS】HTML页面定位CSS - position 属性 relative 、absolute、fixed 、sticky

目录 relative 相对定位 absolute 绝对定位 fixed 固定定位 sticky 粘性定位 position:relative 、absolute、fixed 、sticky (四选一) top:距离上面的像素 bottom:距离底部的像素 left:距离左边的像素…

Ubuntu中双击自动运行shell脚本

方法1: 修改文件双击反应 参考: https://blog.csdn.net/miffywm/article/details/103382405 chmod x test.sh鼠标选中待执行文件,在窗口左上角edit菜单中选择preference设计双击执行快捷键,如下图: 方法2: 设置一个应用 参考: https://blo…

从0开始学习搭网站的第一天

前言,以下内容学习自mdn社区,感兴趣的朋友可以直接去看原文章web技术 目录 web机制互联网是怎么运作的网站服务器是什么什么是URL?什么是web服务器?什么是域名什么是超链接什么是网页DOMgoole浏览器开发者工具 web机制 互联网是怎…

黑马linux笔记(03)在Linux上部署各类软件 MySQL5.7/8.0 Tomcat(JDK) Nginx RabbitMQ

文章目录 实战章节:在Linux上部署各类软件tar -zxvf各个选项的含义 为什么学习各类软件在Linux上的部署 一 MySQL数据库管理系统安装部署【简单】MySQL5.7版本在CentOS系统安装MySQL8.0版本在CentOS系统安装MySQL5.7版本在Ubuntu(WSL环境)系统…

[Transformer] The Structure of GPT, Generative Pretrained Transformer

The Structure of Generative Pretrained Transformer Reference: The Transformer architecture of GPT models How GPT Models Work

浅谈云计算04 | 云基础设施机制

探秘云基础设施机制:云计算的基石 一、云基础设施 —— 云计算的根基![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/1fb7ff493d3c4a1a87f539742a4f57a5.png)二、核心机制之网络:连接云的桥梁(一)虚拟网络边界&#xff…

解锁 JMeter 的 ForEach Controller 高效测试秘籍

各位小伙伴们,今天咱就来唠唠 JMeter 里超厉害的 “宝藏工具”——ForEach Controller,它可是能帮咱们在性能测试的江湖里 “大杀四方” 哦! 一、ForEach Controller 是啥 “神器” 想象一下,你手头有一串神秘钥匙,每…

sosadmin相关命令

sosadmin命令 以下是本人翻译的官方文档,如有不对,还请指出,引用请标明出处。 原本有个对应表可以跳转的,但是CSDN的这个[](#)跳转好像不太一样,必须得用html标签,就懒得改了。 sosadmin help 用法 sosadm…

【WPS】【WORDEXCEL】【VB】实现微软WORD自动更正的效果

1. 代码规范方面 添加 Option Explicit:强制要求显式声明所有变量,这样可以避免因变量名拼写错误等情况而出现难以排查的逻辑错误,提高代码的健壮性。使用 On Error GoTo 进行错误处理:通过设置错误处理机制,当代码执行…

Kafka 分区管理

分区是主题的子集,每个主题可以被分割成多个分区,一个分区有一个主副本(Leader)及一个或多个从(Follower)副本。分区允许将数据分布在多个broker上,这样可以提高数据的处理能力、并行性及可靠性…

【论文阅读+复现】High-fidelity Person-centric Subject-to-Image Synthesis

以人物为中心的主体到图像的高保真合成,CVPR2024 code:CodeGoat24/Face-diffuser: [CVPR2024] Official implementation of High-fidelity Person-centric Subject-to-Image Synthesis. paper:2311.10329 背景 研究问题:这篇文…

详解如何自定义 Android Dex VMP 保护壳

版权归作者所有,如有转发,请注明文章出处:https://cyrus-studio.github.io/blog/ 前言 Android Dex VMP(Virtual Machine Protection,虚拟机保护)壳是一种常见的应用保护技术,主要用于保护 And…

基于华为atlas的重车(满载)空车(空载)识别

该教程主要是想摸索出华为atlas的基于ACL的推理模式。最终实现通过煤矿磅道上方的摄像头,识别出车辆的重车(满载)、空车(空载)情况。本质上是一个简单的检测问题。 但是整体探索过程比较坎坷,Tianxiaomo的…

《零基础Go语言算法实战》【题目 2-25】goroutine 的执行权问题

《零基础Go语言算法实战》 【题目 2-25】goroutine 的执行权问题 请说明以下这段代码为什么会卡死。 package main import ( "fmt" "runtime" ) func main() { var i byte go func() { for i 0; i < 255; i { } }() fmt.Println("start&quo…

IntelliJ IDEA中Maven项目的配置、创建与导入全攻略

大家好&#xff0c;我是袁庭新。 IntelliJ IDEA是当前最流行的Java IDE&#xff08;集成开发环境&#xff09;之一&#xff0c;也是业界公认最好用的Java开发工具之一。IntelliJ IDEA支持Maven的全部功能&#xff0c;通过它我们可以很轻松地实现创建Maven项目、导入Maven项目、…

【Rust】函数

目录 思维导图 1. 函数的基本概念 1.1 函数的定义 2. 参数的使用 2.1 单个参数的示例 2.2 多个参数的示例 3. 语句与表达式 3.1 语句与表达式的区别 3.2 示例 4. 带返回值的函数 4.1 返回值的示例 4.2 返回值与表达式 5. 错误处理 5.1 错误示例 思维导图 1. 函数…

Cython全教程2 多种定义方式

—— 本篇文章&#xff0c;主要讲述Cython中的四种定义关键字 全教程2 多种定义方式&#xff1a; 在Cython中&#xff0c;关于定义的关键字有四个&#xff0c;分别是&#xff1a; cdef、def、cpdef、DEF 一、cdef定义关键字 顾名思义&#xff0c;cdef关键字定义的是一个C函数…

Web开发(一)HTML5

Web开发&#xff08;一&#xff09;HTML5 写在前面 参考黑马程序员前端Web教程做的笔记&#xff0c;主要是想后面自己搭建网页玩。 这部分是前端HTML5CSS3移动web视频教程的HTML5部分。主要涉及到HTML的基础语法。 HTML基础 标签定义 HTML定义 HTML(HyperText Markup Lan…