ViT模型复现项目实战

项目源码获取方式见文章末尾! 600多个深度学习项目资料,快来加入社群一起学习吧。

《------往期经典推荐------》

项目名称
1.【基于CNN-RNN的影像报告生成】
2.【卫星图像道路检测DeepLabV3Plus模型】
3.【GAN模型实现二次元头像生成】
4.【CNN模型实现mnist手写数字识别】
5.【fasterRCNN模型实现飞机类目标检测】
6.【CNN-LSTM住宅用电量预测】
7.【VGG16模型实现新冠肺炎图片多分类】
8.【AlexNet模型实现鸟类识别】
9.【DIN模型实现推荐算法】
10.【FiBiNET模型实现推荐算法】
11.【钢板表面缺陷检测基于HRNET模型】

1. 项目简介

本项目的目标是复现Vision Transformer(ViT)模型,通过深入理解其核心架构和应用,探索其在图像分类任务中的性能表现。ViT模型是近年来在视觉任务中取得突破性进展的深度学习模型之一,核心思想是将Transformer这种原本应用于自然语言处理的模型引入到计算机视觉领域,解决了传统卷积神经网络(CNN)在处理全局信息时的局限性。ViT模型通过将输入图像划分为若干固定大小的图块(patches),再将这些图块展平并转换为序列形式输入Transformer模型,从而捕捉图像中的长距离依赖关系。这种方法克服了CNN的局部感受野限制,在大规模数据集上取得了比CNN更优的效果。本项目通过复现ViT模型,帮助用户深入理解其在计算机视觉中的应用及优势,尤其是在图像分类、目标检测等领域的实际表现。

2.技术创新点摘要

将Transformer引入计算机视觉领域:ViT模型将Transformer这种原本用于自然语言处理的架构引入到计算机视觉领域,摒弃了传统的卷积神经网络(CNN)。这种创新性设计突破了CNN局部感受野的限制,能够更好地捕捉图像中的长距离依赖关系。

图像分割为Patch的处理方法:ViT模型通过将输入图像划分为固定大小的图块(Patch),然后将这些图块展平并作为序列输入到Transformer中进行处理。这种处理方式与传统CNN处理整幅图像的方式不同,允许ViT模型能够对全局信息进行更加灵活的建模。

使用Multi-Head Attention机制捕捉全局依赖:模型中采用了多头自注意力机制(Multi-Head Attention),能够并行处理不同位置的图像信息,有效捕捉全局依赖关系。这使得模型可以在多个头上关注图像的不同部分,增强了模型对复杂场景的理解能力。

更高效的训练和推理:ViT模型相比CNN在大规模数据集上训练时效率更高,尤其是在处理高分辨率图像和复杂任务时展现出了显著的优势。这得益于其Transformer架构的优势,使得模型在图像分类任务中的表现优于传统的卷积神经网络。

权重初始化和自监督预训练:代码中还展示了对ViT模型权重初始化的优化方案,通过自监督预训练(self-supervised pretraining)技术,进一步提升了模型的泛化能力。

3. 数据集与预处理

在本项目中使用的数据集为经典的图像分类数据集,主要用于评估ViT模型在图像分类任务中的表现。常见的数据集包括ImageNet等大规模数据集,这些数据集具有类别丰富、样本数量大、图像分辨率高等特点。项目中选用的数据集包含多种类别的图片,每个类别的样本数较为均衡,能够为模型提供丰富的特征信息,帮助模型学习更具泛化能力的特征表示。

数据预处理流程是模型训练中至关重要的一环,确保输入数据的质量和一致性。首先,对于每张输入图片,进行了统一的图像尺寸调整,确保所有图像都能适配模型的输入要求。具体来说,ViT模型通常将图片划分为固定大小的图块(例如16×16像素),因此在预处理阶段,首先需要将图像缩放到指定大小。

接下来,应用了常见的归一化操作,将像素值缩放到[0, 1]或[-1, 1]区间。这有助于加快模型的收敛速度,并防止梯度消失或爆炸。此外,归一化还可以减少各特征间的量纲差异,提高模型的鲁棒性。

为了增强模型的泛化能力,数据增强技术也在预处理阶段被广泛应用。常见的数据增强方法包括随机裁剪、水平翻转、色彩抖动和旋转等操作。这些增强方法通过生成不同的图像变体,扩大了训练数据的多样性,减少了模型过拟合的风险。

在这里插入图片描述

4. 模型架构

该项目采用了Vision Transformer(ViT)模型,其模型结构由多个Transformer块组成,具体如下:

  • 输入层:输入图像被划分为固定大小的图块(Patch)。假设输入图像大小为 H×W×C,其中 H 为高度,W 为宽度,C为通道数。每个图像被分割为 N=HP×WP个图块, P 是每个Patch的大小。
  • Patch Embedding Layer:图块被展平,并通过一个线性投影层映射到固定维度的嵌入空间中。假设线性投影的输出维度为 D,则Patch的表示为:

Z 0 = [ x 1 E ; x 2 E ; … ; x N E ] + E p o s Z_0 = [x_1E; x_2E; \dots; x_NE] + E_{pos} Z0=[x1E;x2E;;xNE]+Epos

  • 其中,xi 是第 iii 个Patch, E是可学习的嵌入矩阵, Epos是位置编码,确保模型能够捕捉Patch的相对位置。

  • Transformer Block:每个Transformer块包含以下部分:

    • Layer Normalization:对输入进行标准化处理。
    • 多头自注意力机制(Multi-Head Attention) Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V 其中 Q, K, V 分别是查询矩阵、键矩阵和值矩阵, dk是键的维度。
    • 前馈神经网络:包含两个线性层,中间有一个激活函数(通常为GELU)。公式如下: FFN ( x ) = GELU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2 FFN(x)=GELU(xW1+b1)W2+b2
  • 分类层:输出层是一个线性分类器,输入的是Transformer最后一层的输出(即第一个Token的表示),用于图像分类任务。

  1. 模型的整体训练流程

模型的训练过程分为以下几步:

  • 前向传播:将图像输入模型,通过各层的处理,输出分类结果。
  • 损失函数:使用交叉熵损失(CrossEntropy Loss)计算模型预测结果与真实标签之间的误差。公式为:

L = − ∑ i = 1 N y i log ⁡ ( y i ^ ) L = - \sum_{i=1}^{N} y_i \log(\hat{y_i}) L=i=1Nyilog(yi^)

  • 其中 yi是实际标签,yi^ 是模型预测的概率分布。
  • 反向传播:通过计算梯度来更新模型的参数,优化目标是最小化损失函数。
  • 评估指标:训练过程中主要采用准确率(Accuracy)作为评估指标,计算模型正确预测样本的比例:

Accuracy = 正确分类的样本数 总样本数 \text{Accuracy} = \frac{\text{正确分类的样本数}}{\text{总样本数}} Accuracy=总样本数正确分类的样本数

通过该架构,模型在处理图像分类任务时展现了较强的全局建模能力,有效捕捉图像中的长距离依赖关系

5. 核心代码详细讲解

1. 数据预处理

data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
  • transforms.Resize(256):将输入图像的最小边缩放到256像素。
  • transforms.CenterCrop(224):从缩放后的图像中取224×224像素的中心部分。这是标准的图像分类输入尺寸。
  • transforms.ToTensor():将图像从PIL格式转换为PyTorch的张量格式,并且像素值被归一化到 [0, 1]。
  • transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):对每个通道进行归一化,将图像的像素值缩放到[-1, 1]之间,方便模型训练。

2. Patch Embedding Layer

class PatchEmbed(nn.Module):def init(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):super().
__init__
()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
  • img_size=224:输入图像的大小为224×224像素。
  • patch_size=16:图像被分割为16×16的Patch。
  • nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size):使用卷积层将图像分割为不重叠的Patch,并将其投影到嵌入维度(embed_dim),这里通过卷积的步幅等于Kernel Size实现图块划分。

3. Transformer Block

class Block(nn.Module):def init(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):super(Block, self).
__init__
()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
  • self.norm1 = norm_layer(dim):Layer Normalization层,用于对输入进行标准化。
  • self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias):多头自注意力机制(Attention),通过多头机制捕捉图像中不同区域的相互关系。
  • self.mlp = Mlp(...):多层感知机(MLP),包含激活函数GELU及两个线性层。

4. 前向传播

def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x
  • x = x + self.drop_path(self.attn(self.norm1(x))):首先对输入进行Layer Normalization,然后通过Attention层计算注意力得分,最后使用残差连接(Residual Connection)保留输入信息。
  • x = x + self.drop_path(self.mlp(self.norm2(x))):第二步是对输出进行标准化,并通过MLP层处理,同样使用残差连接。

5. 模型训练流程

for step, data in enumerate(data_loader):images, labels = datapred = model(images.to(device))loss = loss_function(pred, labels.to(device))loss.backward()optimizer.step()optimizer.zero_grad()
  • images, labels = data:从数据加载器中获取一批图像和对应的标签。
  • pred = model(images.to(device)):将图像输入模型,得到预测结果。
  • loss = loss_function(pred, labels.to(device)):计算预测结果和实际标签之间的损失,这里使用的是交叉熵损失。
  • loss.backward():通过反向传播算法计算损失的梯度。
  • optimizer.step():更新模型的参数,使损失最小化。
  • optimizer.zero_grad():清除上一步的梯度,以防止累积。

6. 模型评估

@torch.no_grad()def evaluate(model, data_loader, device):model.eval()for step, data in enumerate(data_loader):images, labels = datapred = model(images.to(device))pred_classes = torch.max(pred, dim=1)[1]accu_num += torch.eq(pred_classes, labels.to(device)).sum()return accu_num.item() / sample_num
  • model.eval():将模型置于评估模式,关闭Dropout等随机操作。
  • pred_classes = torch.max(pred, dim=1)[1]:通过取预测值的最大索引,得到模型的预测类别。
  • accu_num += torch.eq(pred_classes, labels.to(device)).sum():计算模型在当前批次中的预测正确数。

6. 模型优缺点评价

优点

  1. 全局信息捕捉能力强:Vision Transformer(ViT)通过自注意力机制能够有效捕捉图像中不同区域之间的长距离依赖关系,与传统的卷积神经网络(CNN)相比,ViT能够更好地理解全局特征。
  2. 较少依赖卷积操作:ViT抛弃了CNN中的卷积操作,减少了对局部感受野的依赖,适合处理高分辨率图像和大规模数据。
  3. 扩展性强:ViT架构灵活,可以通过增加Transformer的深度和宽度来提高模型的能力,在大数据集上表现突出,尤其在大规模预训练后迁移到下游任务时有较好的表现。

缺点

  1. 数据需求高:与传统CNN相比,ViT对大规模数据的依赖更强。如果数据量不足,ViT容易出现过拟合,难以学习到有效的特征。
  2. 训练成本高:Transformer模型计算复杂度高,训练过程中占用大量的计算资源,尤其是在较深的网络结构下,显著增加了计算时间和内存消耗。
  3. 不适合小型数据集:在小型数据集上,ViT的表现不如CNN,因为缺乏丰富的卷积特征提取能力。

改进方向

  1. 模型结构优化:可以在ViT中加入混合架构,如结合卷积层与Transformer层,使模型既具备局部特征提取能力,又能有效捕捉全局信息。
  2. 超参数调整:可以通过调节模型的深度、宽度、注意力头的数量以及学习率等超参数,找到适合特定任务的最佳模型配置。
  3. 更多数据增强方法:为减少对大规模数据的依赖,可以引入更多的数据增强技术,如CutMix、MixUp等,提高模型的泛化能力。
  4. 预训练技术:通过更大规模的自监督学习进行预训练,有助于提升ViT在下游任务中的表现,尤其是小数据集的迁移能力。

↓↓↓更多热门推荐:
Densenet模型花卉图像分类
ResNet18模型扑克牌图片预测
transformer模型写诗词

查看全部项目数据集、代码、教程点击下方名片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/59550.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

16通道AD采集方案,基于复旦微ARM + FPGA国产SoC处理器平台

测试数据汇总 表 1 本文带来的是基于复旦微FMQL20S400M四核ARM Cortex-A7(PS端) + FPGA可编程逻辑资源(PL端)异构多核SoC处理器设计的全国产工业评估板的AD采集案例。本次案例演示的开发环境如下: Windows开发环境:Windows 7 64bit、Windows 10 64bit PL端开发环境:P…

【Python爬虫实战】DrissionPage 与 ChromiumPage:高效网页自动化与数据抓取的双利器

🌈个人主页:易辰君-CSDN博客 🔥 系列专栏:https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、DrissionPage简介 (一)特点 (二)安装 (三…

R7:糖尿病预测模型优化探索

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 一、实验目的: 探索本案例是否还有进一步优化的空间 二、实验环境: 语言环境:python 3.8编译器:Jupyter notebo…

HANDLINK ISS-7000v2 网关 login_handler.cgi 未授权RCE漏洞复现

0x01 产品简介 瀚霖科技股份有限公司ISS-7000 v2网络网关服务器是台高性能的网关,提供各类酒店网络认证计费的完整解决方案。由于智慧手机与平板电脑日渐普及,人们工作之时开始使用随身携带的设备,因此无线网络也成为网络使用者基本服务的项目。ISS-7000 v2可登录300至1000…

RK3576 LINUX RKNN SDK 测试

安装Conda工具 安装 Miniforge Conda wget -c https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh chmod 777 Miniforge3-Linux-x86_64.sh bash Miniforge3-Linux-x86_64.shsource ~/miniforge3/bin/activate # Miniforge 安装的…

深入学习指针(5)!!!!!!!!!!!!!!!

文章目录 1.回调函数是什么?2.qsort使用举例2.1使用qsort函数排序整形数据2.2使用sqort排序结构数据 3.qsort函数的模拟实现 1.回调函数是什么? 回调函数就是⼀个通过函数指针调⽤的函数。 如果你把函数的指针(地址)作为参数传递…

天锐绿盾加密软件与Ping32数据安全防护对比,为企业提供坚实的保障

在当今信息化时代,数据安全已成为企业不可忽视的重要议题。天锐绿盾加密软件与Ping32作为两款备受关注的数据安全解决方案,各自以其卓越的功能和优势,为企业数据安全提供了坚实的保障。 Ping32,同样以其出色的数据加密和防泄密功能…

支持向量机相关证明 解的稀疏性

主要涉及拉格朗日乘子法,对偶问题求解

求职经验分享

更多详情:爱米的前端小笔记,更多前端内容,等你来看!这些都是利用下班时间整理的,整理不易,大家多多👍💛➕🤔哦!你们的支持才是我不断更新的动力!找…

基于Dpabi和spm12的脑脊液(csf)分割和提取笔记

一、前言 脑脊液(csf)一直被认为与新陈代谢有重要关联,其为许多神经科学研究提供重要价值,从fMRI图像中提取脑脊液信号可用于多种神经系统疾病的诊断。特别是自2019年Science上那篇著名的csf-BOLD文章发表后,大家都试图…

力扣:94--中序遍历二叉树

树 – 二叉树 完全二叉树: 完全二叉树可以用数组完美匹配位置(先序存储:根左右), 推论一 : 位置为k的节点,左孩子:2*k 1 ,右孩子 : 2 * (k 1&…

SQL 常用语句

目录 我的测试环境 学习文档 进入数据库 基础通关测验 语句-- 查 展示数据库; 进入某个数据库; 展示表: 展示某个表 desc 查询整个表: 查询特定列: 范围查询 等于特定值 不等于 介于 特定字符查询 Li…

MySQL utf8mb3 和 utf8mb4引发的问题

问题描述 Cause: java.sql.SQLException: Incorrect string value: \xF4\x8F\xBB\xBF-b... for column sddd_aaa_ark at row 1 sddd_aaa_ark 存储中文字符时,出现上述问题 原因分析 sddd_aaa_ark在数据库中结构是 utf8字符的最大字节数是3 byte,但是某些…

ONLYOFFICE 文档8.2更新评测:PDF 协作编辑、性能优化及更多新功能体验

文章目录 🍀引言🍀ONLYOFFICE 产品简介🍀功能与特点🍀体验与测评ONLYOFFICE 8.2🍀邀请用户使用🍀 ONLYOFFICE 项目介绍🍀总结 🍀引言 在日常办公软件的选择中,WPS 和微软…

SAP-ABAP开发-ONLINE 程序、DIALOG屏幕开发

目录 一、Online 程序概览 1、程序类型 2、Online程序的主要对象 二、界面 1、SAP的屏幕开发 2、屏幕功能实现 3、界面中的事件块(Event Block) 4、界面的创建 三、简单界面元素 1、文本/输入框控件 2、数据检查 3、一些常用的关键字 四、复…

java、excel表格合并、指定单元格查找、合并文件夹

#创作灵感# 公司需求 记录工作内容 后端:JAVA、Solon、easyExcel、FastJson2 前端:vue2.js、js、HTML 模式1:合并文件夹 * 现有很多文件夹 想合并全部全部的文件夹的文件到一个文件夹内 * 每个部门发布的表格 合并全部的表格为方便操作 模…

平替谷歌翻译--沉浸式翻译

这款插件真特么的猛啊!!! 谷歌插件或者油猴插件都有。 沉浸式翻译 - 免费双语对照网页翻译插件

印尼市场潜力无限!用友司库直联助力中企印尼“掘金”

在经济全球化的浪潮下,东南亚市场正焕发出勃勃生机。而其中印度尼西亚作为东盟 大的经济体,被认为是东南亚重要、有活力的市场之一,成为中企出海竞相布局的热门目的地。然而,在积极进军印尼市场的过程中,中国企业普遍面…

【贪心算法】No.1---贪心算法(1)

文章目录 前言一、贪心算法:二、贪心算法示例:1.1 柠檬⽔找零1.2 将数组和减半的最少操作次数1.3 最⼤数1.4 摆动序列1.5 最⻓递增⼦序列1.6 递增的三元⼦序列 前言 👧个人主页:小沈YO. 😚小编介绍:欢迎来到…

自动驾驶---“火热的”时空联合规划

1 背景 早期的不少规划算法都是横纵分离的(比如Apollo),先求解path之后,依赖path的结果再进行speed的求解。这种横纵解耦的规划方式具有以下特点: 相对较为简单,计算量通常较小,容易实现实时性…