CVPR2022 ACmix 注意力模块 | On the Integration of Self-Attention and Convolution

在这里插入图片描述

论文名称:《On the Integration of Self-Attention and Convolution》

论文地址:2111.14556 (arxiv.org)


卷积和自注意力是两种强大的表示学习技术,通常被认为是两种截然不同的并列方法。在本文中,我们展示了它们之间存在一种强烈的潜在关系,从计算角度来看,这两个范式的大部分计算实际上是使用相同的操作完成的。具体而言,我们首先表明,传统的卷积操作,具有 k × k k×k k×k 的核大小,可以被分解为 个独立的 1 × 1 1×1 1×1 卷积,然后是移位和求和操作。然后,我们将自注意力模块中查询、键和值的投影解释为多个 1 × 1 1×1 1×1 卷积,然后计算注意力权重并聚合这些值。因此,这两个模块的第一阶段包含了相似的操作。更重要的是,相较于第二阶段,第一阶段在计算复杂性上占据主要地位(频道大小的平方)。这一发现自然引出了这两个看似截然不同的范式的巧妙结合,即一个混合模型,它既享受自注意力和卷积的好处,同时比纯卷积或纯自注意力模型具有最小的计算开销。大量实验表明,我们的模型在图像识别和下游任务上持续取得比竞争基准更好的结果。


问题背景

论文讨论了深度学习中的两大主要方法:卷积和自注意力。在计算机视觉领域,卷积神经网络(CNN)通常是默认的选择,而自注意力则在自然语言处理(NLP)中更常见。该论文的背景在于这些两种技术的不同设计范式,以及它们在计算机视觉任务中的应用。这引发了一个问题:这两者是否可以通过某种方式结合,以实现更好的性能和效率。


核心概念

该论文提出了一种名为ACmix的新方法,将卷积和自注意力结合在一起。核心概念在于发现这两种技术在计算上有共同点。传统的卷积可以分解为一系列的 1 × 1 1×1 1×1 卷积,而自注意力的投影过程也类似。这种相似性为将这两种方法结合在一起提供了机会,从而形成一种混合模型,既具有卷积的局部特性,又具备自注意力的灵活性。


模块的操作步骤


在这里插入图片描述

展示了 ACmix 的概念草图。我们探索了卷积和自注意力之间的密切关系,这种关系在于共享相同的计算负载(1×1 卷积),并结合其余的轻量级聚合操作。我们展示了每个模块相对于特征通道的计算复杂性。


ACmix中,操作步骤分为两个阶段:

  • 阶段一:将输入特征映射通过 1 × 1 1×1 1×1 卷积进行投影,形成中间特征。
  • 阶段二:根据不同的范式,分别应用卷积和自注意力的操作。在卷积路径上,利用ShiftSummation操作实现卷积过程。在自注意力路径上,计算查询、键和值,然后应用传统的自注意力方法。

文章贡献

  1. 揭示了卷积和自注意力之间的强关联性,提供了一种理解这两者之间关系的新方式。
  2. 提出了ACmix,一种将卷积和自注意力优雅地结合在一起的模型。它能够在没有额外计算负担的情况下,享受这两种方法的优势。

实验结果与应用:

ACmixImageNet分类、语义分割和目标检测等任务中进行了验证。实验结果表明,与传统的卷积或自注意力模型相比,ACmix在准确性、计算开销和参数数量上具有优势。在ImageNet分类任务中,ACmix的模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmixADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进,进一步验证了该模型的有效性。


对未来工作的启示:

该论文提出了ACmix这一混合模型,为计算机视觉领域提供了一种新的方向。它揭示了卷积和自注意力之间的关系,为设计新的学习范式提供了灵感。未来的工作可以在以下几个方面继续探索:

  • 在其他自注意力模块中应用ACmix,并进一步验证其有效性。
  • 研究如何在更大的模型中更好地整合卷积和自注意力,从而在更复杂的任务中实现性能提升。
  • 考虑在实际应用中进一步优化ACmix的计算效率,以确保其在生产环境中的可用性。

代码

import torch
import torch.nn as nn
import torch.nn.functional as Fdef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.0)class ACmix(nn.Module):def __init__(self,in_planes,out_planes,kernel_att=7,head=4,kernel_conv=3,stride=1,dilation=1,):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head,self.kernel_conv * self.kernel_conv,kernel_size=1,bias=False,)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim,out_planes,kernel_size=self.kernel_conv,bias=True,groups=self.head_dim,padding=1,stride=stride,)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.0kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stridepe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)  # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out)  # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)  # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)out_att = ((att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out))f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w),k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w),],1,))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_convif __name__ == "__main__":input = torch.randn(64, 256, 8, 8)model = ACmix(in_planes=256, out_planes=256)output = model(input)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/4313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序试题解析(二)

8.4.3 01.在以下排序算法中,每次从未排序的记录中选取最小关键字的记录,加入已排序记录的 末尾,该排序算法是( A ). A.简单选择排序 B.冒泡排序 C.堆排序 D.直接插入排序 02.简单选择排序算法的比较次数和移动次数分别为( C )。…

微信小程序手写文件解决日期少一天且格式无法切割问题

编译环境 微信开发者工具 问题 在小程序中无法实现对日期的切割,并且可能会出现日期少一天的问题,这个问题可以由后端进行解决,也可以前端,这里用了前端新建一个wxs转换文件进行解决。 比如数据库中的数据是2024-03-02… 但是返…

js动态设置css主题(Style-setProperty)

hex颜色转RGB hex2Rgb(str) {str str.replace("#", "");const hxs str.match(/../g);for (let index 0; index < 3; index) hxs[index] parseInt(hxs[index], 16);return hxs; } RGB转HXS rgb2hex(r,g,b){const hexs [r.toString(16), g.toString…

UE5蓝图 函数勾选线程安全的意义,我在动画蓝图状态机中调用了函数(gpt答复分享)

在Unreal Engine中&#xff0c;蓝图函数的“线程安全”选项通常用于确定该函数是否可以安全地在多线程环境下调用。线程安全意味着函数在执行时不会导致数据竞争&#xff0c;状态错误&#xff0c;或其他并发问题。如果一个函数是线程安全的&#xff0c;它就可以在不同的线程中同…

【小沐学Java】VSCode搭建Java开发环境

文章目录 1、简介2、安装VSCode2.1 简介2.2 安装 3、安装Java SDK3.1 简介3.2 安装3.3 配置 4、安装插件Java Extension Pack4.1 简介4.2 安装4.3 配置 结语 1、简介 2、安装VSCode 2.1 简介 Visual Studio Code 是一个轻量级但功能强大的源代码编辑器&#xff0c;可在桌面上…

如何使用小浪助手快速下载学浪中的视频?

今天给大家准备好了一个工具&#xff0c;小浪助手&#xff0c;它可以帮你们快速下载学浪中的视频 小浪助手我已经打包好了&#xff0c;有需要自己取一下 学浪下载工具链接&#xff1a;https://pan.baidu.com/s/1_Sg-EGGXKc4bMW-NPqUqvg?pwd1234 提取码&#xff1a;1234 --…

江苏宿迁服务器的优势有哪些?

江苏宿迁服务器是一款性能强大、稳定可靠的服务器&#xff0c;能够应用在各种应用场景当中&#xff0c;比如云计算、大数据分析等&#xff0c;接下来就让我们来了解一下江苏服务器的优势都有哪些吧&#xff01; 江苏宿迁服务器采用了优秀的散热技术&#xff0c;并且配置了多种安…

opencv动态识别人脸

import cv2 import os import numpy as npdef take_faces():while True:key input(请输入文件夹的名字&#xff0c;姓名拼音的缩写&#xff0c;如果输入Q&#xff0c;程序退出!)if key Q:break# 在faces_dynamic下面创建子文件夹os.makedirs(./faces_dymamic/%s % (key), exi…

【语音识别】搭建本地的语音转文字系统:FunASR(离线不联网即可使用)

参考自&#xff1a; 参考配置&#xff1a;FunASR/runtime/docs/SDK_advanced_guide_offline_zh.md at main alibaba-damo-academy/FunASR (github.com)参考配置&#xff1a;FunASR/runtime/quick_start_zh.md at 861147c7308b91068ffa02724fdf74ee623a909e alibaba-damo-aca…

电脑教程1

一、介绍几个桌面上面的软件 1、火绒&#xff1a;主要用于电脑的安全防护和广告拦截 1.1 广告拦截 1.打开火绒软件点击安全工具 点击弹窗拦截 点击截图拦截 拦截具体的小广告 2、向日葵远程控制&#xff1a;可以通过这个软件进行远程协助 可以自己去了解下 这个软件不要…

模块四:一维前缀和模板——DP34 【模板】前缀和

文章目录 题目描述算法原理解法一&#xff1a;暴力解法&#xff08;时间复杂度为O(n*q))解法二&#xff1a;前缀和(时间复杂度为O(n)O(q))细节问题 代码实现CJava 题目描述 题目链接&#xff1a;DP34 【模板】前缀和 根据描述第一句可得数组长度应设为n 1 算法原理 解法一…

编写一个函数fun,它的功能是:实现两个字符串的连接(不使用库函数strcat),即把p2所指的字符串连接到p1所指的字符串后。

本文收录于专栏:算法之翼 https://blog.csdn.net/weixin_52908342/category_10943144.html 订阅后本专栏全部文章可见。 本文含有题目的题干、解题思路、解题思路、解题代码、代码解析。本文分别包含C语言、C++、Java、Python四种语言的解法完整代码和详细的解析。 题干 编写…

个人学习-前端相关(2):ECMAScript 6-箭头函数、rest、spread

ES6的箭头函数 ES6允许使用箭头函数&#xff0c;语法类似java中的lambda表达式 let fun1 function(){} //普通的函数声明 let fun2 ()>{} //箭头函数声明 let fun3 (x) >{return x1} let fun4 x >{return x1} //参数列表中有且只有一个参数&#xff0c;()可…

支持向量机(SVM)详细介绍

一、SVM基本概念 支持向量机&#xff08;Support Vector Machine&#xff0c;简称SVM&#xff09;是一种二分类模型&#xff0c;它的基本模型是定义在特征空间上的间隔最大的线性分类器。SVM的核心思想是寻找一个超平面&#xff0c;将不同类别的样本点分开&#xff0c;并且使得…

LeetCode题目74:搜索二维矩阵

作者介绍&#xff1a;10年大厂数据\经营分析经验&#xff0c;现任大厂数据部门负责人。 会一些的技术&#xff1a;数据分析、算法、SQL、大数据相关、python 欢迎加入社区&#xff1a;码上找工作 作者专栏每日更新&#xff1a; LeetCode解锁1000题: 打怪升级之旅 python数据分析…

kubebuilder(3)实现operator

在前面的文章我们已经了解了operator项目的基本结构。现在我们来写一点简单的代码&#xff0c;然后把我们的crd和operator部署到k8s集群中。 需求 这是一个真实的需求&#xff0c;只不过做了简化。 在开发公司自己的paas平台&#xff0c;有一个需求是&#xff0c;用户在发版…

FFMpeg - macOS build 报错 : xcrun -sdk iphoneos clang ...

文章目录 报错1&#xff1a;xcrun -sdk iphoneos clang is unable to create an executable file报错 2 &#xff1a; error: unknown type name AudioDeviceID; 在 macOS 上使用 https://github.com/kewlbear/FFmpeg-iOS-build-script 脚本&#xff0c;运行 ./build-ffmpeg.sh…

236基于matlab的三维比例导引法仿真

基于matlab的三维比例导引法仿真&#xff0c;可以攻击静止/机动目标。1.三维空间内的比例导引程序&#xff0c;采用龙哥库塔积分法&#xff1b;2.文件名为bili3dnew的.m文件是主函数&#xff0c;执行时需调用目标机动子函数、导引律子函数、数值积分法子函数&#xff1b;3.文件…

模拟LinkedList实现的双向链表

1. 前言 前文我们用java语言实现了无哨兵的单向链表.稍作修改即可实现有哨兵的单向链表.有哨兵的单向链表相较与无哨兵的而言&#xff0c;其对链表的头结点的增删操作更为方便.而在此我们实现了带有头节点和尾节点的双向链表(该头节点和尾节点都不存储有效的数据). 2. 带有头…

统计建模——模型——python为例

统计建模涵盖了众多数学模型和分析方法&#xff0c;这些模型和方法被广泛应用于数据分析、预测、推断、分类、聚类等任务中。下面列举了一些常见的统计建模方法及其具体应用方式&#xff1a; 目录 1.线性回归模型&#xff1a; ----python实现线性回归模型 -------使用NumPy…