Swin Transformer Paper Reading Notes and Code Analysis


This article records my analysis of the principles in the Swin Transformer paper, together with the PyTorch implementation. The images come from the original paper or from other related blog posts.

The code is not the original source either; it consists of snippets extracted from the codebase and paired with simple example data to make them easier to follow.

Of course, poorly chosen example data may cause misunderstandings; corrections are welcome.

Only part of it is written so far; I am publishing it first. Feedback is appreciated.


Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.
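The linear-vs-quadratic claim is quantified in Eqs. (1)–(2) of the paper: for an h × w feature map with C channels and window size M,

\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2 C
\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2 hwC

The former is quadratic in the number of patches hw, while the latter is linear when M is fixed (the paper uses M = 7 by default).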

Model architecture diagram

Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.
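For reference, the computation of two consecutive blocks in (b) is given by Eq. (3) of the paper:

\hat{z}^{l} = \text{W-MSA}(\text{LN}(z^{l-1})) + z^{l-1}
z^{l} = \text{MLP}(\text{LN}(\hat{z}^{l})) + \hat{z}^{l}
\hat{z}^{l+1} = \text{SW-MSA}(\text{LN}(z^{l})) + z^{l}
z^{l+1} = \text{MLP}(\text{LN}(\hat{z}^{l+1})) + \hat{z}^{l+1}

where \hat{z}^{l} and z^{l} denote the outputs of the (S)W-MSA module and the MLP module of block l, respectively.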

Stage 1 – Patch Embedding

It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.

Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.

In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (with 3 input channels).

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
The phrase "linear embedding layer" does not feel entirely accurate to me (in the code it is implemented as a convolution), but the second half of the statement is: the channel dimension does indeed go from 3 to 96.

Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.

The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.
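As a quick shape check for the standard 224×224×3 input (my own arithmetic, matching the print in the code below):

(224 / 4) × (224 / 4) = 56 × 56 = 3136 tokens, each with raw dimension 4 × 4 × 3 = 48, projected to C = 96.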

Code

The following code is taken from model.py:

"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # if H, W of the input image are not integer multiples of patch_size, pad the input
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # downsample by patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        print(x.shape)
        # torch.Size([1, 3136, 96])
        # 224/4 * 224/4 = 3136
        return x, H, W


if __name__ == '__main__':
    img_path = "tulips.jpg"
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    print(img.size)
    # (500, 375)

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = data_transform(img)
    print(img.shape)
    # torch.Size([3, 224, 224])

    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    print(img.shape)
    # torch.Size([1, 3, 224, 224])

    # split image into non-overlapping patches
    patch_embed = PatchEmbed(norm_layer=nn.LayerNorm)
    patch_embed(img)

Stage 2 – 3.2. Shifted Window based Self-Attention

Shifted window partitioning in successive blocks

The window-based self-attention module lacks connections across windows, which limits its modeling power.

To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.

Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.

Efficient batch computation for shifted configuration

An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.

Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction, as illustrated in Figure 4.

Here, "more efficient" is relative to the naive padding-and-masking approach:

A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.
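To make the cost difference concrete, here is a tiny sketch (the sizes are my own example: a 56×56 feature map with M = 7) counting how many windows each scheme must process:

import math

H = W = 56        # feature-map size after stage 1 (example values)
M = 7             # window size

# regular partitioning
regular = (H // M) * (W // M)                                     # 8 * 8 = 64 windows
# shifted partitioning with naive padding: the grid gains one row and one column,
# i.e., (ceil(h/M)+1) x (ceil(w/M)+1) windows, many of them smaller than MxM
shifted_padded = (math.ceil(H / M) + 1) * (math.ceil(W / M) + 1)  # 9 * 9 = 81 windows

print(regular, shifted_padded)   # 64 81

With the cyclic shift, the shifted configuration is still processed as 64 windows of size 7×7, plus a mask.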


Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.


After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.

With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.


The figure and description above are not very intuitive, so I gathered some related material and analyze it here:

[Figures: step-by-step illustration of the cyclic shift and the resulting window regions, taken from the referenced blog posts]
After the shift, region 4 stands alone, regions 5 and 3 form one group, regions 7 and 1 form one group, and regions 8, 6, 2 and 0 form one group.

But regions 5 and 3 come from opposite edges of the original image, so wouldn't computing attention over them together mix things up? Strictly speaking, computing them together would not be fatal, since ViT computes attention globally anyway.

Still, to prevent this, Swin Transformer uses a masked MSA in the code, so that a mask can isolate the information of the different regions from each other.

The concrete trick in the source code is to subtract 100 from the attention logits at positions that should not interact, so that they effectively vanish after the softmax.

Note that after the window attention has been computed on the shifted data, the data must be shifted back, i.e., rolled back to its original positions.
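A toy sketch of this round trip (my own example, not the model code): rolling by -shift_size and then by +shift_size restores the original layout.

import torch

shift_size = 3
x = torch.arange(7 * 7).reshape(1, 7, 7, 1).float()    # [B, H, W, C] toy tensor

shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))

print(torch.equal(x, restored))   # True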

Code

The following code is taken from model.py:

def window_partition(x, window_size: int):
    """
    Partition the feature map into non-overlapping windows of size window_size.
    The main idea is to reshape the feature into (num_windows*B, window_size*window_size, C),
    placing the windows that need self-attention along dim 0 so that a single batched
    qkv computation covers all of them.
    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # B, 224, 224, C
    # B, 56, 56, C
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # B, 32, 7, 32, 7, C
    # B, 8, 7, 8, 7, C

    # permute:
    # [B, H//Mh, Mh, W//Mw, Mw, C] ->
    # [B, H//Mh, W//Mw, Mh, Mw, C]
    # B, 32, 32, 7, 7, C
    # B, 8, 8, 7, 7, C
    # view:
    # [B, H//Mh, W//Mw, Mh, Mw, C] ->
    # [B*num_windows, Mh, Mw, C]
    # B*1024, 7, 7, C
    # B*64, 7, 7, C
    # 32*32 = 1024
    # 224 / 7 = 32
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

Analysis: a [B, 56, 56, C] feature map ends up as [64·B, 7, 7, C]; the original B·C feature maps of size 56×56 become B·64·C feature maps of size 7×7.

In other words, we now have 64·B samples, each containing C channels of size 7×7.

Note that window_size (M = 7) is the size of each window, i.e., each window is 7×7, not that there are 7×7 windows; I confused the two at first.
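A quick shape check (toy sizes chosen by me), using the window_partition defined above:

import torch

x = torch.randn(1, 56, 56, 96)                 # [B, H, W, C] after stage 1
windows = window_partition(x, window_size=7)
print(windows.shape)                           # torch.Size([64, 7, 7, 96]); (56/7) * (56/7) = 64 windows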


class BasicLayer(nn.Module):
    # A basic Swin Transformer layer for one stage.
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2
        # 7 // 2 = 3

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        ...

    # depth: 2, 2, 6, 2
    # i.e., stage 1: depth=2, two SwinTransformerBlocks with shift_size 0, 3
    #       stage 2: depth=2, two SwinTransformerBlocks with shift_size 0, 3
    #       stage 3: depth=6, six SwinTransformerBlocks with shift_size 0, 3, 0, 3, 0, 3
    #       stage 4: depth=2, two SwinTransformerBlocks with shift_size 0, 3

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        ...

The following standalone snippet walks through the body of create_mask with concrete values (H = W = 7):

import numpy as np
import torch

H = 7
W = 7
window_size = 7
shift_size = 3

Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size

# same channel layout as the feature map, which makes the later window_partition convenient
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

print(img_mask)

[Figure: the printed img_mask, showing the nine labeled regions]

import torch

img_mask = torch.rand((2, 3))
print(img_mask)
'''
tensor([[0.7410, 0.6020, 0.5195],
        [0.9214, 0.2777, 0.8418]])
'''

attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
'''
tensor([[[ 0.0000, -0.1390, -0.2215],
         [ 0.1390,  0.0000, -0.0825],
         [ 0.2215,  0.0825,  0.0000]],

        [[ 0.0000, -0.6437, -0.0796],
         [ 0.6437,  0.0000,  0.5642],
         [ 0.0796, -0.5642,  0.0000]]])
'''

print(img_mask.unsqueeze(1))
'''
tensor([[[0.7410, 0.6020, 0.5195]],

        [[0.9214, 0.2777, 0.8418]]])
'''

print(img_mask.unsqueeze(2))
'''
tensor([[[0.7410],
         [0.6020],
         [0.5195]],

        [[0.9214],
         [0.2777],
         [0.8418]]])
'''
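Putting the two snippets together, the remainder of create_mask turns img_mask into the final attention mask roughly as follows (a sketch based on my reading of model.py; it continues from the img_mask and window_size defined above and reuses window_partition):

# continue from the img_mask computed above, shape [1, Hp, Wp, 1]
mask_windows = window_partition(img_mask, window_size)            # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, window_size * window_size)   # [nW, Mh*Mw]

# broadcast subtraction: 0 wherever two positions belong to the same region
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, Mh*Mw, Mh*Mw]

# positions from different regions get -100, which effectively disappears after softmax
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
                     .masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)   # torch.Size([1, 49, 49]) for this H = W = 7 toy example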

The code above should be read alongside the code below (note shift_size and torch.roll()).

class SwinTransformerBlock(nn.Module):
    # Swin Transformer Block.
    ...

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # Note that F.pad fills the dimensions in reverse order, e.g.:
        # x.shape = (b, h, w, c)
        # x = F.pad(x, (1, 1, 2, 2, 3, 3))
        # x.shape = (b, h+6, w+4, c+2)
        # so for a (B, H, W, C) tensor the tuple is
        # (C_front, C_back, W_left, W_right, H_top, H_bottom),
        # and the line from the source code is already correct:
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            # In the paper, the shift is half the window size (floor division).
            # For torch.roll on the H and W dims, negative shifts move toward the top-left,
            # positive shifts toward the bottom-right; values rolled off one edge reappear
            # on the opposite edge, i.e., a cyclic shift.
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)               # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]
        ...

A minimal example of torch.roll() is shown below:

import torch

x = torch.randn(1, 4, 4, 3)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')

To make the output easier to read, I changed the dimension order (channels first):

import torch

x = torch.randn(1, 3, 7, 7)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')

[Figure: torch.roll output before and after the shift]

Stage 3 – patch merging layers

To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.

The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.

This reduces the number of tokens by a multiple of 2×2 = 4 (2× downsampling of resolution), and the output dimension is set to 2C.

Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.

Again, drawing on other people's analyses, an illustration is shown below:

[Figure: patch merging illustration]
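For reference, a trimmed sketch of how the 2×2 merging is done in the PyTorch implementation (based on my reading of model.py; the padding for odd H or W is omitted here):

import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """Merge each 2x2 group of neighboring patches: [B, H*W, C] -> [B, H/2*W/2, 2C]."""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]   # top-left of each 2x2 block
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)      # [B, H/2, W/2, 4C]
        x = x.view(B, -1, 4 * C)                 # [B, H/2*W/2, 4C]

        x = self.norm(x)
        x = self.reduction(x)                    # 4C -> 2C
        return x


x = torch.randn(1, 56 * 56, 96)
print(PatchMerging(dim=96)(x, H=56, W=56).shape)   # torch.Size([1, 784, 192])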

Related Work

Self-attention based backbone architectures

Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.

...

References

  1. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
  2. https://blog.csdn.net/weixin_42392454/article/details/141395092
