计算机视觉之Vision Transformer图像分类

Vision Transformer(ViT)简介

自注意结构模型的发展,特别是Transformer模型的出现,极大推动了自然语言处理模型的发展。Transformers的计算效率和可扩展性使其能够训练具有超过100B参数的规模空前的模型。ViT是自然语言处理和计算机视觉的结合,能够在图像分类任务上取得良好效果,而不依赖卷积操作。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

vit-architecture

模型特点

ViT模型是一种用于图像分类的模型,将原图像划分为多个图像块,然后将这些图像块转换为一维向量,加上类别向量和位置向量作为模型输入。模型主体采用基于Transformer的Encoder结构,但调整了Normalization的位置,其中最主要的结构是Multi-head Attention。模型在Blocks堆叠后接全连接层,使用类别向量的输出进行分类,通常将全连接层称为Head,Transformer Encoder部分称为backbone。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

transformer-architecture

模型训练

模型训练前需要设定损失函数、优化器、回调函数等,以及建议根据项目需要调整epoch_size。训练ViT模型需要很长时间,可以通过输出的信息查看训练的进度和指标。

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()# construct model
network = ViT()# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),max_lr=0.00005,total_step=epoch_size * step_size,step_per_epoch=step_size,decay_epoch=10)# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)# define loss function
class CrossEntropySmooth(LossBase):"""CrossEntropy."""def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return lossnetwork_loss = CrossEntropySmooth(sparse=True,reduction="mean",smooth_factor=0.1,num_classes=num_classes)# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")# train model
model.train(epoch_size,dataset_train,callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],dataset_sink_mode=False,)

总结

本案例演示了如何在ImageNet数据集上训练、验证和推断ViT模型。通过讲解ViT模型的关键结构和原理,帮助用户理解Multi-Head Attention、TransformerEncoder和pos_embedding等关键概念。建议用户基于源码深入学习,以更详细地理解ViT模型的原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32+HC-05蓝牙模块学习与使用(内附资料)

引言 随着物联网技术的快速发展,短距离无线通信技术变得日益重要。蓝牙作为一种低功耗、低成本的无线通信技术,在嵌入式系统中得到了广泛应用。本文将详细介绍如何使用STM32微控制器与HC-05蓝牙模块进行通信,实现数据的无线传输。 硬件准备…

prompt第一讲-prompt科普

文章目录 大语言模型输入要求中英翻译助手直接抛出问题描述问题描述(详细)问题描述案例问题描述案例上下文问题为什么要加入上下文 prompt总结prompt心得 大语言模型输入要求 大语言模型本质上就是一个NLP语言模型,语言模型其实就是接受一堆…

ubuntu服务器安装labelimg报错记录

文章目录 报错提示查看报错原因安装报错 报错提示 按照步骤安装完labelimg后,在终端输入labelImg后,报错: (labelimg) rootinteractive59753:~# labelImg ………………Got keys from plugin meta data ("xcb") QFactoryLoader::Q…

日常学习--20240713

1、字节流转字符流时,除了使用字节流实例作为参数,还需要什么参数? 还需要使用字符编码作为参数,保证即使在不同平台上也是使用相同的字符编码(否则会使用平台默认的编码,不同平台默认编码可能不一样&…

hutool处理excel时候空指针小记

如图所示&#xff0c;右侧的会识别不到 参考解决方案&#xff1a; /***Description: 填补空缺位置为null/空串*Param: hutool读取的list*return: 无*Author: y*date: 2024/7/13*/public static void formatHutoolExcelArr(List<List<Object>> list) {if (CollUtil…

企业网络实验dhcp-snooping、ip source check,防非法dhcp服务器、自动获取ip(虚拟机充当DHCP服务器)、禁手动修改IP

文章目录 需求相关配置互通性配置配置vmware虚拟机&#xff08;dhcp&#xff09;分配IP服务配置dhcp relay&#xff08;dhcp中继&#xff09;配置dhcp-snooping&#xff08;防非法dhcp服务器&#xff09;配置ip source check&#xff08;禁手动修改IP&#xff09; DHCP中继&…

Android ListView

ListView ListView是以列表的形式展示具体内容的控件&#xff0c;ListView能够根据数据的长度自适应显示&#xff0c;如手机通讯录、短消息列表等都可以使用ListView实现。如图1所示是两个ListView&#xff0c;上半部分是数组形式的ListView&#xff0c;下半部分是简单列表Lis…

《Linux系统编程篇》认识在linux上的文件 ——基础篇

前言 Linux系统编程的文件操作如同掌握了一把魔法钥匙&#xff0c;打开了无尽可能性的大门。在这个世界中&#xff0c;你需要了解文件描述符、文件权限、文件路径等基础知识&#xff0c;就像探险家需要了解地图和指南针一样。而了解这些基础知识&#xff0c;就像学会了魔法咒语…

【C++】指针学习 知识点总结+代码记录

一.示例代码知识点总结 1. 基本指针操作 指针声明和初始化&#xff1a;int* ptr_a a; 表示声明了一个指向整型的指针&#xff0c;并初始化为指向数组a的首地址。引用和指针的区别&#xff1a;int& i2 i; 声明了一个整型引用&#xff0c;绑定到变量i上&#xff0c;而int…

k3s配置docker容器/dev/shm

在使用K3s和Docker容器时&#xff0c;如果你发现容器的 /dev/shm 默认大小是64MB&#xff0c;并且需要扩大它的大小&#xff0c;可以通过以下几种方法实现。 方法1&#xff1a;使用 Docker 的 --shm-size 选项 如果你直接使用 Docker 运行容器&#xff0c;可以通过 --shm-siz…

jenkins系列-07.轻易级jpom安装

jpom是一个容器化服务管理工具&#xff1a;在线构建&#xff0c;自动部署&#xff0c;日常运维, 比jenkins轻量多了。 本篇介绍mac m1安装jpom: #下载&#xff1a;https://jpom.top/pages/all-downloads/ 解压&#xff1a;/Users/jelex/Documents/work/jpom-2.10.40 启动前修…

git 分支介绍

在Git版本控制系统中&#xff0c;分支&#xff08;Branch&#xff09;是一种非常强大的功能&#xff0c;它允许开发者在不影响主代码库&#xff08;如master分支&#xff09;的情况下进行开发或修复工作。你提到的五种分支类型是在Gitflow工作流&#xff08;Gitflow Workflow&a…

css基础(1)

CSS CCS Syntax CSS 规则由选择器和声明块组成。 CSS选择器 CSS选择器用于查找想要设置样式的HTML元素 一般选择器分为五类 Simple selectors (select elements based on name, id, class) 简单选择器&#xff08;根据名称、id、类选择元素&#xff09; //页面上的所有 …

Git配置笔记

文章目录 Git配置一、Git配置文件1.1 配置文件位置1.2 参考 二、换行符相关2.1 背景2.2 相关配置2.3 推荐配置2.4 参考资料 Git配置 一、Git配置文件 1.1 配置文件位置 Git 自带一个 git config 的工具来帮助设置控制 Git 外观和行为的配置变量。 这些变量存储在三个不同的位…

Web 性能入门指南-1.5 创建 Web 性能优化文化的最佳实践

最成功的网站都有什么共同点&#xff1f;那就是他们都有很强的网站性能和可用性文化。以下是一些经过验证的有效技巧和最佳实践&#xff0c;可帮助您建立健康、快乐、值得庆祝的性能文化。 创建强大的性能优化文化意味着在你的公司或团队中创建一个如下所示的反馈循环&#xff…

MySQL入门学习-深入索引.匹配顺序

在 MySQL 中&#xff0c;索引的匹配顺序是指在查询执行时&#xff0c;数据库系统根据查询条件中涉及的列和索引的结构&#xff0c;决定如何使用索引来提高查询效率的方式。 以下是关于深入索引和匹配顺序的一些详细信息&#xff1a; 一、索引的类型&#xff1a; - B-Tree 索引…

centos7|Linux操作系统|编译最新的OpenSSL-3.3,制作rpm安装包

一、 为什么需要编译rpm包 通常&#xff0c;我们需要安装某个软件&#xff0c;尤其是在centos7这样的操作系统&#xff0c;一般是通过yum包管理器来安装软件&#xff0c;yum的作用是管理rpm包的依赖关系&#xff0c;自动的处理rpm包的安装顺序&#xff0c;安装依赖等的相关问…

交换机和路由器的工作流程

1、交换机工作流程&#xff1a; 将接口中的电流识别为二进制&#xff0c;并转换成数据帧&#xff0c;交换机会记录学习该数据帧的源MAC地址&#xff0c;并将其端口关联起来记录在MAC地址表中。然后查看MAC地址表来查找目标MAC地址&#xff0c;会有一下一些情况&#xff1a; MA…

通过Bugly上报的日志查找崩溃闪退原因

第一步&#xff0c;解析堆栈信息 在bugly上收集到的信息是这样的 0x000000010542e46c 0x0000000104db4000 6792300 OS应用发生崩溃时&#xff0c;系统会生成一份崩溃日志&#xff0c;这份日志中包含了崩溃时的堆栈信息&#xff0c;但这些堆栈信息并非直接指向源代码&#x…

【漏洞复现】某赛通 电子文档安全管理系统 多个接口存在远程命令执行漏洞

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…