详解SwinIR的论文和代码(SwinIR: Image Restoration Using Swin Transformer)

paper:https://arxiv.org/abs/2108.10257
code:https://github.com/JingyunLiang/SwinIR

目录

  • 1. Swin Transformer layers
    • 1.1 局部注意力
    • 1.2 移动窗口机制
    • 1.3 关键代码理解
  • 2. 整体网络结构
    • 2.1 浅层特征提取
    • 2.2 深层特征提取
    • 2.3 图像重建
  • 3.总结

SwinIR将Swin transformer1应用到low level领域的图像增强任务,结合卷积设计了网络结构,在以下三个任务上取得了很好的效果:图像超分辨率(包括classical、lightweight和real-world SR)、图像去噪(包括灰度图和彩色图像去噪)和 JPEG压缩失真去除。本文将结合代码对SwinIR进行详解。

SwinIR的网络结构并不复杂,关键部件就是Swin Transformer layers(STL)卷积层残差连接。卷积和残差连接大家都比较熟悉了,因此我首先结合代码介绍一下swin transformer层,然后自底向上的介绍SwinIR的全貌

1. Swin Transformer layers

SwinIR使用的Swin Transformer layers(STL)是在swin transformer中提出的,并未有改动。STL基于原始的多头注意力transformer层进行优化,主要的不同点在于:1. 局部注意力(local attention);2. 移动窗口机制(shifted window mechanism);

1.1 局部注意力

原始的全局注意力会将图像分成若干个patch,所有的patch之间做自注意力计算;所谓的局部注意力就是首先将图像划分成若干个window,每个window内在进行patch的划分,然后在window内部进行自注意力的计算,而不在一个window内的patch是没有交互的。也就是说,只考虑一个window内的patch,他们之间的计算和全局注意力操作是一样的。

理解局部注意力具体是怎么做的,很好的一个办法是看代码和分析tensor在不同层之间的shape整理出来。下面是我整理的tensor shape变化:

请添加图片描述
其中,b: batchsize, h: 输入高, w:输入宽, ws: 窗口大小, C: channel数, num_heads:attention的head数

1.2 移动窗口机制

由于基于窗口的多头注意力(W-MSA)没有考虑跨窗口的连接,模型建模长距离关联的能力受损。因此swin transformer提出了移动窗口多头注意力机制(SW-MSA),可在保证计算高效性的前提下,扩大感受野。

如下图所示,W-MSA的窗口大小为M*M(图中M=4),那么SW-MSA的窗口划分将向右下移动 ⌊ M / 2 ⌋ ∗ ⌊ M / 2 ⌋ \lfloor M/2 \rfloor *\lfloor M/2 \rfloor M/2M/2

请添加图片描述

但是经过位移之后,窗口数量会变多,由原来的 ⌊ h / M ⌋ ∗ ⌊ w / M ⌋ \lfloor h/M \rfloor *\lfloor w/M \rfloor h/Mw/M变成 ( ⌊ h / M ⌋ + 1 ) ∗ ( ⌊ w / M ⌋ + 1 ) (\lfloor h/M \rfloor + 1) *(\lfloor w/M \rfloor +1) (⌊h/M+1)(⌊w/M+1),而且窗口大小不一致。因此swin transformer提出了循环位移,减少窗口数量,同时可以获得相同大小的窗口进行并行计算。循环位移如下图所示。
请添加图片描述

在代码中,循环位移通过torch.roll实现,shifts为负,代表从下往上移动,从右往左移动,最上和最左循环移动到最下和最右。

# cyclic shift
if self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

关于torch.roll可参考:https://blog.csdn.net/weixin_42899627/article/details/116095067
如上图所示,经过循环移位后,有三个窗口中有一些patch是本不相邻的,它们不应该做自注意力,所以swin transformer建立了mask机制来完成最终的注意力计算。

关于mask的理解可参考https://github.com/microsoft/Swin-Transformer/issues/38

1.3 关键代码理解

下面来看一下关键代码及注释,首先是WindowAttention的forward函数:

def forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shape  # 此处的输入是经过window partition的# self.qkv(x): num_windows*B, window_size*window_size, 3*Cqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # 通过一个全连接层获取所有头的qkv,(3, num_windows*B, num_heads, window_size*window_size, C // num_heads)q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scaleattn = (q @ k.transpose(-2, -1)) # num_windows*B, num_heads, window_size*window_size, window_size*window_size# 可学习的相对位置biasrelative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0) # num_windows*B, num_heads, window_size*window_size, window_size*window_sizeif mask is not None:nW = mask.shape[0]# 将mask和attn相加,mask只有两种取值0和-100,因此为0时对attn无影响,为-100时,self.softmax(attn)将变为接近于0attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N) # num_windows*B, num_heads, window_size*window_size, window_size*window_sizeattn = self.softmax(attn) # num_windows*B, num_heads, window_size*window_size, window_size*window_sizeelse:attn = self.softmax(attn)attn = self.attn_drop(attn)# v:num_windows*B, num_heads, window_size*window_size, C // num_heads# attn:num_windows*B, num_heads, window_size*window_size, window_size*window_size# attn @ v: num_windows*B, num_heads, window_size*window_size, C // num_headsx = (attn @ v).transpose(1, 2).reshape(B_, N, C) # num_windows*B, window_size*window_size, Cx = self.proj(x) # 全连接层x = self.proj_drop(x)return x

接下来是SwinTransformerBlock的forward函数

    def forward(self, x, x_size):H, W = x_sizeB, L, C = x.shape# assert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # (num_windows*B, window_size, window_size, C)x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_windows*B, window_size*window_size, C# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window sizeif self.input_resolution == x_size:attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, Celse:attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return x

可以看到每个SwinTransformerBlock内部完成的是:
X = M S A ( L N ( X ) ) + X X = MSA(LN(X)) + X X=MSA(LN(X))+X
X = M L P ( L N ( X ) ) + X X = MLP(LN(X)) + X X=MLP(LN(X))+X
其中MSA为W-MSA和SW-MSA交替。

2. 整体网络结构

请添加图片描述
如上图所示,SwinIR包括三个modules,浅层特征提取、深层特征提取和图像重建。其中特征提取模块对所有任务都是一样的,但是图像重建对于不同的任务是不同的。

2.1 浅层特征提取

一个3×3卷积层将特征图通道转成embed_dim:(b, embed_dim, h, w)

self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

2.2 深层特征提取

深层特征提取的基本模块则是第一节中讲解的STL和卷积层和残差连接。STL和卷积组成RSTB,RSTB和卷积组成了深层特征提取。

2.3 图像重建

以下代码可以看到对于不同的任务,图像重建模块是不同的,有的采用最邻近插值+卷积,有的采用pixelshuffle+卷积,有的直接采用卷积。


if self.upsampler == 'pixelshuffle':# for classical SRx = self.conv_first(x)x = self.conv_after_body(self.forward_features(x)) + xx = self.conv_before_upsample(x)x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':# for lightweight SRx = self.conv_first(x)x = self.conv_after_body(self.forward_features(x)) + xx = self.upsample(x)
elif self.upsampler == 'nearest+conv':# for real-world SRx = self.conv_first(x) # (b, embed_dim, h, w)x = self.conv_after_body(self.forward_features(x)) + xx = self.conv_before_upsample(x)x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:# for image denoising and JPEG compression artifact reductionx_first = self.conv_first(x)res = self.conv_after_body(self.forward_features(x_first)) + x_firstx = x + self.conv_last(res)

SwinIR可以很灵活配置网络的复杂度。影响W-MSA计算复杂度: 4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2hwC 4hwC2+2M2hwC
请添加图片描述

3.总结

  1. 结构简单,性能全面超过cnn-based的方法,适用于多种任务,可做为Low-level的基线模型;
  2. 作者发现与以往基于transformer的方法不同,Swinir不需要比cnn更多的训练数据,收敛速度也更快;
  3. 结构模块化,可以方便调整出不同复杂度的模型;

  1. Liu Z, Lin Y, Cao Y, et al. Swin transformer: Hierarchical vision transformer using shifted windows[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 10012-10022. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/149973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

同城跑腿服务预约小程序的作用是什么

随着生活质量逐渐提升,围绕人们生活的行业或产品非常多,同时互联网赋能下,也出现了很多便捷人们日常消费的场景,如外卖服务、快递服务等。 跑腿仅依赖微信私聊及电话预约是很低效且容易出错及造成极大工作压力的,同时…

基于一致性算法的微电网分布式控制MATLAB仿真模型

微❤关注“电气仔推送”获得资料(专享优惠) 本模型主要是基于一致性理论的自适应虚拟阻抗、二次电压补偿以及二次频率补偿,实现功率均分,保证电压以及频率稳定性。 一致性算法 分布式一致性控制主要分为两类:协调同…

Linux入门攻坚——6、磁盘管理——分区及文件系统管理

磁盘管理主要涉及分区的管理,以及分区后的文件系统管理。 磁盘的使用大体要分两步: 文件系统也是一个软件,根是自引用的。 文件系统的全局结构:物理格式: 一个磁盘刚被生产出来的时候,它里边没有划分扇区…

【Go入门】Web工作方式

【Go入门】 Web工作方式 我们平时浏览网页的时候,会打开浏览器,输入网址后按下回车键,然后就会显示出你想要浏览的内容。在这个看似简单的用户行为背后,到底隐藏了些什么呢? 对于普通的上网过程,系统其实是这样做的&…

基于RK3588的8k多屏异显安卓智能网络机顶盒

采用RK3588芯片方案的8K网络机顶盒,搭载纯净的安卓12操作系统,支持Ubuntu和Debian系统容拓展。主要面向外贸市场。此款机顶盒自带两个HDMI输出接口,一个HDMI输入接口,内置双频WiFi6无线模块,支持千兆以太网和USB接口。…

【文末送书】十大排序算法及C++代码实现

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab,机器人运动控制、多机器人协作,智能优化算法,滤波估计、多传感器信息融合,机器学习,人工智能等相关领域的知识和技术。关…

微创机器人:CRM撬动售后服务数字化升级

一方面,我国医疗器械行业起步较晚,更注重产品的销售和业务的拓展,企业售后服务整体比较滞后。 另一方面,医疗器械售后服务环节数字化程度不足,一些企业仍通过传统的线下手段管理售后服务,进行数字化尝试的…

【快速解决】实验四 对话框 《Android程序设计》实验报告

目录 前言 实验要求 实验四 对话框 正文开始 第一步建立项目 第二步选择empty views activity点击next ​编辑 第三步起名字,点击finish 第四步对 activity _main.xml文件操作进行布局 第五步,建立两个新文件,建立方法如下 SecondA…

SLAM中提到的相机位姿到底指什么?

不小心又绕进去了,所以掰一下。 以我个人最直观的理解,假设无旋转,相机在世界坐标系的(5,0,0)^T的位置上,所谓“位姿”,应该反映相机的位置,所以相机位姿应该如下: Eigen::Matrix4d T Eigen::M…

亚马逊云科技AI创新应用下的托管在AWS上的数据可视化工具—— Amazon QuickSight

目录 Amazon QuickSight简介 Amazon QuickSight的独特之处 Amazon QuickSight注册 Amazon QuickSight使用 Redshift和Amazon QuickSightt平台构建数据可视化应用程序 构建数据仓库 数据可视化 Amazon QuickSight简介 亚马逊QuickSight是一项可用于交付的云级商业智能 (BI…

基于circle group的Reed-Solomon codes

1. 引言 Polygon团队Ulrich Habock等人2023年论文 Reed-Solomon codes over the circle group。 前序博客有: Plonky3 Mersenne素数域的Reed-Solomon codes设计 STARKs支持任意size的域,而不要求是椭圆曲线。STARKs中在选择域size时,越小…

Unity中 Start和Awake的区别

Awake和Start在Unity中都是MonoBehaviour脚本中的生命周期函数 Awake函数在游戏对象首次被加载时调用,在游戏对象初始化之前调用。 start函数在游戏对象初始化完成后调用,在update第一次执行前调用。 这两个函数在其生命周期内都只会调用一次&#xf…

SpringBoot的启动流程

一、SpringBoot是什么? springboot是依赖于spring的,比起spring,除了拥有spring的全部功能以外,springboot无需繁琐的xml配置,这取决于它自身强大的自动装配功能;并且自身已嵌入Tomcat、Jetty等web容器&am…

redis+python 建立免费http-ip代理池;验证+留接口

前言: 效果图: 对于网络上的一些免费代理ip,http的有效性还是不错的;但是,https的可谓是凤毛菱角; 正巧,有一个web可以用http访问,于是我就想到不如直接拿着免费的HTTP代理去做这个! 思路: 1.单页获取ipporttime (获取time主要是为了后面使用的时候,依照时效可以做文章) 2.整…

windows环境搭建Zblog博客并发布上线公网可访问

文章目录 1. 前言2. Z-blog网站搭建2.1 XAMPP环境设置2.2 Z-blog安装2.3 Z-blog网页测试2.4 Cpolar安装和注册 3. 本地网页发布3.1. Cpolar云端设置3.2 Cpolar本地设置 4. 公网访问测试5. 结语 1. 前言 想要成为一个合格的技术宅或程序员,自己搭建网站制作网页是绕…

总结 CNN 模型:将焦点转移到基于注意力的架构

一、说明 在计算机视觉时代,卷积神经网络(CNN)几十年来一直是主导范式。直到 2021 年 Vision Transformers (ViTs) 出现,这个领域才开始发生变化。现在,是时候采用受 Transformer 架构启发的基于注意力的模型了&#x…

Springboot+vue的机动车号牌管理系统(有报告)。Javaee项目,springboot vue前后端分离项目

演示视频: Springbootvue的机动车号牌管理系统(有报告)。Javaee项目,springboot vue前后端分离项目 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的机动车号牌管理系统,采用M(model&#xff09…

项目九、无线组网

目录 1 配置AC使AP放出Wifi1.1 确保AP和AC三层互通且AP知道AC的IP1.1.1 配置管理SVI的IP1.1.2 该SVI配置DHCP下发IP给AP 1.2 AC为AP下发配置1.2.1 AC用哪个接口回复AP1.2.2 AC验证AP身份(可以不认证)1.2.3 配置ssid 文件确定Wifi名称1.2.4 配置security …

Apache Pulsar 技术系列 - 基于 Pulsar 的海量 DB 数据采集和分拣

导语 Apache Pulsar 是一个多租户、高性能的服务间消息传输解决方案,支持多租户、低延时、读写分离、跨地域复制、快速扩容、灵活容错等特性。本文是 Pulsar 技术系列中的一篇,主要介绍 Pulsar 在海量DB Binlog 增量数据采集、分拣场景下的应用。 前言…

程序员开发者神器:10个.Net开源项目

今天一起盘点下,8月份推荐的10个.Net开源项目(点击标题查看详情)。 1、基于C#开发的适合Windows开源文件管理器 该项目是一个基于C#开发、开源的文件管理器,适用于Windows,界面UI美观、方便轻松浏览文件。此外&#…