Vision Transformer (ViT)原理
flyfish
Transformer缺乏卷积神经网络(CNNs)的归纳偏差(inductive biases),比如平移不变性和局部受限的感受野。不变性意味着即使实体entity(即对象)的外观或位置发生变化,仍然可以在图像中识别出它。在计算机视觉中,平移意味着每个图像像素都以固定量朝特定方向移动。
卷积是一个线性局部操作符(linear local operator)。看到由卷积核kernel指示的邻近值。另一方面,Transformer的设计是permutation invariant 的。坏消息是它不能处理网格结构的数据。需要序列!将把空间非序列信号转换为序列!看看怎么做。
Vision Transformer(简称ViT)的工作原理概述如下:
- 将图像分割成小块(patches)
- 将小块展平 Flatten the patches
- 从展平的小块中产生低维线性嵌入 Produce lower-dimensional linear embeddings from the flattened patches
- 添加位置嵌入 Add positional embeddings
- 将序列作为输入输入到标准Transformer编码器 Feed the sequence as an input to a standard transformer encoder
- 使用图像标签对模型进行预训练(在大型数据集上进行全面监督)Pretrain the model with image labels (fully supervised on a huge dataset)
- 在下游数据集上进行微调以进行图像分类 Finetune on the downstream dataset for image classification
名词解释
“inductive”的中文意思是“归纳的;归纳法的;电感的”。在给定的上下文中,“inductive biases”可以理解为“归纳偏向”,指的是通过归纳推理而产生的某种倾向或偏好。例如在机器学习中,不同的模型可能具有不同的归纳偏向,这会影响它们对数据的处理方式和学习结果。
“permutation invariant”的中文意思是“排列不变性”。
在数学和计算机科学等领域,排列不变性指的是某个对象或函数在输入的元素排列顺序发生变化时,其值或性质保持不变。例如,对于一个集合的某种特征描述,如果无论集合中元素的排列顺序如何变化,这个特征描述的值都不变,那么就说这个特征具有排列不变性。
例如,计算一组数字的总和,无论这组数字以何种顺序排列,总和始终不变,这就体现了某种程度的排列不变性。在机器学习中,某些模型可能要求具有排列不变性,以确保对输入数据的不同排列方式具有相同的输出结果。
简化下就是4步
“Image Patching and Embedding”图像分块与嵌入,即将图像分割成小块并进行特定的编码嵌入操作。
“Positional Encoding”即位置编码,用于为模型提供序列中元素的位置信息。
“Transformer Encoder”编码器
“Classification Head (MLP Head)”是分类头(多层感知机头部),通常在模型的最后用于对输入进行分类任务。
详解点
Vision Transformer (ViT) 是一种将 Transformer 架构应用于计算机视觉任务的模型。它借鉴了自然语言处理 (NLP) 中 Transformer 的成功经验,旨在用 Transformer 替代传统卷积神经网络 (CNN),用于图像分类等任务。
1. 从图像到序列:图像分块与嵌入 (Image Patching and Embedding)
由于 Transformer 的输入需要是序列,而图像是二维网格数据,因此需要先将图像转换为序列形式。
-
图像分块
- 将输入图像 X ∈ R H × W × C X \in \mathbb{R}^{H \times W \times C} X∈RH×W×C 划分为固定大小的非重叠小块(patches)。
假设每个小块的大小为 P × P P \times P P×P,图像就被划分成 N = H × W P 2 N = \frac{H \times W}{P^2} N=P2H×W 个小块,每个小块展平为一个向量。
每个向量的大小为 P × P × C P \times P \times C P×P×C,其中 C C C 是图像的通道数(如 RGB 图像中 C = 3 C=3 C=3)。
- 将输入图像 X ∈ R H × W × C X \in \mathbb{R}^{H \times W \times C} X∈RH×W×C 划分为固定大小的非重叠小块(patches)。
-
线性嵌入
- 使用一个可学习的线性投影矩阵 E ∈ R ( P 2 ⋅ C ) × D E \in \mathbb{R}^{(P^2 \cdot C) \times D} E∈R(P2⋅C)×D 将每个展平的小块嵌入到 D D D 维空间中。
- 结果是一个大小为 N × D N \times D N×D 的嵌入序列,类似于 NLP 中的词嵌入(word embedding)。
2. 位置编码 (Positional Encoding)
由于 Transformer 对输入序列是排列不变的(Permutation Invariant),它无法直接利用图像中的空间关系。因此,需要为嵌入序列添加位置信息。
- 使用可学习的位置嵌入 P ∈ R N × D P \in \mathbb{R}^{N \times D} P∈RN×D 或固定的位置编码,将每个小块的嵌入与其位置对应起来。
- 最终输入变为 Z 0 = [ x 1 + p 1 ; x 2 + p 2 ; … ; x N + p N ] Z_0 = [x_1 + p_1; x_2 + p_2; \dots; x_N + p_N] Z0=[x1+p1;x2+p2;…;xN+pN],其中 x i x_i xi 是第 i i i 个小块的嵌入, p i p_i pi 是其位置嵌入。
3. Transformer 编码器 (Transformer Encoder)
ViT 的核心部分是标准 Transformer 编码器,它包含以下模块:
-
多头自注意力机制 (Multi-Head Self-Attention, MHSA)
- 通过自注意力机制捕获小块之间的全局关系,无需限制在局部感受野内操作(如 CNN)。
- 每个小块与其他小块计算注意力权重,以关注哪些部分对当前任务最重要。
-
前馈神经网络 (Feed-Forward Neural Network, FFN)
- 对每个小块的表示单独进行非线性变换,提升模型的表达能力。
-
残差连接与层归一化 (Residual Connections and Layer Normalization)
- 使用残差连接缓解梯度消失问题,同时通过层归一化稳定训练。
多个 Transformer 编码器堆叠后,输出序列保持大小不变,为 N × D N \times D N×D。
4. 分类头 (Classification Head)
-
类别标记 (Class Token)
- 在输入序列中引入一个可学习的 [CLS] 标记,其嵌入表示整个图像的信息。
- 类别标记通过编码器与其他小块交互,最终用于分类任务。
-
多层感知机 (MLP Head)
- 最终 [CLS] 标记的输出经过一个多层感知机(通常是全连接层)进行分类,得到最终的预测结果。
5. 预训练与微调 (Pretraining and Finetuning)
-
预训练
- ViT 通常在大规模数据集(如 ImageNet-21k 或 JFT-300M)上进行监督学习预训练。
- 通过大量的标注数据,模型学习到丰富的视觉特征。
-
微调
- 在特定任务的小型数据集上进行微调,例如 CIFAR-10、ImageNet 等。
ViT 的优点与局限性
优点
-
全局感受野
- 自注意力机制可以直接建模全局依赖关系,而 CNN 的局部感受野需要通过多层叠加才能实现。
-
更少的归纳偏差
- ViT 依赖数据学习到特征,而非依赖 CNN 的平移不变性等归纳偏差,更适合大规模数据。
-
灵活性
- 不受限于卷积核的大小,可适配不同的输入尺寸和任务。
局限性
-
数据需求大
- 由于缺乏 CNN 的归纳偏差,ViT 在小数据集上表现较差,需要大量预训练数据。
-
计算成本高
- 多头自注意力的计算复杂度为 O ( N 2 ⋅ D ) O(N^2 \cdot D) O(N2⋅D),对高分辨率图像或较多的分块数量计算开销较大。
import math
from collections import OrderedDict
from typing import Callable, List, Optionalimport torch
import torch.nn as nn
from torch import Tensor# 假设 Encoder 和 Conv2dNormActivation 已经被定义,并且 ConvStemConfig 是一个数据类或命名元组。
# 这些通常来自其他文件或库,例如 torchvision 或自定义实现。class VisionTransformer(nn.Module):"""Vision Transformer 如 https://arxiv.org/abs/2010.11929 所述."""def __init__(self,image_size: int, # 输入图像的大小(假设为正方形)patch_size: int, # 每个patch的大小(也假设为正方形)num_layers: int, # 编码器中的层数num_heads: int, # 注意力机制中的头数hidden_dim: int, # 隐藏层维度mlp_dim: int, # MLP 层的维度dropout: float = 0.0, # Dropout 概率attention_dropout: float = 0.0, # 注意力机制中的dropout概率num_classes: int = 1000, # 分类的数量representation_size: Optional[int] = None, # 表示层的大小,可选norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # 归一化层conv_stem_configs: Optional[List[ConvStemConfig]] = None, # 卷积干配置,可选):super().__init__()_log_api_usage_once(self) # 记录API使用情况torch._assert(image_size % patch_size == 0, "输入形状不能被patch大小整除!") # 确保图像尺寸能被patch大小整除self.image_size = image_sizeself.patch_size = patch_sizeself.hidden_dim = hidden_dimself.mlp_dim = mlp_dimself.attention_dropout = attention_dropoutself.dropout = dropoutself.num_classes = num_classesself.representation_size = representation_sizeself.norm_layer = norm_layerif conv_stem_configs is not None:# 根据论文 https://arxiv.org/abs/2106.14881 使用卷积干seq_proj = nn.Sequential() # 创建一个序列容器来保存卷积干的层prev_channels = 3 # 初始通道数为3(RGB图像)for i, conv_stem_layer_config in enumerate(conv_stem_configs):# 对于每个卷积干配置,添加一个卷积、归一化和激活层seq_proj.add_module(f"conv_bn_relu_{i}",Conv2dNormActivation(in_channels=prev_channels,out_channels=conv_stem_layer_config.out_channels,kernel_size=conv_stem_layer_config.kernel_size,stride=conv_stem_layer_config.stride,norm_layer=conv_stem_layer_config.norm_layer,activation_layer=conv_stem_layer_config.activation_layer,),)prev_channels = conv_stem_layer_config.out_channels# 添加最后一个1x1卷积层,将通道数转换为隐藏维度seq_proj.add_module("conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1))self.conv_proj: nn.Module = seq_proj # 将卷积干设置为self.conv_projelse:# 如果没有提供卷积干配置,则使用简单的卷积投影self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)# 计算序列长度(包括类别标记)seq_length = (image_size // patch_size) ** 2# 添加一个类别标记self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))seq_length += 1 # 类别标记增加序列长度# 初始化编码器self.encoder = Encoder(seq_length,num_layers,num_heads,hidden_dim,mlp_dim,dropout,attention_dropout,norm_layer,)self.seq_length = seq_length# 定义分类头heads_layers: OrderedDict[str, nn.Module] = OrderedDict()if representation_size is None:heads_layers["head"] = nn.Linear(hidden_dim, num_classes)else:heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)heads_layers["act"] = nn.Tanh()heads_layers["head"] = nn.Linear(representation_size, num_classes)self.heads = nn.Sequential(heads_layers)# 初始化权重if isinstance(self.conv_proj, nn.Conv2d):# 初始化patchify stem的权重fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))if self.conv_proj.bias is not None:nn.init.zeros_(self.conv_proj.bias)elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):# 初始化卷积干中最后的1x1卷积层nn.init.normal_(self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels))if self.conv_proj.conv_last.bias is not None:nn.init.zeros_(self.conv_proj.conv_last.bias)if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):fan_in = self.heads.pre_logits.in_featuresnn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))nn.init.zeros_(self.heads.pre_logits.bias)if isinstance(self.heads.head, nn.Linear):nn.init.zeros_(self.heads.head.weight)nn.init.zeros_(self.heads.head.bias)def _process_input(self, x: Tensor) -> Tensor:# 获取输入张量的形状信息n, c, h, w = x.shapep = self.patch_sizetorch._assert(h == self.image_size, f"错误的图像高度!预期 {self.image_size} 但得到 {h}!")torch._assert(w == self.image_size, f"错误的图像宽度!预期 {self.image_size} 但得到 {w}!")n_h = h // pn_w = w // p# 使用卷积投影将图像划分为patch,并调整形状# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)x = self.conv_proj(x)# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))x = x.reshape(n, self.hidden_dim, n_h * n_w)# 转换到 (N, S, E) 的格式,其中 N 是批次大小,S 是源序列长度,E 是嵌入维度# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)x = x.permute(0, 2, 1)return xdef forward(self, x: Tensor):# 对输入张量进行重塑和转置x = self._process_input(x)n = x.shape[0]# 将类别标记扩展到整个批次,并与patch序列连接batch_class_token = self.class_token.expand(n, -1, -1)x = torch.cat([batch_class_token, x], dim=1)# 通过编码器处理x = self.encoder(x)# 只取每个样本的第一个token(即类别标记)作为表示x = x[:, 0]# 通过分类头进行最终预测x = self.heads(x)return x
参考
https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
论文 Worth 16x16 Words: Transformers for Image Recognition at Scale
https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py