写在最前面之如何只用nn.Linear实现nn.Conv2d的功能
很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么?
首先,Convolution是什么?
Convolution是一种矩形区域内参数共享的Linear
这么说可能不好理解,那么我们上代码
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Conv2D(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride):"""为了简单且便于理解,我们设定图片的Size是Kernel_size的整数倍,且Kernel_size等于Stride"""super(LinearConv2d, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizeself.stride = stride# 计算权重矩阵的维度weight_size = in_channels * kernel_size * kernel_sizeself.linear = nn.Linear(weight_size, out_channels, bias=False)def forward(self, x):# 计算输出特征图的尺寸B, C, H, W = x.size()output_height = H // self.strideoutput_width = W // self.stride# 展开输入特征,沿着kernel_size的窗口展开x_flatten = x.view(B, H // self.kernel_size, self.kernel_size, W // self.kernel_size, self.kernel_size, C)x_flatten = x_flatten.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.kernel_size, self.kernel_size, C)# 应用线性变换output_flatten = self.linear(x_flatten)# 重塑输出形状output = output_flatten.view(B, self.out_channels, output_height, output_width)return output# 使用nn.Linear实现nn.Conv2d(256, 256, k=7, s=7)
conv2d_manual = Conv2D(256, 256, 7, 7)# 创建一个随机初始化的输入张量,确保尺寸是7的整数倍
input_tensor = torch.randn(1, 256, 56, 56) # 假设输入图像大小为56x56,56是7的倍数# 应用卷积操作
output = conv2d_manual(input_tensor)
# 输出形状应为[1, 256, 8, 8]
print(output.shape)
上述代码通过了使用输入数据的维度变换,实现了利用nn.Linear来进行nn.Conv2d的过程,当然,nn.Conv1d甚至nn.Conv3d等也是同样操作。这里我们先记住,后面我们详细解释
Swin-Transformer为什么这么叫
首先,需要理解为什么叫Swin!
作者依然使用了Vision Transformer的主题架构,核心区别是对数据处理的区别!
在Vision Transformer中,数据根据spatial维度进行拉伸,并成为[Batch, HW, C]的样子,如图所示,具体参考Transformer之Vision Transformer结构解读
而在Swin-Transformer中,额外增加了一步,就是把维度为 [ B a t c h , H × W , C ] [Batch, H\times W, C] [Batch,H×W,C]的patch_embedding,进行二次分割,变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],如图所示,
- 第一张图片就是经过patch_embed的patch_embedding
- 第二张图片就是经过window_partrition分割后的图片
- 第三张图片就是处理成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C]的图片
这里还有一个操作,就是在第偶数个Attention-Block中,把输入的patch_embedding进行torch.roll操作,这个操作就是循环位移
这时候就可以解释为什么说Swin-Transformer就是另一种形式的CNN了
从上面的图片中可以看到如下过程: - 一张图片,经过nn.Conv2d(k=patch_size, stride=patch_size),将其分割成 N 2 N^2 N2个patch_embedding
- patch_embedding经过维度重整,从 [ B , H × W , C ] [B, H\times W, C] [B,H×W,C]变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],然后送入nn.Linear()。这里的维度重整加上nn.Linear(),等于nn.Conv2d,可以通过写在最前面的"如何只用nn.Linear()实现nn.Conv2d的功能"看出
- 上一步可以总结为:经过nn.Conv2d的patch_embedding继续经过若干nn.Conv2d
Swin-Transformer的位置编码
绝对位置编码
详情参考Transformer之位置编码的通俗理解
在patch_embedding过程中,依然将Token和PE相加,如上图二所示。
但是既然有了相对位置编码,为什么还要加上绝对位置编码呢?
- 数学解释如下:
Q E + P E × K E + P E T = X E + P E × W q × [ X E + P E × W k ] T = X E + P E × W q × W k T × X E + P E T = ( X q + P E q ) × W q × W k T × ( X k + P E k ) T = X q × W q ⏞ Q u e r y × W k T × X k T ⏞ K e y ⏟ 第一项 + P E q × W q ⏞ a × W k T × X k T ⏞ K e y ⏟ 第二项 + X q × W q ⏞ Q u e r y × W k T × P E k T ⏞ b ⏟ 第三项 + P E q × W q ⏞ a × W k T × P E k T ⏞ b ⏟ 第四项 \begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array} QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项 Xq×Wq Query×WkT×XkT Key+第二项 PEq×Wq a×WkT×XkT Key+第三项 Xq×Wq Query×WkT×PEkT b+第四项 PEq×Wq a×WkT×PEkT b
绝对位置编码只能消去第三项和第四项中的d项,依然需要第二项中的a项,才能具有完整的偏置
- 直觉解释如下
如果只有相对位置编码,也就是相当于只有相对位置偏置,这个过程和只有绝对位置偏置的意义是相同的,所以只有同时具有相对位置编码和绝对位置编码,才能避免两者是等效的
相对位置编码
详情参考Transformer之位置编码的通俗理解
相对位置编码,实际上是Attention机制的偏置的位置编码:
A t t = s o f t m a x ( Q × K T D i m + r e l a t i v e _ p o s i t i o n _ b i a s ) × V Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V Att=softmax(DimQ×KT+relative_position_bias)×V
这里受到CSDN图片尺寸的限制,只能发这种清晰度的,点击这里下载无损svg