CVPR | CNN融合注意力机制,芜湖起飞!

**标题:**On the Integration of Self-Attention and Convolution
**论文链接:**https://arxiv.org/pdf/2111.14556
**代码链接:**https://github.com/LeapLabTHU/ACmix

创新点

1. 揭示卷积和自注意力的内在联系

文章通过重新分解卷积和自注意力模块的操作,发现它们在第一阶段(特征投影)都依赖于 1×1 卷积操作,并且这一阶段占据了大部分的计算复杂度(与通道数的平方成正比)。这一发现为整合两种模块提供了理论基础。

2. 提出 ACmix 模型

基于上述发现,作者提出了 ACmix 模型,它通过共享 1×1 卷积操作来同时实现卷积和自注意力的功能。具体来说:
**第一阶段:**输入特征通过 1×1 卷积投影,生成中间特征。
**第二阶段:**这些中间特征分别用于卷积路径(通过移位和聚合操作)和自注意力路径(计算注意力权重并聚合值)。最终,两条路径的输出通过可学习的权重加权求和,得到最终输出。

3. 改进的移位和聚合操作

文章还提出了一种改进的移位操作,通过使用 固定卷积核的分组卷积 来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

4. 适应性路径权重

ACmix 引入了两个可学习的标量参数(α 和 β),用于动态调整卷积路径和自注意力路径的权重。这种设计不仅提高了模型的灵活性,还允许模型在不同深度上自适应地选择更适合的特征提取方式。实验表明,这种设计在模型的不同阶段表现出不同的偏好,例如在早期阶段更倾向于卷积,在后期阶段更倾向于自注意力。

整体结构

第一阶段:特征投影

在第一阶段,输入特征通过三个1×1卷积进行投影,分别生成查询(query)、键(key)和值(value)特征映射。这些特征映射随后被重塑为N块,形成一个包含3×N特征映射的中间特征集。

第二阶段:特征聚合

在第二阶段,中间特征集被分为两个路径进行处理:

  • **自注意力路径:**将中间特征集分为N组,每组包含三个特征映射(分别对应查询、键和值)。这些特征映射按照传统的多头自注意力机制进行处理,计算注意力权重并聚合值。
  • **卷积路径:**通过轻量级的全连接层生成k²个特征映射(k为卷积核大小)。这些特征映射通过移位和聚合操作,以类似传统卷积的方式处理输入特征,从局部感受野收集信息。

输出整合

最后,自注意力路径和卷积路径的输出通过两个可学习的标量参数(α和β)加权求和,得到最终的输出。

改进的移位和聚合操作

为了提高计算效率,ACmix模型采用了改进的移位操作,通过固定卷积核的分组卷积来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

模型的灵活性和泛化能力

ACmix模型不仅适用于标准的自注意力机制,还可以与各种变体(如Patchwise Attention、Window Attention和Global Attention)结合使用。这种设计使得ACmix能够适应不同的任务需求,具有广泛的适用性。

消融实验

1. 结合两个路径的输出

消融实验探索了卷积和自注意力输出的不同组合方式对模型性能的影响。实验结果表明:

  • **卷积和自注意力的组合优于单一路径:**使用卷积和自注意力模块的组合始终优于仅使用单一路径(如仅卷积或仅自注意力)的模型。
  • **可学习参数的灵活性:**通过引入可学习的参数(如α和β)来动态调整卷积和自注意力路径的权重,ACmix能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性和性能。

2. 组卷积核的选择

实验还对组卷积核的设计进行了验证,结果表明:

  • **用组卷积替代张量位移:**通过使用组卷积替代传统的张量位移操作,显著提高了模型的推理速度。
  • **可学习卷积核和初始化:**使用可学习的卷积核并结合精心设计的初始化方法,进一步增强了模型的灵活性,并有助于提升最终性能。

3. 不同路径的偏好

ACmix模型引入了两个可学习标量α和β,用于动态调整卷积和自注意力路径的权重。通过平行实验,观察到以下趋势:

  • **早期阶段偏好卷积:**在Transformer模型的早期阶段,卷积作为特征提取器表现更好。
  • **中间阶段混合使用:**在网络的中间阶段,模型倾向于混合使用两种路径,并逐渐增加对卷积的偏好。
  • **后期阶段偏好自注意力:**在网络的最后阶段,自注意力表现优于卷积。

4. 对模型性能的影响

这些消融实验结果表明,ACmix模型通过合理结合卷积和自注意力的优势,并优化计算路径,不仅在多个视觉任务上取得了显著的性能提升,还保持了较高的计算效率

ACmix模块的作用

1. 融合卷积和自注意力的优势

ACmix模块通过结合卷积的局部特征提取能力和自注意力的全局感知能力,实现了一种高效的特征融合策略。这种设计使得模型能够同时利用卷积的局部感受野特性和自注意力的灵活性。

2. 优化计算路径

ACmix通过优化计算路径和减少重复计算,提高了整体模块的计算效率。具体来说,它通过1×1卷积对输入特征图进行投影,生成中间特征,然后根据不同的范式(卷积和自注意力)分别重用和聚合这些中间特征。这种设计不仅减少了计算开销,还提升了模型性能。

3. 改进的位移与求和操作

在卷积路径中,ACmix采用深度可分离卷积(depthwise convolution)来替代低效的张量位移操作,从而提高了实际推理效率。

4. 动态调整路径权重

ACmix引入了两个可学习的标量参数(α和β),用于动态调整卷积和自注意力路径的权重。这种设计使得模型能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性。

5. 广泛的应用潜力

ACmix在多个视觉任务(如图像分类、语义分割和目标检测)上均显示出优于单一机制(仅卷积或仅自注意力)的性能,展示了其广泛的应用潜力。

6. 实验验证

实验结果表明,ACmix在保持较低计算开销的同时,能够显著提升模型的性能。例如,在ImageNet分类任务中,ACmix模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmix在ADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进

代码实现

import torch
import torch.nn as nndef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stride# ### att# ## positional encodingpe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)## convf_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv#输入 B C H W, 输出 B C H W
if __name__ == '__main__':block = ACmix(in_planes=64, out_planes=64)input = torch.rand(3, 64, 32, 32)output = block(input)print(input.size(), output.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LOCAL_PREBUILT_JNI_LIBS使用说明

LOCAL_PREBUILT_JNI_LIBS使用说明 使用LOCAL_PREBUILT_JNI_LIBS,可用于控制APK集成时,其相关so的集成方式。 比如,用于将APK中的so,抽取出来。 LOCAL_PREBUILT_JNI_LIBS : \lib/arm64-v8a/libNativeCore.so \lib/arm64-v8a/liba…

Java中的object类

1.Object类是什么? 🟪Object 是 Java 类库中的一个特殊类,也是所有类的父类(超类),位于类继承层次结构的顶端。也就是说,Java 允许把任何类型的对象赋给 Object 类型的变量。 🟦Java里面除了Object类,所有的…

uniapp小程序自定义中间凸起样式底部tabbar

我自己写的自定义的tabbar效果图 废话少说咱们直接上代码,一步一步来 第一步: 找到根目录下的 pages.json 文件,在 tabBar 中把 custom 设置为 true,默认值是 false。list 中设置自定义的相关信息, pagePath&#x…

四、GPIO中断实现按键功能

4.1 GPIO简介 输入输出(I/O)是一个非常重要的概念。I/O泛指所有类型的输入输出端口,包括单向的端口如逻辑门电路的输入输出管脚和双向的GPIO端口。而GPIO(General-Purpose Input/Output)则是一个常见的术语&#xff0c…

【Elasticsearch】post_filter

post_filter是 Elasticsearch 中的一种后置过滤机制,用于在查询执行完成后对结果进行过滤。以下是关于post_filter的详细介绍: 工作原理 • 查询后过滤:post_filter在查询执行完毕后对返回的文档集进行过滤。这意味着所有与查询匹配的文档都…

从零开始:用Qt开发一个功能强大的文本编辑器——WPS项目全解析

文章目录 引言项目功能介绍1. **文件操作**2. **文本编辑功能**3. **撤销与重做**4. **剪切、复制与粘贴**5. **文本查找与替换**6. **打印功能**7. **打印预览**8. **设置字体颜色**9. **设置字号**10. **设置字体**11. **左对齐**12. **右对齐**13. **居中对齐**14. **两侧对…

【IoCDI】_Spring的基本扫描机制

目录 1. 创建测试项目 2. 改变启动类所属包 3. 使用ComponentScan 4. Spring基本扫描机制 程序通过注解告诉Spring希望哪些bean被管理,但在仅使用Bean时已经发现,Spring需要根据五大类注解才能进一步扫描方法注解。 由此可见,Spring对注…

通向AGI之路:人工通用智能的技术演进与人类未来

文章目录 引言:当机器开始思考一、AGI的本质定义与技术演进1.1 从专用到通用:智能形态的范式转移1.2 AGI发展路线图二、突破AGI的五大技术路径2.1 神经符号整合(Neuro-Symbolic AI)2.2 世界模型架构(World Models)2.3 具身认知理论(Embodied Cognition)三、AGI安全:价…

【工具变量】中国省级八批自由贸易试验区设立及自贸区设立数据(2024-2009年)

一、测算方式:参考C刊《中国软科学》任晓怡老师(2022)的做法,使用自由贸易试验区(Treat Post) 表征,Treat为个体不随时间变化的虚拟变量,如果该城市设立自由贸易试验区则赋值为1,反之赋值为0&am…

Java进阶总结——集合

Java进阶总结——集合 说明:对于以上的框架图有如下几点说明 1.所有集合类都位于java.util包下。Java的集合类主要由两个接口派生而出:Collection和Map,Collection和Map是Java集合框架的根接口,这两个接口又包含了一些子接口或实…

计算机视觉和图像处理

计算机视觉与图像处理的最新进展 随着人工智能技术的飞速发展,计算机视觉和图像处理作为其中的重要分支,正逐步成为推动科技进步和产业升级的关键力量。 一、计算机视觉的最新进展 计算机视觉,作为人工智能的重要分支,主要研究如…

3.PPT:华老师-计算机基础课程【3】

目录 NO12​ NO34​ NO56​ NO789​ NO12 根据考生文件夹下的Word文档“PPT素材.docx”中提供的内容在PPT.pptx中生成初始的6张幻灯片 新建幻灯片6张→ctrlc复制→ctrlv粘贴开始→新建幻灯片→幻灯片(从大纲)→Word文档注❗前提是:Word文档必须应用标题1、标题2…

(三)QT——信号与槽机制——计数器程序

目录 前言 信号(Signal)与槽(Slot)的定义 一、系统自带的信号和槽 二、自定义信号和槽 三、信号和槽的扩展 四、Lambda 表达式 总结 前言 信号与槽机制是 Qt 中的一种重要的通信机制,用于不同对象之间的事件响…

基于多智能体强化学习的医疗AI中RAG系统程序架构优化研究

一、引言 1.1 研究背景与意义 在数智化医疗飞速发展的当下,医疗人工智能(AI)已成为提升医疗服务质量、优化医疗流程以及推动医学研究进步的关键力量。医疗 AI 借助机器学习、深度学习等先进技术,能够处理和分析海量的医疗数据,从而辅助医生进行疾病诊断、制定治疗方案以…

Redis --- 秒杀优化方案(阻塞队列+基于Stream流的消息队列)

下面是我们的秒杀流程: 对于正常的秒杀处理,我们需要多次查询数据库,会给数据库造成相当大的压力,这个时候我们需要加入缓存,进而缓解数据库压力。 在上面的图示中,我们可以将一条流水线的任务拆成两条流水…

使用 Ollama 和 Kibana 在本地为 RAG 测试 DeepSeek R1

作者:来自 Elastic Dave Erickson 及 Jakob Reiter 每个人都在谈论 DeepSeek R1,这是中国对冲基金 High-Flyer 的新大型语言模型。现在他们推出了一款功能强大、具有开放权重的思想链推理 LLM,这则新闻充满了对行业意味着什么的猜测。对于那些…

2025年大年初一篇,C#调用GPU并行计算推荐

C#调用GPU库的主要目的是利用GPU的并行计算能力,加速计算密集型任务,提高程序性能,支持大规模数据处理,优化资源利用,满足特定应用场景的需求,并提升用户体验。在需要处理大量并行数据或进行复杂计算的场景…

Unity 2D实战小游戏开发跳跳鸟 - 计分逻辑开发

上文对障碍物的碰撞逻辑进行了开发,接下来就是进行跳跳鸟成功穿越过障碍物进行计分的逻辑开发,同时将对应的分数以UI的形式显示告诉玩家。 计分逻辑 在跳跳鸟通过障碍物的一瞬间就进行一次计分,计分后会同步更新分数的UI显示来告知玩家当前获得的分数。 首先我们创建一个用…

langchain基础(二)

一、输出解析器(Output Parser) 作用:(1)让模型按照指定的格式输出; (2)解析模型输出,提取所需的信息 1、逗号分隔列表 CommaSeparatedListOutputParser:…

游戏AI,让AI 玩游戏有什么作用?

让 AI 玩游戏这件事远比我们想象的要早得多。追溯到 1948 年,图灵和同事钱伯恩共同设计了国际象棋程序 Turochamp。之所以设计这么个程序,图灵是想说明,机器理论上能模拟人脑能做的任何事情,包括下棋这样复杂的智力活动。 可惜的是…