Pyramid Vision Transformer, PVT(ICCV 2021)原理与代码解读

paper:Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions

official implementation:GitHub - whai362/PVT: Official implementation of PVT series

存在的问题

现有的 Vision Transformer (ViT) 主要设计用于图像分类任务,难以直接用于像素级密集预测任务,如目标检测和分割。这是因为存在以下问题

  1. 低分辨率输出:传统的Vision Transformer(ViT)在处理密集预测任务(如目标检测和语义分割)时,输出分辨率较低,难以获得高质量的像素级别预测。
  2. 高计算和内存开销:ViT在处理大尺寸输入图像时,计算和内存开销较高,限制了其在实际应用中的效率。

本文的创新点

为了解决上述问题,作者提出了 Pyramid Vision Transformer (PVT), PVT结合了卷积神经网络的金字塔结构和Transformer的全局感受野,旨在克服传统Transformer在处理密集预测任务时遇到的分辨率低、计算和内存开销大的问题。它可以作为 CNN 骨干网络的替代品,用于多种下游任务,包括图像级预测和像素级密集预测。具体包括:

  1. 金字塔结构:PVT引入了金字塔结构,可以生成多尺度的特征图,这对于密集预测任务是有益的。
  2. 空间缩减注意力层(SRA):为了处理高分辨率特征图并减少计算/内存成本,作者设计了 SRA 层来替代传统的多头注意力 (MHA) 层。
  3. 纯Transformer骨干:PVT 是一个没有卷积的纯 Transformer 骨干网络,可以用于各种像素级密集预测任务,并与 DETR 结合构建了一个完全无需卷积的目标检测系统。

实际效果

  • PVT 在多个下游任务上进行了广泛的实验验证,包括图像分类、目标检测、实例和语义分割等,并与流行的 ResNets 和 ResNeXts 进行了比较。
  • 实验结果表明,在参数数量相当的情况下,PVT 在 COCO 数据集上使用 RetinaNet 作为检测器时,PVT-Small 模型达到了 40.4 的 AP(平均精度),超过了 ResNet50+RetinaNet(36.3 AP)4.1 个百分点。
  • PVT-Large 模型达到了 42.6 的 AP,比 ResNeXt101-64x4d 高出 1.6 个百分点,同时参数数量减少了 30%。
  • 这些结果表明 PVT 可以作为 CNN 骨干网络的一个有效的替代,用于像素级预测,并推动未来的研究。

方法介绍

Overall Architecture

PVT的整体结构如图3所示

和CNN backbone类似,PVT也有四个stage来生成不同尺度的特征图。所有stage都有一个相似的架构,包括一个patch embedding层和 \(L_i\) 个Transformer encoder层。

在第一个stage,给定大小为 \(H\times W\times 3\) 的输入图片,我们首先将其划分为 \(\frac{HW}{4^2}\) 个patch,每个大小为4x4x3。然后将展平的patch送入一个线性映射层得到大小为 \(\frac{HW}{4^2}\times C_1\) 的输出。然后将输出和位置编码一起进入有 \(L_1\) 层的Transformer encoder,得到的输出reshape成大小为 \(\frac{H}{4}\times \frac{W}{4}\times C_1\) 的特征图 \(F_1\)。同样的方式,以前一个stage的输出特征图作为输入,我们得到特征图 \(F_2,F_3,F_4\),相对于原始输入图片的步长分别为8,16,32。用了特征图金字塔 \(\{F_1,F_2,F_3,F_4\}\),我们的方法可以很容易地应用于大多数下游任务,包括图像分类、目标检测和语义分割。

Feature Pyramid for Transformer

和CNN backbone用不同stride的卷积来得到不同尺度特征图不同,PVT使用一个渐进式shrinking策略,通过patch embedding层来控制特征图的尺度。 

我们用 \(P_i\) 来表示第 \(i\) 个stage的patch size,在stage \(i\) 的开始,我们首先将输入特征图 \(F_{i-1}\in \mathbb{R}^{H_{i-1}\times W_{i-1}\times C_{i-1}}\) 均匀地划分成 \(\frac{H_{i-1}W_{i-1}}{P_i^2}\) 个patch,然后将每个patch展平并映射得到一个 \(C_i\) 维的embedding。在线性映射后,embedded patch的大小为 \(\frac{H_{i-1}}{P_i}\times \frac{W_{i-1}}{P_i}\times C_i\),其中宽高比输入小了 \(P_i\) 倍。

这样,我们就可以在每个stage灵活地调整特征图的尺度,从而将Transformer构建成金字塔结构。

Transforme Encoder

由于PVT需要处理高分辨率(stride-4)的特征图,我们提出了一种spatial-reduction attention(SRA)来替换encoder中传统的multi-head attention(MHA)。

和MHA类似,SRA的输入包括一个query \(Q\),一个key \(K\),一个value \(V\)。不同的是SRA在attention operation之前减小了 \(K\) 和 \(V\) 的大小,如图4所示,这大大减少了计算和内存的开销。

stage \(i\) 的SRA如下

其中 \(Concat(\cdot)\) 是拼接操作。\(W^{Q}_j\in \mathbb{R}^{C_i\times d_{head}},W^{K}_j\in \mathbb{R}^{C_i\times d_{head}},W^{V}_j\in \mathbb{R}^{C_i\times d_{head}},W^O\in \mathbb{R}^{C_i\times C_i}\) 是线性映射参数。\(N_i\) 是stage \(i\) 中attention层的head数量,所以每个head的维度(即\(d_{head}\))等于 \(\frac{C_i}{N_i}\)。\(SR(\cdot)\) 是降低输入序列(即 \(K\) 或 \(V\))空间维度的操作,如下:

其中 \(\mathbf{x}\in\mathbb{R}^{(H_iW_i)\times C_i}\) 表示一个输入序列,\(R_i\) 表示stage \(i\) 中attention层的reduction ratio。\(Reshape(\mathbf{x},R_i)\) 是将输入序列 \(\mathbf{x}\) reshape成大小为 \(\frac{H_iW_i}{R^2_i}\times (R^2_iC_i)\) 的序列的操作。\(W_S\in \mathbb{R}^{(R^2_iC_i)\times C_i}\) 是一个linear projection,它将输入序列的维度降低到 \(C_i\)。\(Norm(\cdot)\) 是layer normalization。和原始的Transformer一样,attention operation按下式计算

通过上述公式我们可以发现,MSA的计算/内存开销是MHA的 \(\frac{1}{R^2}\),因此MSA可以在有限的资源下处理更大的输入特征图或序列。

代码解析

见PVT v2的代码解析 PVT v2 原理与代码解析-CSDN博客

实验结果 

模型涉及到的一些超参总结如下:

  • \(P_i\):stage \(i\) 的patch size
  • \(C_i\):stage \(i\) 的输出通道数
  • \(L_i\):stage \(i\) 中的encoder层数
  • \(R_i\):stage \(i\) 中SRA的reduction ratio
  • \(N_i\):stage \(i\) 中SRA的head数量
  • \(E_i\):stage \(i\) 中FFN层的expansion ratio

作者设计了一系列的PVT模型,具体配置如表1

和其它SOTA模型在ImageNet的结果对比如表2所示

用RetinaNet上和其它backbone的结果对比如表3所示,可以看到PVT不同大小的模型与ResNet系列相比,参数更少精度更高。

在语义分割模型Semantic FPN上PVT也超越了对应的ResNet

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/24027.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++结合ffmpeg获取声音的分贝值

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、分贝是什么?1.功率量2.场量 二、实际操作1.分析wav文件2.读取麦克风 总结 前言 最近面对一个需求,就是需要传递声音文件到模型里推…

链表的回文结构OJ

链表的回文结构_牛客题霸_牛客网对于一个链表,请设计一个时间复杂度为O(n),额外空间复杂度为O(1)的算法,判断其是否为。题目来自【牛客题霸】https://www.nowcoder.com/practice/d281619e4b3e4a60a2cc66ea32855bfa?tpId49&&tqId29370&rp1&a…

CodeMeter助力Hilscher,推动实现全球智能制造连接解决方案

Hilscher的旗舰店为开放工业4.0联盟(OI4)社区提供了应用商店的便捷和开放性,将这一概念引入工业领域。该商店依托CodeMeter的许可证管理和加密保护,为工业用户提供了丰富的应用和解决方案库,满足他们在车间自动化和连接…

WPF中读取Excel文件的内容

演示效果 实现方案 1.首先导入需要的Dll(这部分可能需要你自己搜一下) Epplus.dll Excel.dll ICSharpCode.SharpZipLib.dll 2.在你的解决方案的的依赖项->添加引用->浏览->选择1中的这几个Dll点击确定。(添加依赖) 3.然后看代码内容 附上源码 using Excel; usi…

计网复习资料

一、选择题(每题2分,共40分) 1. Internet 网络本质上属于( )网络。 A.电路交换 B.报文交换 C.分组交换 D.虚电路 2.在 OSI 参考模型中,自下而上第一个提供端到端服务的是( )。 A.数据链路层 B.传输…

Thinkphp使用Elasticsearch查询

在Thinkphp中调用ES,如果自己手写json格式的query肯定是很麻烦的。我这里使用的是ONGR ElasticsearchDSL 构建 ES 查询。ongr ElasticsearchDSL 的开源项目地址:GitHub - ongr-io/ElasticsearchDSL: Query DSL library for Elasticsearch。ONGR Elastics…

100V 15A TO-252 N沟道MOS管 HC070N10L 惠海

MOS管的工作原理是基于在P型半导体与N型半导体之间形成的PN结,通过改变栅极电压来调整沟道内载流子的数量,从而改变沟道电阻和源极与漏极之间的电流大小。由于MOS管具有输入电阻高、噪声小、功耗低等优点,它们在大规模和超大规模集成电路中得…

package.json中resolutions的使用场景

文章目录 用途配置示例使用方法注意事项和peerDependencies有什么不同peerDependenciesresolutions 总结 ✍创作者:全栈弄潮儿 🏡 个人主页: 全栈弄潮儿的个人主页 🏙️ 个人社区,欢迎你的加入:全栈弄潮儿的…

git【工具软件】分布式版本控制工具软件

一、Git 的介绍 git软件的作用:管理软件开发项目中的源代码文件。 常用功能: 仓库管理、文件管理、分支管理、标签管理、远程操作 功能指令: add,commit,log,branch,tag,remote…

华为端云一体化开发 (起步1.0)(HarmonyOS学习第七课)

官方文献: 为丰富HarmonyOS对云端开发的支持、实现端云联动,DevEco Studio推出了云开发功能,开发者在创建工程时选择云开发模板,即可在DevEco Studio内同时完成HarmonyOS应用/元服务的端侧与云侧开发,体验端云一体化协…

论文代码解读STPGNN

1.前言 本次代码文章来自于《2024-AAAI-Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting》,基本模型结构如下图所示: 文章讲解视频链接 代码开源链接 接下来就开始代码解读了。 2.代码解读 class nconv(nn.Module):de…

NDIS Filter开发-网络数据的传输

和NIC小端口驱动不同的是,无需考虑网络数据具体是如何传输的,只需要针对NBL进行处理即可。Filter驱动程序可以启动发送请求和接收指示,或“过滤”其他驱动程序的请求和指示。Filter模块堆叠在微型端口适配器上。 驱动程序堆栈中的Filter模块…

谷粒商城实战(033 业务-秒杀功能4-高并发问题解决方案sentinel 1)

Java项目《谷粒商城》架构师级Java项目实战,对标阿里P6-P7,全网最强 总时长 104:45:00 共408P 此文章包含第326p-第p331的内容 关注的问题 sentinel(哨兵) sentinel来实现熔断、降级、限流等操作 腾讯开源的tendis&#xff0c…

ctfshow web

【nl】难了 <?php show_source(__FILE__); error_reporting(0); if(strlen($_GET[1])<4){echo shell_exec($_GET[1]); } else{echo "hack!!!"; } ?> //by Firebasky //by Firebasky ?1>nl //先写个文件 ?1*>b //这样子会把所有文件名写在b里…

JSON 无法序列化

JSON 无法序列化通常出现在尝试将某些类型的数据转换为 JSON 字符串时&#xff0c;这些数据类型可能包含不可序列化的内容。 JSON 序列化器通常无法处理特定类型的数据&#xff0c;例如日期时间对象、自定义类实例等。在将数据转换为 JSON 字符串之前&#xff0c;确保所有数据都…

「动态规划」如何求地下城游戏中,最低初始健康点数是多少?

174. 地下城游戏https://leetcode.cn/problems/dungeon-game/description/ 恶魔们抓住了公主并将她关在了地下城dungeon的右下角。地下城是由m x n个房间组成的二维网格。我们英勇的骑士最初被安置在左上角的房间里&#xff0c;他必须穿过地下城并通过对抗恶魔来拯救公主。骑士…

【Text2SQL 论文】C3:使用 ChatGPT 实现 zero-shot Text2SQL

论文&#xff1a;C3: Zero-shot Text-to-SQL with ChatGPT ⭐⭐⭐⭐ arXiv:2307.07306&#xff0c;浙大 Code&#xff1a;C3SQL | GitHub 一、论文速读 使用 ChatGPT 来解决 Text2SQL 任务时&#xff0c;few-shots ICL 的 setting 需要输入大量的 tokens&#xff0c;这有点昂贵…

MacOS M系列芯片一键配置多个不同版本的JDK

第一步&#xff1a;下载JDK。 官网下载地址&#xff1a;Java Archive | Oracle 选择自己想要下载的版本&#xff0c;一般来说下载一个jdk8和一个jdk11就够用了。 M系列芯片选择这两个&#xff0c;第一个是压缩包&#xff0c;第二个是dmg可以安装的。 第二步&#xff1a;编辑…

eclipse插件开发(二)RCP第三方库的引入方式

RCP第三方库的引入 最近在RCP开发过程中遇到JSON串与对象互转的问题&#xff0c;如何像spring开发模式一样引入第三方库呢&#xff1f;eclipse插件开发中用到p2库&#xff0c;但也支持maven库的引入。关键在于.target这个关键文件。 .target 文件用于定义一个目标平台&#x…

民主测评要做些什么?

民主测评&#xff0c;作为一种重要的民主管理工具&#xff0c;旨在通过广泛征求群众意见&#xff0c;对特定对象或事项进行客观、公正的评价。它不仅是推动民主参与、民主监督的重要手段&#xff0c;也是提升治理效能、促进社会和谐的有效途径。以下将详细介绍民主测评的主要过…