Swin Transformer: paper reading notes and code analysis

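This post records my walkthrough of the principles behind the Swin Transformer paper, read alongside the PyTorch implementation. The figures come from the original paper or from the related blog posts listed at the end.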




Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.
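
The linear-versus-quadratic claim in the caption can be checked with a few lines of arithmetic. The sketch below uses the complexity formulas given in the paper (Eq. (1) and (2)): Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC. The function names are my own, and the numbers only cover a single attention layer, so treat this as a rough illustration rather than a full FLOPs count.

```python
def msa_flops(h, w, C):
    """Global self-attention: the 2*(h*w)^2*C term is quadratic in the token count h*w."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M=7):
    """Window attention with M x M windows: every term is linear in h*w."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


C = 96  # channel dimension of Swin-T stage 1
for side in (56, 112, 224):
    global_cost = msa_flops(side, side, C)
    window_cost = wmsa_flops(side, side, C)
    print(f"{side}x{side} tokens: MSA={global_cost:.3e}  "
          f"W-MSA={window_cost:.3e}  ratio={global_cost / window_cost:.1f}")
```

Doubling the side length quadruples the W-MSA cost, while the MSA cost grows much faster because of the (hw)² term.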


Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.

Stage 1 – Patch Embedding

It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.

Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.

In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (the 3 is the number of RGB channels).

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
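Note: the wording "linear embedding layer" feels slightly imprecise to me, because the code implements it as a Conv2d whose kernel size and stride both equal the patch size. The second half of the sentence is accurate, though: the channel dimension goes from 3 to 96 (C = 96).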
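
To convince myself that "linear embedding of flattened patches" and "convolution with kernel size = stride = patch size" describe the same computation, here is a minimal plain-Python sketch. Everything in it is a simplification of my own: a tiny 8×8 image, a single output channel instead of 96, integer values so the comparison is exact, and hypothetical helper names (`split_into_patches`, `linear`, `strided_conv`).

```python
import random

P = 4      # patch size
C_IN = 3   # RGB channels


def split_into_patches(img, p=P):
    """img[c][y][x] -> list of flattened patches, each of length C_IN * p * p."""
    c_in, H, W = len(img), len(img[0]), len(img[0][0])
    tokens = []
    for py in range(0, H, p):
        for px in range(0, W, p):
            tokens.append([img[c][py + dy][px + dx]
                           for c in range(c_in) for dy in range(p) for dx in range(p)])
    return tokens


def linear(token, weight, bias):
    """One output unit of a linear layer applied to a flattened patch."""
    return sum(t * w for t, w in zip(token, weight)) + bias


def strided_conv(img, kernel, bias, p=P):
    """One output channel of a convolution with kernel_size = stride = p."""
    c_in, H, W = len(img), len(img[0]), len(img[0][0])
    out = []
    for py in range(0, H, p):
        for px in range(0, W, p):
            acc = bias
            for c in range(c_in):
                for dy in range(p):
                    for dx in range(p):
                        acc += img[c][py + dy][px + dx] * kernel[c][dy][dx]
            out.append(acc)
    return out


random.seed(0)
img = [[[random.randint(0, 255) for _ in range(8)] for _ in range(8)] for _ in range(C_IN)]
kernel = [[[random.randint(-3, 3) for _ in range(P)] for _ in range(P)] for _ in range(C_IN)]
# Flatten the kernel in the same (channel, row, column) order as the patches.
flat_w = [kernel[c][dy][dx] for c in range(C_IN) for dy in range(P) for dx in range(P)]

tokens = split_into_patches(img)
print(len(tokens), len(tokens[0]))  # number of tokens, feature dimension per token
print([linear(t, flat_w, 1) for t in tokens] == strided_conv(img, kernel, 1))
```

So the paper's description and the Conv2d in the code agree; the convolution is simply the more convenient way to do the split and the projection in one call.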

Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.

The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.



# @ time : 2024/12/17
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # If H or W of the input image is not a multiple of patch_size, pad it.
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            # The trailing "% patch_size" keeps the padding at 0 for a dimension
            # that is already a multiple of patch_size.
            x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1],
                          0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0],
                          0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        print(x.shape)
        # torch.Size([1, 3136, 96])
        # 224/4 * 224/4 = 3136
        return x, H, W


if __name__ == '__main__':
    img_path = "tulips.jpg"
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    print(img.size)
    # (500,375), PIL reports (width, height)

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = data_transform(img)
    print(img.shape)
    # torch.Size([3, 224, 224])

    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    print(img.shape)
    # torch.Size([1, 3, 224, 224])

    # split image into non-overlapping patches
    patch_embed = PatchEmbed(norm_layer=nn.LayerNorm)
    patch_embed(img)

Stage 2 – 3.2. Shifted Window based Self-Attention

Shifted window partitioning in successive blocks

The window-based self-attention module lacks connections across windows, which limits its modeling power.

To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.
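In other words, consecutive blocks alternate between two window layouts, so cross-window connections are added without giving up the efficiency of computing attention in non-overlapping windows.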

Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.
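
The caption of Figure 2 can be reproduced with a small plain-Python sketch. I use the 8×8 feature map with M = 4 from the figure; the helper names `window_id` and `partition` are hypothetical, and the shifted layout is modelled by displacing the window grid by ⌊M/2⌋ pixels, which is my reading of the paper.

```python
from collections import Counter


def window_id(y, x, M, shift=0):
    """(row, col) index of the window containing pixel (y, x) when the grid is displaced by `shift`."""
    return ((y - shift) // M, (x - shift) // M)


def partition(H, W, M, shift=0):
    """Map each window id to the number of pixels it contains."""
    return Counter(window_id(y, x, M, shift) for y in range(H) for x in range(W))


H = W = 8
M = 4
regular = partition(H, W, M)
shifted = partition(H, W, M, shift=M // 2)

print(len(regular), sorted(regular.values()))  # window count and window sizes in layer l
print(len(shifted), sorted(shifted.values()))  # window count and window sizes in layer l + 1

# Pixels (3, 3) and (4, 4) sit in different regular windows but share a shifted window.
print(window_id(3, 3, M) == window_id(4, 4, M))
print(window_id(3, 3, M, M // 2) == window_id(4, 4, M, M // 2))
```

The shifted layout produces 3×3 windows instead of 2×2, and only the central one has the full M×M size, which is exactly the issue that the efficient batch computation discussed next has to deal with.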

Efficient batch computation for shifted configuration

An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.

Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction, as illustrated in Figure 4.

Here "more efficient" is relative to the naive pad-and-mask approach:

A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.

Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.

After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.

With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.
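
To see what the cyclic shift does to the windows, I find it helpful to label every pixel with the shifted window it belongs to, roll the grid toward the top-left, and then apply the regular partition. The sketch below does this in plain Python for an 8×8 map with M = 4; `roll2d` is my own stand-in for `torch.roll` on a list of lists, and the 0 to 8 labels are just an illustrative numbering.

```python
H = W = 8
M = 4
s = M // 2  # shift size


def roll2d(grid, shift):
    """Cyclic shift of both axes by `shift`, mimicking torch.roll on a 2D grid."""
    rows, cols = len(grid), len(grid[0])
    return [[grid[(y - shift) % rows][(x - shift) % cols] for x in range(cols)]
            for y in range(rows)]


# Label each pixel with the id (0..8) of the shifted window it belongs to.
labels = [[((y - s) // M + 1) * 3 + ((x - s) // M + 1) for x in range(W)] for y in range(H)]

# Negative shift = move toward the top-left; overflowing rows/columns wrap around.
rolled = roll2d(labels, -s)

# Regular M x M partition of the rolled grid: collect the labels found in each window.
windows = [
    {rolled[wy + dy][wx + dx] for dy in range(M) for dx in range(M)}
    for wy in range(0, H, M)
    for wx in range(0, W, M)
]
for k, ids in enumerate(windows):
    print(k, sorted(ids))
```

There are still only four batched windows, but three of them contain pixels from several shifted windows that were not neighbours in the original feature map, which is why a mask is needed.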




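However, to keep tokens from sub-windows that are not adjacent in the original feature map from attending to each other, the Swin Transformer code uses masked MSA: a mask isolates the information of the different regions.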
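
A minimal plain-Python sketch of how such a mask can be built, following the slicing used in `create_mask` (three bands per axis: `0:-M`, `-M:-s`, `-s:`). The function name `build_attn_mask` is hypothetical, and the fill value of -100 is an assumption on my part; any sufficiently large negative number that vanishes after softmax would serve.

```python
def build_attn_mask(H, W, M, s, neg=-100.0):
    """Return one (M*M) x (M*M) mask per window of the cyclically shifted grid."""

    def band(i, n):
        # Same three bands as slice(0, -M), slice(-M, -s), slice(-s, None).
        if i < n - M:
            return 0
        if i < n - s:
            return 1
        return 2

    # Region id of every pixel; pixels with different ids must not attend to each other.
    ids = [[band(y, H) * 3 + band(x, W) for x in range(W)] for y in range(H)]

    masks = []
    for wy in range(0, H, M):
        for wx in range(0, W, M):
            flat = [ids[wy + dy][wx + dx] for dy in range(M) for dx in range(M)]
            # 0 where both tokens share a region, `neg` otherwise.
            masks.append([[0.0 if a == b else neg for b in flat] for a in flat])
    return masks


masks = build_attn_mask(H=8, W=8, M=4, s=2)
for k, m in enumerate(masks):
    allowed = sum(v == 0.0 for row in m for v in row)
    print(f"window {k}: {allowed} of {len(m) * len(m)} token pairs may attend")
```

The top-left window is a single region, so nothing is masked there; the bottom-right window mixes four regions and therefore has the most masked pairs.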





def window_partition(x, window_size: int):
    """
    Partition the feature map into non-overlapping windows of size window_size.
    The main idea is to reshape the feature to (num_windows*B, window_size*window_size, C),
    moving the windows that need self-attention into dim 0, so that a single
    batched qkv computation handles all of them in parallel.

    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # example 1: B,224,224,C
    # example 2: B,56,56,C
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # example 1: B,32,7,32,7,C
    # example 2: B,8,7,8,7,C

    # permute:
    # [B, H//Mh, Mh, W//Mw, Mw, C] ->
    # [B, H//Mh, W//Mw, Mh, Mw, C]
    # example 1: B,32,32,7,7,C
    # example 2: B,8,8,7,7,C

    # view:
    # [B, H//Mh, W//Mw, Mh, Mw, C] ->
    # [B*num_windows, Mh, Mw, C]
    # example 1: B*1024,7,7,C   (224 / 7 = 32, 32*32 = 1024)
    # example 2: B*64,7,7,C
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

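Analysis: conceptually, [B, C, 56, 56] ends up as [64B, C, 7, 7] (the code keeps channels last, i.e. (B, 56, 56, C) -> (64B, 7, 7, C)). The original B*C feature maps of size 56×56 become B*64*C feature maps of size 7×7.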
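
The view/permute/view sequence is easier to trust once the index mapping is written out explicitly. This plain-Python sketch (the helper name `window_partition_indices` is my own) lists, for every window, the (y, x) coordinates it receives, in the order I would expect the reshaping to produce for a single image.

```python
def window_partition_indices(H, W, M):
    """windows[k] = list of (y, x) coordinates that land in window k, row-major inside the window."""
    n_cols = W // M  # windows per row
    windows = [[] for _ in range((H // M) * n_cols)]
    for y in range(H):
        for x in range(W):
            # Window index is row-major over the grid of windows.
            windows[(y // M) * n_cols + x // M].append((y, x))
    return windows


windows = window_partition_indices(56, 56, 7)
print(len(windows), len(windows[0]))  # number of windows, tokens per window
print(windows[0][:3], windows[1][:3], windows[8][:3])
```

For a 56×56 map with M = 7 this gives 8×8 = 64 windows of 49 tokens each, which matches the B*64 in the analysis above.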



class BasicLayer(nn.Module):
    # A basic Swin Transformer layer for one stage.
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2
        # 7//2 = 3

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        ...
        # depth: 2, 2, 6, 2
        # stage 1, depth=2: two SwinTransformerBlocks with shift_size 0, 3
        # stage 2, depth=2: two SwinTransformerBlocks with shift_size 0, 3
        # stage 3, depth=6: six SwinTransformerBlocks with shift_size
        #     0, 3, 0, 3, 0, 3
        # stage 4, depth=2: two SwinTransformerBlocks with shift_size 0, 3

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        ...
import numpy as np
import torch

H = 7
W = 7
window_size = 7
shift_size = 3

Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size

# Same dimension order as the feature map, which makes the later window_partition convenient.
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

print(img_mask)


import torch

img_mask = torch.rand((2, 3))
# tensor([[0.7410, 0.6020, 0.5195],
#         [0.9214, 0.2777, 0.8418]])

attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
# tensor([[[ 0.0000, -0.1390, -0.2215],
#          [ 0.1390,  0.0000, -0.0825],
#          [ 0.2215,  0.0825,  0.0000]],
#         [[ 0.0000, -0.6437, -0.0796],
#          [ 0.6437,  0.0000,  0.5642],
#          [ 0.0796, -0.5642,  0.0000]]])

# img_mask.unsqueeze(1), shape [2, 1, 3]:
# tensor([[[0.7410, 0.6020, 0.5195]],
#         [[0.9214, 0.2777, 0.8418]]])
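
The subtraction trick above yields a zero exactly where two tokens carry the same region id. Below is a hedged plain-Python sketch of what happens next: the mask is added to the attention logits before the softmax. The region ids, the logits, and the -100 fill value are all made-up illustrative numbers, not values taken from the model.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical window with 4 tokens: the first two belong to region 0, the last two to region 1.
region_ids = [0, 0, 1, 1]

# Pairwise difference of ids: zero means "same region", anything else gets a large negative value.
mask = [[0.0 if (b - a) == 0 else -100.0 for b in region_ids] for a in region_ids]

# Made-up attention scores (q·k / sqrt(d)) for each query token.
logits = [[0.5, 1.0, 2.0, -1.0] for _ in region_ids]

attn = [softmax([score + m for score, m in zip(score_row, mask_row)])
        for score_row, mask_row in zip(logits, mask)]
for row in attn:
    print([round(p, 4) for p in row])
```

After the softmax, each token's weights on the other region are vanishingly small, so the attention is effectively computed within each sub-window only.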


class SwinTransformerBlock(nn.Module):
    # Swin Transformer Block.
    ...

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # Note that F.pad takes the dimensions in reverse order (last dimension first), e.g.:
        # x.shape = (b, h, w, c)
        # x = F.pad(x, (1, 1, 2, 2, 3, 3))
        # x.shape = (b, h+6, w+4, c+2)
        # For a (B, H, W, C) tensor the pairs are therefore (C, C), (W_left, W_right), (H_top, H_bottom).
        # I first suspected the source was wrong and swapped the last two pairs:
        # x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))
        # but that pads W with pad_b and H with pad_r, which is only harmless when the two are equal.
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            # In the paper, the shift size is half the window size (rounded down).
            # torch.roll on the H and W dimensions: negative values move toward the top-left,
            # positive values toward the bottom-right.
            # Values that overflow reappear on the opposite side, i.e. the shift is cyclic.
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]
        ...


import torch

# channels-last layout (B, H, W, C): roll along dims 1 and 2
x = torch.randn(1, 4, 4, 3)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')


import torch

# channels-first layout (B, C, H, W): roll along dims 2 and 3
x = torch.randn(1, 3, 7, 7)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')


Stage 3 – patch merging layers

To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.

The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.

This reduces the number of tokens by a multiple of 2×2 = 4 (2× downsampling of resolution), and the output dimension is set to 2C.

Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.
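
The bookkeeping of patch merging can be sketched in plain Python. Two caveats: `merge_2x2` only performs the 2×2 concatenation (C -> 4C) and leaves out the linear layer that maps 4C to 2C, and the order in which the four neighbours are concatenated is an assumption of mine. `stage_shapes` just tracks resolution and channels per stage for a Swin-T-like setting (C = 96, patch size 4).

```python
def merge_2x2(feat):
    """feat[y][x] is a list of C values; returns an (H/2 x W/2) grid of 4C-dim features.

    Assumes H and W are even. The linear projection 4C -> 2C is not included.
    """
    H, W = len(feat), len(feat[0])
    return [[feat[y][x] + feat[y + 1][x] + feat[y][x + 1] + feat[y + 1][x + 1]
             for x in range(0, W, 2)]
            for y in range(0, H, 2)]


def stage_shapes(H, W, C=96, patch=4, num_stages=4):
    """(stage, height, width, channels) after each stage; merging happens before stages 2..4."""
    h, w, c = H // patch, W // patch, C
    shapes = []
    for stage in range(1, num_stages + 1):
        if stage > 1:
            h, w, c = h // 2, w // 2, c * 2  # 4x fewer tokens, 2x more channels
        shapes.append((stage, h, w, c))
    return shapes


# Toy 4 x 4 grid with C = 2: each token stores its own (y, x) coordinates.
grid = [[[y, x] for x in range(4)] for y in range(4)]
merged = merge_2x2(grid)
print(len(merged), len(merged[0]), len(merged[0][0]))  # new height, new width, 4C

for shape in stage_shapes(224, 224):
    print(shape)
```

For a 224×224 input this walks through H/4, H/8, H/16 and H/32 resolutions, with the channel dimension doubling at every merge.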



Related Work

Self-attention based backbone architectures

Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.


Cited link or paper name

  1. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
  2. https://blog.csdn.net/weixin_42392454/article/details/141395092




