该篇文章,是我解析 Swin transformer 论文原理(结合pytorch版本代码)所记,图片来源于源paper或其他相应博客。
代码也非原始代码,而是从代码里摘出来的片段,配上简单数据,以便理解。
当然,也可能因为设置数据不当,造成误解,请多指教。
刚写了一部分。先发布。希望多多指正。
Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers ,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.
模型结构图
Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.
Stage 1 – Patch Embedding
It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.
Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.
In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48.(channel–3)
A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
这个表述,linear embedding layer,我感觉不太准确,但是,后半部分比较准确,哈哈,将channel–3变成了96.
Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.
The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.
代码
以下代码来自于model.py:
class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""
"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as Fclass PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):super().__init__()patch_size = (patch_size, patch_size)self.patch_size = patch_sizeself.in_chans = in_cself.embed_dim = embed_dimself.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shape# padding# 如果输入图片的H,W不是patch_size的整数倍,需要进行paddingpad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)if pad_input:# to pad the last 3 dimensions,# (W_left,W_right, H_top,H_bottom, C_front,C_back)x = F.pad(x,(0, self.patch_size[1] - W % self.patch_size[1],0, self.patch_size[0] - H % self.patch_size[0],0, 0))# 下采样patch_size倍x = self.proj(x)_, _, H, W = x.shape# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = x.flatten(2).transpose(1, 2)x = self.norm(x)print(x.shape)# torch.Size([1, 3136, 96])# 224/4 * 224/4 = 3136return x, H, Wif __name__ == '__main__':img_path = "tulips.jpg"img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]print(img.size)# (500,375)#img_size = 224data_transform = transforms.Compose([transforms.Resize(int(img_size * 1.14)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])img = data_transform(img)print(img.shape)# torch.Size([3, 224, 224])# expand batch dimensionimg = torch.unsqueeze(img, dim=0)print(img.shape)# torch.Size([1, 3, 224, 224])# split image into non-overlapping patchespatch_embed = PatchEmbed(norm_layer=nn.LayerNorm)patch_embed(img)
Stage 2 – 3.2. Shifted Window based Self-Attention
Shifted window partitioning in successive blocks
The window-based self-attention module lacks connections across windows, which limits its modeling power.
To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.
为了在保持非重叠窗口高效计算的同时引入跨窗口连接,我们提出了一种移位窗口划分方法,该方法在连续的Swin Transformer块中交替使用两种不同的划分配置。
Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.
在新窗口中进行的自注意力计算跨越了第l层中先前窗口的边界,从而在它们之间建立了连接。
Efficient batch computation for shifted configuration
An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.
Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction(向左上方向循环移动), as illustrated in Figure 4.
这里的 more efficient,是说相对于直观方法 padding—mask来说:
A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.
Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.
After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.
在此转换之后,批处理窗口可能由特征图中不相邻的几个子窗口组成,因此采用掩蔽机制将自注意力计算限制在每个子窗口内。
With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.
通过循环移位,批处理窗口的数量与常规窗口分区的数量保持不变,因此也是高效的。
上图和叙述,并不太直观,找了相关资料,一起分析:
移动完成之后,4是一个单独区域,5、3为一组,7、1为一组,8、6、2、0为一组。
但,5、3本身是两个图像的边缘,混在一起计算不是乱了吗?一起计算也没问题,ViT也是全局计算的。
但,Swin-Transformer为了防止这个问题,在代码中使用了masked MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。
源码中具体的方法就是将不计算的位置元素减去100。
这里需要注意的是,在窗口数据进行滑动完之后,需要将数据还原回去,即挪回到原来的位置上。
代码
以下代码来自于model.py:
def window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的window主要思路是将feature转成 (num_windows*B, window_size*window_size, C)的shape,把需要self-attn计算的window排列到第0维,一次并行的qkv就可以了Args:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shape# B,224,224,C# B,56,56,Cx = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# B,32,7,32,7,C# B,8,7,8,7,C# permute:# [B, H//Mh, Mh, W//Mw, Mw, C] -># [B, H//Mh, W//Mh, Mw, Mw, C]# B,32,32,7,7,C# B,8,8,7,7,C# view:# [B, H//Mh, W//Mw, Mh, Mw, C] -># [B*num_windows, Mh, Mw, C]# B*1024,7,7,C# B*64,7,7,C# 32*32 = 1024# 224 / 7 = 32windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows
分析:将 [B, C, 56, 56] 最后变成了[64B, C, 7, 7],原先的 B*C 张 56*56 的特征图,最后变成了 B*64*C张7*7的特征;
即,我们有64B个样本,每个样本包含C个7x7的通道。
注意,window_size–M–7,是每个window的大小,7*7,不是7*7个window,我刚开始混淆了这一点。
class BasicLayer(nn.Module):# A basic Swin Transformer layer for one stage.def __init__(self, dim, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.depth = depthself.window_size = window_sizeself.use_checkpoint = use_checkpointself.shift_size = window_size // 2# 7//2 = 3# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else self.shift_size,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop,attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])...# depth: 2, 2, 6, 2# 即,第一层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3# 即,第二层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3# 即,第三层,depth=6, 有两个SwinTransformerBlock,shift_size分别为:# 0,3,0,3,0,3# 即,第四层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3def create_mask(self, x, H, W):# calculate attention mask for SW-MSA
import numpy as np
import torchH = 7
W = 7
window_size = 7
shift_size = 3Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size# 拥有和feature map一样的通道排列顺序,方便后续window_partition
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')h_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))w_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))cnt = 0
for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1print(img_mask)
import torchimg_mask = torch.rand((2, 3))
print(img_mask)
'''
tensor([[0.7410, 0.6020, 0.5195],[0.9214, 0.2777, 0.8418]])
'''
attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
'''
tensor([[[ 0.0000, -0.1390, -0.2215],[ 0.1390, 0.0000, -0.0825],[ 0.2215, 0.0825, 0.0000]],[[ 0.0000, -0.6437, -0.0796],[ 0.6437, 0.0000, 0.5642],[ 0.0796, -0.5642, 0.0000]]])
'''print(img_mask.unsqueeze(1))
'''
tensor([[[0.7410, 0.6020, 0.5195]],[[0.9214, 0.2777, 0.8418]]])
'''
print(img_mask.unsqueeze(2))
'''
tensor([[[0.7410],[0.6020],[0.5195]],[[0.9214],[0.2777],[0.8418]]])
'''
上面那个代码,需要根据下面这个代码对应着走,shift_size–torch.roll()
class SwinTransformerBlock(nn.Module):# Swin Transformer Block....def forward(self, x, attn_mask):H, W = self.H, self.WB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window size# 把feature map给pad到window size的整数倍pad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_size# 注意F.pad的顺序,刚好是反着来的, 例如:# x.shape = (b, h, w, c)# x = F.pad(x, (1, 1, 2, 2, 3, 3))# x.shape = (b, h+6, w+4, c+2)# 源码可能有误,修改成下面的# x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:# paper中,滑动的size是窗口大小的/2(向下取整)# torch.roll以H,W的维度为例子,负值往左上移动,正值往右下移动。# 溢出的值在对角方向出现。即循环移动。shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]...
其中,torch.roll()方法简易示例如下:
import torchx = torch.randn(1, 4, 4, 3)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')
为了方便理解,我更换了维度:
import torchx = torch.randn(1, 3, 7, 7)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')
Stage 3 – patch merging layers
To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.
The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.
首个补丁合并层将每组2×2相邻补丁的特征进行拼接,并在拼接后的4C维特征上应用一个线性层。
This reduces the number of tokens by a multiple of 2×2=4(2 ×downsampling of resolution), and the output dimension is set to 2C.
Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.
同样,结合其他大神分析,图展示如下:
Related Work
Self-attention based backbone architectures
Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.
。。。。。
Cited link or paper name
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
- https://blog.csdn.net/weixin_42392454/article/details/141395092