Vision Transformer详解-CSDN博客
视频:11.1 Vision Transformer(vit)网络详解_哔哩哔哩_bilibili
Vision Transformer学习笔记_linear projection of flattened patches-CSDN博客
一、embedding 层
对于标准的Transformer模块,要求输入的是token (向量)序列,即二维矩阵[num_token,token_dim];
在代码实现中,直接通过一个卷积层来实现以ViT一 B/16为例,使用卷积核大小为16x16,stride为16, 卷积核个数为768;
- [224, 224, 3] -> [14, 14, 768] -> [196, 768]
在输入Transformer Encoder之前需要加上[class]token 以及Position Embedding,都是可训练参数
- 拼接[class]token: Cat([1,768],[196,768])->[197,768]
- 叠加Position Embedding: [197,768]->[197,768]
在这里我画了一个图来解释一下整体过程:
二、Encoder层
主要完成机制就是多头注意力机制。
三、 MLP Head层
把class token从最终结果[197,768]中切片拿出来,对其进行linear全连接(简单理解),如果需要类别概率的话,可以再接一个softmax
借用我导的图片来总结一下