1.1 简介
Swin Transformer 是一种用于计算机视觉任务的新型深度学习架构,由微软亚洲研究院于2021年提出。它结合了Transformer模型在序列数据处理上的强大能力与卷积神经网络(CNN)在图像识别中的高效局部特征提取优势,特别适用于图像分类、目标检测、语义分割等任务。Swin Transformer的主要创新点在于其“Shifted Window”机制,这一机制使得模型既能保持Transformer的全局注意力特性,又能有效利用图像数据的空间局部性,从而在性能和计算效率上取得了平衡。下面是Swin Transformer的几个核心特点和组成部分的详细介绍:
1. 分层Transformer结构
- 分层设计:与ViT(Vision Transformer)直接将整个图像分割为固定大小的patch不同,Swin Transformer采用了类似于CNN的分层设计。它首先将图像分割成小块(称为“patches”),然后通过多层逐步减小这些块的大小和增加通道数,形成了层次化的特征表示。这种设计有助于捕捉不同尺度的特征。
2. Shifted Windows机制
-
窗口划分:Swin Transformer的核心创新在于引入了“窗口”概念,即在每个阶段,模型只在局部窗口内计算自注意力,这大大减少了计算复杂度,使其在图像数据上变得可行。
-
位移窗口:在连续的Transformer层中,窗口会按照一定的模式(例如,上下左右平移一定步长)进行移动或“shift”,这样可以确保信息在不同窗口间流动,实现跨窗口的交互,同时保持计算效率。这一机制既保留了局部特征的有效提取,又引入了一定程度的全局上下文信息。
3. 轻量级自注意力计算
-
无重叠窗口内的自注意力:仅在窗口内部计算自注意力,避免了全局自注意力带来的高昂计算成本。
-
线性计算复杂度:由于窗口内的自注意力计算是独立的,Swin Transformer的自注意力计算复杂度与窗口大小相关,而非与图像尺寸相关,这使得模型在处理高分辨率图像时更加高效。
4. 连续的下采样策略
- 与传统CNN中的池化操作类似,Swin Transformer通过在某些层级减少窗口的数量(即窗口合并)来实现空间下采样,进一步提高模型的感受野,同时保持了窗口内的局部注意力机制。
5. 性能与应用
-
广泛适用性:Swin Transformer因其高效的结构设计,在多个视觉任务上展现出优越的性能,包括图像分类、目标检测、实例分割和语义分割等。
-
可训练性和可扩展性:该模型易于训练且具有良好的可扩展性,能够通过增加层数或调整窗口大小等手段进一步提升性能。
总结
Swin Transformer的成功在于它创造性地融合了Transformer架构与图像处理中的局部性原理,通过Shifted Windows机制实现了高效而强大的视觉特征表示能力。这种设计不仅在理论上具有创新意义,也在实际应用中展示了强大的竞争力,成为当前视觉变换器研究的一个重要里程碑。
下图中可以看到,Swin Transformer的模型效果非常好。
1.2 网络整体框架
和VIT做一个对比,我们可以发现,(1)Swin Transformer的下采样的倍率是不一样的,所以能构建出不同层次的特征图,而VIT全都是16倍下采样。(2)Swin Transformer将特征图用窗口的形式分割开了,窗口与窗口之间是没有重叠的。在VIT当中,特征图是一个整体并没有分割。
window与window之间不进行信息的传递。这样做的好处在于能够减少运算量,尤其是浅层网络。
网络的整体架构如下:
先来说一下图(a)中的patch partition是什么东西,看图中的上半部分。
对于输入的图像,patch partition会使用一个4x4大小的窗口对图像进行分割。分割之后,按照channel方向进行展平(沿深度方向进行拼接),因为每个图像都是三通道的,这里一个通道有16个,那么三通道就是48个。所以经过patch partition,原来图像的高宽缩减为原来的四分之一,通道数变为48。接下来通过linear embedding层来对输入特征矩阵的channel进行一个调整,调整之后通道数变为C,这个C具体为多少需要针对Swin transformer的不同版本(T\S\B\L)采用不同值。然后linear embedding后跟着一个layer norm层归一化。
patch partition和linear embedding都是通过卷积层实现的。
接着看图(a),在stage1还要堆叠swin transformer block 2次。注意一下,在不同stage中block的堆叠次数都是偶数次,原因是因为图b,这俩block我们都是成对去使用的。和VIT不同的是,这里的block将多头注意力机制替换成了W-msa和sw-msa。
1.3 Patch Merging
这是对于stage2\3\4,在block前都会进行一个patch emerging。它的作用就是下采样。
假设我们输入的特征矩阵的高和宽都是4x4的,以单通道为例,我们以2x2大小作为一个窗口,在这窗口中有四个像素,然后我们将每个窗口位置中相同的像素取出来,然后得到四个特征矩阵,按照channel进行concat,层归一化,再通过一个全连接层进行线性映射。
经过Patch Merging,高宽缩减为原来的一半,通道数变为原来的二倍。
1.4 W-MSA(windows-multi head self attention)
对于普通的MSA,我们会对每一个像素pixel去求它的QKV,然后对于每一个像素所求的Q和K进行匹配然后进行一系列操作。也就是说对于每一个pixel会对特征矩阵当中所有的像素进行一个沟通。
对于W-MSA,我们会对特征图分成一个个的window,然后对每个window内部进行MSA。这么做的目的就是为了减少计算量。缺点是窗口之间无法进行信息交互,也就导致感受野变小,无法看到全局的视野。
MSA和W-MSA模块计算量
这俩公式是怎么来的:https://blog.csdn.net/qq_37541097/article/details/121119988
(对于矩阵相乘,假设有矩阵A(形状axb)和B(bxc),他俩进行矩阵相乘,FLOPS应该为:axbxc。)
1.5 SW-MSA(shifted windows-multi head self attention)
对于SW-MSA是怎么划分窗口的,其实是将原来WMSA划分窗口的那个东西向右又向下移动了两个像素(或者说patch)。
这样划分之后,我们那第一行第二个窗口举例,在对划分后的窗口进行MSA计算的话就会融合上一层第一行那两个窗口的信息。其他窗口同理。
一开始WMSA划分窗口是这样的:
然后向右再向下移动两个像素:
这样划分来的就是SWMSA的窗口了。
现在有一个问题,原来是四个window,现在变成9个window了,如果我们要实现并行计算的话,除了中间那个,其他八个大小都是小于 4x4的,那么就要进行填充到 4x4大小,那么这相当于9个windows计算量,那么这样做我们的计算量又增加了。
作者为了解决这个问题提出了一个解决方法,见图右下角以及后面的图片。
通过SWMSA已经划分为9个window了,接下来我们将左上角的window标记为区域A,然后1,2标记为C,3和6标记为B。
接下来我们将A和C移动到下面来:
接下来我们再将A和B移动到右边去:
然后我们重新划分window,接下来我们把4不变,3和5合并,7和1合并,0268合并,这样我们对这合并后的四个window分别进行一个MSA,那么它的计算量和WMSA其实是一样的。
那么这又会引入一个新的问题。比如,对于5和3,它本来是两个分开的区域并不连续,我们强行把它俩划在一个windows里面去了,如果我们对它进行一个MSA计算的话,其实是有问题的。
那么,我们就希望在这个window内,单独计算5和3的MSA。
具体怎么做呢?论文中给出了方法:加上MSAK MSA,就是加上一个蒙版。
举一个例子,还是刚刚的5和3,我们有16个patch,计算5和3的MSA是分开计算的。根据之前讲的MSA计算公式,我们对0这个像素先算,算出来q0矩阵,然后与其他像素进行匹配(点乘计算相似度),得到α0,0到α0,15。但是我们又不希望引入区域3的信息,作者将区域三的α的值全部减去100,然后进行softmax的时候就会非常小趋近于0,也就是说,对于像素0,它与区域三内的像素的相似度全是0。
通过这个方法我们就将区域5和区域3的像素分开了。计算量和WMSA是一样的,只不过多了一个mask中的减法操作。
注意,我们在计算完之后,需要将数据挪回去(因为之间对区域进行了移动)。
现在举一个例子,假设我们经过WMSA得到的矩阵是9x9的特征矩阵,然后window尺寸是3x3的,那么现在就有一个问题,我们需要将那几行或者那几列数据进行挪动呢?
论文给的方法是:首先我们将窗口尺寸M除以2然后向下取整,那么3对应的就是1,也就是说我们只需要移动第一行和第一列。
粗的黑色分割线是原始的window。接下来我们再用3x3的window重新划分。
对于这四个橙色的而言我们可以直接进行MSA操作,因为内部它们的数据本来就是连续的。而且我们可以观察到,橙色的每一个windows都能融合上一层的四个windows的信息。
而对于紫色的而言并不是连续的,需要mask MSA。
1.6 Relative Position Bias (相对位置偏置)
下图中偏置就是公式中的B。
表4说明,如果我们不去使用任何位置编码的话,准确率是80.1%,如果使用绝对位置编码(即VIT中的位置编码),Imagenet准确率上升了一点点,但是在COCO和ADE上的目标检测性能反而降低了,这就说明使用绝对位置编码其实效果并不好。
如果我们使用相对位置偏置的话,那么它在下列任务中准确率均有了比较大的提升。
表的前两行使用了WMSA和SWMSA的对比可以发现是用SWMSA也是非常有必要的。
那么到底什么是相对位置偏置呢?
我们首先假设我们这里有一个feature map,我们先对这个特征图进行标注绝对位置索引。
接下来我们标注相对位置索引,相对位置索引的值,就是先选一个颜色所在的像素,然后剩余位置的相对索引就是用选中参考点的减去剩余的位置的绝对位置索引。比如选蓝色,然后橙色的相对位置索引就是蓝色的绝对位置索引减去橙色的绝对位置索引。以此类推。
我们将每一个相对位置索引的矩阵在行方向上展平,然后拼接在一起就得到了下图的矩阵。根据矩阵每个位置上的相对位置索引可以在relative position bias table中取到一个对应的参数。
注意相对位置索引和相对位置偏置是两个不同的概念,我们需要利用索引去取相应的参数。
在原作者使用的代码中并不是使用的这样的一个二元的位置坐标,它使用的是一个一元的位置坐标,那么就需要将二元坐标转为一元坐标。
首先,我们先使偏移从0开始,我们将位置坐标全部加上M-1(M是窗口尺寸,这里举的例子是2)
然后,我们在行标乘上2M-1。
接着我们将行标列标相加。
我们可以发现,原来相同的索引的位置在变换后还是相同的。
我们再来看这个bias table,元素个数为(2M-1)X(2M-1)。
拿取到表中的偏置值之后,矩阵变化成如下这样子。这个最终的矩阵才是我们最后要的B。
那么元素个数为什么是(2M-1)X(2M-1)呢?我们回到这张图上,这里我们取两个最极端的位置,即蓝色和绿色,那么我们可以的出来,相对位置索引的值最大是M-1,最小是-M+1,以M=2为例,那么范围内就有-1,0,1这三个数,对它进行排列组合一共有9种。
所以,对于行而言,一共有2M-1个数可以取,对于列而言,也有2M-1个数可以取。
1.7 模型详细配置参数
对于stage1中的,就是指这两个模块,作用就是对输入进行下采样,调整通道数,4x4意思就是将高和宽下采样四倍,96-d意思是通过Linear embedding后特征矩阵通道数变为96,LN就是layer normalization。
stage1中,win.sz指的是窗口的大小,dim86就是说经过stage1后输出的channel是96,也就是那个C。head3指的是多头注意力机制MSA是3头。
下面stage同理。