Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析

最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。对其做各种改进的顶会论文也是层出不穷,本文将聚焦于各种最新的视觉transformer的位置编码PE(positional encoding)部分的设计思想及代码实现做一些总结。

ViT

[2021-ICLR] AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

论文:https://arxiv.org/abs/2010.11929

代码:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch

对于原始的ViT,笔者曾做过一份较为全面的代码解析及图解:Vision Transformer(ViT)PyTorch代码全解析(附图解),有兴趣的读者可以参考。

论文中的位置编码方法

PE的设计

在这里插入图片描述

在ViT中,并没有对位置编码做过多的设计,只是使用一组可学习的参数来学习位置编码,注意这样的位置编码如果在面对测试时的高分辨率图像时是无法处理的。

ViT原文是这么说的:

When feeding images of higher resolution, we keep the patch size the same, which results in a larger effective sequence length. The Vision Transformer can handle arbitrary sequence lengths (up to memory constraints), however, the pre-trained position embeddings may no longer be meaningful. We therefore perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image. Note that this resolution adjustment and patch extraction are the only points at which an inductive bias about the 2D structure of the images is manually injected into the Vision Transformer.

大概意思就是:当输入高分图像时,会导致序列的长度变长,ViT是可以处理任意长度的,但此时训练得到的位置编码就不再有意义了,并且只能通过2D插值实现。

z=[xclass;xp1E,xp2E,…;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D(1)\mathbf{z}=[\mathbf{x}_{class};\mathbf{x}^1_p\mathbf{E},\mathbf{x}^2_p\mathbf{E},\dots;\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\ \ \ \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N+1)\times D} \ \ \ \ \ \ \ \ \ \ \ \ \ (1) z=[xclass;xp1E,xp2E,;xpNE]+Epos,   ER(P2C)×D,EposR(N+1)×D             (1)
根据原文公式(即上式),ViT中位置编码的维度应该为 (N+1)×D(N+1)\times D(N+1)×D ,这里 NNN 是图块的个数,+1是加上class token, DDD 是映射后的每个token的维度,因为要直接相加,所以要保持一致。下面会用代码来验证查看。

关于PE的消融实验

原文附录中的实验也显示肯定是有位置编码比没有效果要好,但是看起来比较有设计的二维位置编码和相对位置编码相较于简单的一维位置编码性能反而更差。

在这里插入图片描述

第一行是完全没有位置编码,即没有提供位置信息,相当于将一堆patch直接输入进去;第二行是一维位置编码,即将输入patch看作是序列;第三行是二维位置编码,将输入看作是二维的patch网格;第四行是相对位置编码,考虑到patch之间的相对距离,将空间信息编码为而不是其绝对位置。

注意:如果要使用相对位置编码,一定要考虑好自己的任务需不需要绝对位置信息,如目标检测,由于要输出预测的边界框的坐标,因此绝对位置信息是必须的,这时使用相对位置编码就不合适了。

关于PE的可视化实验

ViT原文对位置编码做的可视化实验如下图所示,热力图的含义是某个位置的图块的位置编码与全图其他位置图块的位置编码的余弦相似度。我们可以看到,当然与自己相似度最高,然后就是同行同列也比较高,其他的位置就低一些,这也基本符合我们对位置编码的基本期望,因为所谓的位置编码要的就是图像块在原图中的位置信息,更通俗点说就是行列信息,即某个图像块是在原图中的哪行哪列。
在这里插入图片描述

代码分析

ViT代码中的位置编码:

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))	
# ...
x += self.pos_embedding[:, :(n+1)] 		

直接用可学习的参数torch.Parameter()作为位置编码直接加到token序列中,跟随整个训练过程一起学习。(关于torch.Parameter()的介绍可见博客:PyTorch中的torch.nn.Parameter() 详解)

另外,我们再用代码来检查一下ViT中的位置编码的维度形状,这里我们直接借用timm库中的实现:

import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
pos_embed = model.state_dict()['pos_embed']
print(pos_embed.shape)         

输出:

torch.Size([1, 197, 768])

我们是将224x224的图像分为14x14个图块,共196块,再加上class token 为197,而768则是我们指定的维度,符合我们的预期。

CPVT

Conditional Positional Encodings for Vision Transformers

论文:https://arxiv.org/abs/2102.10882

代码:https://github.com/Meituan-AutoML/Twins (原文中给的链接中没有实做代码,实做代码发布在这个仓库了)

论文中的位置编码方法

CPVT与ViT的位置编码的区别在下图中体现的很明显,ViT的位置编码PE没有过多的设计,直接加到patch token和cls token得到的embedding上,然后就送到后面的多个transformer block(图中encoder)中,注意ViT中的PE必须显示地指定好token序列的长度。而CPVT则是先不加PE,在第一个transformer block之后,仅过PEG(Postional Encoding Generator)来生成位置编码,在加到第一层的输出上,在进行后面的计算,这样长度就不需要显式指定,可以随输入变化而变化,因此被称为隐式的条件位置编码。

在这里插入图片描述

其中的PEG模块是用来产生条件位置编码的模块,其框架如下图所示:

在这里插入图片描述

在 PEG 中,将上一层 Encoder 的 1D 输出变形成 2D,再使用 F 学习其位置信息,最后重新变形到 1D 空间,与之前的 1D 输出相加之后作为下一个 Encoder 的输入。

具体来说,在上图中,为了根据局部领域,我们首先将DeiT flatten过的输入序列 X∈RB×N×CX\in \mathbb{R}^{B\times N\times C}XRB×N×C​ reshape回二维图像空间 X′∈RB×H×W×CX'\in\mathbb{R}^{B\times H\times W\times C}XRB×H×W×C​ 。然后某个函数 F\mathcal{F}F​ 会反复作用于 X′X'X​ 中的局部图块来生成条件位置编码 EB×H×W×CE^{B\times H\times W\times C}EB×H×W×C​ ,PEG可以由二维卷积高效地实现,其卷积核 k>=3k>=3k>=3​,并且有零填充 k−12\frac{k-1}{2}2k1​ 。注意这里的零填充是很重要的,它可以使模型感知到绝对位置, F\mathcal{F}F​ 可以是多种形式,比如可分离卷积。

代码分析

在CPVT的代码实现中,我们主要来看PEG部分:

class PosCNN(nn.Module):def __init__(self, in_chans, embed_dim=768, s=1):super(PosCNN, self).__init__()self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), )self.s = sdef forward(self, x, H, W):B, N, C = x.shapefeat_token = xcnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)if self.s == 1:x = self.proj(cnn_feat) + cnn_featelse:x = self.proj(cnn_feat)x = x.flatten(2).transpose(1, 2)return xdef no_weight_decay(self):return ['proj.%d.weight' % i for i in range(4)]

可以看到,与原文中对PEG的介绍一致:将第一层Encoder 的1D 输出变形成 2D,再使用F学习其位置信息,最后重新变形到 1D 空间,与之前的 1D 输出相加之后作为下一个 Encoder 的输入。

这里的self.proj就是文中的转换函数 F​。

我们再来看PEG模块在整个CPVT中的使用:

class CPVTV2(PyramidVisionTransformer):def __init__(self, ...)# ...self.pos_block = nn.ModuleList(			# 实例化一个PEG模块[PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims])# ...def forward_features(self, x):B = x.shape[0]for i in range(len(self.depths)):x, (H, W) = self.patch_embeds[i](x)x = self.pos_drops[i](x)for j, blk in enumerate(self.blocks[i]):x = blk(x, H, W)if j == 0:x = self.pos_block[i](x, H, W)  # PEG模块 在这里使用if i < len(self.depths) - 1:x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()x = self.norm(x)return x.mean(dim=1) 

可以看到,只有在第一个encoder之后(for循环中j=0时),使用PEG模块计算位置编码,后面正常进行其他的其他Encoder的计算,与论文原文一致。

本文将保持持续更新,读者如果遇到有趣的Vision Transformer的改进方法,也欢迎分享讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532841.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql 分析查询语句,MySQL教程之SQL语句分析查询优化

怎么获取有功能问题的SQL1、经过用户反应获取存在功能问题的SQL2、经过慢查询日志获取功能问题的SQL3、实时获取存在功能问题的SQL运用慢查询日志获取有功能问题的SQL首要介绍下慢查询相关的参数1、slow_query_log 发动定制记载慢查询日志设置的办法&#xff0c;能够经过MySQL指…

树莓派摄像头基础配置及测试

树莓派摄像头基础配置 step 1 硬件连接 硬件连接&#xff0c;注意不要接反了&#xff0c;排线蓝色一段朝向网口的方向。&#xff08;笔者的设备是树莓派4B&#xff09; step 2 安装raspi-config 安装 raspi-config raspi-config在raspbian中是预装的&#xff0c;而在kali、…

使用百度云智能SDK和树莓派搭建简易的人脸识别系统 Python语言版

硬件 树莓派4B一个CSI摄像头一个 笔者使用的是树莓派4B和CSI摄像头&#xff0c;但是树莓派3和USB摄像头等相似设备均可。 百度云智能设置 Step 1 登录 百度云智能 网址https://cloud.baidu.com/ 首先登录百度账号&#xff0c;与百度云、百度贴吧等互通&#xff0c;可直接…

xp搭建 php环境,windows xp 下 LAMP环境搭建

1. apache安装步骤如下图在浏览器中输入&#xff1a;localhost&#xff0c;出现下面页面说明已成功安装apache。2. mysql安装如下图显示在运行里面输入cmd &#xff0c;然后连接测试mysql &#xff0c;如图所示&#xff1a;3. php安装(1)将php压缩包解压到安装路径中的php目录…

C++中的虚函数(表)实现机制以及用C语言对其进行的模拟实现

C中的虚函数(表)实现机制以及用C语言对其进行的模拟实现 声明&#xff1a;本文非博主原创&#xff0c;转自https://blog.twofei.com/496/&#xff0c;博主读后受益良多&#xff0c;特地转载&#xff0c;一是希望好文能有更多人看到&#xff0c;二是为了日后自己查阅。 前言 …

C++中数组和指针的关系(区别)详解

C中数组和指针的关系&#xff08;区别&#xff09;详解 本文转自&#xff1a;http://c.biancheng.net/view/1472.html 博主在阅读后将文中几个知识点提出来放在前面&#xff1a; 没有方括号和下标的数组名称实际上代表数组的起始地址&#xff0c;这意味着数组名称实际上就是…

安装php独立环境,0507-php独立环境的安装与配置 Web程序 - 贪吃蛇学院-专业IT技术平台...

1.在一个纯英文目录下新建三个文件夹2.安装apache(选择好版本)过程中该填的按格式填好&#xff0c;其余的只更改安装目录即可如果报错1901是安装版本的问题。检查&#xff1a;安装完成后localhost打开为It works!添加到电脑属性环境变量&#xff1a;3.将php文件解压文档放到AMP…

linux中PATH变量-详细介绍

转自&#xff1a;https://blog.csdn.net/haozhepeng/article/details/100584451 转载者勘误 原文最后提到的 echo 命令对于环境变量的修改无影响。这是肯定的&#xff0c;echo 命令相当于只是一个打印的函数&#xff08;比如 Python 中的 print&#xff09;。这里要修改环境变…

php assert eval,代码执行函数之一句话木马

前言大家好&#xff0c;我是阿里斯&#xff0c;一名IT行业小白。非常抱歉&#xff0c;昨天的内容出现瑕疵比较多&#xff0c;今天重新整理后再次发出&#xff0c;修改并添加了细节&#xff0c;另增加了常见的命令执行函数如果哪里不足&#xff0c;还请各位表哥指出。eval和asse…

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理 转自&#xff1a;https://www.cnblogs.com/marsggbo/p/11838823.html#nvccnvidia-smi GPU型号含义 显卡&#xff1a; 简单理解这个就是我们前面说的GPU&#xff0c;尤其指NVIDIA公司生产的GPU系列&#xff0c;因为后面介绍的…

VS Code的Error: Running the contributed command: ‘_workbench.downloadResource‘ failed解决

VS Code的Error: Running the contributed command: _workbench.downloadResource failed解决 转自&#xff1a;https://blog.csdn.net/ibless/article/details/118610776 1 问题描述 此前&#xff0c;本人参考网上教程在VS Code中配置了“Remote SSH”插件&#xff08;比如这…

Oracle闪回报错,oracle 闪回区满了,ORA-19815

oracle 闪回区满了&#xff0c;查看日志报错&#xff1a;ORA-19815&#xff0c;命令行输入&#xff1a;sqlplus / as sysdbastartup mount //如果你的数据库出现了无法连接的情况时&#xff0c;可以加上这句select file_type, percent_space_used as used,percent_space_rec…

[2021-ICCV] MUSIQ Multi-scale Image Quality Transformer 论文简析

[2021-ICCV] MUSIQ: Multi-scale Image Quality Transformer 论文简析 论文&#xff1a;https://arxiv.org/abs/2108.05997 代码&#xff1a;https://github.com/google-research/google-research/tree/master/musiq 概述 当前SOTA的IQA&#xff08;图像质量评估&#xff0…

安装oracle不动了,windows2008安装ORACLE到2%不动的问题 | 信春哥,系统稳,闭眼上线不回滚!...

最近又有网友遇到在windows2008服务器上安装ORACLE软件时到2%就卡住不动的问题&#xff0c;下面是该网友的描述&#xff1a;oralce 11g r2 windows server 2008 R2安装到最后一步复制数据文件时卡到2% 不走了内存一直飙升求解决这个问题前段时间也有人遇到过&#xff0c;但是他…

手把手教你入门Git --- Git使用指南(Linux)

手把手教你入门Git — Git使用指南&#xff08;Linux&#xff09; 系统&#xff1a;ubuntu 18.04 LTS 本文所有git命令操作实验具有连续性&#xff0c;git小白完全可以从头到尾跟着本文所有给出的命令走一遍&#xff0c;就会对git有一个初步的了解&#xff0c;应当能做到会用并…

php数据关系图,如何利用navicat查看数据表的ER关系图

文章背景&#xff1a;(相关推荐&#xff1a;navicat)由于工作需要&#xff0c;现在要分析一个数据库&#xff0c;然后查看各个表之间的关系&#xff0c;所以需要查看表与表之间的关系图&#xff0c;专业术语叫做ER关系图。默认情况下&#xff0c;Navicat显示的界面是这样的&…

Linux中g++与gcc的区别

转自&#xff1a;https://blog.csdn.net/bit_clearoff/article/details/53965514 Windows中我们常用vs来编译编写好的C和C代码&#xff1b;vs把编辑器&#xff0c;编译器和调试器等工具都集成在这一款工具中&#xff0c;在Linux下我们能用什么工具来编译所编写好的代码呢&#…

从C源代码到可执行文件的四个过程:预处理、编译、汇编、链接

从C源代码到可执行文件的四个过程&#xff1a;预处理、编译、汇编、链接 总览 我们将在Linux操作系统中&#xff0c;以C语言的Hello World程序为例&#xff0c;用gcc编译器分步执行这四个步骤。 我们有再熟悉不过的HelloWorld程序&#xff0c;hello.c&#xff1a; #include …

linux内核中cent文件夹,Centos 中如何快速定制二进制的内核 RPM 包

1、rpm 制作前的环境准备&#xff1a;yum install -y ncurses-devel qt-devel rpm-build redhat-rpm-config asciidoc hmaccalc perl-ExtUtils-Embed xmlto audit-libs-devel binutils-devel elfutils-devel elfutils-libelf-devel newt-devel python-devel zlib-devel bc2、准…

TabError- inconsistent use of tabs and spaces in indentation 查验及解决方法

TabError: inconsistent use of tabs and spaces in indentation 查验及解决方法 报错代码 def eccv16(pretrainedTrue):model ECCVGenerator()if(pretrained):import torch.utils.model_zoo as model_zoomodel.load_state_dict(torch.load(/home/ps/.cache/torch/hub/check…