昇思25天学习打卡营第16天 | Vision Transformer图像分类

昇思25天学习打卡营第16天 | Vision Transformer图像分类

文章目录

  • 昇思25天学习打卡营第16天 | Vision Transformer图像分类
    • Vision Transform(ViT)模型
      • Transformer
        • Attention模块
        • Encoder模块
      • ViT模型输入
    • 模型构建
      • Multi-Head Attention模块
      • Encoder模块
      • Patch Embedding模块
      • ViT网络
    • 总结
    • 打卡

Vision Transform(ViT)模型

ViT是NLP和CV领域的融合,可以在不依赖于卷积操作的情况下在图像分类任务上达到很好的效果。

ViT模型的主体结构是基于Transformer的Encoder部分。

Transformer

Transformer由很多Encoder和Decoder模块构成,包括多头注意力(Multi-Head Attention)层,Feed Forward层,Normalization层和残差连接(Residual Connection)。
encoder-decoder
多头注意力结构基于自注意力机制(Self-Attention),是多个Self-Attention的并行组成。

Attention模块

Attention的核心在于为输入向量的每个单词学习一个权重。

  1. 最初的输入向量首先经过Embedding层映射为Q(Query),K(Key),V(Value)三个向量。
  2. 通过将Q和所有K进行点乘初一维度平方根,得到向量间的相似度,通过softmax获取每词向量之间的关系权重。
  3. 利用关系权重对词向量的V加权求和,得到自注意力值。
    self-attention
    多头注意力机制只是对self-attention的并行化:
    multi-head-attention
Encoder模块

ViT中的Encoder相对于标准Transformer,主要在于将Normolization放在self-attention和Feed Forward之前,其他结构与标准Transformer相同。
vit-encoder

ViT模型输入

传统Transformer主要应用于自然语言处理的一维词向量,而图像时二维矩阵的堆叠。
在ViT中:

  1. 通过卷积将输入图像在每个channel上划分为 16 × 16 16\times 16 16×16个patch。如果输入 224 × 224 224\times224 224×224的图像,则每一个patch的大小为 14 × 14 14\times 14 14×14
  2. 将每一个patch拉伸为一个一维向量,得到近似词向量堆叠的效果。如将 14 × 14 14\times14 14×14展开为 196 196 196的向量。
    这一部分Patch Embedding用来替换Transformer中Word Embedding,用作网络中的图像输入。

模型构建

Multi-Head Attention模块

from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

Encoder模块

from typing import Optional, Dictclass FeedForward(nn.Cell):def __init__(self,in_features: int,hidden_features: Optional[int] = None,out_features: Optional[int] = None,activation: nn.Cell = nn.GELU,keep_prob: float = 1.0):super(FeedForward, self).__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.dense1 = nn.Dense(in_features, hidden_features)self.activation = activation()self.dense2 = nn.Dense(hidden_features, out_features)self.dropout = nn.Dropout(p=1.0-keep_prob)def construct(self, x):"""Feed Forward construct."""x = self.dense1(x)x = self.activation(x)x = self.dropout(x)x = self.dense2(x)x = self.dropout(x)return xclass ResidualCell(nn.Cell):def __init__(self, cell):super(ResidualCell, self).__init__()self.cell = celldef construct(self, x):"""ResidualCell construct."""return self.cell(x) + xclass TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

Patch Embedding模块

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

ViT网络

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

总结

这一节对Transformer进行介绍,包括Attention机制、并行化的Attention以及Encoder模块。由于传统Transformer主要作用于一维的词向量,因此二维图像需要被转换为类似的一维词向量堆叠,在ViT中通过将Patch Embedding解决这一问题,并用来代替传统Transformer中的Word Embedding作为网络的输入。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/46145.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工业三防平板助力工厂生产数据实时管理

在当今高度数字化和智能化的工业生产环境中,工业三防平板正逐渐成为工厂实现生产数据实时管理的得力助手。这种创新的技术设备不仅能够在恶劣的工业环境中稳定运行,还为工厂的生产流程优化、效率提升和质量控制带来了前所未有的机遇。 工业生产场景通常充…

机器学习——数据预处理和特征工程(sklearn)

目录 一、数据挖掘流程 1. 获取数据 2. 数据预处理 3. 特征工程 4. 建模,测试模型并预测出结果 5. 验证模型效果 二、sklearn中的相关包 1.sklearn.preprocessing 2.sklearn.Impute 3.sklearn.feature_selection 4.sklearn.decomposition 三、数据预处理…

【网络安全】PostMessage:分析JS实现XSS

未经许可,不得转载。 文章目录 前言示例正文 前言 PostMessage是一个用于在网页间安全地发送消息的浏览器 API。它允许不同的窗口(例如,来自同一域名下的不同页面或者不同域名下的跨域页面)进行通信,而无需通过服务器…

【Arduino IDE】安装及开发环境、ESP32库

一、Arduino IDE下载 二、Arduino IDE安装 三、ESP32库 四、Arduino-ESP32库配置 五、新建ESP32-S3N15R8工程文件 乐鑫官网 Arduino官方下载地址 Arduino官方社区 Arduino中文社区 一、Arduino IDE下载 ESP-IDF、MicroPython和Arduino是三种不同的开发框架,各自适…

定制开发AI智能名片商城微信小程序在私域流量池构建中的应用与策略

摘要 在数字经济蓬勃发展的今天,私域流量已成为企业竞争的新战场。定制开发AI智能名片商城微信小程序,作为私域流量池构建的创新工具,正以其独特的优势助力企业实现用户资源的深度挖掘与高效转化。本文深入探讨了定制开发AI智能名片商城微信…

AIoTedge智能物联网边缘计算平台:引领未来智能边缘技术

引言 随着物联网技术的飞速发展,我们正步入一个万物互联的时代。AIoTedge智能物联网边缘计算平台,以其创新的边云协同架构,为智能设备和系统提供了强大的数据处理和智能决策能力,开启了智能物联网的新篇章。 智能边缘计算平台的核…

LLaMA-Factory

文章目录 一、关于 LLaMA-Factory项目特色性能指标 二、如何使用1、安装 LLaMA Factory2、数据准备3、快速开始4、LLaMA Board 可视化微调5、构建 DockerCUDA 用户:昇腾 NPU 用户:不使用 Docker Compose 构建CUDA 用户:昇腾 NPU 用户&#xf…

【Java项目笔记】01项目介绍

一、技术框架 1.后端服务 Spring Boot为主体框架 Spring MVC为Web框架 MyBatis、MyBatis Plus为持久层框架,负责数据库的读写 阿里云短信服务 2.存储服务 MySql redis缓存数据 MinIO为对象存储,存储非结构化数据(图片、视频、音频&a…

推荐一款处理TCP数据的架构--EasyTcp4Net

EasyTcp4Net是一个基于c# Pipe,ReadonlySequence的高性能Tcp通信库,旨在提供稳定,高效,可靠的tcp通讯服务。 基础的消息通讯 重试机制 超时机制 SSL加密通信支持 KeepAlive 流量背压控制 粘包和断包处理 (支持固定头处理,固定长度处理,固定字符处理) 日志支持Pipe &…

Spring MVC 的常用注解

RequestMapping 和 RestController注解 上面两个注解,是Spring MCV最常用的注解。 RequestMapping , 他是用来注册接口的路由映射。 路由映射:当一个用户访问url时,将用户的请求对应到某个方法或类的过程叫做路由映射。 Reques…

定制QCustomPlot 带有ListView的QCustomPlot 全网唯一份

定制QCustomPlot 带有ListView的QCustomPlot 文章目录 定制QCustomPlot 带有ListView的QCustomPlot摘要需求描述实现关键字: Qt、 QCustomPlot、 魔改、 定制、 控件 摘要 先上效果,是你想要的,再看下面的分解,顺便点赞搜藏一下;不是直接右上角。 QCustomPlot是一款…

基于springboot+vue+uniapp的驾校预约平台小程序

开发语言:Java框架:springbootuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包&#…

35.UART(通用异步收发传输器)-RS232(2)

(1)RS232接收模块visio框图: (2)接收模块Verilog代码编写: /* 常见波特率: 4800、9600、14400、115200 在系统时钟为50MHz时,对应计数为: (1/4800) * 10^9 /20 -1 10416 …

【作业】 贪心算法1

Tips:三题尚未完成。 #include <iostream> #include <algorithm> using namespace std; int a[110]; int main(){int n,r,sum0;cin>>n>>r;for(int i0;i<n;i){cin>>a[i];}sort(a0,an);for(int i0;i<n;i){if(i>r){a[i]a[i-r]a[i];}suma[…

大气热力学(8)——热力学图的应用之一(气象要素求解)

本篇文章源自我在 2021 年暑假自学大气物理相关知识时手写的笔记&#xff0c;现转化为电子版本以作存档。相较于手写笔记&#xff0c;电子版的部分内容有补充和修改。笔记内容大部分为公式的推导过程。 文章目录 8.1 复习斜 T-lnP 图上的几种线8.1.1 等温线和等压线8.1.2 干绝热…

连锁零售门店分析思路-人货场 数据分析

连锁零售门店分析思路 以下是一个连锁零售门店的分析思路&#xff1a; 一、市场与竞争分析 二、门店运营分析&#xff08;销售分析&#xff09; 三、销售与财务分析 四、客户分析 五、数字化与营销分析 最近帮一个大学生培训&#xff0c;就门店销售分析 &#xff0c;说到门店…

使用windows批量解压和布局ImageNet ISLVRC2012数据集

使用的系统是windows&#xff0c;找到的解压命令很多都linux系统中的&#xff0c;为了能在windows系统下使用&#xff0c;因此下载Git这个软件&#xff0c;在其中的Git Bash中使用以下命令&#xff0c;因为Git Bash集成了很多linux的命令&#xff0c;方便我们的使用。 ImageNe…

[iOS]类和对象的底层原探索

[iOS]类和对象的底层探索 文章目录 [iOS]类和对象的底层探索继承链&#xff08;类&#xff0c;父类&#xff0c;元类&#xff09;instance 实例对象class 类对象meta-class 元类对象 对对象、类、元类和分类的探索instance 实例对象class 类对象meta-class 元类对象分类(catego…

防火墙之带宽管理篇

核心思想 1.带宽限制&#xff1a;限制非关键业务流量占用带宽的比例 2.带宽保证&#xff1a;保证关键的业务流量传输不受影响。业务繁忙时&#xff0c;确保业务不受影响。 3.限制连接数&#xff1a;可以针对某些业务进行连接数的限制&#xff0c;首先可以降低该业务占用带宽…

基于UltraFace的人脸检测在地平线旭日X3派上的部署和测试(Python版本和C++版本)

电脑端的测试环境搭建 如果不想再搭建环境和测试代码bug上浪费更多的时间可以直接获取本人的测试虚拟机&#xff0c;所有的测试代码、虚拟环境和板端测试工程以全部打包到了虚拟机&#xff0c;需要的可以通过网盘获取&#xff1a; 代码和虚拟机百度网盘链接&#xff1a; 链接…