Vision Transformer (ViT)原理

Vision Transformer (ViT)原理

flyfish
请添加图片描述
Transformer缺乏卷积神经网络(CNNs)的归纳偏差(inductive biases),比如平移不变性和局部受限的感受野。不变性意味着即使实体entity(即对象)的外观或位置发生变化,仍然可以在图像中识别出它。在计算机视觉中,平移意味着每个图像像素都以固定量朝特定方向移动。
卷积是一个线性局部操作符(linear local operator)。看到由卷积核kernel指示的邻近值。另一方面,Transformer的设计是permutation invariant 的。坏消息是它不能处理网格结构的数据。需要序列!将把空间非序列信号转换为序列!看看怎么做。
Vision Transformer(简称ViT)的工作原理概述如下:

  1. 将图像分割成小块(patches)
  2. 将小块展平 Flatten the patches
  3. 从展平的小块中产生低维线性嵌入 Produce lower-dimensional linear embeddings from the flattened patches
  4. 添加位置嵌入 Add positional embeddings
  5. 将序列作为输入输入到标准Transformer编码器 Feed the sequence as an input to a standard transformer encoder
  6. 使用图像标签对模型进行预训练(在大型数据集上进行全面监督)Pretrain the model with image labels (fully supervised on a huge dataset)
  7. 在下游数据集上进行微调以进行图像分类 Finetune on the downstream dataset for image classification

名词解释

“inductive”的中文意思是“归纳的;归纳法的;电感的”。在给定的上下文中,“inductive biases”可以理解为“归纳偏向”,指的是通过归纳推理而产生的某种倾向或偏好。例如在机器学习中,不同的模型可能具有不同的归纳偏向,这会影响它们对数据的处理方式和学习结果。

“permutation invariant”的中文意思是“排列不变性”。
在数学和计算机科学等领域,排列不变性指的是某个对象或函数在输入的元素排列顺序发生变化时,其值或性质保持不变。例如,对于一个集合的某种特征描述,如果无论集合中元素的排列顺序如何变化,这个特征描述的值都不变,那么就说这个特征具有排列不变性。
例如,计算一组数字的总和,无论这组数字以何种顺序排列,总和始终不变,这就体现了某种程度的排列不变性。在机器学习中,某些模型可能要求具有排列不变性,以确保对输入数据的不同排列方式具有相同的输出结果。
简化下就是4步

“Image Patching and Embedding”图像分块与嵌入,即将图像分割成小块并进行特定的编码嵌入操作。
“Positional Encoding”即位置编码,用于为模型提供序列中元素的位置信息。
“Transformer Encoder”编码器
“Classification Head (MLP Head)”是分类头(多层感知机头部),通常在模型的最后用于对输入进行分类任务。

详解点

Vision Transformer (ViT) 是一种将 Transformer 架构应用于计算机视觉任务的模型。它借鉴了自然语言处理 (NLP) 中 Transformer 的成功经验,旨在用 Transformer 替代传统卷积神经网络 (CNN),用于图像分类等任务。


1. 从图像到序列:图像分块与嵌入 (Image Patching and Embedding)

由于 Transformer 的输入需要是序列,而图像是二维网格数据,因此需要先将图像转换为序列形式。

  1. 图像分块

    • 将输入图像 X ∈ R H × W × C X \in \mathbb{R}^{H \times W \times C} XRH×W×C 划分为固定大小的非重叠小块(patches)。
      假设每个小块的大小为 P × P P \times P P×P,图像就被划分成 N = H × W P 2 N = \frac{H \times W}{P^2} N=P2H×W 个小块,每个小块展平为一个向量。
      每个向量的大小为 P × P × C P \times P \times C P×P×C,其中 C C C 是图像的通道数(如 RGB 图像中 C = 3 C=3 C=3)。
  2. 线性嵌入

    • 使用一个可学习的线性投影矩阵 E ∈ R ( P 2 ⋅ C ) × D E \in \mathbb{R}^{(P^2 \cdot C) \times D} ER(P2C)×D 将每个展平的小块嵌入到 D D D 维空间中。
    • 结果是一个大小为 N × D N \times D N×D 的嵌入序列,类似于 NLP 中的词嵌入(word embedding)。

2. 位置编码 (Positional Encoding)

由于 Transformer 对输入序列是排列不变的(Permutation Invariant),它无法直接利用图像中的空间关系。因此,需要为嵌入序列添加位置信息。

  • 使用可学习的位置嵌入 P ∈ R N × D P \in \mathbb{R}^{N \times D} PRN×D 或固定的位置编码,将每个小块的嵌入与其位置对应起来。
  • 最终输入变为 Z 0 = [ x 1 + p 1 ; x 2 + p 2 ; … ; x N + p N ] Z_0 = [x_1 + p_1; x_2 + p_2; \dots; x_N + p_N] Z0=[x1+p1;x2+p2;;xN+pN],其中 x i x_i xi 是第 i i i 个小块的嵌入, p i p_i pi 是其位置嵌入。

3. Transformer 编码器 (Transformer Encoder)

ViT 的核心部分是标准 Transformer 编码器,它包含以下模块:

  1. 多头自注意力机制 (Multi-Head Self-Attention, MHSA)

    • 通过自注意力机制捕获小块之间的全局关系,无需限制在局部感受野内操作(如 CNN)。
    • 每个小块与其他小块计算注意力权重,以关注哪些部分对当前任务最重要。
  2. 前馈神经网络 (Feed-Forward Neural Network, FFN)

    • 对每个小块的表示单独进行非线性变换,提升模型的表达能力。
  3. 残差连接与层归一化 (Residual Connections and Layer Normalization)

    • 使用残差连接缓解梯度消失问题,同时通过层归一化稳定训练。

多个 Transformer 编码器堆叠后,输出序列保持大小不变,为 N × D N \times D N×D


4. 分类头 (Classification Head)
  1. 类别标记 (Class Token)

    • 在输入序列中引入一个可学习的 [CLS] 标记,其嵌入表示整个图像的信息。
    • 类别标记通过编码器与其他小块交互,最终用于分类任务。
  2. 多层感知机 (MLP Head)

    • 最终 [CLS] 标记的输出经过一个多层感知机(通常是全连接层)进行分类,得到最终的预测结果。

5. 预训练与微调 (Pretraining and Finetuning)
  1. 预训练

    • ViT 通常在大规模数据集(如 ImageNet-21k 或 JFT-300M)上进行监督学习预训练。
    • 通过大量的标注数据,模型学习到丰富的视觉特征。
  2. 微调

    • 在特定任务的小型数据集上进行微调,例如 CIFAR-10、ImageNet 等。

ViT 的优点与局限性

优点
  1. 全局感受野

    • 自注意力机制可以直接建模全局依赖关系,而 CNN 的局部感受野需要通过多层叠加才能实现。
  2. 更少的归纳偏差

    • ViT 依赖数据学习到特征,而非依赖 CNN 的平移不变性等归纳偏差,更适合大规模数据。
  3. 灵活性

    • 不受限于卷积核的大小,可适配不同的输入尺寸和任务。
局限性
  1. 数据需求大

    • 由于缺乏 CNN 的归纳偏差,ViT 在小数据集上表现较差,需要大量预训练数据。
  2. 计算成本高

    • 多头自注意力的计算复杂度为 O ( N 2 ⋅ D ) O(N^2 \cdot D) O(N2D),对高分辨率图像或较多的分块数量计算开销较大。
import math
from collections import OrderedDict
from typing import Callable, List, Optionalimport torch
import torch.nn as nn
from torch import Tensor# 假设 Encoder 和 Conv2dNormActivation 已经被定义,并且 ConvStemConfig 是一个数据类或命名元组。
# 这些通常来自其他文件或库,例如 torchvision 或自定义实现。class VisionTransformer(nn.Module):"""Vision Transformer 如 https://arxiv.org/abs/2010.11929 所述."""def __init__(self,image_size: int,  # 输入图像的大小(假设为正方形)patch_size: int,  # 每个patch的大小(也假设为正方形)num_layers: int,  # 编码器中的层数num_heads: int,   # 注意力机制中的头数hidden_dim: int,  # 隐藏层维度mlp_dim: int,     # MLP 层的维度dropout: float = 0.0,  # Dropout 概率attention_dropout: float = 0.0,  # 注意力机制中的dropout概率num_classes: int = 1000,  # 分类的数量representation_size: Optional[int] = None,  # 表示层的大小,可选norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),  # 归一化层conv_stem_configs: Optional[List[ConvStemConfig]] = None,  # 卷积干配置,可选):super().__init__()_log_api_usage_once(self)  # 记录API使用情况torch._assert(image_size % patch_size == 0, "输入形状不能被patch大小整除!")  # 确保图像尺寸能被patch大小整除self.image_size = image_sizeself.patch_size = patch_sizeself.hidden_dim = hidden_dimself.mlp_dim = mlp_dimself.attention_dropout = attention_dropoutself.dropout = dropoutself.num_classes = num_classesself.representation_size = representation_sizeself.norm_layer = norm_layerif conv_stem_configs is not None:# 根据论文 https://arxiv.org/abs/2106.14881 使用卷积干seq_proj = nn.Sequential()  # 创建一个序列容器来保存卷积干的层prev_channels = 3  # 初始通道数为3(RGB图像)for i, conv_stem_layer_config in enumerate(conv_stem_configs):# 对于每个卷积干配置,添加一个卷积、归一化和激活层seq_proj.add_module(f"conv_bn_relu_{i}",Conv2dNormActivation(in_channels=prev_channels,out_channels=conv_stem_layer_config.out_channels,kernel_size=conv_stem_layer_config.kernel_size,stride=conv_stem_layer_config.stride,norm_layer=conv_stem_layer_config.norm_layer,activation_layer=conv_stem_layer_config.activation_layer,),)prev_channels = conv_stem_layer_config.out_channels# 添加最后一个1x1卷积层,将通道数转换为隐藏维度seq_proj.add_module("conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1))self.conv_proj: nn.Module = seq_proj  # 将卷积干设置为self.conv_projelse:# 如果没有提供卷积干配置,则使用简单的卷积投影self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)# 计算序列长度(包括类别标记)seq_length = (image_size // patch_size) ** 2# 添加一个类别标记self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))seq_length += 1  # 类别标记增加序列长度# 初始化编码器self.encoder = Encoder(seq_length,num_layers,num_heads,hidden_dim,mlp_dim,dropout,attention_dropout,norm_layer,)self.seq_length = seq_length# 定义分类头heads_layers: OrderedDict[str, nn.Module] = OrderedDict()if representation_size is None:heads_layers["head"] = nn.Linear(hidden_dim, num_classes)else:heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)heads_layers["act"] = nn.Tanh()heads_layers["head"] = nn.Linear(representation_size, num_classes)self.heads = nn.Sequential(heads_layers)# 初始化权重if isinstance(self.conv_proj, nn.Conv2d):# 初始化patchify stem的权重fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))if self.conv_proj.bias is not None:nn.init.zeros_(self.conv_proj.bias)elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):# 初始化卷积干中最后的1x1卷积层nn.init.normal_(self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels))if self.conv_proj.conv_last.bias is not None:nn.init.zeros_(self.conv_proj.conv_last.bias)if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):fan_in = self.heads.pre_logits.in_featuresnn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))nn.init.zeros_(self.heads.pre_logits.bias)if isinstance(self.heads.head, nn.Linear):nn.init.zeros_(self.heads.head.weight)nn.init.zeros_(self.heads.head.bias)def _process_input(self, x: Tensor) -> Tensor:# 获取输入张量的形状信息n, c, h, w = x.shapep = self.patch_sizetorch._assert(h == self.image_size, f"错误的图像高度!预期 {self.image_size} 但得到 {h}!")torch._assert(w == self.image_size, f"错误的图像宽度!预期 {self.image_size} 但得到 {w}!")n_h = h // pn_w = w // p# 使用卷积投影将图像划分为patch,并调整形状# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)x = self.conv_proj(x)# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))x = x.reshape(n, self.hidden_dim, n_h * n_w)# 转换到 (N, S, E) 的格式,其中 N 是批次大小,S 是源序列长度,E 是嵌入维度# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)x = x.permute(0, 2, 1)return xdef forward(self, x: Tensor):# 对输入张量进行重塑和转置x = self._process_input(x)n = x.shape[0]# 将类别标记扩展到整个批次,并与patch序列连接batch_class_token = self.class_token.expand(n, -1, -1)x = torch.cat([batch_class_token, x], dim=1)# 通过编码器处理x = self.encoder(x)# 只取每个样本的第一个token(即类别标记)作为表示x = x[:, 0]# 通过分类头进行最终预测x = self.heads(x)return x

参考
https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
论文 Worth 16x16 Words: Transformers for Image Recognition at Scale
https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/66972.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何让QPS提升20倍

一、什么是QPS QPS,全称Queries Per Second,即每秒查询率,是用于衡量信息检索系统(例如搜索引擎或数据库)或请求-响应系统(如Web服务器)每秒能够处理的请求数或查询次数的一个性能指标。以下是…

Domain Adaptation(李宏毅)机器学习 2023 Spring HW11 (Boss Baseline)

1. 领域适配简介 领域适配是一种迁移学习方法,适用于源领域和目标领域数据分布不同但学习任务相同的情况。具体而言,我们在源领域(通常有大量标注数据)训练一个模型,并希望将其应用于目标领域(通常只有少量或没有标注数据)。然而,由于这两个领域的数据分布不同,模型在…

SQL从入门到实战-1

目录 学前须知 sqlzoo数据介绍 world nobel covid ge game、goal、eteam teacher、dept movie、casting、actor 基础语句 select&from 基础查询select单列&多列&所有列&别名应用 例题一 例题二 例题三 select使用distinct去重 例题四 例题五…

Python在Excel工作表中创建数据透视表

在数据处理和分析工作中,Excel作为一个广泛使用的工具,提供了强大的功能来管理和解析数据。当面对大量复杂的数据集时,为了更高效地总结、分析和展示数据,创建数据透视表成为一种不可或缺的方法。通过使用Python这样的编程语言与E…

springboot整合h2

在 Spring Boot 中整合 H2 数据库非常简单。H2 是一个轻量级的嵌入式数据库&#xff0c;非常适合开发和测试环境。以下是整合 H2 数据库的步骤&#xff1a; 1. 添加依赖 首先&#xff0c;在你的 pom.xml 文件中添加 H2 数据库的依赖&#xff1a; <dependency><grou…

Web前端界面开发

前沿&#xff1a;介绍自适应和响应式布局 自适应布局&#xff1a;-----针对页面1个像素的变换而变化 就是我们上一个练习的效果 我们的页面效果&#xff0c;随着我们的屏幕大小而发生适配的效果&#xff08;类似等比例&#xff09; 如&#xff1a;rem适配 和 vw/vh适配 …

【01】AE特效开发制作特技-Adobe After Effects-AE特效制作快速入门-制作飞机,子弹,爆炸特效以及导出png序列图-优雅草央千澈

【01】AE特效开发制作特技-Adobe After Effects-AE特效制作快速入门-制作飞机&#xff0c;子弹&#xff0c;爆炸特效以及导出png序列图-优雅草央千澈 开发背景 优雅草央千澈所有的合集&#xff0c;系列文章可能是不太适合完全初学者的&#xff0c;因为课程不会非常细致的系统…

java项目之在线文档管理系统源码(springboot+mysql+vue+文档)

大家好我是风歌&#xff0c;曾担任某大厂java架构师&#xff0c;如今专注java毕设领域。今天要和大家聊的是一款基于springboot的在线文档管理系统。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 在线文档管理系统的主要使用者分为管…

可靠的人形探测,未完待续(III)

一不小心&#xff0c;此去经年啊。问大家新年快乐&#xff01; 那&#xff0c;最近在研究毫米波雷达模块嘛&#xff0c;期望用在后续的产品中&#xff0c;正好看到瑞萨的活动送板子&#xff0c;手一下没忍住。 拿了板子就得干活咯&#xff0c;我一路火花带闪电&#xff0c;开整…

【灵码助力安全3】——利用通义灵码辅助智能合约漏洞检测的尝试

前言 随着区块链技术的快速发展&#xff0c;智能合约作为去中心化应用&#xff08;DApps&#xff09;的核心组件&#xff0c;其重要性日益凸显。然而&#xff0c;智能合约的安全问题一直是制约区块链技术广泛应用的关键因素之一。由于智能合约代码一旦部署就难以更改&#xf…

腾讯云下架印度云服务器节点,印度云服务器租用何去何从

近日&#xff0c;腾讯云下架印度云服务器节点的消息引起了业界的广泛关注。这一变动让许多依赖印度云服务器的用户开始担忧&#xff0c;印度云服务器租用的未来究竟在何方&#xff1f; 从印度市场本身来看&#xff0c;其云服务市场的潜力不容小觑。据 IDC 报告&#xff0c;到 2…

【RTSP】使用webrtc播放rtsp视频流

一、简介 rtsp流一般是监控、摄像机的实时视频流,现在的主流浏览器是不支持播放rtsp流文件的,所以需要借助其他方案来播放实时视频,下面介绍下我采用的webrtc方案,实测可行。 二、webrtc-streamer是什么? webrtc-streamer是一个使用简单机制通过 WebRTC 流式传输视频捕获…

多并发发短信处理(头条项目-07)

1 pipeline操作 Redis数据库 Redis 的 C/S 架构&#xff1a; 基于客户端-服务端模型以及请求/响应协议的 TCP服务。客户端向服务端发送⼀个查询请求&#xff0c;并监听Socket返回。通常是以 阻塞模式&#xff0c;等待服务端响应。服务端处理命令&#xff0c;并将结果返回给客…

【网络协议】动态路由协议

前言 本文将概述动态路由协议&#xff0c;定义其概念&#xff0c;并了解其与静态路由的区别。同时将讨论动态路由协议相较于静态路由的优势&#xff0c;学习动态路由协议的不同类别以及无类别&#xff08;classless&#xff09;和有类别&#xff08;classful&#xff09;的特性…

c#集成npoi根据excel模板导出excel

NuGet中安装npoi 创建excel模板&#xff0c;替换其中的内容生成新的excel文件。 例子中主要写了这四种情况&#xff1a; 1、替换单个单元格内容&#xff1b; 2、替换横向多个单元格&#xff1b; 3、替换表格&#xff1b; 4、单元格中插入图片&#xff1b; using System.IO; …

人工智能知识分享第十天-机器学习_聚类算法

聚类算法 1 聚类算法简介 1.1 聚类算法介绍 一种典型的无监督学习算法&#xff0c;主要用于将相似的样本自动归到一个类别中。 目的是将数据集中的对象分成多个簇&#xff08;Cluster&#xff09;&#xff0c;使得同一簇内的对象相似度较高&#xff0c;而不同簇之间的对象相…

B树及其Java实现详解

文章目录 B树及其Java实现详解一、引言二、B树的结构与性质1、节点结构2、性质 三、B树的操作1、插入操作1.1、插入过程 2、删除操作2.1、删除过程 3、搜索操作 四、B树的Java实现1、节点类实现2、B树类实现 五、使用示例六、总结 B树及其Java实现详解 一、引言 B树是一种多路…

本地缓存:Guava Cache

这里写目录标题 一、范例二、应用场景三、加载1、CacheLoader2、Callable3、显式插入 四、过期策略1、基于容量的过期策略2、基于时间的过期策略3、基于引用的过期策略 五、显示清除六、移除监听器六、清理什么时候发生七、刷新八、支持更新锁定能力 一、范例 LoadingCache<…

【高录用 | 快见刊 | 快检索】第十届社会科学与经济发展国际学术会议 (ICSSED 2025)

第十届社会科学与经济发展国际学术会议(ICSSED 2025)定于2025年2月28日-3月2日在中国上海隆重举行。会议主要围绕社会科学与经济发展等研究领域展开讨论。会议旨在为从事社会科学与经济发展研究的专家学者提供一个共享科研成果和前沿技术&#xff0c;了解学术发展趋势&#xff…

[ComfyUI]接入Google的Whisk,巨物融合玩法介绍

一、介紹​ 前段时间&#xff0c;谷歌推出了一个图像生成工具whisk&#xff0c;有一个很好玩的图片融合玩法&#xff0c;分别提供三张图片,就可以任何组合来生成图片。​ ​ 最近我发现有人开发了对应的ComfyUI插件&#xff0c;对whisk做了支持&#xff0c;就来体验了下&#…