ViT(Vision Transformer)网络结构详解

本文在transformer的基础上对ViT进行讲解,transformer相关部分可以看我另一篇博客(transformer中对于QKV的个人理解-CSDN博客)。

一、网络结构概览

上图展示了Vision Transformer (ViT) 的基本架构,我按照运行顺序分为三个板块进行讲解。以下是 ViT 处理图像的基本步骤:

  1. Embedding(Linear Projection of Flattened Patches):首先,输入图像被分割成固定大小的patch。每个patch被展平并经过一个线性层,将其转换为特征向量。此外,每个特征向量都会附加一个表示其在原始图像中相对位置的位置嵌入。

  2. Transformer 编码器(Transformer Encoder):这些预处理过的特征向量作为输入馈送到一系列的 Transformer 层进行编码。每个 Transformer 层包含多头注意力机制(Multi-Head Attention)和前馈神经网络(Feed-Forward Network)。在每个 Transformer 层之间有层归一化(Layer Normalization),以保持特征分布的一致性。

  3. MLP Head 和分类(MLP Head):最后,来自 Transformer 编码器的最后一层输出通过一个多层感知机(MLP)头部,通常是一个简单的全连接层,用于执行最终的任务,如图像分类。在这个阶段,模型将学习到的特征映射到类概率空间。

上图是另外一个大佬做的细致一点的网络结构图。

1. Embedding

我将Embedding分为4个步骤:

  1. 分割patch
    首先输入为一张图片,vit中默认将图像缩放为224*224,通常利用16*16的patch进行划分得到(224*224) / (16*16)=14*14=196张子图,每张patch的大小为16*16,3通道。
  2. 转换获得token
    将每个二维的补丁转换为可以输入到 Transformer 的一维序列,将14*14个patch展平成一个长向量,再通过线性变换得到196个token,token的长度为768。


    然而,在vit中是利用768个卷积核大小为16*16,stride为16,padding为0的一个卷积层直接对224*224的输入图像进行卷积,从而得到14*14*768的输出,再对其展平得到[196,768]的token,相当于合并了第一步和第二步。
  3. 拼接上类别token

    之后在得到的196个 token 的前面加上加上一个新的Class Token(即图中0号紫色框右边带*的那个框,这不是通过某个patch产生的。其作用类似于BERT中的Class Token。),得到[197,768]的数据。
    在ViT中,Class Token(通常记为CLS)是一个可学习的参数。它在网络初始化时被随机初始化,类似于其他神经网络权重参数。在经过多层Transformer编码器后,Class Token会聚集来自所有图像块的信息,形成图像的全局表示。最终的Class Token表示(即最后一层Transformer编码器输出的Class Token)被输入到一个分类头(通常是一个全连接层)中,用于图像的分类任务。分类头的输出即为预测的类别概率分布。

  4. 加上位置编码

    patch得到的图像是没有位置信息的,需要用position embedding将位置信息加到模型中去。如上图所示,编号有0-9的紫色框表示各个位置的position embedding
    位置编码是
    可训练的编码,通过叠加加入到[197,768]的数据中,最终输出[197,768]的数据。


    加了位置编码,性能有明显提升,但是不同编码器的方式对性能提升差不多。所以源码中使用的是1-D位置编码。


    上图呈现了一个热力图,其中水平轴代表输入补丁的列数,垂直轴代表输入补丁的行数,颜色深浅表示相似程度。
    通过观察热力图,我们可以看到随着补丁距离变远,它们的位置嵌入变得越来越不相似。在真实世界中,离得近的物体往往比离得远的物体具有更强的空间相关性。这种特性有助于 ViT 更好地理解图像内容及其结构。

2. Transformer 编码器

Transformer encoder层如图所示,Transformer encoder重复堆叠 L 次,整个模型也就包括 L 个 Transformer。关于Layer Norm和多头注意力模块的具体解析可以看我主页其他的博客,这里不过多赘述。
MLP中先通过一个线性层将输入数据的通道数变为原来的4倍,之后通过GELU激活函数和Dropout,再通过一个线性层将4倍通道数变为原来的通道数。

3. MLP Head 和分类

MLP Head 层位于 Transformer 编码器之后,用于完成特定任务,如图像分类。该层通常是一个多层感知机(Multilayer Perceptron,简称 MLP),它接收来自 Transformer 编码器的输出,并对其进行进一步处理以生成最终的预测。在不同的场景下,MLP Head 的结构可能会有所不同。

在这张图片中,我们能看到两种情况下的 MLP Head 设计:

  1. 训练 ImageNet21K 时的 MLP Head

    当训练 ViT 时,特别是在大型数据集如 ImageNet21K 上,MLP Head 包含三个组成部分:线性层、tanh 激活函数以及另一个线性层。这样的设计允许 MLP Head 对编码器输出进行非线性的转换,以适应复杂的模式和特征。激活函数 tanh 提供了一种非线性变形,可以帮助模型学习更复杂的表示。
  2. 迁移至 ImageNet1K 或其他数据集时的 MLP Head

    当将 ViT 迁移到较小的数据集,如 ImageNet1K 或者自己的数据集时,通常只保留一个线性层。这是因为较小的数据集可能不需要那么多的复杂性,一个线性层就足以提供足够的泛化能力。减少层数也可以降低过拟合的风险。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/48869.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rancher

文章目录 Rancher1. 安装和配置2. 服务部署和管理3. 容器自动化缩容和扩容 Rancher Rancher 是一个开源的企业级容器管理平台,旨在简化容器化应用的部署、管理和运维。它支持多种容器编排引擎,如 Kubernetes、Docker Swarm 等,并提供了统一的…

自动驾驶系统开发与调试:车路云一体化无人驾驶挑战赛参赛体验

点击蓝字 关注我们 在过去的几年里,自动驾驶技术在全球范围内吸引了大量关注。其潜力不仅在于提升行车安全,而且还可以改变我们的出行方式和城市规划,提高交通运输效率。国际汽车工程师学会(SAE)根据不同自动驾驶程度&…

JAVA在线文档

1.存在码 JDK21中文API 2.全栈行动派 JDK17中文API 3.mklab.cn JDK11中文API JDK8中文API JDK7-21英文API 4.docs.oracle.com JDK7-22英文文档

项目笔记| 基于Arduino和IR2101的无刷直流电机控制器

本文介绍如何使用 Arduino UNO 板构建无传感器无刷直流 (BLDC) 电机控制器或简单的 ESC(电子速度控制器)。 无刷直流电机有两种类型:有传感器和无传感器。有感无刷直流电机内置3个霍尔效应传感器,这些传感…

MLIR的TOY教程学习笔记

MLIR TOY Language 文章目录 MLIR TOY Language如何编译该项目ch1: MLIR 前端IR解析ch2: 定义方言和算子 (ODS)1. 定义方言2. 定义OP3. OP相关操作4. 定义OP ODS (Operation Definition Specification)1. 基本定义2. 添加文档3. 验证OP4. 新增构造函数5. 定义打印OP的格式 ch3:…

【机器学习】超参数选择:解锁机器学习模型潜力的关键

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 超参数选择:解锁机器学习模型潜力的关键引言什么是超参数&#xff1…

计算机的错误计算(三十八)

摘要 计算机的错误计算(十九)指出:两个等价大数相减,差不是正确值0,而是一个大数。本节用 Python的 torch库中函数进行计算验证,进一步说明错误的一般性。 例1. 在Windows10,Python 3.12.4 下…

Android APP Camerax应用(02)预览流程

说明:camera子系统 系列文章针对Android12.0系统,主要针对 camerax API框架进行解读。 1 CameraX简介 1.1 CameraX 预览流程简要解读 CameraX 是 Android 上的一个 Jetpack 支持库,它提供了一套统一的 API 来处理相机功能,无论 …

【HarmonyOS NEXT】网络请求 - 分页加载

分页加载关键字:onReachEnd 一、申请网络权限 在 module.json5 文件中,添加网络权限: {"module": {..."requestPermissions": [{"name": "ohos.permission.INTERNET","usedScene": {&qu…

网络安全常用易混术语定义与解读(Top 20)

没有网络安全就没有国家安全,网络安全已成为每个人都重视的话题。随着技术的飞速发展,各种网络攻击手段层出不穷,保护个人和企业的信息安全显得尤为重要。然而,在这个复杂的领域中,许多专业术语往往让人感到困惑。为了…

portainer教程-docker可视化管理工具

很多朋友刚接触docker 学习,就想问 docker有图形化界面吗 ,答案是肯定的, 这里白眉大叔 给大家推荐 Docker可视化管理平台 -- Portainer 1- 运行Portainer: docker run -d -p 8000:8000 -p 9000:9000 --name portainer --restarta…

【保姆级讲解C语言中的运算符的优先级!】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步! 🪶C语言中的运算符的优先级 🪶C语言中的运算符的优先级决定了…

二、C#数据类型

本文是网页版《C# 12.0 本质论》第二章解读。欲完整跟踪本系列文章,请关注并订阅我的Essential C# 12.0解读专栏。 前言 数据类型(Data Type)是一个很恼人的话题。 似乎根本没必要对数据类型进行展开讲解,因为人人都懂。 但是…

grafana大坑,es找不到时间戳 | No date field named timestamp found

grafana大坑,es找不到时间戳。最近我这边的es重新装了一遍,结果发现grafana连不上elasticsearch了(以下简称es),排查问题查了好久一直以为是es没有装成功或者两边的版本不兼容,最后才发现是数值类型问题 一…

浅聊 Three.js 屏幕空间反射SSR-SSRShader

浅聊 Three.js 屏幕空间反射SSR(2)-SSRShader 前置基础 渲染管线中的相机和屏幕示意图 -Z (相机朝向的方向)||| -------------- <- 屏幕/投影平面| | || | || | (f) | <- 焦距| | ||…

【BUG】已解决:error: legacy - install - failure

error: legacy - install - failure 目录 error: legacy - install - failure 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff0c;211科班出身&#xff0c;就职于医疗科技公司&…

51单片机14(独立按键实验)

一、按键介绍 1、按键是一种电子开关&#xff0c;使用的时候&#xff0c;只要轻轻的按下我们的这个按钮&#xff0c;按钮就可以使这个开关导通。 2、当松开这个手的时候&#xff0c;我们的这个开关&#xff0c;就断开开发板上使用的这个按键&#xff0c;它的内部结构&#xff…

免费分享:2021年度全国城乡划分代码(附下载方法)

《关于统计上划分城乡的规定》指出&#xff1a;“本规定作为统计上划分城乡的依据&#xff0c;不改变现有的行政区划、隶属关系、管理权限和机构编制&#xff0c;以及土地规划、城乡规划等有关规定”。统计用区划代码和城乡划分代码用于统计工作&#xff0c;需要在其他工作中使…

回溯题目的套路总结

前言 昨天写完了LeeCode的7&#xff0c;8道回溯算法的题目&#xff0c;写一下总结&#xff0c;这类题目的共同特点就是暴力搜索问题&#xff0c;排列组合或者递归&#xff0c;枚举出所有可能的答案&#xff0c;思路很简单&#xff0c;实现起来的套路也很通用&#xff0c;一…

java题目之抽奖以及优化方式

public class Main9 {public static void main(String[] args) {int[]arr{ 2,588,888,1000,10000};int [] newArrnew int[arr.length];//3.抽奖Random rnew Random();//因为有5个奖项,所以这里循环五次for (int i 0; i <5 ; ) {//获取随机索引int randomIndexr.nextInt(arr…