CVPR2022 ACmix 注意力模块 | On the Integration of Self-Attention and Convolution

在这里插入图片描述

论文名称:《On the Integration of Self-Attention and Convolution》

论文地址:2111.14556 (arxiv.org)


卷积和自注意力是两种强大的表示学习技术,通常被认为是两种截然不同的并列方法。在本文中,我们展示了它们之间存在一种强烈的潜在关系,从计算角度来看,这两个范式的大部分计算实际上是使用相同的操作完成的。具体而言,我们首先表明,传统的卷积操作,具有 k × k k×k k×k 的核大小,可以被分解为 个独立的 1 × 1 1×1 1×1 卷积,然后是移位和求和操作。然后,我们将自注意力模块中查询、键和值的投影解释为多个 1 × 1 1×1 1×1 卷积,然后计算注意力权重并聚合这些值。因此,这两个模块的第一阶段包含了相似的操作。更重要的是,相较于第二阶段,第一阶段在计算复杂性上占据主要地位(频道大小的平方)。这一发现自然引出了这两个看似截然不同的范式的巧妙结合,即一个混合模型,它既享受自注意力和卷积的好处,同时比纯卷积或纯自注意力模型具有最小的计算开销。大量实验表明,我们的模型在图像识别和下游任务上持续取得比竞争基准更好的结果。


问题背景

论文讨论了深度学习中的两大主要方法:卷积和自注意力。在计算机视觉领域,卷积神经网络(CNN)通常是默认的选择,而自注意力则在自然语言处理(NLP)中更常见。该论文的背景在于这些两种技术的不同设计范式,以及它们在计算机视觉任务中的应用。这引发了一个问题:这两者是否可以通过某种方式结合,以实现更好的性能和效率。


核心概念

该论文提出了一种名为ACmix的新方法,将卷积和自注意力结合在一起。核心概念在于发现这两种技术在计算上有共同点。传统的卷积可以分解为一系列的 1 × 1 1×1 1×1 卷积,而自注意力的投影过程也类似。这种相似性为将这两种方法结合在一起提供了机会,从而形成一种混合模型,既具有卷积的局部特性,又具备自注意力的灵活性。


模块的操作步骤


在这里插入图片描述

展示了 ACmix 的概念草图。我们探索了卷积和自注意力之间的密切关系,这种关系在于共享相同的计算负载(1×1 卷积),并结合其余的轻量级聚合操作。我们展示了每个模块相对于特征通道的计算复杂性。


ACmix中,操作步骤分为两个阶段:

  • 阶段一:将输入特征映射通过 1 × 1 1×1 1×1 卷积进行投影,形成中间特征。
  • 阶段二:根据不同的范式,分别应用卷积和自注意力的操作。在卷积路径上,利用ShiftSummation操作实现卷积过程。在自注意力路径上,计算查询、键和值,然后应用传统的自注意力方法。

文章贡献

  1. 揭示了卷积和自注意力之间的强关联性,提供了一种理解这两者之间关系的新方式。
  2. 提出了ACmix,一种将卷积和自注意力优雅地结合在一起的模型。它能够在没有额外计算负担的情况下,享受这两种方法的优势。

实验结果与应用:

ACmixImageNet分类、语义分割和目标检测等任务中进行了验证。实验结果表明,与传统的卷积或自注意力模型相比,ACmix在准确性、计算开销和参数数量上具有优势。在ImageNet分类任务中,ACmix的模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmixADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进,进一步验证了该模型的有效性。


对未来工作的启示:

该论文提出了ACmix这一混合模型,为计算机视觉领域提供了一种新的方向。它揭示了卷积和自注意力之间的关系,为设计新的学习范式提供了灵感。未来的工作可以在以下几个方面继续探索:

  • 在其他自注意力模块中应用ACmix,并进一步验证其有效性。
  • 研究如何在更大的模型中更好地整合卷积和自注意力,从而在更复杂的任务中实现性能提升。
  • 考虑在实际应用中进一步优化ACmix的计算效率,以确保其在生产环境中的可用性。

代码

import torch
import torch.nn as nn
import torch.nn.functional as Fdef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.0)class ACmix(nn.Module):def __init__(self,in_planes,out_planes,kernel_att=7,head=4,kernel_conv=3,stride=1,dilation=1,):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head,self.kernel_conv * self.kernel_conv,kernel_size=1,bias=False,)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim,out_planes,kernel_size=self.kernel_conv,bias=True,groups=self.head_dim,padding=1,stride=stride,)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.0kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stridepe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)  # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out)  # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)  # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)out_att = ((att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out))f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w),k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w),],1,))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_convif __name__ == "__main__":input = torch.randn(64, 256, 8, 8)model = ACmix(in_planes=256, out_planes=256)output = model(input)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/4313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序试题解析(二)

8.4.3 01.在以下排序算法中,每次从未排序的记录中选取最小关键字的记录,加入已排序记录的 末尾,该排序算法是( A ). A.简单选择排序 B.冒泡排序 C.堆排序 D.直接插入排序 02.简单选择排序算法的比较次数和移动次数分别为( C )。…

【小沐学Java】VSCode搭建Java开发环境

文章目录 1、简介2、安装VSCode2.1 简介2.2 安装 3、安装Java SDK3.1 简介3.2 安装3.3 配置 4、安装插件Java Extension Pack4.1 简介4.2 安装4.3 配置 结语 1、简介 2、安装VSCode 2.1 简介 Visual Studio Code 是一个轻量级但功能强大的源代码编辑器,可在桌面上…

如何使用小浪助手快速下载学浪中的视频?

今天给大家准备好了一个工具,小浪助手,它可以帮你们快速下载学浪中的视频 小浪助手我已经打包好了,有需要自己取一下 学浪下载工具链接:https://pan.baidu.com/s/1_Sg-EGGXKc4bMW-NPqUqvg?pwd1234 提取码:1234 --…

【语音识别】搭建本地的语音转文字系统:FunASR(离线不联网即可使用)

参考自: 参考配置:FunASR/runtime/docs/SDK_advanced_guide_offline_zh.md at main alibaba-damo-academy/FunASR (github.com)参考配置:FunASR/runtime/quick_start_zh.md at 861147c7308b91068ffa02724fdf74ee623a909e alibaba-damo-aca…

电脑教程1

一、介绍几个桌面上面的软件 1、火绒:主要用于电脑的安全防护和广告拦截 1.1 广告拦截 1.打开火绒软件点击安全工具 点击弹窗拦截 点击截图拦截 拦截具体的小广告 2、向日葵远程控制:可以通过这个软件进行远程协助 可以自己去了解下 这个软件不要…

模块四:一维前缀和模板——DP34 【模板】前缀和

文章目录 题目描述算法原理解法一:暴力解法(时间复杂度为O(n*q))解法二:前缀和(时间复杂度为O(n)O(q))细节问题 代码实现CJava 题目描述 题目链接:DP34 【模板】前缀和 根据描述第一句可得数组长度应设为n 1 算法原理 解法一…

编写一个函数fun,它的功能是:实现两个字符串的连接(不使用库函数strcat),即把p2所指的字符串连接到p1所指的字符串后。

本文收录于专栏:算法之翼 https://blog.csdn.net/weixin_52908342/category_10943144.html 订阅后本专栏全部文章可见。 本文含有题目的题干、解题思路、解题思路、解题代码、代码解析。本文分别包含C语言、C++、Java、Python四种语言的解法完整代码和详细的解析。 题干 编写…

个人学习-前端相关(2):ECMAScript 6-箭头函数、rest、spread

ES6的箭头函数 ES6允许使用箭头函数,语法类似java中的lambda表达式 let fun1 function(){} //普通的函数声明 let fun2 ()>{} //箭头函数声明 let fun3 (x) >{return x1} let fun4 x >{return x1} //参数列表中有且只有一个参数,()可…

kubebuilder(3)实现operator

在前面的文章我们已经了解了operator项目的基本结构。现在我们来写一点简单的代码,然后把我们的crd和operator部署到k8s集群中。 需求 这是一个真实的需求,只不过做了简化。 在开发公司自己的paas平台,有一个需求是,用户在发版…

236基于matlab的三维比例导引法仿真

基于matlab的三维比例导引法仿真,可以攻击静止/机动目标。1.三维空间内的比例导引程序,采用龙哥库塔积分法;2.文件名为bili3dnew的.m文件是主函数,执行时需调用目标机动子函数、导引律子函数、数值积分法子函数;3.文件…

统计建模——模型——python为例

统计建模涵盖了众多数学模型和分析方法,这些模型和方法被广泛应用于数据分析、预测、推断、分类、聚类等任务中。下面列举了一些常见的统计建模方法及其具体应用方式: 目录 1.线性回归模型: ----python实现线性回归模型 -------使用NumPy…

【C++】---STL容器适配器之queue

【C】---STL容器适配器之queue 一、队列1、队列的性质 二、队列类1、队列的构造2、empty()3、push()4、pop()5、size()6、front()7、back() 三、队列的模拟实现1、头文件(底层:deque)2、测试文件3、底层:list 一、队列 1、队列的…

Java基础_集合类_List

List Collection、List接口1、继承结构2、方法 Collection实现类1、继承结构2、相关类(1)AbstractCollection(2)AbstractListAbstractSequentialList(子类) 其它接口RandomAccess【java.util】Cloneable【j…

Kafka学习笔记01【2024最新版】

一、Kafka-课程介绍 官网地址:Apache KafkaApache Kafka: A Distributed Streaming Platform.https://kafka.apache.org/ kafka 3.6.1版本,作为经典分布式订阅、发布的消息传输中间件,kafka在实时数据处理、消息队列、流处理等领域具有广泛…

容器安全-镜像扫描

前言 容器镜像安全是云原生应用交付安全的重要一环,对上传的容器镜像进行及时安全扫描,并基于扫描结果选择阻断应用部署,可有效降低生产环境漏洞风险。容器安全面临的风险有:镜像风险、镜像仓库风险、编排工具风险,小…

Python_AI库 Matplotlib的应用简例:绘制与保存折线图

本文默认读者已具备以下技能: 熟悉Python基础语法,以自行阅读python代码块熟悉Vscode或其它编辑工具的应用 在数据可视化领域,Matplotlib无疑是一个强大的工具。它允许我们创建各种静态、动态、交互式的可视化图形,帮助我们更好…

python中如何用matplotlib写雷达图

#代码 import numpy as np # import matplotlib as plt # from matplotlib import pyplot as plt import matplotlib.pyplot as pltplt.rcParams[font.sans-serif].insert(0, SimHei) plt.rcParams[axes.unicode_minus] Falselabels np.array([速度, 力量, 经验, 防守, 发球…

新科技辅助器具赋能视障生活:让盲人出行融入日常

随着科技日新月异的发展,一款名为蝙蝠避障专为改善盲人日常生活的盲人日常生活辅助器具应运而生,它通过巧妙整合实时避障与拍照识别功能,成功改变了盲人朋友们的生活格局,为他们提供了更为便捷、高效的生活体验。 这款非同…

注意力机制:SENet详解

SENet(Squeeze-and-Excitation Networks)是2017年提出的一种经典的通道注意力机制,这种注意力可以让网络更加专注于一些重要的featuremap,它通过对特征通道间的相关性进行建模,把重要的特征图进行强化来提升模型的性能…

【Redis 开发】Redisson

Redisson RedissonRedisson分布式锁Redisson可重入锁Redission解决超时释放的问题Redission解决锁的判断一次性问题Redission分布式锁主从一致性问题 Redisson Redisson是一个在Redis的基础上实现的java驻内存数据网格,就是提供了一系列的分布式的java对象 官方地址…