混合注意力 ACmix | On the Integration of Self-Attention and Convolution

在这里插入图片描述

论文名称:《On the Integration of Self-Attention and Convolution》

论文地址:2111.14556 (arxiv.org)


卷积和自注意力是两种强大的表示学习技术,通常被认为是两种截然不同的并列方法。在本文中,我们展示了它们之间存在一种强烈的潜在关系,从计算角度来看,这两个范式的大部分计算实际上是使用相同的操作完成的。具体而言,我们首先表明,传统的卷积操作,具有 k × k k×k k×k 的核大小,可以被分解为 个独立的 1 × 1 1×1 1×1 卷积,然后是移位和求和操作。然后,我们将自注意力模块中查询、键和值的投影解释为多个 1 × 1 1×1 1×1 卷积,然后计算注意力权重并聚合这些值。因此,这两个模块的第一阶段包含了相似的操作。更重要的是,相较于第二阶段,第一阶段在计算复杂性上占据主要地位(频道大小的平方)。这一发现自然引出了这两个看似截然不同的范式的巧妙结合,即一个混合模型,它既享受自注意力和卷积的好处,同时比纯卷积或纯自注意力模型具有最小的计算开销。大量实验表明,我们的模型在图像识别和下游任务上持续取得比竞争基准更好的结果。


问题背景

论文讨论了深度学习中的两大主要方法:卷积和自注意力。在计算机视觉领域,卷积神经网络(CNN)通常是默认的选择,而自注意力则在自然语言处理(NLP)中更常见。该论文的背景在于这些两种技术的不同设计范式,以及它们在计算机视觉任务中的应用。这引发了一个问题:这两者是否可以通过某种方式结合,以实现更好的性能和效率。


核心概念

该论文提出了一种名为ACmix的新方法,将卷积和自注意力结合在一起。核心概念在于发现这两种技术在计算上有共同点。传统的卷积可以分解为一系列的 1 × 1 1×1 1×1 卷积,而自注意力的投影过程也类似。这种相似性为将这两种方法结合在一起提供了机会,从而形成一种混合模型,既具有卷积的局部特性,又具备自注意力的灵活性。


模块的操作步骤


在这里插入图片描述

展示了 ACmix 的概念草图。我们探索了卷积和自注意力之间的密切关系,这种关系在于共享相同的计算负载(1×1 卷积),并结合其余的轻量级聚合操作。我们展示了每个模块相对于特征通道的计算复杂性。


ACmix中,操作步骤分为两个阶段:

  • 阶段一:将输入特征映射通过 1 × 1 1×1 1×1 卷积进行投影,形成中间特征。
  • 阶段二:根据不同的范式,分别应用卷积和自注意力的操作。在卷积路径上,利用ShiftSummation操作实现卷积过程。在自注意力路径上,计算查询、键和值,然后应用传统的自注意力方法。

文章贡献

  1. 揭示了卷积和自注意力之间的强关联性,提供了一种理解这两者之间关系的新方式。
  2. 提出了ACmix,一种将卷积和自注意力优雅地结合在一起的模型。它能够在没有额外计算负担的情况下,享受这两种方法的优势。

实验结果与应用:

ACmixImageNet分类、语义分割和目标检测等任务中进行了验证。实验结果表明,与传统的卷积或自注意力模型相比,ACmix在准确性、计算开销和参数数量上具有优势。在ImageNet分类任务中,ACmix的模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmixADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进,进一步验证了该模型的有效性。


对未来工作的启示:

该论文提出了ACmix这一混合模型,为计算机视觉领域提供了一种新的方向。它揭示了卷积和自注意力之间的关系,为设计新的学习范式提供了灵感。未来的工作可以在以下几个方面继续探索:

  • 在其他自注意力模块中应用ACmix,并进一步验证其有效性。
  • 研究如何在更大的模型中更好地整合卷积和自注意力,从而在更复杂的任务中实现性能提升。
  • 考虑在实际应用中进一步优化ACmix的计算效率,以确保其在生产环境中的可用性。

代码

import torch
import torch.nn as nn
import torch.nn.functional as Fdef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.0)class ACmix(nn.Module):def __init__(self,in_planes,out_planes,kernel_att=7,head=4,kernel_conv=3,stride=1,dilation=1,):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head,self.kernel_conv * self.kernel_conv,kernel_size=1,bias=False,)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim,out_planes,kernel_size=self.kernel_conv,bias=True,groups=self.head_dim,padding=1,stride=stride,)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.0kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stridepe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)  # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out)  # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)  # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head,self.head_dim,self.kernel_att * self.kernel_att,h_out,w_out,)out_att = ((att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out))f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w),k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w),],1,))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_convif __name__ == "__main__":input = torch.randn(64, 256, 8, 8)model = ACmix(in_planes=256, out_planes=256)output = model(input)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/830191.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

history命令显示时间戳、IP地址、用户名

一、前置知识 history命令的功能是显示和管理用户所执行过的所有命令记录。这些记录默认被Linux系统保存。用户可以使用history命令查阅这些记录,也可以对其记录进行修改和删除操作。 history命令的常用参数如下: -a: 保存命令记录-c: 清空命令记录-d:…

JavaScript 如何理解柯里化函数结构及调用

文章目录 柯里化函数是什么逐步理解柯里化函数 柯里化函数是什么 柯里化(Currying)函数,又称部分求值,是一种函数转换技术。这种技术将一个接受多个参数的函数转换为一系列接受单一参数的函数。具体来说,一个柯里化的…

2024 java使用Graceful Response,告别自己去封装响应,可以接收数据异常,快看我这一篇,足够你用!

参考官网手册地址&#xff1a;快速入门 | Docs 一、导入依赖&#xff08;根据springboot查看对应依赖版本&#xff09; <!-- Graceful --><dependency><groupId>com.feiniaojin</groupId><artifactId>graceful-response</artifactId&g…

微信小程序与web-view网页进行通信的尝试

首先&#xff0c;微信小程序向web-view传递数据一般通过地址栏传参的形式&#xff08;给src赋值或者修改hash&#xff09;&#xff0c;这样一般就已经能够满足实际开发需求了&#xff0c;所以这里主要探讨web-view向微信小程序传参。下面&#xff0c;我们从官方文档入手&#x…

基于51单片机智能窗帘仿真设计( proteus仿真+程序+设计报告+讲解视频)

基于51单片机智能窗帘仿真设计( proteus仿真程序设计报告讲解视频&#xff09; 基于51单片机智能窗帘仿真设计 1. 主要功能&#xff1a;2. 讲解视频&#xff1a;3. 仿真设计4. 程序代码5. 设计报告6. 原理图7. 设计资料内容清单资料下载链接&#xff1a; 仿真图proteus8.9及以上…

【JAVA进阶篇教学】第七篇:Spring中常用注解

博主打算从0-1讲解下java进阶篇教学&#xff0c;今天教学第七篇&#xff1a;Spring中常用注解 在Java Spring框架中&#xff0c;注解&#xff08;Annotation&#xff09;是一种元数据&#xff0c;它提供了关于程序代码的额外信息&#xff0c;这些信息可以用于编译时检查、运行时…

【国信华源北斗型雨量站新品亮相第三届防汛抗旱抢险新技术新产品展示会】

4月24—25日&#xff0c;第三届防汛抗旱抢险新技术、新产品应用研讨与展示会暨中国水利企业协会防灾与抢险装备技术分会年会在河南郑州召开。由《中国防汛抗旱》杂志社、水利部防洪抗旱减灾工程技术研究中心主办&#xff0c;围绕我国防汛抗旱形势、防灾与抢险新技术新产品现状和…

vue2实现字节流byte[]数组的图片预览

项目使用vantui框架&#xff0c;后端返回图片的字节流byte[]数组&#xff0c;在移动端实现预览&#xff0c;实现代码如下&#xff1a; <template><!-- 附件预览 --><div class"file-preview-wrap"><van-overlay :show"show"><…

【Markdown笔记】——设置markdown中文字的颜色

【Markdown笔记】——设置markdown中文字的颜色 Markdownmarkdown中设置文字颜色常用颜色对照表【含RGB值对照】 &#x1f49d;&#x1f49d;&#x1f49d; 欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#x…

笔记:能量谱密度与功率谱密度(二)

目录 一、ESD与PSD的定义、单位、性质 二、对ESD与PSD的直观理解 三、总结&#xff1a; 某物理量的“分布”在离散系统中&#xff0c;各点(纵坐标含义&#xff09;的物理意义仍然是该物理量&#xff0c;而在连续系统中&#xff0c;各点&#xff08;纵坐标含义&#xff09;的物…

实战干货|Spark 在袋鼠云数栈的深度探索与实践

Spark 是一个快速、通用、可扩展的大数据计算引擎&#xff0c;具有高性能、易用、容错、可以与 Hadoop 生态无缝集成、社区活跃度高等优点。在实际使用中&#xff0c;具有广泛的应用场景&#xff1a; 数据清洗和预处理&#xff1a;在大数据分析场景下&#xff0c;数据通常需要…

后台架构总结

前言 疫情三年&#xff0c;全国各地的健康码成为了每个人的重要生活组成部分。虽然过去一年&#xff0c;但是回想起来任然历历在目。 今天我就通过当时基于小程序的健康码架构&#xff0c;来给大家讲一下如何基于java&#xff0c;springboot等技术来快速搭建一个后台业务系统…

Pixelmator Pro for Mac:简洁而强大的图像编辑软件

Pixelmator Pro for Mac是一款专为Mac用户设计的图像编辑软件&#xff0c;它集简洁的操作界面与强大的功能于一身&#xff0c;为用户提供了卓越的图像编辑体验。 Pixelmator Pro for Mac v3.5.9中文激活版下载 该软件支持多种文件格式&#xff0c;包括常见的JPEG、PNG、TIFF等&…

系统触发器

目录 数据库触发器 常见触发器&#xff0c;记录登录和退出数据库事件 模式触发器 创建一个模式触发器&#xff0c;记录各种 DDL 操作的日志 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 前面已经介绍过&#xff0c;…

WEB攻防-PHP特性-函数缺陷对比

目录 和 MD5函数 intval ​strpos in_array preg_match str_replace 和 使用 时&#xff0c;如果两个比较的操作数类型不同&#xff0c;PHP 会尝试将它们转换为相同的类型&#xff0c;然后再进行比较。 使用 进行比较时&#xff0c;不仅比较值&#xff0c;还比较变量…

MATLAB非均匀网格梯度计算

在matlab中&#xff0c;gradient函数可以很方便的对均匀网格进行梯度计算&#xff0c;但是对于非均匀网格&#xff0c;但是gradient却无法求解非均匀网格的梯度&#xff0c;这一点我之前犯过错误。我之前以为在gradient函数中指定x&#xff0c;y等坐标&#xff0c;其求解的就是…

Metasploit 溢出 samba 提权漏洞

一、信息收集 1.1 右键单击桌面&#xff0c;选择 Open Terminal Here &#xff0c;打开终端。 1.2 输入命令 nmap -sS -p 139,445 -A 192.168.1.254 ,对目标主机进行扫描,发现 139、445 端口开放。 1.3 输入命令“msfconsole”&#xff0c;启动 MSF 终端。 1.4 输入命令“searc…

电脑录制视频快捷键,一键开启录屏新时代(干货)

“最近尝试录制一些电脑上的操作视频&#xff0c;用来制作教学教程。不过&#xff0c;每次录制都要通过菜单或搜索来打开录屏软件&#xff0c;实在是有些繁琐。有没有人知道哪些电脑录制视频的快捷键呀&#xff1f;或者有没有通用的快捷键设置方法&#xff1f;” 在当今数字时…

免费语音转文字:自建Whisper,贝锐花生壳3步远程访问

Whisper是OpenAI开发的自动语音识别系统&#xff08;语音转文字&#xff09;。 OpenAI称其英文语音辨识能力已达到人类水准&#xff0c;且支持其它98中语言的自动语音辨识&#xff0c;Whisper神经网络模型被训练来运行语音辨识与翻译任务。 此外&#xff0c;与其他需要联网运行…

MySQL中脏读与幻读

一般对于我们的业务系统去访问数据库而言&#xff0c;它往往是多个线程并发执行多个事务的&#xff0c;对于数据库而言&#xff0c;它会有多个事务同时执行&#xff0c;可能这多个事务还会同时更新和查询同一条数据&#xff0c;所以这里会有一些问题需要数据库来解决 我们来看…