【MindSpore学习打卡】应用实践-计算机视觉-深入解析 Vision Transformer(ViT):从原理到实践

在近年来的深度学习领域,Transformer模型凭借其在自然语言处理(NLP)中的卓越表现,迅速成为研究热点。尤其是基于自注意力(Self-Attention)机制的模型,更是推动了NLP的飞速发展。然而,随着研究的深入,Transformer模型不仅在NLP领域大放异彩,还被引入到计算机视觉领域,形成了Vision Transformer(ViT)。ViT模型在不依赖传统卷积神经网络(CNN)的情况下,依然能够在图像分类任务中取得优异的效果。本文将深入解析ViT模型的结构、特点,并通过代码示例展示如何使用MindSpore框架实现ViT模型的训练、验证和推理。

ViT模型结构

ViT模型的主体结构基于Transformer模型的编码器(Encoder)部分,其整体结构如下图所示:

vit-architecture

模型特点

为什么要使用Patch Embedding?

在传统的Transformer模型中,输入通常是一维的词向量序列,而图像数据是二维的像素矩阵。为了将图像数据转换为Transformer可以处理的形式,我们需要将图像划分为多个小块(patch),并将每个patch转换为一维向量。这一过程称为Patch Embedding。通过这种方式,我们可以将图像数据转换为类似于词向量的形式,从而利用Transformer模型处理图像数据。
为什么要使用位置编码(Position Embedding)?

由于Transformer模型在处理输入序列时不考虑顺序信息,因此在图像数据中,patch之间的空间关系可能会丢失。为了解决这个问题,我们引入了位置编码(Position Embedding),它为每个patch增加了位置信息,使得模型能够识别不同patch之间的空间关系。这对于保留图像的空间结构信息非常重要。

  1. Patch Embedding:输入图像被划分为多个patch(图像块),然后将每个二维patch转换为一维向量,并加上类别向量和位置向量作为模型输入。
  2. Transformer Encoder:模型主体的Block结构基于Transformer的Encoder部分,主要结构是多头注意力(Multi-Head Attention)和前馈神经网络(Feed Forward)。
  3. 分类头(Head):在Transformer Encoder堆叠后接一个全连接层,用于分类。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境和MindSpore。

首先下载本案例的数据集,该数据集是从ImageNet中筛选出来的子集。数据集路径结构如下:

.dataset/├── ILSVRC2012_devkit_t12.tar.gz├── train/├── infer/└── val/
from download import download
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms# 下载数据集
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)trans_train = [transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),transforms.RandomHorizontalFlip(prob=0.5),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

Transformer基本原理

Transformer模型源于2017年的一篇文章,其主要结构为多个编码器和解码器模块。编码器和解码器由多头注意力(Multi-Head Attention)、前馈神经网络(Feed Forward)、归一化层(Normalization)和残差连接(Residual Connection)组成。

Self-Attention机制

Self-Attention机制是Transformer的核心,其主要步骤如下:

  1. 输入向量映射:将输入向量映射成Query(Q)、Key(K)、Value(V)三个向量。
  2. 计算注意力权重:通过点乘计算Query和Key的相似性,并通过Softmax函数归一化。
  3. 加权求和:使用注意力权重对Value进行加权求和,得到最终的Attention输出。

以下是Self-Attention的代码实现:

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self, dim: int, num_heads: int = 8, keep_prob: float = 1.0, attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Transformer Encoder

为什么要使用残差连接(Residual Connection)和归一化层(Normalization Layer)?

在深层神经网络中,随着层数的增加,梯度消失和梯度爆炸的问题变得越来越严重。残差连接通过在每一层加上输入的跳跃连接,可以有效缓解这些问题,确保信息能够顺利传递。此外,归一化层(如LayerNorm)可以加速模型的训练,并提高模型的稳定性和泛化能力。这些技术的结合,使得Transformer模型能够在更深的层次上进行有效的训练。

Transformer Encoder由多层Self-Attention和前馈神经网络(Feed Forward)组成,通过残差连接和归一化层增强模型的训练效果和泛化能力。

class FeedForward(nn.Cell):def __init__(self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, activation: nn.Cell = nn.GELU, keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self, dim: int, num_layers: int, num_heads: int, mlp_dim: int, keep_prob: float = 1., attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim, num_heads=num_heads, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim, hidden_features=mlp_dim, activation=activation, keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])), ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):return self.layers(x)

ViT模型的输入

ViT模型通过将输入图像划分为多个patch,将每个patch转换为一维向量,并加上类别向量和位置向量作为模型输入。以下是Patch Embedding的代码实现:

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 768, input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

整体构建ViT

以下代码构建了一个完整的ViT模型:

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self, image_size: int = 224, input_channels: int = 3, patch_size: int = 16, embed_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, keep_prob: float = 1.0, attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: Optional[nn.Cell] = nn.LayerNorm, pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim, input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0), shape=(1, 1, embed_dim), dtype=ms.float32, name='cls', requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0), shape=(1, num_patches + 1, embed_dim), dtype=ms.float32, name='pos_embedding', requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim, num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob, drop_path_keep_prob=drop_path_keep_prob, activation=activation, norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

模型训练与推理

模型训练

模型训练前,需要设定损失函数、优化器和回调函数。以下是训练ViT模型的代码:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train# 定义超参数
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()# 构建模型
network = ViT()# 加载预训练模型参数
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)# 定义学习率
lr = nn.cosine_decay_lr(min_lr=float(0), max_lr=0.00005, total_step=epoch_size * step_size, step_per_epoch=step_size, decay_epoch=10)# 定义优化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)# 定义损失函数
class CrossEntropySmooth(LossBase):def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):super(CrossEntropySmooth, self).__init__()self.onehot = ops.OneHot()self.sparse = sparseself.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)def construct(self, logit, label):if self.sparse:label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)loss = self.ce(logit, label)return lossnetwork_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)# 设置检查点
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)# 初始化模型
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")# 训练模型
model.train(epoch_size, dataset_train, callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)], dataset_sink_mode=False)

在这里插入图片描述

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。以下是验证ViT模型的代码:

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)trans_val = [transforms.Decode(),transforms.Resize(224 + 32),transforms.CenterCrop(224),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)# 构建模型
network = ViT()# 加载预训练模型参数
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)# 定义评价指标
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(), 'Top_5_Accuracy': train.Top5CategoricalAccuracy()}if ascend_target:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")# 验证模型
result = model.eval(dataset_val)
print(result)

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。以下是推理ViT模型的代码:

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)trans_infer = [transforms.Decode(),transforms.Resize([224, 224]),transforms.Normalize(mean=mean, std=std),transforms.HWC2CHW()
]dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"], num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)# 读取推理数据
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):image = image["image"]image = ms.Tensor(image)prob = model.predict(image)label = np.argmax(prob.asnumpy(), axis=1)mapping = index2label()output = {int(label): mapping[int(label)]}print(output)show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG", result=output, out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/40316.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pandas,dataframe使用笔记

目录 新建一个dataframe不带列名带列名 dataframe添加一行内容查看dataframe某列的数据类型新建dataframe时设置了列名,则数据类型为object dataframe的保存保存为csv文件保存为excel文件 dataframe属于pandas 新建一个dataframe 不带列名 df pd.DataFrame() 带…

GuLi商城-商品服务-API-品牌管理-效果优化与快速显示开关

<template><div class"mod-config"><el-form :inline"true" :model"dataForm" keyup.enter.native"getDataList()"><el-form-item><el-input v-model"dataForm.key" placeholder"参数名&qu…

java集合(1)

目录 一.集合概述 二. 集合体系概述 1. Collection接口 1.1 List接口 1.2 Set接口 2. Map接口 三. ArrayList 1.ArrayList常用方法 2.ArrayList遍历 2.1 for循环 2.2 增强for循环 2.3 迭代器遍历 一.集合概述 我们经常需要存储一些数据类型相同的元素,之前我们学过…

C++ 仿QT信号槽二

// 实现原理 // 每个signal映射到bitset位&#xff0c;全集 // 每个slot做为signal的bitset子集 // signal全集触发&#xff0c;标志位有效 // flip将触发事件队列前置 // slot检测智能指针全集触发的标志位&#xff0c;主动运行子集绑定的函数 // 下一帧对bitset全集进行触发清…

【C++】 解决 C++ 语言报错:Segmentation Fault

文章目录 引言 段错误&#xff08;Segmentation Fault&#xff09;是 C 编程中常见且令人头疼的错误之一。段错误通常发生在程序试图访问未被允许的内存区域时&#xff0c;导致程序崩溃。本文将深入探讨段错误的产生原因、检测方法及其预防和解决方案&#xff0c;帮助开发者在…

4.2 投影

一、投影和投影矩阵 我们以下面两个问题开始&#xff0c;问题一是为了展示投影是很容易视觉化的&#xff0c;问题二是关于 “投影矩阵”&#xff08;projection matrices&#xff09;—— 对称矩阵且 P 2 P P^2P P2P。 b \boldsymbol b b 的投影是 P b P\boldsymbol b Pb。…

2024年7月5日 (周五) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键&#xff0c;实现一键唤起、一键隐藏的 Windows 工具&#xff0c;并且支持窗口动态绑定快捷键&#xff08;无需设置自动实现&#xff09;。 卸载工具 HiBitUninstaller: Windows上的软件卸载工具 《乐高地平线大冒险》为何不登陆…

江汉大学刘春萌同学整理的wifi模块 上传mqtt实验步骤

一.固件烧录 1.打开安信可官网 2.点击wifi模组系列的ESP8266 3.点击各类固件后选择固件号1471下载 4.打开烧录工具将下载的二进制文件导入并将后面的起始地址写为0x00000,下面勾选40mhz QIO 8Mbit点击start下载即可 二.本地部署mqtt服务器(windows) 1.下载mosquitto后有一个m…

Java并发编程知识整理笔记

目录 ​1. 什么是线程和进程&#xff1f; 线程与进程有什么区别&#xff1f; 那什么是上下文切换&#xff1f; 进程间怎么通信&#xff1f; 什么是用户线程和守护线程&#xff1f; 2. 并行和并发的区别&#xff1f; 3. 创建线程的几种方式&#xff1f; Runnable接口和C…

Qt实现流动的管道效果代码示例

在现代图形用户界面&#xff08;GUI&#xff09;应用程序中&#xff0c;动态效果可以显著增强用户体验。本文将介绍如何使用Qt框架实现一个流动的管道效果。我们将通过自定义QWidget来绘制管道&#xff0c;并使用定时器来实现流动效果。 1. 准备工作 首先&#xff0c;确保你已…

HMI 的 UI 风格创造奇迹

HMI 的 UI 风格创造奇迹

Laravel5+mycat 报错 “Packets out of order”

背景 近期对负责项目&#xff0c;配置了一套 主从复制的 MySQL 集群 使用了中间件 mycat 但测试发现&#xff0c;替换了原来的数据连接后&#xff0c;会出现 Packets out of order 的报错 同时注意到&#xff0c;有的框架代码中竟然也会失效&#xff0c;比如 controller 类中&…

Linux:进程间通信(一.初识进程间通信、匿名管道与命名管道、共享内存)

上次结束了基础IO&#xff1a;Linux&#xff1a;基础IO&#xff08;三.软硬链接、动态库和静态库、动精态库的制作和加载&#xff09; 文章目录 1.认识进程间通信2.管道2.1匿名管道2.2pipe()函数 —创建匿名管道2.3匿名管道的四种情况2.4管道的特征 3.基于管道的进程池设计4.命…

Vue3学习笔记(n.0)

vue指令之v-for 首先创建自定义组件&#xff08;practice5.vue&#xff09;&#xff1a; <!--* Author: RealRoad1083425287qq.com* Date: 2024-07-05 21:28:45* LastEditors: Mei* LastEditTime: 2024-07-05 21:35:40* FilePath: \Fighting\new_project_0705\my-vue-app\…

重载一元运算符

自增运算符 #include<iostream> using namespace std; class CGirl { public:string name;int ranking;CGirl() { name "zhongge"; ranking 5; }void show() const{ cout << "name : "<<name << " , ranking : " <…

cmake编译源码教程(一)

1、介绍 本次博客介绍使用cmake编译平面点云分割的源代码,其对室内点云以及TLS点云中平面结构进行分割,分割效果如下: 2、编译过程 2.1 源代码下载 首先,下载源代码,如下所示,在该文件夹下新建一个build文件夹,用于后续生成sln工程。 同时,由于该库依赖open…

自动化设备上位机设计 二

目录 一 设计原型 二 后台代码 一 设计原型 二 后台代码 namespace 自动化上位机设计 {public partial class Form1 : Form{public Form1(){InitializeComponent();timer1.Enabled true;timer1.Tick Timer1_Tick;}private void Timer1_Tick(object? sender, EventArgs e)…

您的私人办公室!-----ONLYOFFICE8.1版本的桌面编辑器测评

随时随地创建并编辑文档&#xff0c;还可就其进行协作 ONLYOFFICE 文档是一款强大的在线编辑器&#xff0c;为您使用的平台提供文本文档、电子表格、演示文稿、表单和 PDF 编辑工具。 网页地址链接&#xff1a; https://www.onlyoffice.com/zh/office-suite.aspxhttps://www…

AJAX-day1:

注&#xff1a;文件布局&#xff1a; 一、AJAX的概念&#xff1a; AJAX是浏览器与服务器进行数据通信的技术 >把数据变活 二、AJAX的使用&#xff1a; 使用axios库&#xff0c;与服务器进行数据通信 基于XMLHttpRequest封装&#xff0c;代码简单 Vue,React项目使用 学习…

哪个品牌的加密软件稳定方便使用?

一、什么是企业加密软件&#xff1f; 企业加密软件是一种用于保护企业内部数据安全的工具。在数字化时代&#xff0c;随着数据量的爆炸式增长&#xff0c;信息安全和隐私保护变得愈发重要。企业加密软件作为保障数据安全的关键工具&#xff0c;受到越来越多用户的青睐。 企业…