【Block总结】DynamicFilter,动态滤波器降低计算复杂度,替换传统的MHSA|即插即用

论文信息

标题: FFT-based Dynamic Token Mixer for Vision

论文链接: https://arxiv.org/pdf/2303.03932

关键词: 深度学习、计算机视觉、对象检测、分割

GitHub链接: https://github.com/okojoalg/dfformer

在这里插入图片描述

创新点

本论文提出了一种新的标记混合器(token mixer),称为动态滤波器(Dynamic Filter),旨在解决多头自注意力(MHSA)模型在处理高分辨率图像时的计算复杂度问题。传统的MHSA模型在输入特征图中像素数量的平方上具有计算复杂度,导致处理速度缓慢。通过引入基于快速傅里叶变换(FFT)的动态滤波器,论文展示了在保持性能的同时显著降低计算复杂度的可能性。

方法

论文中提出的动态滤波器结合了全局操作的优点,类似于MHSA,但在计算效率上更具优势。具体方法包括:

  • FFT-based Token Mixer: 通过FFT实现全局操作,降低计算复杂度。
  • DFFormer和CDFFormer模型: 这两种新型图像识别模型利用动态滤波器进行图像分类和其他下游任务。
    在这里插入图片描述

动态滤波器如何具体降低MHSA模型的计算复杂度?

动态滤波器通过引入基于快速傅里叶变换(FFT)的机制,显著降低了多头自注意力(MHSA)模型的计算复杂度。以下是其具体工作原理和优势:

计算复杂度问题

传统的MHSA模型在处理输入特征图时,其计算复杂度与特征图中像素数量的平方成正比。这意味着,当输入图像的分辨率增加时,计算需求会急剧上升,导致处理速度变慢,尤其是在高分辨率图像的情况下。

动态滤波器的工作原理

  1. 频域转换: 动态滤波器首先利用FFT将输入特征图转换到频域。FFT是一种高效的算法,可以将计算复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN),其中 N N N是数据的长度。这一转换使得后续的操作可以在频域中进行,从而减少了计算量。

  2. 动态生成滤波器: 在频域中,动态滤波器通过一个多层感知机(MLP)动态生成每个特征通道的滤波器。这些滤波器是根据输入特征图的内容进行调整的,能够更好地捕捉到图像中的重要信息。

  3. 频域操作: 生成的滤波器在频域中应用于特征图,进行全局信息的捕捉。通过这种方式,动态滤波器能够有效地进行全局操作,同时避免了MHSA中计算复杂度的急剧增加。

  4. 逆FFT转换: 最后,经过滤波的频域特征图通过逆FFT转换回空间域,得到最终的输出结果。

优势

  • 降低计算复杂度: 通过在频域中进行操作,动态滤波器显著降低了MHSA模型的计算复杂度,使得处理高分辨率图像时的速度得以提升。

  • 提高内存效率: 动态滤波器的设计使得模型在处理时占用更少的内存,适合在资源有限的环境中运行。

  • 保持性能: 尽管计算复杂度降低,动态滤波器仍然能够保持与MHSA相似的性能,尤其是在图像分类和其他视觉任务中表现出色。

效果

实验结果表明,DFFormer和CDFFormer在高分辨率图像识别任务中表现出色,具有显著的吞吐量和内存效率。具体而言,这些模型在处理高分辨率图像时的性能优于传统的MHSA模型,显示出动态滤波器在实际应用中的潜力。

实验结果

论文通过一系列实验验证了提出模型的有效性,包括:

  • 图像分类: DFFormer和CDFFormer在标准数据集上的表现接近或超过了现有的最先进模型。
  • 下游任务分析: 通过对比实验,展示了动态滤波器在不同视觉任务中的适用性和优势。

总结

本论文的研究表明,基于FFT的动态滤波器是一种值得认真考虑的标记混合器选项,尤其是在处理高分辨率图像时。通过降低计算复杂度,动态滤波器不仅提高了模型的处理速度,还保持了良好的性能,推动了计算机视觉领域的进一步发展。研究结果为未来的视觉模型设计提供了新的思路和方向。

代码

import torch
import torch.nn as nn
from timm.models.layers import to_2tupleclass StarReLU(nn.Module):"""StarReLU: s * relu(x) ** 2 + b"""def __init__(self, scale_value=1.0, bias_value=0.0,scale_learnable=True, bias_learnable=True,mode=None, inplace=False):super().__init__()self.inplace = inplaceself.relu = nn.ReLU(inplace=inplace)self.scale = nn.Parameter(scale_value * torch.ones(1),requires_grad=scale_learnable)self.bias = nn.Parameter(bias_value * torch.ones(1),requires_grad=bias_learnable)def forward(self, x):return self.scale * self.relu(x) ** 2 + self.biasclass Mlp(nn.Module):""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.Mostly copied from timm."""def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,bias=False, **kwargs):super().__init__()in_features = dimout_features = out_features or in_featureshidden_features = int(mlp_ratio * in_features)drop_probs = to_2tuple(drop)self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass DynamicFilter(nn.Module):def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,act1_layer=StarReLU, act2_layer=nn.Identity,bias=False, num_filters=4, size=14, weight_resize=False,**kwargs):super().__init__()size = to_2tuple(size)self.size = size[0]self.filter_size = size[1] // 2 + 1self.num_filters = num_filtersself.dim = dimself.med_channels = int(expansion_ratio * dim)self.weight_resize = weight_resizeself.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)self.act1 = act1_layer()self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)self.complex_weights = nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2,dtype=torch.float32) * 0.02)self.act2 = act2_layer()self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)def forward(self, x):B, H, W, _ = x.shaperouteing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,-1).softmax(dim=1)x = self.pwconv1(x)x = self.act1(x)x = x.to(torch.float32)x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')if self.weight_resize:complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],x.shape[2])complex_weights = torch.view_as_complex(complex_weights.contiguous())else:complex_weights = torch.view_as_complex(self.complex_weights)routeing = routeing.to(torch.complex64)weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)if self.weight_resize:weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)else:weight = weight.view(-1, self.size, self.filter_size, self.med_channels)x = x * weightx = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')x = self.act2(x)x = self.pwconv2(x)return x
def resize_complex_weight(origin_weight, new_h, new_w):h, w, num_heads = origin_weight.shape[0:3]  # size, w, c, 2origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2)new_weight = torch.nn.functional.interpolate(origin_weight,size=(new_h, new_w),mode='bicubic',align_corners=True).permute(0, 2, 3, 1).reshape(new_h, new_w, num_heads, 2)return new_weightif __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUinput_size=20device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1, input_size , input_size, 32).to(device)# 初始化 pconv 模块dim = 32block = DynamicFilter(dim=dim,size=input_size)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(done) MIT6.S081 2023 学习笔记 (Day6: LAB5 COW Fork)

网页:https://pdos.csail.mit.edu/6.S081/2023/labs/cow.html 任务1:Implement copy-on-write fork(hard) (完成) 现实中的问题如下: xv6中的fork()系统调用会将父进程的用户空间内存全部复制到子进程中。如果父进程很大,复制过程…

鸢尾花书01---基本介绍和Jupyterlab的上手

文章目录 1.致谢和推荐2.py和.ipynb区别3.Jupyterlab的上手3.1入口3.2页面展示3.3相关键介绍3.4代码的运行3.5重命名3.6latex和markdown说明 1.致谢和推荐 这个系列是关于一套书籍,结合了python和数学,机器学习等等相关的理论,总结的7本书籍…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》033-响应式编程的原理及在Vue中的应用

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主&…

【javaweb项目idea版】蛋糕商城(可复用成其他商城项目)

该项目虽然是蛋糕商城项目,但是可以复用成其他商城项目或者购物车项目 想要源码的uu可点赞后私聊 技术栈 主要为:javawebservletmvcc3p0idea运行 功能模块 主要分为用户模块和后台管理员模块 具有商城购物的完整功能 基础模块 登录注册个人信息编辑…

为什么LabVIEW适合软硬件结合的项目?

LabVIEW是一种基于图形化编程的开发平台,广泛应用于软硬件结合的项目中。其强大的硬件接口支持、实时数据采集能力、并行处理能力和直观的用户界面,使得它成为工业控制、仪器仪表、自动化测试等领域中软硬件系统集成的理想选择。LabVIEW的设计哲学强调模…

Fort Firewall:全方位守护网络安全

Fort Firewall是一款专为 Windows 操作系统设计的开源防火墙工具,旨在为用户提供全面的网络安全保护。它基于 Windows 过滤平台(WFP),能够与系统无缝集成,确保高效的网络流量管理和安全防护。该软件支持实时监控网络流…

【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状

目录 1. reshape 函数的用法 2. transpose 和 permute 函数的使用 4. squeeze 和 unsqueeze 函数的用法 5. 小节 个人主页:Icomi 专栏地址:PyTorch入门 在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架&am…

[STM32 - 野火] - - - 固件库学习笔记 - - -十三.高级定时器

一、高级定时器简介 高级定时器的简介在前面一章已经介绍过,可以点击下面链接了解,在这里进行一些补充。 [STM32 - 野火] - - - 固件库学习笔记 - - -十二.基本定时器 1.1 功能简介 1、高级定时器可以向上/向下/两边计数,还独有一个重复计…

Cyber Security 101-Build Your Cyber Security Career-Security Principles(安全原则)

了解安全三元组以及常见的安全模型和原则。 任务1:介绍 安全已成为一个流行词;每家公司都想声称其产品或服务是安全的。但事实真的如此吗? 在我们开始讨论不同的安全原则之前,了解我们正在保护资产的对手至关重要。您是否试图阻止蹒跚学步…

python:斐索实验(Fizeau experiment)

斐索实验(Fizeau experiment)是在1851年由法国物理学家阿曼德斐索(Armand Fizeau)进行的一项重要实验,旨在测量光在移动介质中的传播速度。这项实验的结果对当时的物理理论产生了深远的影响,并且在后来的相…

青少年CTF练习平台 贪吃蛇

题目 CtrlU快捷键查看页面源代码 源码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>贪吃蛇游戏</title><style>#gameCanvas {border: 1px solid black;}</style> </head>…

芯片AI深度实战:基础篇之Ollama

有这么多大模型&#xff0c;怎么本地用&#xff1f; Ollama可以解决这一问题。不依赖GPU&#xff0c;也不需要编程。就可以在CPU上运行自己的大模型。 软件甚至不用安装&#xff0c;直接在ollama官网下载可执行文件即可。 现在最流行的deepseek-r1也可以使用。当然还有我认为最…

本地部署deepseek模型步骤

文章目录 0.deepseek简介1.安装ollama软件2.配置合适的deepseek模型3.安装chatbox可视化 0.deepseek简介 DeepSeek 是一家专注于人工智能技术研发的公司&#xff0c;致力于打造高性能、低成本的 AI 模型&#xff0c;其目标是让 AI 技术更加普惠&#xff0c;让更多人能够用上强…

DeepSeek R1中提到“知识蒸馏”到底是什么

在 DeepSeek-R1 中&#xff0c;知识蒸馏&#xff08;Knowledge Distillation&#xff09;是实现模型高效压缩与性能优化的核心技术之一。在DeepSeek的论文中&#xff0c;使用 DeepSeek-R1&#xff08;教师模型&#xff09;生成 800K 高质量训练样本&#xff0c;涵盖数学、编程、…

关联传播和 Python 和 Scikit-learn 实现

文章目录 一、说明二、什么是 Affinity Propagation。2.1 先说Affinity 传播的工作原理2.2 更多细节2.3 传播两种类型的消息2.4 计算责任和可用性的分数2.4.1 责任2.4.2 可用性分解2.4.3 更新分数&#xff1a;集群是如何形成的2.4.4 估计集群本身的数量。 三、亲和力传播的一些…

通过配置代理解决跨域问题(Vue+SpringBoot项目为例)

跨域问题&#xff1a; 是由浏览器的同源策略引起的&#xff0c;同源策略是一种安全策略&#xff0c;用于防止一个网站访问其他网站的数据。 同源是指协议、域名和端口号都相同。 跨域问题常常出现在前端项目中&#xff0c;当浏览器中的前端代码尝试从不同的域名、端口或协议…

(1)Linux高级命令简介

Linux高级命令简介 在安装好linux环境以后第一件事情就是去学习一些linux的基本指令&#xff0c;我在这里用的是CentOS7作演示。 首先在VirtualBox上装好Linux以后&#xff0c;启动我们的linux&#xff0c;输入账号密码以后学习第一个指令 简介 Linux高级命令简介ip addrtou…

TOGAF之架构标准规范-信息系统架构 | 数据架构

TOGAF是工业级的企业架构标准规范&#xff0c;信息系统架构阶段是由数据架构阶段以及应用架构阶段构成&#xff0c;本文主要描述信息系统架构阶段中的数据架构阶段。 如上所示&#xff0c;信息系统架构&#xff08;Information Systems Architectures&#xff09;在TOGAF标准规…

Windows 程序设计7:文件的创建、打开与关闭

文章目录 前言一、文件的创建与打开CreateFile1. 创建新的空白文件2. 打开已存在文件3. 打开一个文件时&#xff0c;如果文件存在则打开&#xff0c;如果文件不存在则新创建文件4.打开一个文件&#xff0c;如果文件存在则打开文件并清空内容&#xff0c;文件不存在则 新创建文件…

FastReport.NET控件篇之富文本控件

简介 FastReport.NET 提供了 RichText 控件&#xff0c;用于在报表中显示富文本内容。富文本控件支持多种文本格式&#xff08;如字体、颜色、段落、表格、图片等&#xff09;&#xff0c;非常适合需要复杂排版和格式化的场景。 富文本控件(RichText)使用场景不多&#xff0c…