解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

文章目录

          • 一、ViT的微调问题的本质
          • 二、Positional Embedding如何处理
            • 1,绝对位置编码
            • 2,相对位置编码
            • 3,对位置编码进行插值
          • 三、Patch Embedding Layer如何处理
          • 四、使用timm库来对任意尺寸进行微调

一、ViT的微调问题的本质

自从ViT被提出以来,在CV领域引起了新的研究热潮。理论上来说,Transformer的输入是一个序列,并且其参数主要来自于Transformer Block中的Linear层,因此Transformer可以处理任意长度的输入序列。但是在Vision Transformer中,由于需要将二维的图像通过Patch Embedding Layer映射为一个一维的序列,并且需要添加pos_embedding来保留位置信息。因此当patch_size和img_size发生改变时,会造成pos_embbeding的长度和Patch Embedding Layer的参数发生改变,从而导致预训练权重无法直接加载。更多有关ViT的实现细节和原理,可以参考Vision Transformer , 通用 Vision Backbone 超详细解读 (原理分析+代码解读)。

二、Positional Embedding如何处理

在Vision Transformer中有两种主流的编码方式:相对位置编码和绝对位置编码。

1,绝对位置编码

绝对位置编码依据token每个的绝对位置分配一个固定的值,其本质上是一组一维向量,有两种实现方式:

# 可学习的位置编码,ViT中使用, +1是因为有cls_tokenself.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))# 根据正余弦获取位置编码, Transformer中使用
def get_positional_embeddings(sequence_length,dim):result = torch.ones(sequence_length,dim)for i in range(sequence_length):for j in range(dim):result[i][j] = np.sin(i/(10000**(j/dim))) if j %2==0 else np.cos(i/(10000**((j-1)/dim)))return result

在forward过程中,绝对位置编码会在最开始直接和token相加:

	tokens += self.pos_embedding[:, :(n + 1)]
2,相对位置编码

相对位置编码,依据每个token的query相对于key的位置来分配位置编码,典型例子就是swin transformer,其本质是构建一个可学习的二维table,然后依据相对位置索引(x,y)来从table中取值,具体可以参考:有关swin transformer相对位置编码的理解

不过,在swin transformer中,query和key都是来自于同一个window,因此query和key的数量相同,构建位置编码的方式相对来说比较简单。如果query和key的数量不同,例如Focal Transformer中多层次的self-attention,其位置编码的方式可以参考:Focal Transformer。

对于相对位置编码的构造,还有一种方式是CrossFormer中提出的Dynamic Position Bias。其核心思想为构建一个MLP,其输入是二维的相对位置索引,输出是指定dim的位置偏置。这个和根据正余弦获取位置编码有点类似,只不过一个是依据一维的绝对坐标来生成位置编码,一个是依据二维的相对坐标来生成位置编码。

image-20231122172434141

在forward过程中,相对位置编码不会在一开始与token相加,而是在Attention Layer中以Bias的形式参与self-attention计算,核心代码如下:

        attn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)
3,对位置编码进行插值

综上,我们可以依据实现方式将位置编码分为两大类:可学习的位置编码(例如,ViT、Swin Transformer、Focal Transformer等)和生成式的位置编码(例如,正余弦位置编码和CrossFormer中的DPB)。更多有关位置编码的内容,可以参考论文:Rethinking and Improving Relative Position Encoding for Vision Transformer。对于生成式的位置编码而言,其编码方式与序列长度无关,因此当patch_size和img_size改变而造成num_patches改变时,仍然可以加载与位置编码有关的预训练权重。

但是,对于可学习的位置编码而言,num_patches改变时,无法直接加载与位置编码的预训练权重。以ViT为例,其参数一般是一个shape为[N+1, C]的tensor。与cls_token有关的位置编码不用改变,我们只需要关心与img patch相关的位置编码即可,其shape为[N, C]。当num_patches变为n时,所需要位置编码shape为[n, C]。这显然无法直接加载预训练权重。

Pytorch官方提供了一种思路,通过插值算法,来获取新的权重。我们不妨将原始的位置编码想象为一个shape为[ N , N , C \sqrt{N}, \sqrt{N}, C N ,N ,C]的tensor,将所需要的位置编码想象为一个shape为[ n , n , C \sqrt{n}, \sqrt{n}, C n ,n ,C]。这样我们就可以通过插值算法,将原始的权重映射到所需要的权重上。核心代码如下:

# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)new_seq_length_1d = image_size // patch_size# Perform interpolation.# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)new_pos_embedding_img = nn.functional.interpolate(pos_embedding_img,size=new_seq_length_1d,mode=interpolation_mode,align_corners=True,)# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)

不过,Pytorch官方的这个代码,只能适配当num_patches是一个完全平方数的情况,因为需要开根号操作。实际上,num_patches一般是通过如下方式计算获得,理论上来说通过插值算法是可以适配到任意尺寸的num_patches的。

n u m _ p a t c h e s = i m g _ s i z e h p a t c h _ s i z e h i m g _ s i z e w p a t c h _ s i z e w (1) num\_patches=\frac{img\_size_h}{patch\_size_h}\frac{img\_size_w}{patch\_size_w} \tag{1} num_patches=patch_sizehimg_sizehpatch_sizewimg_sizew(1)

从上式可以看出,pos_embedding主要与img_size/patch_size有关,因此当把img_size和patch_size等比例缩放时,是不需要调整pos_embedding的。

在timm库中,提供了resample_abs_pos_embed函数,并将其集成到了VisionTransformer类中,所以我们在使用时无需自己考虑对位置编码进行插值处理。

三、Patch Embedding Layer如何处理

Patch Embedding Layer用于将二维的图像转为一维的输入序列,其实现方式通常有两种,如下所示:

### 基于MLP的实现方式patch_dim = in_channels * patch_height * patch_widthself.patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 使用einops库nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)### 基于Conv2d的实现方式self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

从这两种实现可以看出,Patch Embedding Layer的参数主要与patch_size和in_channels有关,而与img_size无关。Pytorch官方和Timm库都采用基于Conv2d的方式来实现,当patch_size和in_channels改变时,无法直接加载预训练权重。

Pytorch官方并未给出解决方案,timm库通过resample_patch_embed来解决这一问题,并且也集成到了VisionTransformer类中。在使用时,我们也不需要考虑手动对Patch Embedding Layer的权重进行调整。

四、使用timm库来对任意尺寸进行微调

首先需要安装timm库

pip install timm
# 如果安装的Pytorch2.0及以上版本,无需考虑一下步骤
# 如果是其他版本的Pytorch,需要下载functorch库
pip install functorch==版本号
# 具体版本号,需要依据自己环境中的pytorch版本来
# 例如:0.20.0对应Pytorch1.12.0,0.2.1对应Pytorch1.12.1
# 对应关系可以去github上查看:https://github.com/pytorch/functorch/releases

代码示例如下:

import timm
from timm.models.registry import register_model@register_model # 注册模型
def vit_tiny_patch4_64(pretrained: bool = False, **kwargs) -> VisionTransformer:""" ViT-Tiny (Vit-Ti/16)"""# 在model_args中对需要部分参数进行修改,此处调整了img_size, patch_size和in_chansmodel_args = dict(img_size = 64, patch_size=4, in_chans=1, embed_dim=192, depth=12, num_heads=3) # vit_tiny_patch16_224是想要加载的预训练权重对应的模型model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model# 注册模型之后,就可以通过create_model来创建模型了
vit = timm.create_model('vit_tiny_patch4_64', pretrained = True) 

不过,由于预训练权重在线下载一般比较慢,可以通过pretrained_cfg来实现加载本地模型,代码如下:

    vit = timm.create_model('vit_tiny_patch4_64')cfg = vit.default_cfgprint(cfg['url']) # 查看下载的url来手动下载cfg['file'] = 'vit-tiny.npz' # 这里修改为你下载的模型vit = timm.create_model('vit_tiny_patch4_64', pretrained=True, pretrained_cfg=cfg).cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/159720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

气膜体育馆:低碳环保体育新潮流

在追求健康生活的今天,体育运动的重要性无法忽视。为了满足人民日益增长的体育需求,气膜体育馆应运而生,成为体育场馆领域的一次革命性创新。这种新型体育馆解决了传统体育场馆建设中面临的审批难、周期长、门槛高等问题,为我们的…

马蹄集oj赛(双周赛第十五次)

目录 小码哥的开心数字 淘金者 捡麦子 小码哥玩游戏 手机测试 自动浇花机 买月饼 未来战争 双人成行 魔法水晶球 ​编辑自驾游 文章压缩 银河贸易市场 小码哥的开心数字 子难度:青铜 0时间限制:1秒 巴占用内存:64M 小码哥有超能…

深入浅出 Linux 中的 ARM IOMMU SMMU I

Linux 系统下的 SMMU 介绍 在计算机系统架构中,与传统的用于 CPU 访问内存的管理的 MMU 类似,IOMMU (Input Output Memory Management Unit) 将来自系统 I/O 设备的 DMA 请求传递到系统互连之前,它会先转换请求的地址,并对系统 I…

海外IP代理:数据中心代理IP是什么?好用吗?

数据中心代理是代理IP中最常见的类型,也被称为机房IP。这些代理服务器为用户分配不属于 ISP(互联网服务提供商)而来自第三方云服务提供商的 IP 地址。数据中心代理的最大优势——它们允许在访问网络时完全匿名。 如果你正在寻找海外代理IP&am…

【JavaSE】-4-单层循环结构

回顾 运算符: 算术 --、逻辑 && & || |、比较 、三元 、赋值 int i 1; i; j i; //j2 i3 syso(--j"-----"i) //1 3 选择结构 if(){} if(){}else{} if(){}else if(){}else if(){}else{}//支持byte、short、int //支持char //支持枚举…

动态规划:2304. 网格中的最小路径代价

2304. 网格中的最小路径代价 给你一个下标从 0 开始的整数矩阵 grid ,矩阵大小为 m x n ,由从 0 到 m * n - 1 的不同整数组成。你可以在此矩阵中,从一个单元格移动到 下一行 的任何其他单元格。如果你位于单元格 (x, y) ,且满足…

网络安全之渗透测试入门准备

渗透测试入门所需知识 操作系统基础:Windows,Linux 网络基础:基础协议与简单原理 编程语言:PHP,python web安全基础 渗透测试入门 渗透测试学习: 1.工具环境准备:①VMware安装及使用&#xff1b…

BUUCTF--[ACTF2020 新生赛]Include

目录 1、本题详解 2、延伸拓展 1、本题详解 访问题目链接 有一个tips的链接,我们点击 请求了file,内容是flag.php的内容:Can you find out the flag? 尝试请求一下index.php 并没有发现什么信息 flag.php也没发现什么 尝试爆破一下它的…

java游戏制作-飞翔的鸟游戏

一.准备工作 首先创建一个新的Java项目命名为“飞翔的鸟”,并在src中创建一个包命名为“com.qiku.bird",在这个包内分别创建4个类命名为“Bird”、“BirdGame”、“Column”、“Ground”,并向需要的图片素材导入到包内。 二.代码呈现 …

Android线程优化——整体思路与方法

**在日常开发APP的过程中,难免需要使用第二方库和第三方库来帮助开发者快速实现一些功能,提高开发效率。但是,这些库也可能会给线程带来一定的压力,主要表现在以下几个方面: 线程数量增多:一些库可能会在后…

AIGC 是通向 AGI 的那条路吗?

AIGC 是通向 AGI 的那条路吗? 目录 一、背景知识 1.1、AGI(人工通用智能) 1.1.1、概念定义 1.1.2、通用人工智能特质 1.1.3、通用人工智能需要掌握能力 1.2、AIGC 二、AIGC 是通向 AGI 的那条路吗? 三、当前实现真正的 A…

【云原生】初识 Service Mesh

目录 一、什么是Service Mesh 二、微服务发展历程 2.1 微服务架构演进历史 2.1.1 单体架构 2.1.2 SOA阶段 2.1.3 微服务阶段 2.2 微服务治理中的问题 2.2.1 技术栈庞杂 2.2.2 版本升级碎片化 2.2.3 侵入性强 2.2.4 中间件多,学习成本高 2.2.5 服务治理功…

知虾数据软件:电商人必备知虾数据软件,轻松掌握市场趋势

在当今数字化时代,数据已经成为了企业决策的重要依据。对于电商行业来说,数据更是至关重要。如果你想在电商领域中脱颖而出,那么你需要一款强大的数据分析工具来帮助你更好地了解市场、分析竞争对手、优化运营策略。而知虾数据软件就是这样一…

【React-Router】导航传参

1. searchParams 传参 // /page/Login/index.js import { Link, useNavigate } from react-router-dom const Login () > {const navigate useNavigate()return <div>登录页<button onClick{() > navigate(/article?id91&namejk)}>searchParams 传参…

SpringBoot中使用注解的方式创建队列和交换机

SpringBoot中使用注解的方式创建队列和交换机 前言 最开始蘑菇博客在进行初始化配置的时候&#xff0c;需要手动的创建交换机&#xff0c;创建队列&#xff0c;然后绑定交换机&#xff0c;这个步骤是非常繁琐的&#xff0c;而且一不小心的话&#xff0c;还可能就出了错误&…

phpinfo中的重要信息

phpinfo中的重要信息 1.PHP/操作系统版本信息2.Configuration File(ini配置文件位置)3.Registered PHP Streams(支持的流)4.Registered Stream Filters(支持的流过滤器)5.allow_url_fopen&allow_url_include6.disable_functions7.display_errors8.include_path9.open_based…

【OpenCV实现图像:使用OpenCV进行物体轮廓排序】

文章目录 概要读取图像获取轮廓轮廓排序小结 概要 在图像处理中&#xff0c;经常需要进行与物体轮廓相关的操作&#xff0c;比如计算目标轮廓的周长、面积等。为了获取目标轮廓的信息&#xff0c;通常使用OpenCV的findContours函数。然而&#xff0c;一旦获得轮廓信息后&#…

Java8新特性 ----- Lambda表达式和方法引用/构造器引用详解

前言 在讲一下内容之前,我们需要引入函数式接口的概念 什么是函数式接口呢? 函数式接口&#xff1a;有且仅有一个抽象方法的接口 java中函数式编程的体现就是Lambda表达式,你可以认为函数式接口就是适用于Lambda表达式的接口. 也可以加上注解来在编译层次上限制函数式接口 Fun…

视频云存储EasyCVR平台国标接入获取通道设备未回复是什么原因?该如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等&#xff0c;以及支持厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…