计算机视觉之Vision Transformer图像分类

Vision Transformer(ViT)简介

自注意结构模型的发展,特别是Transformer模型的出现,极大推动了自然语言处理模型的发展。Transformers的计算效率和可扩展性使其能够训练具有超过100B参数的规模空前的模型。ViT是自然语言处理和计算机视觉的结合,能够在图像分类任务上取得良好效果,而不依赖卷积操作。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

vit-architecture

模型特点

ViT模型是一种用于图像分类的模型,将原图像划分为多个图像块,然后将这些图像块转换为一维向量,加上类别向量和位置向量作为模型输入。模型主体采用基于Transformer的Encoder结构,但调整了Normalization的位置,其中最主要的结构是Multi-head Attention。模型在Blocks堆叠后接全连接层,使用类别向量的输出进行分类,通常将全连接层称为Head,Transformer Encoder部分称为backbone。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

transformer-architecture

模型训练

模型训练前需要设定损失函数、优化器、回调函数等,以及建议根据项目需要调整epoch_size。训练ViT模型需要很长时间,可以通过输出的信息查看训练的进度和指标。

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()# construct model
network = ViT()# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),max_lr=0.00005,total_step=epoch_size * step_size,step_per_epoch=step_size,decay_epoch=10)# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)# define loss function
class CrossEntropySmooth(LossBase):"""CrossEntropy."""def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return lossnetwork_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")# train model
model.train(epoch_size,dataset_train,callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],dataset_sink_mode=False,)

总结

本案例演示了如何在ImageNet数据集上训练、验证和推断ViT模型。通过讲解ViT模型的关键结构和原理,帮助用户理解Multi-Head Attention、TransformerEncoder和pos_embedding等关键概念。建议用户基于源码深入学习,以更详细地理解ViT模型的原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

prompt第一讲-prompt科普

文章目录 大语言模型输入要求中英翻译助手直接抛出问题描述问题描述(详细)问题描述案例问题描述案例上下文问题为什么要加入上下文 prompt总结prompt心得 大语言模型输入要求 大语言模型本质上就是一个NLP语言模型,语言模型其实就是接受一堆…

ubuntu服务器安装labelimg报错记录

文章目录 报错提示查看报错原因安装报错 报错提示 按照步骤安装完labelimg后,在终端输入labelImg后,报错: (labelimg) rootinteractive59753:~# labelImg ………………Got keys from plugin meta data ("xcb") QFactoryLoader::Q…

hutool处理excel时候空指针小记

如图所示&#xff0c;右侧的会识别不到 参考解决方案&#xff1a; /***Description: 填补空缺位置为null/空串*Param: hutool读取的list*return: 无*Author: y*date: 2024/7/13*/public static void formatHutoolExcelArr(List<List<Object>> list) {if (CollUtil…

企业网络实验dhcp-snooping、ip source check,防非法dhcp服务器、自动获取ip(虚拟机充当DHCP服务器)、禁手动修改IP

文章目录 需求相关配置互通性配置配置vmware虚拟机&#xff08;dhcp&#xff09;分配IP服务配置dhcp relay&#xff08;dhcp中继&#xff09;配置dhcp-snooping&#xff08;防非法dhcp服务器&#xff09;配置ip source check&#xff08;禁手动修改IP&#xff09; DHCP中继&…

Android ListView

ListView ListView是以列表的形式展示具体内容的控件&#xff0c;ListView能够根据数据的长度自适应显示&#xff0c;如手机通讯录、短消息列表等都可以使用ListView实现。如图1所示是两个ListView&#xff0c;上半部分是数组形式的ListView&#xff0c;下半部分是简单列表Lis…

《Linux系统编程篇》认识在linux上的文件 ——基础篇

前言 Linux系统编程的文件操作如同掌握了一把魔法钥匙&#xff0c;打开了无尽可能性的大门。在这个世界中&#xff0c;你需要了解文件描述符、文件权限、文件路径等基础知识&#xff0c;就像探险家需要了解地图和指南针一样。而了解这些基础知识&#xff0c;就像学会了魔法咒语…

jenkins系列-07.轻易级jpom安装

jpom是一个容器化服务管理工具&#xff1a;在线构建&#xff0c;自动部署&#xff0c;日常运维, 比jenkins轻量多了。 本篇介绍mac m1安装jpom: #下载&#xff1a;https://jpom.top/pages/all-downloads/ 解压&#xff1a;/Users/jelex/Documents/work/jpom-2.10.40 启动前修…

css基础(1)

CSS CCS Syntax CSS 规则由选择器和声明块组成。 CSS选择器 CSS选择器用于查找想要设置样式的HTML元素 一般选择器分为五类 Simple selectors (select elements based on name, id, class) 简单选择器&#xff08;根据名称、id、类选择元素&#xff09; //页面上的所有 …

Web 性能入门指南-1.5 创建 Web 性能优化文化的最佳实践

最成功的网站都有什么共同点&#xff1f;那就是他们都有很强的网站性能和可用性文化。以下是一些经过验证的有效技巧和最佳实践&#xff0c;可帮助您建立健康、快乐、值得庆祝的性能文化。 创建强大的性能优化文化意味着在你的公司或团队中创建一个如下所示的反馈循环&#xff…

centos7|Linux操作系统|编译最新的OpenSSL-3.3,制作rpm安装包

一、 为什么需要编译rpm包 通常&#xff0c;我们需要安装某个软件&#xff0c;尤其是在centos7这样的操作系统&#xff0c;一般是通过yum包管理器来安装软件&#xff0c;yum的作用是管理rpm包的依赖关系&#xff0c;自动的处理rpm包的安装顺序&#xff0c;安装依赖等的相关问…

交换机和路由器的工作流程

1、交换机工作流程&#xff1a; 将接口中的电流识别为二进制&#xff0c;并转换成数据帧&#xff0c;交换机会记录学习该数据帧的源MAC地址&#xff0c;并将其端口关联起来记录在MAC地址表中。然后查看MAC地址表来查找目标MAC地址&#xff0c;会有一下一些情况&#xff1a; MA…

通过Bugly上报的日志查找崩溃闪退原因

第一步&#xff0c;解析堆栈信息 在bugly上收集到的信息是这样的 0x000000010542e46c 0x0000000104db4000 6792300 OS应用发生崩溃时&#xff0c;系统会生成一份崩溃日志&#xff0c;这份日志中包含了崩溃时的堆栈信息&#xff0c;但这些堆栈信息并非直接指向源代码&#x…

【漏洞复现】某赛通 电子文档安全管理系统 多个接口存在远程命令执行漏洞

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…

【RAG 实践】LlamaIndex 快速实现一个基于 OpenAI 的 RAG

这是 LlamaIndex 官方 Starter Tutorial 中 demo&#xff0c;用很少的代码来使用 OpenAI 快速实现出一个 RAG。 Ref: Starter Tutorial | LlamaIndex 代码&#xff1a;llamindex-rag-demo | Kaggle 1&#xff09;设置 OpenAI Token 这里使用国内的 OpenAI 中转 API token&…

【Python】数据分析-Matplotlib绘图

数据分析 Jupyter Notebook Jupyter Notebook: 一款用于编程、文档、笔记和展示的软件。 启动命令&#xff1a; jupyter notebookMatplotlib 设置中文格式&#xff1a;plt.rcParams[font.sans-serif] [KaiTi] # 查看本地所有字体 import matplotlib.font_manager a sorted…

802.11ax RU - 传输的最小单元

子载波 无线信号是加载在某个固定频率上进行传输的&#xff0c;这个频率被称为载波。802.11标准中&#xff0c;对传输频率有更新的划分&#xff0c;而这些划分的频率被称为子载波。Wi-Fi 6中&#xff0c;以20Mhz信道为例&#xff0c;20Mhz信道被划分成256个子载波&#xff0c;…

QML 鼠标和键盘事件

学习目标&#xff1a;Qml 鼠标和键盘事件 学习内容 1、QML 鼠标事件处理QML 直接提供 MouseArea 来捕获鼠标事件&#xff0c;该操作必须配合Rectangle 获取指定区域内的鼠标事件, 2、QML 键盘事件处理&#xff0c;并且获取对OML直接通过键盘事件 Keys 监控键盘任意按键应的消…

基于3D感知的端到端具身操作论文导读

DexIL&#xff1a;面向双臂灵巧手柔性操作的端到端具身执行模型 模型架构 输入&#xff1a;   观测Ot&#xff1a; RGB点云&#xff0c;使用PointNet进行编码;   状态St&#xff1a; 双臂末端7x2Dof位姿16x2灵巧手关节位置&#xff0c;只进行归一化&#xff0c;无编码&am…

Linux Win 10 Windows上安装Ollama部署大模型qwen2 7b/15配置启动 LangChain-ChatChat 0.2.10进行对话

Win 10 Window安装Ollama部署qwen2 7b LangChain-ChatChat 环境说明 Win 10 Python 3.11.9 LangChain-ChatChat 0.20 Ollama 0.2.10 Qwen2 1.5b/7b Windows 安装Ollama 下载并安装Windows版Ollama https://ollama.com/download#/ 下载大模型qwen2:1.5b或者qwen2:7b 在命令…

从实时监控到风险智能预警:EasyCVR视频AI智能监控技术在工业制造中的应用

随着科技的不断进步和工业制造领域的持续发展&#xff0c;传统的生产管理方式正逐渐转型&#xff0c;迈向更加智能、高效和安全的新阶段。在这个变革过程中&#xff0c;视频智能监控技术凭借其独特的优势&#xff0c;成为工业制造领域的管理新引擎&#xff0c;推动着从“制造”…