DVT:华为提出动态级联Vision Transformer,性能杠杠的 | NeurIPS 2021

论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错

来源:晓飞的算法工程笔记 公众号

论文: Not All Images are Worth 16x16 Words: Dynamic Transformers for Efficient Image Recognition

  • 论文地址:https://arxiv.org/abs/2105.15075
  • 论文代码:https://github.com/blackfeather-wang/Dynamic-Vision-Transformer

Introduction


  Transformers是自然语言处理 (NLP) 中占主导地位的自注意的模型,最近很多研究将其成功适配到图像识别任务。这类模型不仅在ImageNet上取得了SOTA,而且性能还能随着数据集规模的增长而不断增长。这类模型一般都先将图像拆分为固定数量的图像块,然后转换为1D token作为输入,拆分更多的token有助于提高预测的准确性,但也会带来巨额的计算成本(与token数成二次增长)。为了权衡性能和准确率,现有的这类模型都采用14x14或16x16的token数量。

  论文认为不同图片之间存在相当大的差异,使用相同数量的token处理所有图片并不是最优的。最理想的做法应为每个输入专门配置token数量,这也是模型计算效率的关键。以T2T-ViT-12为例,官方推荐的14x14 token数仅比4x4 token数增加了15.9%(76.7% 对 60.8%)的准确率,却增加了8.5倍的计算成本(1.78G 对 0.21G)。也就是说,对“简单”图片使用14x14 token数配置浪费了大量计算资源,使用4x4 token数配置就足够了。

  受此启发,论文提出了一种动态Vision Transformer(DVT)框架,能够根据每个图片自动配置合适的token数,实现高效计算。训练时使用逐渐增多的token数训练级联Transformer,测试时从较少的token数开始依次推理,得到置信度足够的预测即终止推理过程。通过自动调整token数,“简单”样本和“困难”样本的计算消耗将会不一样,从而显着提高效率。

  另外,论文还设计了基于特征和基于关系的两种复用机制,减少冗余的计算。前者允许下游模型在先前提取的深度特征上进行训练,而后者允许利用上游模型中的自注意力关系来学习更准确的注意力图。

  DVT是一个通用框架,可集成到大多数图像识别的Transformer模型中。而且可以通过简单地调整提前终止标准,在线调整整体计算成本,适用于计算资源动态波动或需要以最小功耗来实现特定性能的情况。从ImageNet和CIFAR的实验结果来看,在精度相同的情况下,DVT能将T2T-ViT的计算成本降低1.6-3.6倍,而在NVIDIA 2080Ti上的真实推理速度也与理论结果一致。

Dynamic Vision Transformer


Overview

  • Inference

  DVT的推理过程如图2所示。对于每张测试图片,先使用少量1D token序列对其进行粗略表示,可通过直接使用分割图像块或利用如tokens-to-token模块之类的技术来实现,然后通过Vision Transformer对这些token进行快速预测。由于Transformer的计算消耗与token数量成二次增长,所以这个过程很快。最后基于预设的终止标准对预测结果进行快速评估,确定是否足够可靠。

  如果预测未能满足终止标准,原始输入图像将被拆分为更多token,再进行更准确、计算成本更高的推理。每个token embedding的维度保持不变,只增加token数量,从而实现更细粒度的表示。此时推理使用的Vision Transformer与上一级具有相同架构,但参数是不同的。根据设计,此阶段在某些“困难”测试图片上权衡计算量以获得更高的准确性。为了提高效率,新模型可以复用之前学习的特征和关系。在获得新的预测结果后,同样根据终止标准进行判断,不符合则继续上述过程,直到结果符合标准或已使用最终的Vision Transformer。

  • Training

  训练时,需保证DVT中所有级联Vision Transformer输出正确的预测结果,其优化目标为:

  其中, ( x , y ) (x, y) (x,y)为训练集 D t r a i n D_{train} Dtrain中的一个样本及其对应的标签,采用标准的交叉熵损失函数 L C E ( ⋅ ) L_{CE}(·) LCE(),而 p i p_i pi表示第 i i i个模型输出的softmax预测概率。

  • Transformer backbone

  DVT是一个通用且灵活的框架,可以嵌入到大多数现有的Vision Transformer模型(如ViT、DeiT和T2T-ViT)之中,提高其性能。

Feature and Relationship Reuse

  DVT的一个重要挑战是如何进行计算的复用。在使用的具有更多token的下游Vision Transformer时,直接忽略之前模型中的计算结果显然是低效的。虽然上游模型的token数量较少,但也提取了对预测有价值的信息。因此,论文提出了两种机制来复用学习到的深度特征和自注意力关系,仅增加少量的额外计算成本就能显着提高准确率。

  • Background

  介绍前,先重温一下Vision Transformer的基本公式。Transformer encoder由交替堆叠的多头自注意力(MSA)和多层感知器 (MLP)块组成,每个块的之前和之后分别添加了层归一化(LN)和残差连接。定义 z l ∈ R N × D z_l\in R^{N\times D} zlRN×D表示第 l l l层的输出,其中 N N N是样本的token数, D D D是token的维度。需要注意的是, N = H W + 1 N=HW+1 N=HW+1,对应 H × W H\times W H×W图像块和可学习的分类token。假设Transformer共 L L L层,则整个模型的计算可表示为:

  得到最终的结果 z L z_L zL后,取其中的分类token通过LN层+全连接层进行最终预测。这里省略了position embedding的细节,论文没有对其进行修改。

  • Feature reuse

  DVT中的所有Transformer都具有相同的目标,即提取关键特征进行准确识别。 因此,下游模型应该在上游模型计算的深度特征的基础上学习才是最高效的,而不是从头开始提取特征。为此,论文提出了图3的特征复用机制,利用上游Transformer最后输出的结果 z L u p z^{up}_L zLup来生成下游模型每层的辅助embedding输入 E l E_l El

f l : R N × D → R N × D ′ f_l:\mathbb{R}^{N\times D}\to \mathbb{R}^{N\times D^{'}} fl:RN×DRN×D 由LN+MLP( R D → R D ′ \mathbb{R}^{D}\to \mathbb{R}^{D^{'}} RDRD)开头,对上游模型输出进行非线性转换。转换后将结果reshape到原始图像中的相应位置,然后上采样并展平来匹配下游模型的token数量。一般情况下,使用较小的 D ′ D^{'} D以便快速生成 f l f_l fl

  之后将 E l E_l El拼接到下游模型对应层的中间特征作为预测的先验知识,也就是将公式3替换为:

E l E_l El与中间特征 z l ′ z^{'}_l zl拼接,LN 的维度和MLP的第一层从 D D D增加到 D + D ′ D+D^{'} D+D。 由于 E l E_l El是基于上游输出 z L u p z^{up}_L zLup生成的,token数少于 z l ′ z^{'}_l zl,它实际上为 z l ′ z^{'}_l zl中的每个token总结了输入图像的上下文信息。 因此,将 E l E_l El命名为上下文embedding。此外,论文发现不复用分类token对性能有提升,因此在公式5中将其填充零。

  公式4和5允许下游模型在每层灵活地利用 z L u p z^{up}_L zLup内的信息,从而最小化最终识别损失,这种特征重用方式也可以认为隐式地扩大了模型深度。

  • Relationship reuse

  Vision Transformer的关键在于自注意力模块能够整合整个图像的信息,从而有效地模拟图像中的长距离关系。通常情况下,模型需要在每一层学习一组注意力图来描述token之间的关系。除了上面提到的特征复用,论文认为下游模型还可以复用上游模型产生的自注意力图来进行优化。

  定义输入特征 z l z_l zl,自注意力模块先通过线性变换得到query矩阵 Q l Q_l Ql、key矩阵 K l K_l Kl和value矩阵 V l V_l Vl

  其中, W l Q W^Q_l WlQ W l K W^K_l WlK W l V W^V_l WlV为权重矩阵。然后通过一个带有softmax的缩放点乘矩阵运算得到注意力图,最后根据注意力图来计算所有token的值:

  其中, d d d Q Q Q K K K的点积结果维度, A l ∈ R N × N A_l\in \mathbb{R}^{N\times N} AlRN×N为注意力图。为了清楚起见,这省略了多头注意力机制的细节,多头情况下 A l A_l Al包含多个注意力图。

  对于关系复用,先将上游模型所有层产生的注意力图(即 A l u p , l ∈ { 1 , ⋯ , L } A^{up}_l, l\in \{1,\cdots , L\} Alup,l{1,,L})拼接起来:

  其中, N u p N^{up} Nup N u p A t t N^{Att}_{up} NupAtt 分别为上游模型中的toekn数和注意力图数,通常 N u p A t t = N H L N^{Att}_{up} = N^H L NupAtt=NHL N H N^H NH是多头注意力的head数, L L L是层数。

  下游的模型同时利用自己的token和 A u p A^{up} Aup来构成注意力图,也就是将公式7替换为:

  其中 r l ( ⋅ ) r_l(\cdot) rl()是一个转换网络,整合 A u p A^{up} Aup提供的信息来细化下游注意力图 A l A_l Al r l ( ⋅ ) r_l(\cdot) rl()的架构如图5所示,先进行非线性MLP转换,然后上采样匹配下游模型的注意力图大小。

  公式9虽然很简单,但很灵活。有两个可以魔改的地方:

  • 由于下游模型中的每个自注意力模块可以访问上游模型的所有浅层和深层的注意力头,可以尝试通过可学习的方式来对多层的注意力信息进行加权整合。
  • 新生成的注意力图和复用注意力图直接相加,可以尝试通过可学习的方式来对两者加权。

  还需要注意的是, r l ( ⋅ ) r_l(\cdot) rl()不能直接使用常规上采样操作。如图5所示,假设需要将 H W × H W HW\times HW HW×HW( H = W = 2 H =W = 2 H=W=2)的注意力图映射上采样到 H ′ W ′ × H ′ W ′ H^{'}W^{'}\times H^{'}W^{'} HW×HW( H ′ = W ′ = 3 H^{'} =W^{'} = 3 H=W=3)的大小。由于每一行对应单个token与其他 H × W H\times W H×W个token的关系,直接对注意力图上采样会引入混乱的数据。因此,需要先将行reshape为 H × W H\times W H×W,然后再缩放到 H ′ W ′ × H ′ W ′ H^{'}W^{'}\times H^{'}W^{'} HW×HW,最后再展平为 H ′ W ′ H^{'}W^{'} HW向量。

  • Adaptive Infernece

  如前面所述,DVT框架逐渐增加测试样本的token数量并执行提前终止,“简单”和“困难”图像可以使用不同的token数来处理,从而提高了整体效率。对于第 i i i个模型产生的softmax预测 p i p_i pi,将 p i p_i pi的最大项 m a x j p i j max_j p_{ij} maxjpij与阈值 μ i {\mu}_{i} μi进行比较。如果 m a x j p i j ≥ μ i max_j p_{ij}\ge {\mu}_{i} maxjpijμi,则停止并采用 p i p_i pi作为输出。否则,将使用更多token数更多的下游模型继续预测直到最后一个模型。

  阈值 { μ 1 , μ 2 , ⋯ } \{\mu_1, \mu_2, \cdots\} {μ1,μ2,}需要在验证集上求解。假设一个计算资源有限的批量数据分类场景,DVT需要在给定的计算预算 B > 0 B > 0 B>0内识别一组样本 D v a l D_{val} Dval。定义 A c c ( D v a l , { μ 1 , μ 2 , ⋯ } ) Acc(D_{val}, \{\mu_1, \mu_2, \cdots\}) Acc(Dval,{μ1,μ2,}) F L O P s ( D v a l , { μ 1 , μ 2 , ⋯ } ) FLOPs(D_{val}, \{\mu_1, \mu_2, \cdots\}) FLOPs(Dval,{μ1,μ2,})为数据集 D v a l D_{val} Dval上使用阈值 { μ 1 , μ 2 , ⋯ } \{\mu_1, \mu_2, \cdots\} {μ1,μ2,}时的准确度和计算成本,最优阈值可以通过求解以下优化问题得到:

  由于公式10是不可微的,论文使用遗传算法解决了这个问题。

Experiment


  ImageNet上的性能对比。

  推理性能对比。

  CIFAR上对比DVT在不同模型规模的性能。

  在ImageNet上与SOTA vision transformer提升方法的性能对比。

  基于DeiT的DVT性能对比。

  复用机制的对比实验。

  与类似的提前退出方法的性能对比。

  复用机制提升的性能与计算量。

  复用机制实现细节的对比实验。

  难易样本的例子以及数量分布。

  不同终止标准的性能对比。

  与自适应深度方法进行性能对比,自适应方法是在模型的不同位置插入分类器。

Conclusion


  论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39868.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智慧应急管理平台:数字孪生,让防汛救灾更科学高效

近期全国各地暴雨频发,城市排水系统面临着前所未有的挑战,应急防涝已成为城市管理中不可或缺的一环。在这个信息化、智能化的时代,数字孪生技术以其独特的优势,为应急领域带来了革命性的变革。数字孪生,作为现实世界在…

揭秘:学校教室采用数码管同步时钟的原因-讯鹏电子钟

在学校的教室里,我们常常会看到数码管同步时钟的身影。究竟是什么原因让它成为学校教室的宠儿呢?让我们一同来探究其中的奥秘。 数码管同步时钟具有极高的准确性。对于学校这样一个对时间管理要求严格的场所,准确的时间是保障教学秩序的基石。…

SwinIR: Image Restoration Using Swin Transformer(ICCV 2021)含代码复现

目录 一、Introduction 1 Motivation 2 Contribution 二、原理分析 1 Network Architecture 1)Shallow feature extraction 2) deep feature extraction 3) image reconsruction modules 4) loss function 2 Residual Swin Transformer Block 三、实验结果…

没有调用memcpy却报了undefined reference to memcpy错误

现象 在第5行出现了,undefined reference to memcpy’ 1 static void printf_x(unsigned int val) 2{ 3 char buffer[32]; 4 const char lut[]{0,1,2,3,4,5,6,7,8,9,A,B,C,D,E,F}; 5 char *p buffer; 6 while (val || p buffer) { 7 *(p) …

基于循环神经网络的一维信号降噪方法(简单版本,Python)

代码非常简单。 import torch import torch.nn as nn from torch.autograd import Variable from scipy.io.wavfile import write #need install pydub module #pip install pydub import numpy as np import pydub from scipy import signal import IPython import matplot…

C语言学习记录(十二)——指针与数组及字符串

文章目录 前言一、指针和数组二、指针和二维数组**行指针(数组指针)** 三、 字符指针和字符串四、指针数组 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、指针和数组 在C语言中 &#xff0…

AI降重,不再难:降AI率的实用技巧大揭秘

如何有效降低AIGC论文的重复率,也就是我们说的aigc如何降重?AIGC疑似度过高确实是个比较愁人的问题。如果你用AI帮忙写了论文,就一定要在交稿之前做一下AIGC降重的检查。一般来说,如果论文的AIGC超过30%,很可能会被判定…

CAS操作

CAS 全称:Compare and swap,能够比较和交换某个寄存器中的值和内存中的值,看是否相等,如果相等,则把另外一个寄存器中的值和内存进行交换. (这是一个伪代码,所以这里的&address实际上是想要表示取出address中的值) 那么我们可以看到,CAS就是这样一个简单的交换操作,那么…

基于SpringBoot房屋租赁管理系统设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 Java精品实战案例《1000套》 2025-2026年最值得选择的Java毕业设计选题大全&#xff…

新火种AI|国产大模型展开决战,是资本游戏还是技术革命?

作者:一号 编辑:美美 资本角逐与技术革新,国产大模型的双线战场已然开启。 随着人工智能技术的不断进步,国产大模型正迅速成为行业关注的焦点。在这个由数据驱动的时代,资本的注入和技术创新的加速,让国…

Python28-6 随机森林

随机森林算法详细介绍 1. 理论背景 随机森林(Random Forest)是一种由Leo Breiman和Adele Cutler在2001年提出的集成学习方法。它结合了多个决策树的预测结果,以提高模型的准确性和鲁棒性。 2. 算法细节 随机森林的构建过程可以分为以下几…

Qt——升级系列(Level Eight):界面优化

目录 QSS 背景介绍 基本语法 QSS设置方式 指定控件样式设置 全局样式设置 从文件加载样式表 使用Qt Designer 编辑样式 选择器 选择器概况 子控件选择器 伪类选择器 样式属性 盒模型 控件样式示例 按钮 复选框、单选框 输入框 列表 菜单栏 登录界面 绘图 基本概念 绘制各种形…

[Go 微服务] Kratos 使用的简单总结

文章目录 1.Kratos 简介2.传输协议3.日志4.错误处理5.配置管理6.wire 1.Kratos 简介 Kratos并不绑定于特定的基础设施,不限定于某种注册中心,或数据库ORM等,所以您可以十分轻松地将任意库集成进项目里,与Kratos共同运作。 API -&…

Linux内网端口转公网端口映射

由于服务商做安全演练,把原先服务器内网的端口映射到外网端口全都关闭了,每次维护服务器特别麻烦,像数据库查询如果用原生的mysql 去连接,查询返回的结果乱了,非常不方便。 查了服务还是可以正常访问部分外网的&#x…

抖音外卖服务商入驻流程及费用分别是什么?入驻官方平台的难度大吗?

随着抖音关于新增《【到家外卖】内容服务商开放准入公告》的意见征集通知(以下简称“通知”)的发布,抖音外卖服务商入驻流程及费用逐渐成为众多创业者所关注和热议的话题。不过,就当前的讨论情况来看,这个话题似乎没有…

软件测试中安全测试包含内容及安全测试怎么测

一、软件测试安全测试包含哪些 1. 漏洞扫描 漏洞扫描是软件测试安全测试的基础,它用于检测应用程序和系统中存在的已知漏洞。安全测试工具如AppScan、OWASP ZAP和Nessus等可以对应用程序进行自动化扫描,发现可能存在的漏洞,如跨站点脚本&am…

7.2、指针变量的定义和使用

代码 #include <iostream> using namespace std; #include <string>int main() {//定义指针int a 10;//指针定义语法&#xff1a;数据类型 * 指针变量名int * p;//让指针记录变量a的地址p &a;cout << "a的地址为&#xff1a;" << &am…

MySQL之应用层优化(二)

应用层优化 Web服务器问题 寻找最优并发度 每个Web服务器都有一个最佳并发度——就是说&#xff0c;让进程处理请求尽可能快&#xff0c;并且不超过系统负载的最优的并发连接数。这就是前面说的最大系统容量。进行一个简单的测量和建模&#xff0c;或者只是反复试验&#xf…

nginx SSI(Server Side Include)服务端包含 合并拼装静态内容

一、什么是SSI 在被传送给浏览器之前&#xff0c;服务器会对 HTML 文档进行完全地读取、分析以及修改&#xff0c;使用SSI指令将文本、图片或代码信息包含到网页中。对于整个页面可以拆分成多个模块&#xff0c;通过SSI指令将几个模块拼接成一个完整的页面&#xff0c;当有内容…

【数据库原理】课程笔记

数据库原理 一、数据库系统基础 数据模型的类型 概念数据模型&#xff1a; 概念数据模型也称概念模型或信息模型,是对现实世界中问题域内事务(特性)的描述,是以用户观点实现世界的模型(图形表示)。主要用于描述事物的概念化结构,使数据库的设计人员在设计初期,避开计算机系统及…