T2T VIT 学习笔记(附代码)

论文地址:https://arxiv.org/abs/2101.11986

代码地址:https://github.com/PaddlePaddle/PASSL/tree/main/configs/t2t_vit

1.是什么?

T2T-ViT是一种基于Transformer的视觉模型,用于图像分类任务。它通过将图像分割成小的图块,并使用Transformer模型对这些图块进行编码和处理,从而实现对图像的分类。

T2T-ViT模型的复杂度可以根据不同的配置进行调整。作者设计了几种不同复杂度的T2T-ViT模型,以与常见的手动设计的CNN模型进行对比。这些模型包括T2T-ViT-14、T2T-ViT-19、T2T-ViT-24、T2T-ViT-7、T2T-ViT-10和T2T-ViT-12,分别对应于ResNet50、ResNet101、ResNet152、MobileNetV1、MobileNetV2等模型的复杂度。

在实验中,作者对这些T2T-ViT模型进行了评估和比较。他们使用了常见的图像分类数据集,并通过计算模型在测试集上的准确率来评估模型的性能。实验结果表明,T2T-ViT模型在图像分类任务上取得了与手动设计的CNN模型相媲美甚至更好的性能。

2.为什么?

通过分析发现:

(1) 输入图像的简单token化难以很好的建模近邻像素间的重要局部结构(比如边缘、线条等),这就导致了少量样本时的低效性;

(2) 在固定计算负载与有限训练样本约束下,ViT中的冗余注意力骨干设计限制了特征的丰富性。可参考下图,ResNet50可以逐步捕获到期望的局部结构(比如边缘、线条、纹理等),然而ViT的结构建模信息较差,反而全局相关性被更好的获取。与此同时,可以看到:ViT的特征中存在零值,这导致其不如ResNet高效,限制了其特征丰富性

为克服上述局限性,我们提出了一种新的Tokens-to-Token Vision Transformer,T2T-ViT,它引入了(1) 层级Tokens-to-Token变换通过递归的集成近邻Tokens为Token将渐进的将图像结构化为tokens,因此局部结构可以更好的建模且tokens长度可以进一步降低;(2) 受CNN架构设计启发,设计一种高效的deep-narrow的骨干结构用于ViT。

相比标准ViT,所提T2T-ViT的参数量与MACs可以减低200%,并取得2.5%的性能提升(注仅用ImageNet从头开始训练);所提T2T-ViT同样取得了优于ResNet的性能,比如,在ImageNet数据集上,T2T-ViT-ResNet50取得了80.7%的Top1精度。可参考下图。

 

总而言之,本文的主要贡献包含以下几个方面:

  • 首次通过精心设计Transformer结构在标准ImageNet数据集上取得了全面超越CNN的性能,而无需在JFT-300M数据进行预训练;
  • 提出一种新颖的渐进式Token化机制用于ViT,并证实了其优越性,所提T2T模块可以更好的协助每个token建模局部重要结构信息;
  • CNN的架构设计思想有助于ViT的骨干结构设计并提升其特征丰富性、减少信息冗余。通过实验发现:deep-narrow结构设计非常适合于ViT。

3.怎么样?

3.1网络结构

3.1.1Tokens-to-Token

Tokens-to-Token(T2T)模块旨在克服ViT中简单Token化机制的局限性,它采用渐进式方式将图像结构化为token并建模局部结构信息;而Tokens的长度可以通过渐进式迭代降低,每个T2T过程包含两个步骤:Restructurization与SoftSplit,见下图。

Re-structurization 

 如上图所示,给定Tokens序列T,将通过自注意力模块对齐进行变换处理,可以描述为:

注:MSA表示多头自注意力模块,MLP表示多层感知器。经过上述变换后,Tokens将在空间维度上reshape为图像形式,描述如下: 

 Soft Split

正如上图所示,在得到重结构化图像I后,作者对其进行软拆分操作以建模局部结构信息,降低Tokens长度。为避免图像到tokens过程中的信息损失,将图像拆分为重叠块,也就是说:每个块将与近邻块之间构建相关性。每个拆分块内的Token将通过Concat方式变换为一个Token(即Tokens-to-Token),因此可以更好的建模局部信息。作者假设每个块大小为k*k ,重叠尺寸为s,padding为p,对于重建图像I\epsilon R^{h*w*c} ,其对应的输出Token 可以表示为如下尺寸:

注:每个拆分块尺寸为k*k*c 。最后将所有块在空域维度上flatten为Token T_{o\epsilon}\epsilon R^{l_{o}*ck^{2}}。这里所得到的输出Token将被送入到下一个T2T处理过程。 

3.1.2 T2T module

通过交替执行上述Re-structurization与Soft Split操作,T2T模块可以逐渐的减少Token的长度、变换图像的空间结构。T2T模块可以表示为如下形式:

对于输入图像T_{o} ,作者采用SoftSplit操作将其拆分为Token:T_{1}=SS(I_{o}) 。在完成最后的迭代后,输出TokenT_{f} 具有固定IG长度,因此T2T-ViT可以在 T_{f}上建模全局相关性。

另外,由于T2T模块中的Token长度要比常规形式的更大,故而MAC与内存占用会非常大。为克服该局限性,在T2T模块中,作者设置通道维度为32(或者64)以降低MACs。

例子:

为了方便理解流程,这里不考虑 batch size,定义输入尺寸为 3, 224, 224,首先经过一次 soft split,将图像转为二维张量,这个过程利用 nn.Unfold 函数,如下所示

然后我们通过 T2T Transformer 来提取 Token 的全局信息,T2T Transformer 和 ViT 的 MSA 极其相似,都采用了 Self-Attention 结构 

接下来就是 T2T process 部分,论文给出的流程图如下所示,其中我们目前得到的数据(size=3136, 64)对应下图的 Ti’,我们可以看见 T2T process 操作流程为Reshape+Unfold

为什么要这样做呢?怎么理解上图的 Tokens to Token?

实际上 soft split 操作是用来模拟局部结构信息,作者希望建立一个先验,即周围的 Token 之间应该有更强的关联性,如上图所示,token1、token2、token4、token5 这4个 token 局部相邻,将它们通过 Unfold(soft split 操作)来融合成新的 token,如上图 Ti+1[0] 所示

T2T process 详细输出变化如下所示

 其中 T2T Transformer 用于提取全局表征,T2T process 用于引入局部先验,这两个模块交替进行特征提取,得到输出为196, 576,然后通过 Linear 得到 final token(196,576 --Linear--> 196,786)。

3.1.3 T2T-ViT Backbone

 由于ViT骨干中的不少通道是无效的,故而作者计划设计一种高效骨干以降低冗余提升特征丰富性。T2T-ViT将CNN架构设计思想引入到ViT骨干设计以提升骨干的高效性、增强所学习特征的丰富性。由于每个Transformer具有类似ResNet的跳过连接,一个最直接的想法是采用类似DenseNet的稠密连接提升特征丰富性;或者采用Wide-ResNet、ResNeXt结构改变通道维度。本文从以下五个角度进行了系统性的比较:

  • Dense Connection,类似于DenseNet;

  • Deep-narrow vs shallow-wide结构,类似于Wide-ResNet一文的讨论;

  • Channel Attention,类似SENet;

  • More Split Head,类似ResNeXt;

  • Ghost操作,类似GhostNet。

结合上述五种结构设计,作者通过实验发现:(1) Deep-Narrow结构可以在通道层面通过减少通道维度减少冗余,可以通过提升深度提升特征丰富性,可以减少模型大小与MACs并提升性能;(2) 通道注意力可以提升ViT的性能,但不如Deep-Narrow结构高效。

基于上述结构上的探索与发现,作者为T2T-ViT设计了Deep-Narrow形式的骨干结构,也就是说:更少的通道数、更深的层数。对于定长Token ,将类Token预期Concat融合并添加正弦位置嵌入(Sinusoidal Position Embedding, SPE),类似于ViT进行最后的分类:

3.2代码实现

3.2.1. T2T ViT

class T2T_ViT(nn.Module):""":paramimg_size: 图像尺寸tokens_type: tokens类型,分为经典的vit的token和谷歌performer的tokenin_chans: 输入的通道数num_classes: 分类类别数embed_dim: 每个patch的编码长度depth: token to token处理的深度drop_path_rate: 防过拟合的比例mlp_ratio: 多层感知器中隐藏层结点数目与输入结点数目的比例qkv_bias: query、key、value向量是否偏置,因此qkv是通过Linear层生成的参数矩阵,可以设定其偏置qk_scale: query * key 的缩放比例drop_rate:"""def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12,num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dim  # num_features for consistency with other modelsself.tokens_to_token = T2T_module(img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim)num_patches = self.tokens_to_token.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim),requires_grad=False)self.pos_drop = nn.Dropout(p=drop_rate)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay ruleself.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()trunc_normal_(self.cls_token, std=.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def no_weight_decay(self):return {'cls_token'}def get_classifier(self):return self.headdef reset_classifier(self, num_classes, global_pool=''):self.num_classes = num_classesself.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):B = x.shape[0]x = self.tokens_to_token(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)return x[:, 0]def forward(self, x):x = self.forward_features(x)x = self.head(x)return x
3.2.2 T2T Module:通过这样的一个模块,让模型更好的感知图像中的纹理特征信息,从而达到更好的分类效果
class T2T_module(nn.Module):"""Tokens-to-Token encoding module"""def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):super().__init__()if tokens_type == 'transformer':print('adopt transformer encoder for tokens-to-token')# unfold将卷积核相同位置的元素放置在一起组成新的向量,这样可以类似卷积网络那样提取出领域特征信息self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.project = nn.Linear(token_dim * 3 * 3, embed_dim)self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2))  # there are 3 sfot split, stride are 4,2,2 seperatelydef forward(self, x):# 第一次重构数据x = self.soft_split0(x).transpose(1, 2)# 第一次Attention机制x = self.attention1(x)B, new_HW, C = x.shapex = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))# 第二次重构数据x = self.soft_split1(x).transpose(1, 2)# 第二次Attention机制x = self.attention2(x)B, new_HW, C = x.shapex = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))# 第三次重构数据x = self.soft_split2(x).transpose(1, 2)# 产生新的tokenx = self.project(x)return x

        2.Embedding:(cls token adding+ position embedding)

        3.Block:(和Vit中的Norm-Attention、Norm-MLP组成的encoder一致)

                NormLayer -> Attention -> ResLayer in DropPath

                -> NormLayer -> MLP -> ResLayer in DropPath

        4.Linear 类别映射
————————————————
版权声明:本文为CSDN博主「尼卡尼卡尼」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_42418728/article/details/120550784

参考:ResNet被全面超越了,是Transformer干的:依图科技开源“可大可小”T2T-ViT

各类Transformer都得稍逊一筹,LV-ViT:探索多个用于提升ViT性能的高效Trick 

【视觉 Transformer】超详细解读 T2T-ViT 模型

Vision Transformer(2):T2T ViT源码阅读以及Drop解释

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/632581.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在uni-app中使用sku插件,实现商品详情页规格展示和交互。

商品详情 - SKU 模块 学会使用插件市场,下载并使用 SKU 组件,实现商品详情页规格展示和交互。 存货单位(SKU) SKU 概念 存货单位(Stock Keeping Unit),库存管理的最小可用单元,通…

【MyBatis-Plus】逻辑删除

对于一些比较重要的数据,我们通常采用逻辑删除。(即用一个字段表示是否删除,实际上始终在数据库没有被删除) 当逻辑删除字段为 true,业务处理的时候会自动把该数据当做一个“不存在”的数据处理。(即不处理…

计算机网络课程设计-Tracert 与 Ping 程序设计与实现

目录 前言 1 实验题目 2 实验目的 3 实验内容 3.1 步骤 3.2 关键代码 3.2.1 发送ICMP数据报 3.2.2 解析收到的数据报 4 实验结果与分析 5 代码 5.1 ping代码 5.2 Tracert代码 前言 本实验为计算机网络课程设计内容,基本上所有代码都是根据指导书给的附录…

BigeMap在Unity3d中的应用,助力数字孪生

1. 首先需要用到3个软件,unity,gis office 和 bigemap离线服务器 Unity下载地址:点击前往下载页面(Unity需要 Unity 2021.3.2f1之后的版本) Gis office下载地址:点击前往下载页面 Bigemap离线服务器 下载地址: 点击前往下载页面 Unity用于数字孪生项…

ilqr 算法说明

1 Introduction 希望能用比较简单的方式将ilqr算法进行整理和总结。 2 HJB方程 假定我们现在需要完成一个从A点到B点的任务,执行这段任务的时候,每一步都需要消耗能量,可以用下面这个图表示。 我们在执行这个A点到B点的任务的时候&#xf…

什么是区块链?

区块链 区块链 (英语:blockchain)是借由 密码学 与 共识机制 等技术建立,存储数据的 保证不可篡改和不可伪造的 分布式技术。 什么是区块 区块 就是将一批数据打包在一起,并且给打包出来的区块编号。第一个区块的编…

vue3-事件处理

事件监听 DOM 事件监听指令 v-on 简写 v-on:click"handler" 或者 click"handler"事件处理器 (handler) 的值可以是: 内联事件处理器:比如 click 方法事件处理器:一个指向组件上定义的方法的属性名或是路径。 在内联…

【干货】网络安全之URL过滤

热门IT课程【视频教程】-华为/思科/红帽/oraclehttps://xmws-it.blog.csdn.net/article/details/134398330?spm1001.2014.3001.5502 URL过滤是一种针对用户的URL请求进行上网控制的技术,通过允许或禁止用户访问某些网页资源,达到规范上网行为和降低安全…

matlab行操作快?还是列操作快?

在MATLAB中,通常情况下,对矩阵的列进行操作比对行进行操作更有效率。这是因为MATLAB中内存是按列存储的,因此按列访问数据会更加连续,从而提高访问速度。 一、实例代码 以下是一个简单的测试代码, % 测试矩阵大小 ma…

Python实战 -- PySide6 制作天气查询软件

一、环境准备 开发环境:Python 3.9.2 pycharm PySide6 申请天气情况 API :https://console.amap.com/dev/key/app designer 设计 ui 目录下 Weather.ui 转换为 Weather.py 结果显示 二、完整代码 import sysfrom PySide6 import QtWidgetsimport…

Mysql深度分页优化的一个实践

问题简述: 最近在工作中遇到了大数据量的查询场景, 日产100w左右明细, 会查询近90天内的数据, 总数据量约1亿, 业务要求支持分页查询与导出. 无论是分页或导出都涉及到深度分页查询, mysql通过limit/offset实现的深度分页查询会存在全表扫描的问题, 比如offset1000w, limit10…

7. UE5 RPG修改GAS的Attribute的值

前面几节文章介绍了如何在角色身上添加AbilitySystemComponent和AttributeSet。并且还实现了给AttributeSet添加自定义属性。接下来,实现一下如何去修改角色身上的Attribute的值。 实现拾取药瓶回血功能 首先创建一个继承于Actor的c类,actor是可以放置到…

WAF攻防相关知识点总结2-代码免杀绕过

WAF的检测除了有对于非正常的流量检测外还对于非正常的数据包特征进行检测 以宝塔为例 在宝塔的后台可以放置一句话木马的文件 宝塔不会对于这个文件进行拦截,但是一旦我们使用菜刀蚁剑等webshell工具去进行连接的时候,数据报中有流量特征就会被拦截 …

JRT和springboot比较测试

想要战胜他,必先理解他。这两天系统的学习Maven和跑springboot工程,从以前只是看着复杂,现在到亲手体验一下,亲自实践的才是更可靠的了解。 第一就是首先Maven侵入代码结构,代码一般要按约定搞src/main/java。如果是能…

C#判断输入的数字是否符合货币格式

目录 一、用正则表达式判断输入是否符合货币格式 二、用double.TryParse()判断输入是否符合货币格式 一、用正则表达式判断输入是否符合货币格式 // 判断输入是否货币合格 using System.Text.RegularExpressions; namespace IsCurrency_Format {partial class Program{stati…

虚幻UE 特效-Niagara特效实战-雨天

回顾Niagara特效基础知识:虚幻UE 特效-Niagara特效初识 其他两篇实战:虚幻UE 特效-Niagara特效实战-火焰、烛火、虚幻UE 特效-Niagara特效实战-烟雾、喷泉 本篇笔记我们再来实战雨天,雨天主要用到了特效中的事件。 文章目录 一、雨天1、创建雨…

九、K8S-label和label Selector

label和label selector 标签和标签选择器 1、label 标签: 一个label就是一个key/value对 label 特性: label可以被附加到各种资源对象上一个资源对象可以定义任意数量的label同一个label可以被添加到任意数量的资源上 2、label selector 标签选择器 L…

k8s的对外服务ingress

1、service的作用体现在两个方面 (1)集群内部:不断跟踪pod的变化,更新deployment中的pod对象,基于pod的ip地址不断变化的一种服务发现机制 (2)集群外部:类似于负载均衡器&#xff…

【计算机组成与体系结构Ⅱ】Tomasulo 算法模拟和分析(实验)

实验5:Tomasulo 算法模拟和分析 一、实验目的 1:加深对指令级并行性及开发的理解。 2:加深对 Tomasulo 算法的理解。 3:掌握 Tomasulo 算法在指令流出、执行、写结果各阶段对浮点操作指令以及 load 和 store 指令进行了什么处…

MATLAB R2023a安装教程

鼠标右击软件压缩包,选择“解压到MATLAB.R2023a”。 打开解压后的文件夹,鼠标右击“R2023a_Windows_iso”选择“装载”。 鼠标右击“setup.exe”选择“以管理员身份运行”。 点击“高级选项”选择“我有文件安装密钥”。 点击“是”,然后点击…