FcaNet:频率通道注意力,进阶版SE

paper:https://arxiv.org/abs/2012.11879

github:GitHub - cfzd/FcaNet: FcaNet: Frequency Channel Attention Networks

目录

1. 动机

2. 方法

2.1. 回顾通道注意力和离散余弦变换(DCT)

通道注意力:

离散余弦变换(DCT):

2.2. 多光谱通道注意力

通道注意力的理论分析:

多光谱注意力模块:

选择频率分量的标准:

2.3. 讨论

多光谱框架是如何嵌入更多的信息的:

 复杂度分析:

3. 实验结果

4. 总结

5. 代码


1. 动机

注意力机制,尤其是通道注意力,在CV领域取得了巨大成功。大部分研究都集中在如何设计更高效的通道注意力机制,却忽略了一个基本问题,也即:他们都是使用全局平均池化(GAP)来作为预处理方法。尽管GAP十分简单高效,但他的捕获的信息也确实不足。那么,有没有更好的预处理方式呢?

这篇文章提供了一个新的视角:从频域对注意力进行重新思考,并从数学上证明了GAP就是频域特征分解的一个特例。基于此,作者将通道注意力机制的预处理泛化到了频域,并基于多光谱通道注意力构建了FcaNet。这种方法将SENet-50在ImageNet上的top-1准确率提升了1.8%。

图1. ImageNet的分类精度比较。在相同的参数和计算代价下,我们的方法始终优于baseline SENet。我们使用 ResNet-50 主干的方法甚至可以在 ResNet-152 主干上优于 SENet。

2. 方法

2.1. 回顾通道注意力和离散余弦变换(DCT)

通道注意力:

在CNN中常用通道注意力来对不同通道的特征进行加权。对于输入X\in \mathbb{R}^{C*H*W},通道注意力模块可表示为:

其中att \in \mathbb{R}^C就是注意力向量,这种一维的向量和原始输入X对应通道相乘,即可产生注意力的效果:

离散余弦变换(DCT):

 对于一维信号x\in\mathbb{R}^L,其DCT公式表示为:

对于二维信号x\in\mathbb{R}^{H\times W},则其2d DCT就是:

相应的,逆2D DCT可表示为:

从通道注意力和DCT可以总结两点:

  • 1)通道注意力使用GAP进行数据预处理;
  • 2)DCT可以看做输入的加权和,上述DCT公式中的cos部分可以当做权重。

对于GAP,它是对输入的feature map每个通道求全局平均,其实可以看作一种最简单频谱。也正是因为简单,其所包含的信息也是不足的,因此本文探索了更复杂的频谱,来代替GAP。

2.2. 多光谱通道注意力

通道注意力的理论分析:

上面已经讨论过,DCT可以看作输入的加权和,而GAP是2D DCT的一个特例。因此提出以下定理:

定理一:GAP 是 2D DCT 的一个特例,其结果与 2D DCT 的最低频率分量成正比。

证明:

假设公式4中的h和w都为0,则有:

在公式6中,f_{0,0}^{2d}表示2D DCT的最低频率分量,且与GAP成正比。至此,定理一得证。

既然GAP是2D DCT的一个特例,那么是否可以考虑再加入其他分量呢?作者正是这样想的,所以接下来讨论了一下融合其他频率分量的可能性。

简单起见,使用B来表示2D DCT的基函数:

则公式5中的2D DCT可以改写为:

 从上述公式8可以看出,公式1所表示的通道注意力只是DCT中的一个分量,而其余的分量则被舍弃了:

其中HWB_{0,0}^{i,j}是一个恒定的比例因子,可以在注意力机制中被忽略。

自然地,可以想到把其他的频率分量也考虑进来,是不是就可以获取信息更丰富的通道注意力了呢?这就是下面要讲的多光谱注意力模块。

多光谱注意力模块:

首先,将输入X沿着通道分为多个部分,表示为X = [X^0, X^1, ... , X^{n-1}],其中X^i \in \mathbb R^{C'\times H \times W}, i \in {\{0,1,...,n-1\}}, C'=\frac{C}{n}, C 可以被n整除。对于每个部分,分配相应的 2D DCT 频率分量,2D DCT 结果可以用作通道注意力的预处理结果。如此则有:

其中[u, v]为xi对应的频率分量2D指数 ,而Freq^i \in \mathbb R^{C'}则是预处理后的C'维的向量。整体预处理的向量可以通过拼接得到:

其中的Freq则是获得的多光谱向量。整体的多光谱通道注意力就可以写为:

 从式12和13,可以看到提出的方法包括了 GAP 的原始方法,并从最低频率分量推广到具有多个频率源。如此,解决了原始方法不足的问题。总体框架如图2所示:

图 2. 现有通道注意力和多光谱通道注意力的图示。为简单起见,2D DCT 索引以一维格式表示。我们可以看到,我们的方法使用多个频率分量和选定的 DCT 基,而 SENet 在通道注意力中仅使用 GAP。

选择频率分量的标准:

如何为X的每个部分选择合适的频率分量[u,v]是一个重要问题。对于空间大小为 H × W 的每个通道,我们可以得到 2D DCT 之后的 HW 频率分量。在这种情况下,这些频率分量的组合总数为 CHW。例如,ResNet-50 主干的 C 可能等于 2048。测试所有组合是昂贵的。因此,我们提出了一种启发式两步标准来选择多光谱注意模块中的频率分量。

主要思想是首先确定每个频率分量的重要性,然后确定将不同数量的频率分量一起使用的效果。首先,我们分别检查通道注意中每个频率分量的结果。然后,我们根据结果选择 Topk 最高性能频率分量。这样,就可以满足多光谱通道注意力。

2.3. 讨论

多光谱框架是如何嵌入更多的信息的:

上面已经证明GAP只是2D DCT的一个特例,其只考虑最低的频率分量,而丢弃了其他频率分量;而进一步考虑利用起来其他频率分量能够丰富通道注意力的的信息量。那么,这些多光谱的频率分量为什么能够嵌入更多信息呢?

这里,作者给出了一个思想实验:

众所周知,深度网络是冗余的。对于两个冗余通道,使用GAP可以获取相同的信息,但在所提出的多光谱框架中,却可以获取更多不同的信息,因为不同的频率分量包含的信息不同。这样一来,所提出的多光谱框架可以在通道注意机制中嵌入更多的信息。

这部分我的个人理解是:两个通道里面的feature map中的每个像素值大概率是不同的,直接GAP是有可能获取相似的值的(也就是所谓的两个通道冗余了),但如果对这两个通道使用不同的频率分量进行信息提取,那它们的结果就很可能不一样了(也就是不冗余了);所以这种多光谱框架其实就是从原先看似冗余的通道之间建模了一丝细微的差别出来,从而丰富了或者说细化了通道注意力的信息提取能力。

 复杂度分析:

从参数数量和计算成本两个方面分析了此方法的复杂性。对于参数数量,与基线 SENet 相比,我们的方法没有额外的参数,因为 2D DCT 的权重是预先计算的常数。对于计算成本,我们的方法的额外成本可以忽略不计,可以被视为具有与SENet相同的计算成本。使用 ResNet-34、ResNet-50、ResNet101 和 ResNet-152 主干,与 SENet 相比,我们的方法的相对计算成本分别提高了 0.04%、0.13%、0.11% 和 0.11%。

只需一行代码改动:

所提出的多光谱框架的另一个重要特性是,它可以通过现有的通道注意实现轻松实现。如前面所述,2D DCT可以看作是输入的加权和。这样,我们的方法的实现可以简单地通过元素乘法和求和来实现。实现如图3所示:

图3. 我们的方法和SENet的实现: 在计算中,我们只需要更改一行代码来实现基于现有代码的方法。红色和绿色的线表示 SENet 和我们的工作之间的差异。得到的dct权重函数是实现Eq. 7,详细信息可以在附录中找到。

3. 实验结果

使用不同数量的频率分量,得到效果如下表,可以看出16个分量效果最好。

 在ImageNet上,FacNet明显优于其他方法:

在COCO上,也有优异表现:

实例分割同样毫不示弱:

4. 总结

本文针对SE通道注意力中使用GAP作为预处理存在的问题,即其获取的信息太过简单,提出了从频域分析获取更丰富信息的多光谱通道注意力方法,并提出了带有多光谱通道注意力的FcaNet网络,从频域泛化了已有的通道注意力。同时,在多光谱框架中探索了频率分量的不同组合,并提出了频率分量选择的两步标准。实验表明,在相同的参数量和相同计算量上,可以获取比SENet更优的效果。与其他通道注意方法相比,还在图像分类、目标检测和实例分割方面取得了最先进的性能。此外,FcaNet 简单而有效,可以基于现有的通道注意方法仅使用一行代码更改来实现。

5. 代码

虽然文章提供了SE和所提出多光谱通道注意力的对比,不过实际实现略有不同。从github上可以找到其代码(具体参考https://github.com/cfzd/FcaNet/blob/master/model/layer.py):

class MultiSpectralAttentionLayer(torch.nn.Module):def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'):super(MultiSpectralAttentionLayer, self).__init__()self.reduction = reductionself.dct_h = dct_hself.dct_w = dct_wmapper_x, mapper_y = get_freq_indices(freq_sel_method)self.num_split = len(mapper_x)mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]# make the frequencies in different sizes are identical to a 7x7 frequency space# eg, (2,2) in 14x14 is identical to (1,1) in 7x7self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):n,c,h,w = x.shapex_pooled = xif h != self.dct_h or w != self.dct_w:x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))# If you have concerns about one-line-change, don't worry.   :)# In the ImageNet models, this line will never be triggered. # This is for compatibility in instance segmentation and object detection.y = self.dct_layer(x_pooled)y = self.fc(y).view(n, c, 1, 1)return x * y.expand_as(x)

其中的self.dct_layer(也即是所谓的只需改动一行代码)定义如下:


class MultiSpectralDCTLayer(nn.Module):"""Generate dct filters"""def __init__(self, height, width, mapper_x, mapper_y, channel):super(MultiSpectralDCTLayer, self).__init__()assert len(mapper_x) == len(mapper_y)assert channel % len(mapper_x) == 0self.num_freq = len(mapper_x)# fixed DCT initself.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))# fixed random init# self.register_buffer('weight', torch.rand(channel, height, width))# learnable DCT init# self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))# learnable random init# self.register_parameter('weight', torch.rand(channel, height, width))# num_freq, h, wdef forward(self, x):assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))# n, c, h, w = x.shapex = x * self.weightresult = torch.sum(x, dim=[2,3])return resultdef build_filter(self, pos, freq, POS):result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) if freq == 0:return resultelse:return result * math.sqrt(2)def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)c_part = channel // len(mapper_x)for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):for t_x in range(tile_size_x):for t_y in range(tile_size_y):dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)return dct_filter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/762315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用Android平板公网访问本地Linux code-server

文章目录 1.ubuntu本地安装code-server2. 安装cpolar内网穿透3. 创建隧道映射本地端口4. 安卓平板测试访问5.固定域名公网地址6.结语 1.ubuntu本地安装code-server 准备一台虚拟机,Ubuntu或者centos都可以,这里以VMwhere ubuntu系统为例 下载code server服务,浏览器…

Skywalking的Helm Chart方式部署

背景 之前介绍了AWS云上面的EKS的集中日志方案。这次主要介绍调用链监控了,这里我们用的是Skywalking。监控三王者(EFKPrometheusSkywalking)之一。之前AWS云上面使用fluent bit替代EFK方案,其实,AWS云在调用链方面&a…

Elasticsearch:ES|QL 入门 - Python Notebook

数据丰富在本笔记本中,你将学习 Elasticsearch 查询语言 (ES|QL) 的基础知识。 你将使用官方 Elasticsearch Python 客户端。 你将学习如何: 运行 ES|QL 查询使用处理命令对表格进行排序查询数据链式处理命令计算值计算统计数据访问列创建直方图丰富数…

UE4 Json事件设置Asset值(Asset如果都在同一目录下)

通过Json事件来设置,比如骨骼网格体(换皮)等等

docker可视化管理工具-DockerUI

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 一个可视化的管理工…

ABAP笔记:定义指针,动态指针分配:ASSIGN COMPONENT <N> OF STRUCTURE <结构> TO <指针>.

参考大佬文章学习,总结了下没有提到的点:SAP ABAP指针的6种用法。_abap 指针-CSDN博客 定义指针:其实指针这玩意,就是类似你给个地方,把东西临时放进去,然后指针就是这个东西的替身了,写代码的…

iPhone语音备忘录误删?掌握这几个技巧轻松恢复【详】

语音备忘录是一款强大的应用程序,它允许用户使用语音输入功能来快速记录想法、提醒、待办事项等。无论是在行进间、工作中还是日常生活中,语音备忘录都是一个非常实用的工具,可以帮助您随时随地记录重要信息,而无需打字或者手动输…

redis-黑马点评-商户查询缓存

缓存:cache public Result queryById(Long id) {//根据id在redis中查询数据String s redisTemplate.opsForValue().get(CACHE_SHOP_KEY id);//判断是否存在if (!StrUtil.isBlank(s)) {//将字符串转为bean//存在,直接返回Shop shop JSONUtil.toBean(s, …

专家解读!IMAP的要点助您在旅途中保持邮件无忧!

你是否经常因会议而出差,需要在各种设备上灵活地访问你的电子邮件?如果是的话,你可能会想了解你的电子邮件系统是如何通过使用互联网消息访问协议(IMAP)来工作的,这样当你不在办公桌前时,你可以…

_.debounce防抖函数 在vue中使用this问题,应该传匿名函数而不是箭头函数

简单理解:_.debounce内部做了apply操作,箭头函数由于没有this,无法绑定this,导致最终this是undefined, 而匿名函数,成功通过applay绑定了this,所以this指向了vue组件实例。 methods: {// 防抖动dSave1: _.debounce(() > {console.log(thi…

你知道弧幕影院如何制作吗?其应用领域竟如此广泛!

“沉浸式”作为如今备受热议的内容展示形式,其有着多种可实现的途径,其中弧幕影院作为一项有着独特视觉效果、沉浸式观影体验的技术类型,便是大多数影院、主题公园等娱乐场景的必备设计展项,这种弧幕影院通常使用大型的半圆形屏幕…

python 爬取杭州小区挂牌均价

下载chrome驱动 通过chrome浏览器的 设置-帮助-关于Google Chrome 查看你所使用的Chrome版本 驱动可以从这两个地方找: 【推荐】https://storage.googleapis.com/chrome-for-testing-publichttp://npm.taobao.org/mirrors/chromedriver import zipfile import os import r…

leetcode 232.用栈实现队列 JAVA

题目 思路 使用两个栈(输入栈和输出栈)来模拟一个队列。 队列的push操作实现:直接将元素push到输入栈中。 队列的pop操作实现:队列是先入先出,将输入栈的元素全部pop到输出栈中,然后再由输出栈pop&#…

PMP备考时间、出成绩时间有多久?从在威班培训到拿证我用了60天

尽管PMI官方没有对PMP考试通过分数进行具体规定,能否通过也是看成绩页显示的是“PASS”(通过)还是“FAIL”(未通过),没有成绩的数值体现,但有每个领域的等级可以进行查看,比如下图。…

Windows系统服务器宝塔面板打开提示Internal Server Error错误

1、cmd运行bt命令 2、尝试输入16修复程序 3、如果不行,输入17升级程序

STL —— string(1)

目录 1. 模板 1.1 泛型编程 1.2 函数模板 1.2.1 函数模板概念 1.2.2 函数模板格式 1.2.3 函数模板的原理 1.2.4 显式实例化 1.2.5 模板参数的匹配原则 1.3 类模板 1.3.1 类模板定义格式 1.3.2 类模板的实例化 2. STL —— string类 2.1 STL 简介 2.2 标准库中的s…

怎样隐藏查询和分组?

发布查询时,遇到信息量较大需要提前制作好,但不用马上发布的查询,该怎样隐藏查询和分组? 📌使用教程 01“开始”和“暂停”查询 如果想要隐藏查询,可以通过点击绿色开始按钮来暂停查询,暂停后的…

【软考高项】十五、信息系统工程之系统集成

1、集成基础 定义:通过硬件平台、网络通信平台、数据库平台、工具平台、应用软件平台将各类资源有机、高效地集成到一起,形成一个完整的工作台面 基本原则包括:开放性、结构化、先进性和主流化 2、网络集成 包括:传输子系统、交换子系统、…

调试西门子G120STO模式出现O.F1600等一系列报警

目录 一、现象描述 二、 解决经历 三、结果展示 四、总结 一、现象描述 在调试使用西门子G120的STO功能时,一直无法使用,变频器也一直在报警(RDY灯红灯快闪、SAFE灯黄灯快闪)。在博图上查询发现下面一系列的故障报警。 二、 解决经历 也查询了很多网…

Vue中的状态管理Vuex,基本使用

1.什么是Vuex? Vuex是专门为Vue.js设计的状态管理模式;特点:集中式存储和管理应用程序中所有组件状态,保证状态以一种可预测的方式发生变化。 1.1.什么是状态管理模式? 先看一个单向数据流的简单示意图 state:驱动应用的数据源 view:以声明方式将state映射到视图 actions:…