ViT: transformer在图像领域的应用

文章目录

  • 1. 概要
  • 2. 方法
  • 3. 实验
    • 3.1 Compare with SOTA
    • 3.2 PRE-TRAINING DATA REQUIREMENTS
    • 3.3 SCALING STUDY
    • 3.4 自监督学习
  • 4. 总结
  • 参考

论文: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer
代码2:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

我们在Transformer详解(1)—原理部分详细介绍了transformer在NLP领域应用的原理,transformer架构自发布以来已经在自然语言处理任务上广泛应用,今天我们将介绍如何将transformer架构应用在图像领域。

1. 概要

基于self-attention的网络架构在NLP领域中取得了很大的成功,但是在CV领域卷积网络架构仍然占据主导地位。受到transformer在NLP中应用成功的启发,也有很多工作尝试将self-attention与CNN网络结合,甚至有些工作直接替换CNN网络,理论上这些模型是高效的,由于这些特殊的注意力机制未与硬件加速器有效适配,因此在大规模的图像检测中,经典的ResNet网络架构仍然是SOTA

受到Transformer网络在NLP领域中成功适配的启发,作者提出对transformer尽可能少的修改,直接在图片上应用标准的transformer。为了实现这个目标,首先需要将图片分割成多个patch,并将这些patch转换成embedding作为transformer的输入。图片的patch就相当于NLP中的token

最后作者得到结论:在数据量不足的情况下进行训练时,ViT不能很好地泛化,效果不如CNN,不过在训练大规模数据时,vit的效果会反超CNN

2. 方法

在模型设计方面,version transformer尽量与原始transformer结构保持一致,因为NLP中的transformer具有高效的实现方式,这样可以开箱即用。模型的整体结构如下所示:
在这里插入图片描述
标准的 transformer 输入是一维向量序列,为了处理二维图像,将输入图片 x ∈ R H × W × C \mathbf{x}\in\mathbb{R}^{H\times W\times C} xRH×W×C 分割成一系列的patch,并将这些patch平整成一维向量,最终得到 x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_p\in\mathbb{R}^{N\times(P^2\cdot C)} xpRN×(P2C),其中 ( H , W ) (H,W) (H,W)是原始图片分辨率, C C C 是图片的通道数, ( P , P ) (P,P) (P,P)是每个patch的分辨率, N = H W P 2 N=\frac{HW}{P^2} N=P2HW 是patch的个数,也可以看作是输入序列的长度。由于transformer每一层的输入向量维度都是固定的 D D D,因此需要通过一个可训练的线性层将 flatten patch 的维度从 P 2 C P^2C P2C转换成 D D D,这个线性层的输出称为patch的embedding.

和BERT的 [class] token 类似,在path embedding序列的首位增加了一个可学习向量 z 0 0 = x c l a s s z_0^0=x_{class} z00=xclass,该向量在transformer encoder的输出部分看做是图片的表征,在预训练和微调阶段,该表征后都会接一个分类层。

为了保持位置信息,位置embedding会加到patch embedding上,这里作者使用了一个一维可学习的位置向量,因为通过实验发现使用二维位置向量并没有获得很大的性能提升,通过以上流程处理后的embedding就是transformer的输入embedding。从输入图片到transformer encoder输出可由以下式子表示:

z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s ; E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D z ′ ℓ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 ; ℓ = 1 … L z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ ; ℓ = 1 … L y = L N ( z L 0 ) \begin{align} z_0 =&[\mathbf{x}_\text{class};\mathbf{x}_p^1\mathbf{E};\mathbf{x}_p^2\mathbf{E};\cdots;\mathbf{x}_p^N\mathbf{E}]+\mathbf{E}_{pos}; \ \ \ \mathbf{E}\in\mathbb{R}^{(P^{2}\cdot C)\times D}, \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D}\\ \mathbf{z}^{\prime}{}_{\ell} =& \mathrm{MSA(LN(z_{\ell-1}))+z_{\ell-1}};\ \ \ell=1\ldots L \\ \mathbf{z}_{\ell} = &\mathrm{MLP}(\mathrm{LN}(\mathbf{z^{\prime}}_\ell))+\mathbf{z^{\prime}}_\ell; \ \ \ \ell=1\ldots L \\ y =& \mathrm{LN}(\mathbf{z}_{L}^{0}) \end{align} z0=z=z=y=[xclass;xp1E;xp2E;;xpNE]+Epos;   ER(P2C)×D,EposR(N+1)×DMSA(LN(z1))+z1;  =1LMLP(LN(z))+z;   =1LLN(zL0)

其中 E E E 是patch维度转换矩阵, M S A MSA MSA是多头注意力层(multi-head self attention), L N LN LN是layer normalization 层, M L P MLP MLP是transformer中前馈网络层

另外,也可以使用CNN网络的特征图作为输入序列,在这种混合模型中,patch embeding 投影层将被用于改变CNN特征图的形状。

在微调阶段,将移除预训练的prediction layer,并新增一个零初始化的预测层,一般来说,在更高分辨率图像上微调是非常有益的。在喂入更高分辨率图像时,保持patch的尺寸不变,这样会造成输入序列长度增加,虽然ViT模型可以处理任意长的输入序列(直到内存不够),但是预训练的位置编码将无效,因此作者根据当前位置在原始图片中的位置,对预训练的位置编码采用2D插值的方法获取最新的位置编码

3. 实验

下文中将用一些简写来代表模型的尺寸和输入patch的尺寸,如ViT-L/16 代表模型为ViT-Large,输入patch的尺寸为 16 × 16 16 \times 16 16×16,下表展示了不同尺寸模型的配置及参数量
在这里插入图片描述
这里需要注意,由于输入序列长度与patch的尺寸成反比,所以,patch 尺寸越小,反而计算量越大

3.1 Compare with SOTA

在这里插入图片描述
TPU v3-core-days:代表计算量,All models were trained on TPUv3 hardware, and we
report the number of TPUv3-core-days taken to pre-train each of them, that is, the number of TPU
v3 cores (2 per chip) used for training multiplied by the training time in days

在这里插入图片描述
不同模型简介:

  • Big Transfer (BiT), which performs supervised transfer learning with large ResNets
  • VIVI – a ResNet co-trained on ImageNet and Youtube
  • S4L – supervised plus semi-supervised learning on ImageNet

3.2 PRE-TRAINING DATA REQUIREMENTS

作者经过实验得到如下结论:

  • 在小数据集上预训练,ViT-Large比ViT-Base要差,在大数据集上训练对ViT-Large比较有益
  • 在小数据集上预训练,ViT的效果比CNN还要差,在大数据集上预训练ViT的效果超过CNN
  • CNN网络的归纳有偏性在小数据集上是有用的,但是在大数据集上,直接从数据中学习相关的模式更有效
    在这里插入图片描述

3.3 SCALING STUDY

如下图所示,作者得到如下结论:

  • ViT在效果和计算量平衡之间相比ResNet占绝对优势,ResNet需要使用约3倍的算力来获得与ViT相似的结果
  • 混合模型在小计算量上相比ViT具有一定的优势,但是这种优势在大模型(大计算量)上逐渐消失
  • ViT在当前实验中貌似并没有饱和,这激励着未来的研究
    在这里插入图片描述

3.4 自监督学习

作者模仿BERT通过mask patch prediction任务进行自监督预训练,ViT-B/16在ImageNet上获得了79.9%的准确率,相比从随机初始化开始训练提升了2%,但是相比于监督学习仍然落后4%。

4. 总结

作者将图片看作是patch序列,并使用标准的Transformer对patch序列进行处理,最终在大数据集上预训练取得了很不错的效果,在图片分类任务上超过了很多SOTA模型。但也还存在一些挑战等待后期处理:

  • 将ViT应用在其他计算机视觉任务中,如目标检测、语义分割等
  • 还需进一步探索自监督预训练方法
  • 进一步扩大ViT模型的规模,可能会取得更好的效果

参考

如何理解Inductive bias?
Translation Equivariance
CNN中的Translation Equivariance【理解】
2D插值(2D interpolation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/688075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫详解(一看就懂)

爬虫 爬虫是什么 爬虫简单的来说就是用程序获取网络上数据这个过程的一种名称。 爬虫的原理 如果要获取网络上数据,我们要给爬虫一个网址(程序中通常叫URL),爬虫发送一个HTTP请求给目标网页的服务器,服务器返回数据…

机器学习---规则学习(序贯覆盖、单条规则学习、剪枝优化)

1. 序贯覆盖 回归: 分类: 聚类: 逻辑规则: 读作:若(文字1且文字2且...),则目标概念成立 规则集:充分性与必要性;冲突消解:顺序规则、缺省规则…

nacos 2.3.1-SNAPSHOT 源码springboot方式启动(详细)附改造工程地址

文章时间是2024-2-18日,nacos默认develop分支,最新版是2.3.1-SNAPSHOT版本。 我们这里就以nacos最新版进行改造成springboot启动方式。 1. Clone 代码 nacos github地址:https://github.com/alibaba/nacos.git 根据上面git地址把源码克隆到…

[ai笔记10] 关于sora火爆的反思

欢迎来到文思源想的ai空间,这是技术老兵重学ai以及成长思考的第10篇分享! 最近sora还持续在技术圈、博客、抖音发酵,许多人都在纷纷发表对它的看法,这是一个既让人惊喜也感到焦虑的事件。openai从2023年开始,每隔几个…

c++中浮点类型比较的理解

为什么浮点类型存在误差 带有小数的表示: 25.3 整数通过除2取余法表示: 25/2…1 12/2…0 6/2…0 3/2…1 1/2…1 倒过来:25(十进制) 11001(二进制) 小数部分通过乘2取整法: 0.3 * 2 …

wps快速生成目录及页码设置(自备)

目录 第一步目录整理 标题格式设置 插入页码(罗马和数字) 目录生成(从罗马尾页开始) ​编辑目录格式修改 第一步目录整理 1罗马标题 2罗马标题1一级标题 1.1 二级标题 1.2二级标题2一级标题 2.1 二级标题 2.2二级标题3一级标…

VMWare ubuntu共享宿主机window11文件夹

宿主机window的设置 找到需要共享的文件夹,比如我需要share文件夹共享到虚拟机中 点击“共享”文件夹属性,如果找不到“共享”选项卡,需要在下面的“选项”中 注意勾选“使用共享向导(推荐)”,如果已经勾…

notepad++打开文本文件乱码的解决办法

目录 第一步 在编码菜单栏下选择GB2312中文。如果已经选了忽略这一步 第二步 点击编码,红框圈出来的一个个试。我切换到UTF-8编码就正常了。 乱码如图。下面分享我的解决办法 第一步 在编码菜单栏下选择GB2312中文。如果已经选了忽略这一步 第二步 点击编码&#…

生成式 AI - Diffusion 模型 (DDPM)原理解析(1)

来自 论文《 Denoising Diffusion Probabilistic Model》(DDPM) 论文链接:https://arxiv.org/abs/2006.11239 Hung-yi Lee 课件整理 文章目录 一、整体运作二、Denoise module三、Noise Predictor四、Text-to-Image 简单地介绍diffusion mode…

安装部署k8s集群

系统: CentOS Linux release 7.9.2009 (Core) 准备3台主机 192.168.44.148k8s-master92.168.44.154k8s-worker01192.168.44.155k8s-worker02 3台主机准备工作 关闭防火墙和selinux systemctl disable firewalld --nowsetenforce 0sed -i s/SELINUXenforcing/SELI…

dm_control 翻译: Software and Tasks for Continuous Control

dm_control: Software and Tasks for Continuous Control dm_control:连续控制软件及任务集 文章目录 dm_control: Software and Tasks for Continuous Controldm_control:连续控制软件及任务集Abstract1 Introduction1 引言1.1 Software for research1…

Java - SPI机制

本文参考:SPI机制 SPI(Service Provider Interface),是JDK内置的一种服务提供发现机制,可以用来启动框架扩展和替换组件,主要是被框架的开发人员使用,比如 java.sql.Driver接口,其他…

TensorRT转换onnx的Transpose算子遇到的奇怪问题

近来把一个模型导出为onnx并用onnx simplifier化简后转换为TensorRT engine遇到非常奇怪的问题,在我们的网络中有多个检测头时,转换出来的engine的推理效果是正常的,当网络中只有一个检测头时,转换出来的engine的推理效果奇差&…

动态代理IP如何选择?

IP地址是由IP协议所提供的一种统一的地址格式,通过为每一个网络和每一台主机分配逻辑地址的方式来屏蔽物理地址的差异。根据IP地址的分配方式,IP可以分为动态IP与静态IP两种。对于大部分用户而言,日常使用的IP地址均为动态IP地址。从代理IP的…

LeetCode 0429.N 叉树的层序遍历:广度优先搜索(BFS)

【LetMeFly】429.N 叉树的层序遍历:广度优先搜索(BFS) 力扣题目链接:https://leetcode.cn/problems/n-ary-tree-level-order-traversal/ 给定一个 N 叉树,返回其节点值的层序遍历。(即从左到右,逐层遍历)…

aiofiles:解锁异步文件操作的神器

aiofiles:解锁异步文件操作的神器 在Python的异步编程领域,文件操作一直是一个具有挑战性的任务。传统的文件操作函数在异步环境下无法发挥其最大的潜力,而aiofiles库应运而生。aiofiles是一个针对异步I/O操作的Python库,它简化了…

C#使用迭代器实现文字的动态效果

目录 一、涉及到的知识点 1.GDI 2.Thread类 3.使用IEnumerable()迭代器 二、实例 1.源码 2.生成效果: 一、涉及到的知识点 1.GDI GDI主要用于在窗体上绘制各种图形图像。 GDI的核心是Graphics类,该类表示GDI绘图表面,它提供将对象绘制…

不等式的证明之二

不等式的证明之二 证明下述不等式证法一证法二证法二的补充 证明下述不等式 设 a , b , c a,b,c a,b,c 是正实数,请证明下述不等式: 11 a 5 a 6 b 11 b 5 b 6 c 11 c 5 c 6 a ≤ 3 \begin{align} \sqrt{\frac{11a}{5a6b}}\sqrt{\frac{11b}{5b6c}…

leetcode hot100不同路径

本题可以采用动态规划来解决。还是按照五部曲来做 确定dp数组:dp[i][j]表示走到(i,j)有多少种路径 确定递推公式:我们这里,只有两个移动方向,比如说我移动到(i,j&#x…

STM32 寄存器操作 systick 滴答定时器 与中断

一、什么是 SysTick SysTick—系统定时器是属于CM3内核中的一个外设,内嵌在NVIC中。系统定时器是一个24bit的向下递减的计数器, 计数器每计数一次的时间为1/SYSCLK,一般我们设置系统时钟SYSCLK等于72M。当重装载数值寄存器的值递减到0的时候…