解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

解决Vision Transformer在任意尺寸图像上微调的问题:使用timm库

文章目录

          • 一、ViT的微调问题的本质
          • 二、Positional Embedding如何处理
            • 1,绝对位置编码
            • 2,相对位置编码
            • 3,对位置编码进行插值
          • 三、Patch Embedding Layer如何处理
          • 四、使用timm库来对任意尺寸进行微调

一、ViT的微调问题的本质

自从ViT被提出以来,在CV领域引起了新的研究热潮。理论上来说,Transformer的输入是一个序列,并且其参数主要来自于Transformer Block中的Linear层,因此Transformer可以处理任意长度的输入序列。但是在Vision Transformer中,由于需要将二维的图像通过Patch Embedding Layer映射为一个一维的序列,并且需要添加pos_embedding来保留位置信息。因此当patch_size和img_size发生改变时,会造成pos_embbeding的长度和Patch Embedding Layer的参数发生改变,从而导致预训练权重无法直接加载。更多有关ViT的实现细节和原理,可以参考Vision Transformer , 通用 Vision Backbone 超详细解读 (原理分析+代码解读)。

二、Positional Embedding如何处理

在Vision Transformer中有两种主流的编码方式:相对位置编码和绝对位置编码。

1,绝对位置编码

绝对位置编码依据token每个的绝对位置分配一个固定的值,其本质上是一组一维向量,有两种实现方式:

# 可学习的位置编码,ViT中使用, +1是因为有cls_tokenself.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))# 根据正余弦获取位置编码, Transformer中使用
def get_positional_embeddings(sequence_length,dim):result = torch.ones(sequence_length,dim)for i in range(sequence_length):for j in range(dim):result[i][j] = np.sin(i/(10000**(j/dim))) if j %2==0 else np.cos(i/(10000**((j-1)/dim)))return result

在forward过程中,绝对位置编码会在最开始直接和token相加:

	tokens += self.pos_embedding[:, :(n + 1)]
2,相对位置编码

相对位置编码,依据每个token的query相对于key的位置来分配位置编码,典型例子就是swin transformer,其本质是构建一个可学习的二维table,然后依据相对位置索引(x,y)来从table中取值,具体可以参考:有关swin transformer相对位置编码的理解

不过,在swin transformer中,query和key都是来自于同一个window,因此query和key的数量相同,构建位置编码的方式相对来说比较简单。如果query和key的数量不同,例如Focal Transformer中多层次的self-attention,其位置编码的方式可以参考:Focal Transformer。

对于相对位置编码的构造,还有一种方式是CrossFormer中提出的Dynamic Position Bias。其核心思想为构建一个MLP,其输入是二维的相对位置索引,输出是指定dim的位置偏置。这个和根据正余弦获取位置编码有点类似,只不过一个是依据一维的绝对坐标来生成位置编码,一个是依据二维的相对坐标来生成位置编码。

image-20231122172434141

在forward过程中,相对位置编码不会在一开始与token相加,而是在Attention Layer中以Bias的形式参与self-attention计算,核心代码如下:

        attn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)
3,对位置编码进行插值

综上,我们可以依据实现方式将位置编码分为两大类:可学习的位置编码(例如,ViT、Swin Transformer、Focal Transformer等)和生成式的位置编码(例如,正余弦位置编码和CrossFormer中的DPB)。更多有关位置编码的内容,可以参考论文:Rethinking and Improving Relative Position Encoding for Vision Transformer。对于生成式的位置编码而言,其编码方式与序列长度无关,因此当patch_size和img_size改变而造成num_patches改变时,仍然可以加载与位置编码有关的预训练权重。

但是,对于可学习的位置编码而言,num_patches改变时,无法直接加载与位置编码的预训练权重。以ViT为例,其参数一般是一个shape为[N+1, C]的tensor。与cls_token有关的位置编码不用改变,我们只需要关心与img patch相关的位置编码即可,其shape为[N, C]。当num_patches变为n时,所需要位置编码shape为[n, C]。这显然无法直接加载预训练权重。

Pytorch官方提供了一种思路,通过插值算法,来获取新的权重。我们不妨将原始的位置编码想象为一个shape为[ N , N , C \sqrt{N}, \sqrt{N}, C N ,N ,C]的tensor,将所需要的位置编码想象为一个shape为[ n , n , C \sqrt{n}, \sqrt{n}, C n ,n ,C]。这样我们就可以通过插值算法,将原始的权重映射到所需要的权重上。核心代码如下:

# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)new_seq_length_1d = image_size // patch_size# Perform interpolation.# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)new_pos_embedding_img = nn.functional.interpolate(pos_embedding_img,size=new_seq_length_1d,mode=interpolation_mode,align_corners=True,)# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)

不过,Pytorch官方的这个代码,只能适配当num_patches是一个完全平方数的情况,因为需要开根号操作。实际上,num_patches一般是通过如下方式计算获得,理论上来说通过插值算法是可以适配到任意尺寸的num_patches的。

n u m _ p a t c h e s = i m g _ s i z e h p a t c h _ s i z e h i m g _ s i z e w p a t c h _ s i z e w (1) num\_patches=\frac{img\_size_h}{patch\_size_h}\frac{img\_size_w}{patch\_size_w} \tag{1} num_patches=patch_sizehimg_sizehpatch_sizewimg_sizew(1)

从上式可以看出,pos_embedding主要与img_size/patch_size有关,因此当把img_size和patch_size等比例缩放时,是不需要调整pos_embedding的。

在timm库中,提供了resample_abs_pos_embed函数,并将其集成到了VisionTransformer类中,所以我们在使用时无需自己考虑对位置编码进行插值处理。

三、Patch Embedding Layer如何处理

Patch Embedding Layer用于将二维的图像转为一维的输入序列,其实现方式通常有两种,如下所示:

### 基于MLP的实现方式patch_dim = in_channels * patch_height * patch_widthself.patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), # 使用einops库nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)### 基于Conv2d的实现方式self.patch_embedding = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

从这两种实现可以看出,Patch Embedding Layer的参数主要与patch_size和in_channels有关,而与img_size无关。Pytorch官方和Timm库都采用基于Conv2d的方式来实现,当patch_size和in_channels改变时,无法直接加载预训练权重。

Pytorch官方并未给出解决方案,timm库通过resample_patch_embed来解决这一问题,并且也集成到了VisionTransformer类中。在使用时,我们也不需要考虑手动对Patch Embedding Layer的权重进行调整。

四、使用timm库来对任意尺寸进行微调

首先需要安装timm库

pip install timm
# 如果安装的Pytorch2.0及以上版本,无需考虑一下步骤
# 如果是其他版本的Pytorch,需要下载functorch库
pip install functorch==版本号
# 具体版本号,需要依据自己环境中的pytorch版本来
# 例如:0.20.0对应Pytorch1.12.0,0.2.1对应Pytorch1.12.1
# 对应关系可以去github上查看:https://github.com/pytorch/functorch/releases

代码示例如下:

import timm
from timm.models.registry import register_model@register_model # 注册模型
def vit_tiny_patch4_64(pretrained: bool = False, **kwargs) -> VisionTransformer:""" ViT-Tiny (Vit-Ti/16)"""# 在model_args中对需要部分参数进行修改,此处调整了img_size, patch_size和in_chansmodel_args = dict(img_size = 64, patch_size=4, in_chans=1, embed_dim=192, depth=12, num_heads=3) # vit_tiny_patch16_224是想要加载的预训练权重对应的模型model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model# 注册模型之后,就可以通过create_model来创建模型了
vit = timm.create_model('vit_tiny_patch4_64', pretrained = True) 

不过,由于预训练权重在线下载一般比较慢,可以通过pretrained_cfg来实现加载本地模型,代码如下:

    vit = timm.create_model('vit_tiny_patch4_64')cfg = vit.default_cfgprint(cfg['url']) # 查看下载的url来手动下载cfg['file'] = 'vit-tiny.npz' # 这里修改为你下载的模型vit = timm.create_model('vit_tiny_patch4_64', pretrained=True, pretrained_cfg=cfg).cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/159720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

气膜体育馆:低碳环保体育新潮流

在追求健康生活的今天,体育运动的重要性无法忽视。为了满足人民日益增长的体育需求,气膜体育馆应运而生,成为体育场馆领域的一次革命性创新。这种新型体育馆解决了传统体育场馆建设中面临的审批难、周期长、门槛高等问题,为我们的…

马蹄集oj赛(双周赛第十五次)

目录 小码哥的开心数字 淘金者 捡麦子 小码哥玩游戏 手机测试 自动浇花机 买月饼 未来战争 双人成行 魔法水晶球 ​编辑自驾游 文章压缩 银河贸易市场 小码哥的开心数字 子难度:青铜 0时间限制:1秒 巴占用内存:64M 小码哥有超能…

深入浅出 Linux 中的 ARM IOMMU SMMU I

Linux 系统下的 SMMU 介绍 在计算机系统架构中,与传统的用于 CPU 访问内存的管理的 MMU 类似,IOMMU (Input Output Memory Management Unit) 将来自系统 I/O 设备的 DMA 请求传递到系统互连之前,它会先转换请求的地址,并对系统 I…

海外IP代理:数据中心代理IP是什么?好用吗?

数据中心代理是代理IP中最常见的类型,也被称为机房IP。这些代理服务器为用户分配不属于 ISP(互联网服务提供商)而来自第三方云服务提供商的 IP 地址。数据中心代理的最大优势——它们允许在访问网络时完全匿名。 如果你正在寻找海外代理IP&am…

【JavaSE】-4-单层循环结构

回顾 运算符: 算术 --、逻辑 && & || |、比较 、三元 、赋值 int i 1; i; j i; //j2 i3 syso(--j"-----"i) //1 3 选择结构 if(){} if(){}else{} if(){}else if(){}else if(){}else{}//支持byte、short、int //支持char //支持枚举…

力扣labuladong——一刷day47

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、力扣993. 二叉树的堂兄弟节点二、力扣1315. 祖父节点值为偶数的节点和三、力扣1448. 统计二叉树中好节点的数目四、力扣1469. 寻找所有的独生节点 前言 二叉…

动态规划:2304. 网格中的最小路径代价

2304. 网格中的最小路径代价 给你一个下标从 0 开始的整数矩阵 grid ,矩阵大小为 m x n ,由从 0 到 m * n - 1 的不同整数组成。你可以在此矩阵中,从一个单元格移动到 下一行 的任何其他单元格。如果你位于单元格 (x, y) ,且满足…

网络安全之渗透测试入门准备

渗透测试入门所需知识 操作系统基础:Windows,Linux 网络基础:基础协议与简单原理 编程语言:PHP,python web安全基础 渗透测试入门 渗透测试学习: 1.工具环境准备:①VMware安装及使用&#xff1b…

JAVA SQL

-- /* */ -- 简单查询: -- 查询所有字段: select * from 表名 -- *:通配符,代表所有 select * from employees -- 查询部分字段: select 列名1,列名2,.. from 表名 -- 查询员工ID,员工姓名,员工的工资 select employee_id,salary,first_name from employees -- 查…

花木兰从军

《花木兰从军》 作家/罗光记 在遥远的古代,有一位英勇的女子,名叫花木兰。她以一介女子的柔弱之躯,披挂上阵,驰骋沙场,成为了那个时代最为传奇的英雄。 花木兰出生在一个普通的农家,从小便被…

云原生Docker系列 | docker-compose自动编排使用

云原生Docker系列 | docker-compose自动编排使用 1. yaml文件格式要求2. 使用docker compose搭建博客3. docker compose 命令总结1. yaml文件格式要求 yaml是以空格缩进来控制层级关系,不能使用Tab键,而且大小写敏感的。参数要遵循驼峰写法如:imagePullPolicy    yaml缩进…

BUUCTF--[ACTF2020 新生赛]Include

目录 1、本题详解 2、延伸拓展 1、本题详解 访问题目链接 有一个tips的链接,我们点击 请求了file,内容是flag.php的内容:Can you find out the flag? 尝试请求一下index.php 并没有发现什么信息 flag.php也没发现什么 尝试爆破一下它的…

java游戏制作-飞翔的鸟游戏

一.准备工作 首先创建一个新的Java项目命名为“飞翔的鸟”,并在src中创建一个包命名为“com.qiku.bird",在这个包内分别创建4个类命名为“Bird”、“BirdGame”、“Column”、“Ground”,并向需要的图片素材导入到包内。 二.代码呈现 …

Android线程优化——整体思路与方法

**在日常开发APP的过程中,难免需要使用第二方库和第三方库来帮助开发者快速实现一些功能,提高开发效率。但是,这些库也可能会给线程带来一定的压力,主要表现在以下几个方面: 线程数量增多:一些库可能会在后…

AIGC 是通向 AGI 的那条路吗?

AIGC 是通向 AGI 的那条路吗? 目录 一、背景知识 1.1、AGI(人工通用智能) 1.1.1、概念定义 1.1.2、通用人工智能特质 1.1.3、通用人工智能需要掌握能力 1.2、AIGC 二、AIGC 是通向 AGI 的那条路吗? 三、当前实现真正的 A…

window下查看端口是否被占用

window下查看端口是否被占用 1、开始---->运行---->cmd,或者是windowR组合键,调出命令窗口; 2、输入命令:netstat -ano,列出所有端口的情况。在列表中我们观察被占用的端口,比如是49157&#xff0c…

10.26~10.27论文,ALP: AdaptiveLossless floating-Point Compression

ALP使用自适应编码,加强版的伪小数去编码双进度小数作为整数,如果它们产生于小数,否则就用矢量化压缩小数的前位。它的高速来源于在标量编码的实现,自动矢量化,使用他们FastLanes的库,一个高效的二阶段压缩…

【云原生】初识 Service Mesh

目录 一、什么是Service Mesh 二、微服务发展历程 2.1 微服务架构演进历史 2.1.1 单体架构 2.1.2 SOA阶段 2.1.3 微服务阶段 2.2 微服务治理中的问题 2.2.1 技术栈庞杂 2.2.2 版本升级碎片化 2.2.3 侵入性强 2.2.4 中间件多,学习成本高 2.2.5 服务治理功…

年报类型页面动画

一. 动画分析 年报类型页面的动画都是第一行字体先触发动画 之后第二行字体再出现动画 以此类推 二. 动画写法 1.动画封装 keyframes identifier {0% {opacity: 0;transform: translateY(40px);}100% {opacity: 1;transform: translateY(0);} } 2. 添加动画 a. 第一行字…

导购APP与淘客查券返利机器人的对比:深刻揭露行业本质

导购APP与淘客查券返利机器人的对比:深刻揭露行业本质 随着互联网的快速发展,购物方式也发生了翻天覆地的变化。导购APP和淘客机器人成为了消费者购物时的两种常见工具。然而,在这两者之间,究竟该如何选择呢?本文将通…