DeepViT:字节提出深层ViT的训练策略 | 2021 arxiv

作者发现深层ViT出现的注意力崩溃问题,提出了新颖的Re-attention机制来解决,计算量和内存开销都很少,在增加ViT深度时能够保持性能不断提高

来源:晓飞的算法工程笔记 公众号

论文: DeepViT: Towards Deeper Vision Transformer

  • 论文地址:https://arxiv.org/abs/2103.11886
  • 论文代码:https://github.com/zhoudaquan/dvit_repo

Introduction


  作者在ViT上效仿CNN堆叠更多层来提升性能的做法,但如图1所示,ViT的性能随着层数的增加会快速饱和。经过深入研究,作者发现这种缩放困难可能是由注意力崩溃问题引起。随着网络的深入,各层计算的注意力图逐渐变得相似,甚至在某些层之后几乎相同。这一事实表明,在ViT更深层中,self-attention机制无法有效地学习特征提取规律,阻碍了模型获得预期的性能提升。

  为了解决注意力崩溃问题并有效地扩展ViT的深度,作者提出了简单而有效的Re-attention方法。通过可学习的方式,该方法能够在多头自注意力(MHSA)的多个Head间进行信息交换,重新生成注意力图。重新生成的注意力图能够增加层的多样性,而且额外增加的计算和内存成本可以忽略不计。

  在没有任何额外的数据增强和正则化策略的情况下,只需用Re-attention替换ViTs中的MHSA模块,就可以训练非常深的ViT模型并得到相应的性能提升,如图2所示。

  总体而言,论文的贡献如下:

  • 深入研究ViT的行为,观察到ViT不能像CNN那样堆叠更多层中持续来提升性能,并且进一步确定了这种反直觉现象背后的根本原因为注意力崩溃。
  • 提出了Re-attention,一种简单而有效的注意机制,通过在不同注意头之间的进行信息交换来生成新的注意力图。
  • 第一个在ImageNet-1k上成功从零开始训练32层ViT并获得相应的性能提升,达到SOTA。

Revisiting Vision Transformer


  ViT模型如图2(a) 所示,由三个主要组件组成:用于Patch Embedding的线性层(即将高分辨率输入图像映射到低分辨率特征图),用于特征编码的多个包含MHSA和MLP的Transformer Block,用于分类分数预测的线性层。

  其中,最关键的MHSA层如公式1所示,也是Re-attention替换的目标。

Attention Collapse

  作者对ViT随深度增加而变化的性能进行了系统研究。首先根据DeiT的设置将中间层维度和MHSA的Head数量分别固定为384和12,然后堆叠不同数量的transformer blocks(从12到32不等)来构建不同深度的ViT模型。如前面所说的,作者惊讶地发现分类准确率会随着模型的深入而缓慢提高并快速饱和,在使用24个transformer blocks后提升就停止了。这一现象表明,现有的ViT难以从更深层次的架构中获益。

  这样的问题非常违反直觉,也值得探索。在CNN的早期开发阶段也观察到了类似的问题(即如何有效地训练深层模型),但后来被ResNet妥善解决了。通过更深入地研究transfromer的架构,作者认为自注意机制在ViT中起着关键作用,这使得它与CNN有显着不同。因此,作者首先研究自注意机制,观察其生成的注意力图如何随着模型的深入而变化

  为了测量各层注意力图的变化,需计算不同层注意力图之间的相似度:

  其中, M p , q M^{p,q} Mp,q是层pq的注意力图之间的余弦相似度矩阵,每个元素 M h , t p , q M^{p,q}_{h,t} Mh,tp,q衡量headh和tokent对应的层间注意力图的相似度。 A h , : , t ∗ A^{∗}_{h,:,t} Ah,:,t 是一个T维向量,表示输入token序列tT个输出标记中的每一个的贡献程度。因此, M h , t p , q M^{p,q}_{h,t} Mh,tp,q提供了关于token的权重如何从p层变化到q层的度量手段。当 M h , t p , q M^{p,q}_{h,t} Mh,tp,q等于1时,这意味着token序列t在层pq中对self-attention的作用完全相同。

  基于公式2,将ImageNet-1k上预训练32层ViT模型的所有注意力图之间的相似性进行可视化。如图3a所示,在第17层之后,相邻 k k k层的注意力图的相似度大于90%,这表明后面学习的注意力图都是相似的,即注意力崩溃问题。

  为了进一步验证不同深度的ViT是否存在这种现象,我们分别对12、16、24和32层的ViT进行了相同的实验,并计算了具有相似注意力图的块的数量。结果如图3b所示,当添加更多层时,相似注意力图的层数量与总层数的比率增加。

  为了解注意力崩溃如何影响ViT模型的性能,作者基于32层ViT模型,比较最终输出特征与每个中间层输出余弦相似度。结果如图4所示,学习到的特征在第20层之后停止变化,而且注意力图相似度的增加与特征相似度之间存在密切的相关性。这一观察表明,注意力崩溃是造成ViT不可扩展问题的根本原因。

Re-attention for Deep ViT


  将ViT扩展到更深的一个主要障碍是注意力崩溃问题,作者提出了两种解决方法,一种是增加自注意计算的中间维度,另一种是Re-attention机制。

Self-Attention in Higher Dimension Space

  克服注意力崩溃的一种直接解决方案是增加每个token的embedding维度。增加维度能够增强每个token embedding的表达能力,从而编码更多信息,生成更加多样化的注意力图以及减少相似性。

  作者基于12层ViT进行了不同中间维度的快速实验,维度范围从256到768。如图5和表1所示,增加embedding维度能够减少具有相似注意力图的层数以及缓解注意力崩溃,模型性能也得到相应的提高。这验证了作者的核心假设,注意力崩溃是ViT扩展的主要瓶颈。尽管这个方法有效,但持续增加embedding维度会显著增加计算成本,而且带来的性能提升往往也会减弱。此外,更大的模型通常需要更多的数据进行训练,存在过拟合风险以及降低训练效率。

Re-attention

  虽然不同transformer block之间的注意力图的相似性很高,但作者发现来自同一个Transformer block的不同Head的注意力图的相似性非常小,如图3c所示。实际上,同一自注意力层的不同Head主要关注输入token的不同方面。于是作者打算建立Head间交互来重新生成注意力图,使得训练的深层ViT的性能更优。

  Re-attention使用Head的注意力图作为基础,通过动态聚合生成一组新的注意力图。为了实现这一点,首先定义一个可学习的变换矩阵 Θ ∈ R H × H \Theta\in\mathbb{R}^{H\times H} ΘRH×H,在乘以V之前,使用该矩阵混合多个Head的注意力图重新生成新的注意力图。具体来说,Re-attention可定义为以下公式:

  其中变换矩阵 Θ \Theta Θ沿Head
维度乘以自注意力图ANorm是归一化函数,用于减少每层的方差, Θ \Theta Θ是可端到端学习的。

  Re-attention 的优点有两个:

  • 与其他注意力增强方法相比(随机丢弃注意力图元素或调节SoftMax温度),Re-attention利用Head之间的交互来收集互补信息,可以更好地提高注意力图的多样性。
  • Re-attention高效且易于实现,与原始的自注意力相比,只需要几行代码和可忽略不计的计算开销,比增加嵌入维度的方法更高效。

Experiments


  实验的基础模型配置,输入图片大小都是224x224

More Analysis on Attention Collapse

  • Attention reuse

  作者在24层和32层ViT模型上进行注意力复用的实验,将一个block的的注意力图直接共享给之后的所有块,block的选择为最后一个注意力图与相邻层的相似度小于90%的block。更多实现细节可以在补充材料中找到。

  结果如表3所示,共享注意力图的性能下降并不明显,这意味着注意力崩溃问题确实存在。当模型很深时,添加更多层的效率低下。

  • Visualization

  原始MHSA和Re-attention的注意力图可视化如图6所示。原始的MHSA学在较早层中主要关注相邻token之间的局部关系,并且随着层的深入逐渐覆盖更多token,最后在深层中具有高度相似性全局平均注意力图。在添加Re-attention后,深层的注意力图保持了多样性,并且与相邻层具有较小的相似性

Analysis on Re-attention

  • Re-attention v.s. Self-attention

  不同层数ViT上替换Re-attention对比。

  • Comparison to adding temperature in self-attention

  对比不同的缓解注意力图平滑问题的策略。

  • Comparison to dropping attentions

  对比注意力图dropout以及温度调节对相似性的影响。

Comparison with other SOTA models

  对比SOTA方法。

Conclusion


  作者发现深层ViT出现的注意力崩溃问题,提出了新颖的Re-attention机制来解决,计算量和内存开销都很少,在增加ViT深度时能够保持性能不断提高。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/868988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

提升爬虫OCR识别率:解决嘈杂验证码问题

引言 在数据抓取和网络爬虫技术中,验证码是常见的防爬措施,特别是嘈杂文本验证码。处理嘈杂验证码是一个复杂的问题,因为这些验证码故意设计成难以自动识别。本文将介绍如何使用OCR技术提高爬虫识别嘈杂验证码的准确率,并结合实际…

面向对象的程序设计设计思想(解决问题所需要的类),面向过程的程序设计思想(解决问题的步骤)

一、引言 面向对象思想是现代编程语言的主流编程思想,除了C语言外,其他的主流编程语言,无论是脚本的还是非脚本的,基本上都引入了面向对象这一设计思想,面向对象设计思想是怎样的?为什么现在的编程语言大都…

模型驱动开发(Model-Driven Development,MDD):提高软件开发效率与一致性的利器

目录 前言1. 模型驱动开发的原理1.1 什么是模型驱动开发1.2 MDD的核心思想 2. 模型驱动开发的优势2.1 提高开发效率2.2 确保代码一致性2.3 促进沟通和协作2.4 方便维护和扩展 3. 实现模型驱动开发的方法3.1 选择合适的建模工具3.1.1 UML3.1.2 BPMN3.1.3 SysML 3.2 建模方法3.2.…

大学生竞赛管理系统-计算机毕业设计源码37276

大学生竞赛管理系统的设计与实现 摘 要 随着教育信息化的不断发展,大学生竞赛已成为高校教育的重要组成部分。传统的竞赛组织和管理方式存在着诸多问题,如信息不透明、效率低下、管理不便等。为了解决这些问题,提高竞赛组织和管理效率&#x…

K8S 上部署大数据相关组件

文章目录 一、前言二、Redis 一、前言 Artifact Hub 是一个专注于云原生应用的集中式搜索和发布平台。它旨在简化开发者在 CNCF(Cloud Native Computing Foundation)项目中寻找、安装和分享包与配置的过程。用户可以通过这个平台方便地发现、安装各类云原…

用SurfaceView实现落花动画效果

上篇文章 Android子线程真的不能刷新UI吗?(一)复现异常 中可以看出子线程更新main线程创建的View,会抛出异常。SurfaceView不依赖main线程,可以直接使用自己的线程控制绘制逻辑。具体代码怎么实现了? 这篇文章用Surfa…

vscode启用项目后,没有触发debugger

启动项目后在debugger时,一直不走断点,重启vscode和电脑,打开其他vscode项目,都不行 1.F12点击设置 2.然后取消忽略列表的勾选即可。

【力扣高频题】042.接雨水问题

上一篇我们通过采用 双指针 的方法解决了 经典 容器盛水 问题 ,本文我们接着来学习一道在面试中极大概率会被考到的经典题目:接雨水 问题 。 42. 接雨水 给定 n 个非负整数,表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子…

Java-Redis-Clickhouse-Jenkins-MybatisPlus-Zookeeper-vscode-Docker-jdbc

文章目录 Clickhouse基础实操windows docker desktop 下载clickhousespringboot项目配置clickhouse Redis谈下你对Redis的了解?Redis一般都有哪些使用的场景?Redis有哪些常见的功能?Redis支持的数据类型有哪些?Redis为什么这么快…

第一个ffmpeg程序

在进行使用ffmpeg进行编写程序时,首先要记得进行注册设备(avdevice_register_all ),程序运行时,只需要注册一次就可以 avdevice_register_all 是 FFmpeg 多媒体处理库中的一个函数,其作用是注册所有可用的音…

【AI前沿】人工智能的历史演进

文章目录 📑引言一、人工智能的起源与早期发展1.1 古代与早期的智能机器设想1.2 20世纪初期的机械计算机1.3 图灵测试与计算智能1.4 达特茅斯会议与人工智能的正式诞生 二、早期AI研究与第一次冬天2.1 早期的探索与挑战2.2 早期的专家系统2.3 第一次AI冬天 三、专家…

SpringBoot日常:@Scheduled实现服务启动时执行一次

文章目录 一、Scheduled详解二、逻辑实现1、创建定时任务逻辑方法2、新建一个启动执行类 三、测试结果 说到定时任务,我们应该会想起Scheduled,Quartz以及XXL-JOB,但是有的单体服务或者小项目,为了方便快捷,可能会直接…

【昇思25天学习打卡营第1天】

前言 例如:随着大模型的爆火,这门技术也越来越重要,很多人都开启了关于大模型知识的学习,但大模型需要一定的资源且涉及的模块很多,如果个人想要系统的学习会有些难度,好在有昇思大模型平台,能…

WebRTC群发消息API接口选型指南!怎么用?

WebRTC群发消息API接口安全性如何?API接口怎么优化? WebRTC技术在现代实时通信中占据了重要地位。对于需要实现群发消息功能的应用程序来说,选择合适的WebRTC群发消息API接口是至关重要的。AokSend将详细介绍WebRTC群发消息API接口的选型指南…

本地部署 SenseVoice - 阿里开源语音大模型

本地部署 SenseVoice - 阿里开源语音大模型 1. 创建虚拟环境2. 克隆代码3. 安装依赖模块4. 启动 WebUI5. 访问 WebUI 1. 创建虚拟环境 conda create -n sensevoice python3.11 -y conda activate sensevoice 2. 克隆代码 git clone https://github.com/FunAudioLLM/SenseVoic…

本地部署 Llama3 – 8B/70B 大模型!

Llama3,作为Meta公司新发布的大型语言模型,在人工智能领域引起了广泛的关注。特别是其8B(80亿参数)版本,在性能上已经超越了GPT-3.5,而且由于是开源的,用户可以在自己的电脑上进行部署。 本文和…

太多项会毁了回归

「AI秘籍」系列课程: 人工智能应用数学基础 人工智能Python基础 人工智能基础核心知识 人工智能BI核心知识 人工智能CV核心知识 多项式回归的过度拟合及其避免方法 通过添加现有特征的幂,多项式回归可以帮助你充分利用数据集。它允许我们甚至使用简…

【智能算法改进】多策略改进的蜣螂优化算法

目录 1.算法原理2.改进点3.结果展示4.参考文献5.代码获取 1.算法原理 【智能算法】蜣螂优化算法(DBO)原理及实现 2.改进点 混沌反向学习初始化 采用 Pwlcm 分段混沌映射,由于 Pwlcm 在其定义区间上具有均匀的密度函数,在特定的…

User parameters 用户参数与Web监控

目录 一. 自定义键介绍 二. 制作步骤 1. 添加无可变部分参数 2. 添加有可变参数 3. 使用用户参数监控php-fpm 服务的状态 三. Web页面导入应用监控 四. Web监控 主要功能和操作: 开启方式 官方预定义监控项文档https://www.zabbix.com/documentation/6…

华三m-lag三层转发+VRRP配置案例

目录 一、相关理论介绍 1.1 华三M-LAG介绍 1.2 DRCP协议 1.3 keepalive机制 1.4 MAD机制 1.5 一致性检查功能 二、M-LAG系统建立及工作过程 三、实验组网案例 3.1 组网需求 3.2 组网拓扑 3.3 设备接口及地址规划 四、具体配置命令 4.1 S6850-1的配置 4.2 S6850-2…