论文信息
标题: FFT-based Dynamic Token Mixer for Vision
论文链接: https://arxiv.org/pdf/2303.03932
关键词: 深度学习、计算机视觉、对象检测、分割
GitHub链接: https://github.com/okojoalg/dfformer
创新点
本论文提出了一种新的标记混合器(token mixer),称为动态滤波器(Dynamic Filter),旨在解决多头自注意力(MHSA)模型在处理高分辨率图像时的计算复杂度问题。传统的MHSA模型在输入特征图中像素数量的平方上具有计算复杂度,导致处理速度缓慢。通过引入基于快速傅里叶变换(FFT)的动态滤波器,论文展示了在保持性能的同时显著降低计算复杂度的可能性。
方法
论文中提出的动态滤波器结合了全局操作的优点,类似于MHSA,但在计算效率上更具优势。具体方法包括:
- FFT-based Token Mixer: 通过FFT实现全局操作,降低计算复杂度。
- DFFormer和CDFFormer模型: 这两种新型图像识别模型利用动态滤波器进行图像分类和其他下游任务。
动态滤波器如何具体降低MHSA模型的计算复杂度?
动态滤波器通过引入基于快速傅里叶变换(FFT)的机制,显著降低了多头自注意力(MHSA)模型的计算复杂度。以下是其具体工作原理和优势:
计算复杂度问题
传统的MHSA模型在处理输入特征图时,其计算复杂度与特征图中像素数量的平方成正比。这意味着,当输入图像的分辨率增加时,计算需求会急剧上升,导致处理速度变慢,尤其是在高分辨率图像的情况下。
动态滤波器的工作原理
-
频域转换: 动态滤波器首先利用FFT将输入特征图转换到频域。FFT是一种高效的算法,可以将计算复杂度降低到 O ( N log N ) O(N \log N) O(NlogN),其中 N N N是数据的长度。这一转换使得后续的操作可以在频域中进行,从而减少了计算量。
-
动态生成滤波器: 在频域中,动态滤波器通过一个多层感知机(MLP)动态生成每个特征通道的滤波器。这些滤波器是根据输入特征图的内容进行调整的,能够更好地捕捉到图像中的重要信息。
-
频域操作: 生成的滤波器在频域中应用于特征图,进行全局信息的捕捉。通过这种方式,动态滤波器能够有效地进行全局操作,同时避免了MHSA中计算复杂度的急剧增加。
-
逆FFT转换: 最后,经过滤波的频域特征图通过逆FFT转换回空间域,得到最终的输出结果。
优势
-
降低计算复杂度: 通过在频域中进行操作,动态滤波器显著降低了MHSA模型的计算复杂度,使得处理高分辨率图像时的速度得以提升。
-
提高内存效率: 动态滤波器的设计使得模型在处理时占用更少的内存,适合在资源有限的环境中运行。
-
保持性能: 尽管计算复杂度降低,动态滤波器仍然能够保持与MHSA相似的性能,尤其是在图像分类和其他视觉任务中表现出色。
效果
实验结果表明,DFFormer和CDFFormer在高分辨率图像识别任务中表现出色,具有显著的吞吐量和内存效率。具体而言,这些模型在处理高分辨率图像时的性能优于传统的MHSA模型,显示出动态滤波器在实际应用中的潜力。
实验结果
论文通过一系列实验验证了提出模型的有效性,包括:
- 图像分类: DFFormer和CDFFormer在标准数据集上的表现接近或超过了现有的最先进模型。
- 下游任务分析: 通过对比实验,展示了动态滤波器在不同视觉任务中的适用性和优势。
总结
本论文的研究表明,基于FFT的动态滤波器是一种值得认真考虑的标记混合器选项,尤其是在处理高分辨率图像时。通过降低计算复杂度,动态滤波器不仅提高了模型的处理速度,还保持了良好的性能,推动了计算机视觉领域的进一步发展。研究结果为未来的视觉模型设计提供了新的思路和方向。
代码
import torch
import torch.nn as nn
from timm.models.layers import to_2tupleclass StarReLU(nn.Module):"""StarReLU: s * relu(x) ** 2 + b"""def __init__(self, scale_value=1.0, bias_value=0.0,scale_learnable=True, bias_learnable=True,mode=None, inplace=False):super().__init__()self.inplace = inplaceself.relu = nn.ReLU(inplace=inplace)self.scale = nn.Parameter(scale_value * torch.ones(1),requires_grad=scale_learnable)self.bias = nn.Parameter(bias_value * torch.ones(1),requires_grad=bias_learnable)def forward(self, x):return self.scale * self.relu(x) ** 2 + self.biasclass Mlp(nn.Module):""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.Mostly copied from timm."""def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,bias=False, **kwargs):super().__init__()in_features = dimout_features = out_features or in_featureshidden_features = int(mlp_ratio * in_features)drop_probs = to_2tuple(drop)self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass DynamicFilter(nn.Module):def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,act1_layer=StarReLU, act2_layer=nn.Identity,bias=False, num_filters=4, size=14, weight_resize=False,**kwargs):super().__init__()size = to_2tuple(size)self.size = size[0]self.filter_size = size[1] // 2 + 1self.num_filters = num_filtersself.dim = dimself.med_channels = int(expansion_ratio * dim)self.weight_resize = weight_resizeself.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)self.act1 = act1_layer()self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)self.complex_weights = nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2,dtype=torch.float32) * 0.02)self.act2 = act2_layer()self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)def forward(self, x):B, H, W, _ = x.shaperouteing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,-1).softmax(dim=1)x = self.pwconv1(x)x = self.act1(x)x = x.to(torch.float32)x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')if self.weight_resize:complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],x.shape[2])complex_weights = torch.view_as_complex(complex_weights.contiguous())else:complex_weights = torch.view_as_complex(self.complex_weights)routeing = routeing.to(torch.complex64)weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)if self.weight_resize:weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)else:weight = weight.view(-1, self.size, self.filter_size, self.med_channels)x = x * weightx = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')x = self.act2(x)x = self.pwconv2(x)return x
def resize_complex_weight(origin_weight, new_h, new_w):h, w, num_heads = origin_weight.shape[0:3] # size, w, c, 2origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2)new_weight = torch.nn.functional.interpolate(origin_weight,size=(new_h, new_w),mode='bicubic',align_corners=True).permute(0, 2, 3, 1).reshape(new_h, new_w, num_heads, 2)return new_weightif __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUinput_size=20device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1, input_size , input_size, 32).to(device)# 初始化 pconv 模块dim = 32block = DynamicFilter(dim=dim,size=input_size)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)