【Transformer系列(4)】基于vision transformer(ViT)实现猫狗二分类项目实战

文章目录

  • 一、vision transformer(ViT)结构解释
  • 二、Patch Embedding部分
    • 2.1 图像Patch化
    • 2.2 cls token
    • 2.3 位置编码(positional embedding)
  • 三、Transformer Encoder部分
    • (1) Multi-head Self-Attention
    • (2) encoder block
  • 四、head部分
  • 五、vision transformer(ViT)完整代码
  • 六、基于vision transformer(ViT)实现猫狗二分类项目实战


一、vision transformer(ViT)结构解释

vision transformerViT)结构大致流程如下图

+------------+       +--------------+
|   Input    | ----> |    Patch     |
+------------+       +--------------+|v+-------+|  Embed  |+-------+|v+-------------------+|   Transformer     |+-------------------+|v+-------+|  Pool |+-------+|v+-------+|  MLP  |+-------+|v+-------+|  Class|+-------+|vOutput

Vision TransformrerViT)是一种基于自注意力机制的图像分类模型,它试图将图像分类任务转化为自然语言处理中的序列建模问题。与传统的卷积神经网络不同,ViT使用Transformer作为它的基本结构。

ViT的整体结构可以分为两个部分:Patch EmbeddingTransformer Encoder

Patch Embedding阶段,输入的图像首先被划分为多个小的固定尺寸的图像块,称为patch。每个patch经过一个线性投影层和一个位置编码层得到相应的向量表示。这些向量表示被展平为序列并通过一个可训练的嵌入层得到输入序列。

Transformer Encoder阶段,输入序列通过多个堆叠的Transformer Encoder层进行处理。每个Transformer Encoder层由多个注意力机制和多层感知机组成。注意力机制用于捕捉全局和局部的上下文信息,通过计算输入序列中不同位置的相互关系来获取注意力权重。多层感知机则用于在每个位置上对向量进行非线性转换。

ViT的最后,经过多个Transformer Encoder层处理后的序列经过一个全局平均池化层得到固定长度的表示,再通过一个线性分类层进行分类预测。

总的来说,ViT的结构利用自注意力机制,将输入的图像转化为序列,并通过多个Transformer Encoder层对序列进行处理,最后通过全局平均池化和线性分类层得到分类结果。这种结构在图像分类任务上取得了不错的性能,并且能够处理较大尺寸的图像。

二、Patch Embedding部分

Patch Embedding部分主要由(1)图像Patch化(2)cls token(3)positional embeding构成

2.1 图像Patch化

在这里插入图片描述
代码实现

# 序列组合位置编码
class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()# 196 = (224 // 16) * (224 // 16) Patch化self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)# Trueself.flatten        = flatten# 注意: kernel_size = stride 才能实现patch之间不相交self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):# Step 1. Patch using Conv2d with 'kernel_size = patch_size'# [1,3,224,224] -> [1,768,14,14]x = self.proj(x)if self.flatten:# x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC# Step 2. H*W -> N 宽高维度平铺,形成序列# BCHW -> BCN   N = H*W# [1,768,14,14] -> [1,768,196]x = x.flatten(2)# BCN -> BNC 交换1维和2维,即CN transpose一次只能对两个维度进行操作# [1,768,196] -> [1,196,768]x = x.transpose(1, 2)x = self.norm(x)return x

在这里插入图片描述
在这里插入图片描述
PatchEmbed之后,维度由[1,3,224,224]变为了[1,196,768]

2.2 cls token

ViT模型中,每个图块都经过一系列的Transformer编码器层,这些编码器层处理图块之间的局部关系。而cls token则是在第一个编码器层的输入中插入的一个特殊令牌。它作为整个图像的表示引入了全局信息。

cls token的计算方式与其他图块的计算方式相同,它经过自注意力机制和前馈神经网络进行特征转换。然后,将经过编码器层处理后的cls token的输出连接到分类器中,用于图像分类任务的最终预测。

cls token的作用是捕捉整个图像的全局特征。因为Transformer模型是一种自注意力模型,并没有显式的全局信息概念,cls token的引入可以将整个图像的特征聚合成一个向量,使得模型具备对整个图像的全局理解能力。这样,模型就可以利用cls token的特征进行分类任务的预测。
代码实现

# [1,1,768]
cls_token = self.cls_token.expand(batch_size, -1, -1)
# H*W+1
# [1,196,768] -> [1,197,768]
x = torch.cat((cls_token, x), dim=1)

在这里插入图片描述

cls token之后,维度由[1,196,768]变为了[1,197,768]

2.3 位置编码(positional embedding)

由于ViT是基于自注意力机制(self-attention mechanism)构建的,它无法直接处理序列中项目的顺序信息。
位置编码通常是通过将位置信息转换为向量形式,然后将其添加到输入图像的嵌入表示中来实现的。这样,每个嵌入向量就会包含图像中的位置信息。位置编码的加入可以帮助模型在处理图像时更好地理解不同位置之间的关系,它使得模型能够关注图像中的全局和局部结构,从而更好地对图像进行建模。

总结起来,位置编码在ViT中的作用是引入图像中不同位置的位置关系,以帮助模型理解和处理图像中的全局和局部结构。
代码实现

# [1,196+1,768] -> [1,196,768]
img_token_pe = self.pos_embed[:, 1:, :]# old_feature_shape: [1,196,768] -> [1,14,14,768] -> [1,768,14,14]
img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
# new_feature_shape: [1,768,14,14] -> [1,768,14,14]
img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
# [1,768,14,14] -> [1,14,14,768] -> [1,196,768]
img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)
# [1,1,768] cat [1,196,768] -> [1,197,768]
pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)# Step 4. residual connection + droppath
# [1,197,768] + [1,197,768] -> [1,197,768]
x = self.pos_drop(x + pos_embed)

位置编码(positional embedding)之后,维度还是[1,197,768],没有改变。

三、Transformer Encoder部分

Encoder部分主要由下面流程构成
(1)Multi-head Self-Attention:在Encoder block中,每个patch embedding都会与其他所有patch embeddings进行注意力计算。这种注意力计算将每个patch embedding与其他patch embeddings进行交互,从而使每个patch能够“看到”其他patch的信息。这种注意力计算可以通过独立的多头注意力机制实现,其中每个注意力头都可以学习不同的关注模式。
(2)Layer Normalization:在注意力计算之后,对每个patch embedding进行层归一化操作,以减少信息波动。
(3)Feed-Forward Network:在层归一化之后,通过一个全连接前馈网络,对每个patch的特征进行非线性转换。这个前馈网络可以是多层感知机(MLP),可以通过两个线性变换和一个激活函数来实现。
(4)Residual Connection:Encoder block中的每个操作都有一个残差连接,将输入与输出相加,以保留输入的信息。
(5)Layer Normalization:在前馈网络之后,再次对每个patch embedding进行层归一化操作。
(6)Dropout:为了防止过拟合,可以在Encoder block中应用dropout操作,以随机丢弃一部分特征。

(1) Multi-head Self-Attention

原理参考:【Transformer系列(2)】Multi-head self-attention 多头自注意力
代码实现

# multi-head self-attention
class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()# multi-head,这里有点类似分组卷积self.num_heads  = num_heads# 尺度self.scale      = (dim // num_heads) ** -0.5# qkv通过Linear生成self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop  = nn.Dropout(attn_drop)self.proj       = nn.Linear(dim, dim)self.proj_drop  = nn.Dropout(proj_drop)def forward(self, x):# Step 1. get qkv# N=W*HB, N, C     = x.shape# [B,N,3,num_heads,C//num_heads] -> [3,B,num_heads,N,//num_heads]qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v     = qkv[0], qkv[1], qkv[2]# Step 2. get attention# q*k的转置,再除以根号维度attn = (q @ k.transpose(-2, -1)) * self.scale# softmax就是attentionattn = attn.softmax(dim=-1)# dropout,随机失活attn = self.attn_drop(attn)# Step 3. use attention on v# 注意力乘vx = (attn @ v).transpose(1, 2).reshape(B, N, C)# Linearx = self.proj(x)# dropout,随机失活x = self.proj_drop(x)return x

(2) encoder block

代码实现

# 多个block组成
self.blocks = nn.Sequential(*[Block(dim=num_features,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[i],norm_layer=norm_layer,act_layer=act_layer) for i in range(depth)]
)

Transformer Encoder之后,维度还是[1,197,768],没有改变。

四、head部分

head部分,(1)会对encoder输出的[1,197,768]向量进行归一化,(2)然后再取出cls token,(3)再将cls token送入Linear层。
(1)归一化

# [1,197,768]
x = self.norm(x)

(2)取出cls token

#  get cls_token 768类似channel
# [1,197,768] -> [1,768]
x= x[:, 0]

维度变化:1,197,768] -> [1,768]
(3)送入Linar层

self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
# [1,768] -> [1,2] 2分类问题
x = self.head(x)

维度变化:[1,768] -> [1,2] 2分类问题

五、vision transformer(ViT)完整代码

class VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()# [224, 224, 3] -> [196, 768]# 196 = (224 // 16) * (224 // 16) Patch化后,再平铺# 768 = 16 * 16 * 3   input_channel = 3 HW分别缩放16倍 output_channel拓宽16*16倍self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans,num_features=num_features)# kernel_size = stride# 196 = (224 // 16) * (224 // 16)num_patches = (224 // patch_size) * (224 // patch_size)self.num_features = num_features# new feature shape: [14,14]self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]# old feature shape: [14,14]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]# --------------------------------------------------------------------------------------------------------------------##   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。##   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。# --------------------------------------------------------------------------------------------------------------------##   196, 768 -> 197, 768# [1,1,768]self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))# --------------------------------------------------------------------------------------------------------------------##   为网络提取到的特征添加上位置信息。#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768#   196 = (224 // 16) * (224 //16) 768 = 16 * 16 * 3#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。# --------------------------------------------------------------------------------------------------------------------##   197, 768 -> 197, 768# [1,196,768] -> [1,196+1,768]self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop = nn.Dropout(p=drop_rate)# -----------------------------------------------##   197, 768 -> 197, 768  12次# -----------------------------------------------## 0~drop_path_rate的等差数列,12位dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]# 多个block组成self.blocks = nn.Sequential(*[Block(dim=num_features,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[i],norm_layer=norm_layer,act_layer=act_layer) for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):# Step 1. 序列化:先Patch,再HW平铺# [B,C,H,W] -> BNC N=H*Wx = self.patch_embed(x)# Step 2. 增加cls token# -1:当前维度不拓展# 1batch_size = x.shape[0]# [1,1,768]cls_token = self.cls_token.expand(batch_size, -1, -1)# H*W+1# [1,196,768] -> [1,197,768]x = torch.cat((cls_token, x), dim=1)# Step 3. 位置编码# [1,196+1,768] -> [1,1,768]cls_token_pe = self.pos_embed[:, 0:1, :]# [1,196+1,768] -> [1,196,768]img_token_pe = self.pos_embed[:, 1:, :]# old_feature_shape: [1,196,768] -> [1,14,14,768] -> [1,768,14,14]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)# new_feature_shape: [1,768,14,14] -> [1,768,14,14]img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)# [1,768,14,14] -> [1,14,14,768] -> [1,196,768]img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)# [1,1,768] cat [1,196,768] -> [1,197,768]pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)# Step 4. residual connection + droppath# [1,197,768] + [1,197,768] -> [1,197,768]x = self.pos_drop(x + pos_embed)# Step 5. multi-head self_attention# [1, 197, 768]x = self.blocks(x)# Step 6. layers_norm# [1,197,768]x = self.norm(x)# Step 7. get cls_token 768类似channel# [1,197,768] -> [1,768]x= x[:, 0]return xdef forward(self, x):# Step 1~6.# [1,3,224,224] -> [1,768]x = self.forward_features(x)# Step 7. Linear# [1,768] -> [1,2] 2分类问题x = self.head(x)return x

六、基于vision transformer(ViT)实现猫狗二分类项目实战

项目链接:https://download.csdn.net/download/m0_51579041/89255878
数据集链接:https://download.csdn.net/download/m0_51579041/89255922

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/6022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app(优医咨询)项目实战 - 第2天

学习目标: 掌握WXML获取节点信息的用法 知道如何修改 uni-ui 扩展组件的样式 掌握 uniForm 表单验证的使用方法 能够在 uni-app 中使用自定义字体图标 一、uni-app 基础知识 uni-app 是组合了 Vue 和微信小程序的相关技术知识,要求大家同时俱备 Vue 和原生小程序的开发基础。…

Python中的else魔法:不止是if

写在前面 提到else,肯定会对应一个if。虽然在许多编程语言中这都是正确的,但 Python 却不是。Python 的else语句有着更广泛的用途。从循环语句后的else到try-except块后的else…,本文将探讨else语句鲜为人知的功能。 1. if-else else 可以与 if 一起使用,这也是最常用的…

程序包的实例和删除

目录 程序包的实例 我们创建一个程序包,内容包含上一章所创建的存储过程和函数 程序包的删除 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 程序包的实例 下面就通过具体范例来演示程序包的使用。 我们…

pyqt 按钮常用格式Qss设置

pyqt 按钮常用格式Qss设置 QSS介绍按钮常用的QSS设置效果代码 QSS介绍 Qt Style Sheets (QSS) 是 Qt 框架中用于定制应用程序界面样式的一种语言。它类似于网页开发中的 CSS(Cascading Style Sheets),但专门为 Qt 应用程序设计。使用 QSS&am…

高可用系列三:事务

都成功或者都失败是事务目标,实际中往往会采用最终一致、最大努力一致和不一致时人工介入策略。 评估事务,通常会根据业务特点,考虑对于事务相关业务之间所需的时效性、依赖联系因素,框定事务可用方案,并结合事务实现…

【论文阅读笔记】Frequency Perception Network for Camouflaged Object Detection

1.论文介绍 Frequency Perception Network for Camouflaged Object Detection 基于频率感知网络的视频目标检测 2023年 ACM MM Paper Code 2.摘要 隐蔽目标检测(COD)的目的是准确地检测隐藏在周围环境中的目标。然而,现有的COD方法主要定位…

信息系统项目管理师0083:项目管理的重要性(6项目管理概论—6.2项目基本要素—6.2.2项目管理的重要性)

点击查看专栏目录 文章目录 6.2.2项目管理的重要性 6.2.2项目管理的重要性 项目管理就是将知识、技能、工具与技术应用于项目活动,以满足项目的要求。通过合理地应用并整合特定的项目管理过程,项目管理使组织能够有效并高效地开展项目。 有效的项目管理能…

Rust个人学习之Rust操作Mysql数据库

Rust 使用 mysql 的 crate 进行 mysql 的连接操作,特进行记录。 写在前面 如果想使用 mysql 需要在 CargoToml 文件中增加 mysql 的引用 [dependencies] chrono "0.4" mysql "*"连接数据库 数据库信息如下: 字段数据数据库地…

可靠的智能组网系统有哪些?

天联是一种可靠的智能组网解决方案,在现今复杂网络环境下具备明显的优势。本文将介绍天联组网以及其所带来的诸多优势。 天联组网的优势 天联组网具有以下优势,使其成为一种可靠的智能组网方案: 无网络限制:天联组网能够解决复杂…

Tire 字典树、前缀树

字典树(又称单词查找树或Trie树)是一种树形结构,它是哈希树的变种,通常用于统计、排序和保存大量的字符串(但不仅限于字符串)。字典树在搜索引擎系统中常用于文本词频统计。它的主要优点在于能够利用字符串…

每日一题(力扣213):打家劫舍2--dp+分治

与打家劫舍1不同的是它最后一个和第一个会相邻,事实上,从结果思考,最后只会有三种:1 第一家不被抢 最后一家被抢 2 第一家被抢 最后一家不被抢 3 第一和最后一家都不被抢 。那么,根据打家劫舍1中的算法 我们能算出在i…

excel办公系列-图表元素及其作用

Excel图表元素及其作用 Excel图表由各种元素组成,每个元素都有其特定的作用,可以帮助我们更清晰地传达数据信息。下面将介绍Excel图表中常见的一些元素及其作用,并附上相关截图。 原始数据 月份 网站访问量 (万次) 销售额 (万…

FIFO Generate IP核使用——Data Counts页详解

在Vivado IDE中,当看到一个用于设置数据计数选项的选项卡时,需要注意的是,尽管某些选项值可能因为当前的配置而显示为灰色(即不可选或已禁用),但IDE中显示的有效范围值实际上是你可以选择的真实值。即使某些…

opencv t函数

在OpenCV中&#xff0c;t函数通常用于转置矩阵&#xff08;Transpose&#xff09;。这意味着矩阵的行和列互换位置。 在C中&#xff0c;使用OpenCV库进行矩阵转置的代码如下所示&#xff1a; #include <opencv2/opencv.hpp> #include <iostream>int main() {// 创…

《十二》Qt各种对话框之FileDialog文件对话框及QMessageBox 消息对话框

QFileDialog 对话框 选择打开一个文件 若要打开一个文件&#xff0c;可调用静态函数 QFileDialog::getOpenFileName()&#xff0c;“打开一个文件”按钮的响应代码如下&#xff1a; void Dialog::on_btnOpen_clicked() { //选择单个文件QString curPathQDir::currentPath()…

基于React实现B站评论区

今天继续来学习一下React&#xff0c;使用React实现B站评论区&#xff0c;如下图&#xff1a; 在使用React开发类似B站评论区的功能时&#xff0c;我们需要考虑以下几个关键点来构建一个基本的评论系统&#xff1a; 1. 设计组件结构 首先&#xff0c;设计组件结构是关键。至少…

Unity Animation--动画剪辑

Unity Animation--动画剪辑 动画剪辑 动画剪辑是Unity动画系统的核心元素之一。Unity支持从外部来源导入动画&#xff0c;并提供创建动画剪辑的能力使用“动画”窗口在编辑器中从头开始。 外部来源的动画 从外部来源导入的动画剪辑可能包括&#xff1a; 人形动画 运动捕捉…

Python中关于子类约束的开发规范

Python中关于子类约束的开发规范 我们知道&#xff0c;在java和C#中有一种接口的类型&#xff0c;用来约束实现该接口的类&#xff0c;必须要定义接口中指定的方法 而在python中&#xff0c;我们可以基于父类子类异常来仿照着实现这个功能 class Base:def func():raise NotI…

css---浮动知识点精炼汇总

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 浮动简单理解与介绍 这是我们普通的页面标签效果。 每个标签从上到下依次排列。 浮动顾名思义就是让这个标签飞翔起来。 他飞起来后&#xff0c;后面的标签来到他的位置上。 而浮动的标签就会显示在标签的上面。…

设计模式之MVC模式

在编程江湖闯荡多年&#xff0c;我手中打磨过的设计模式多如繁星&#xff0c;但论及经典与实用&#xff0c; MVC&#xff08;Model-View-Controller&#xff09;模式 绝对是个中翘楚&#xff01;它不仅是Web应用的骨架&#xff0c;更是软件架构的智慧结晶。今天&#xff0c;咱们…