【Vit】Vision Transformer 入门与理解

在学习VIT之前,建议先把 Transformer 搞明白了:【transformer】入门与理解

做了那些改进?

在这里插入图片描述

看图就比较明白了,VIT只用了Encoder的部分,把每一个图片裁剪成若干子图,然后把一个子图flatten一下,当成nlp中的一个token处理。
值得注意的是,在首个 token中嵌入了一个 class_token,维度为(1,embed_dim=768),这个class_token在预测的时候比较有意思,见下图:

在这里插入图片描述
注意上图中有些细节遗漏,全流程应该是:先把输入进行 patch_embedding 变成 visual tokens,然后和 class_token 合并,最后 position_embedding。

另外需要注意的是,class_token 是一个可学习的参数,并不是每次输入时都需要输入的类别数值。

self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)

代码

其实有了 Transformer 的基础后,直接看代码就知道VIT是怎么做的了。

import copy
import torch
import torch.nn as nn# 所有基于nn.Module结构的模版,可以删掉
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass Mlp(nn.Module):def __init__(self, embed_dim, mlp_ratio, dropout=0.):super().__init__()self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) # 中间层扩增self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)self.act = nn.GELU()self.dropout = nn.Dropout(dropout)def forward(self, x):# TODOx = self.fc1(x)x = self.act(x)x = self.dropout(x)x = self.fc2(x)x = self.dropout(x)return xclass PatchEmbedding(nn.Module):def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.):super().__init__()n_patches = (image_size // patch_size) * (image_size // patch_size)  # 196 个 patchself.patch_embedding = nn.Conv2d(in_channels=in_channels,  # embedding 操作后变成 torch.Size([10, 768, 14, 14])out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)self.dropout = nn.Dropout(dropout)# TODO: add class tokenself.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)# TODO: add position embeddingself.position_embedding = nn.Parameter(torch.ones(1, n_patches+1, embed_dim) * 0.98)  #(1,196+1,768)def forward(self, x): # 先把 x patch_embedding,然后和 class_token 合并,最后 position_embedding# [n, c, h, w]cls_tokens = self.class_token.expand([x.shape[0], -1, -1]) #(10,1,768) 根据batch扩增 class_tokenx = self.patch_embedding(x) # [n, embed_dim, h', w']x = x.flatten(2) # torch.Size([10, 768, 196])x = x.permute([0, 2, 1]) # torch.Size([10, 196, 768])x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)x = x + self.position_embeddingreturn x # torch.Size([10, 197, 768])class Attention(nn.Module):"""multi-head self attention"""def __init__(self, embed_dim, num_heads, qkv_bias=True, dropout=0., attention_dropout=0.):super().__init__()self.num_heads = num_headsself.head_dim = int(embed_dim / num_heads) # 768/4=192self.all_head_dim = self.head_dim * num_headsself.scales = self.head_dim ** -0.5self.qkv = nn.Linear(embed_dim,self.all_head_dim * 3) # [768, 768*3]self.proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)self.attention_dropout = nn.Dropout(attention_dropout)self.softmax = nn.Softmax()def transpose_multihead(self, x):# x: [N, num_patches 197, all_head_dim 768] -> [N, n_heads, num_patches, head_dim]new_shape = [x.shape[:-1][0], x.shape[:-1][1], self.num_heads, self.head_dim] # [10, 197, 4, 192]x = x.reshape(new_shape) x = x.permute([0, 2, 1, 3]) # [10, 4, 197, 192]return xdef forward(self, x): # Attention 前后输入输出维度不变,都是 [10, 197, 768]B, N, _ = x.shape   # torch.Size([10, 197, 768])qkv = self.qkv(x).chunk(3, axis=-1) # 含有三个元素的列表,每一个元素大小 [10, 197, 768]q, k, v = map(self.transpose_multihead, qkv) # [10, 4, 197, 192]attn = torch.matmul(q, k.transpose(2,3)) # [10, 4, 197, 197]attn = attn * self.scalesattn = self.softmax(attn)attn = self.attention_dropout(attn)out = torch.matmul(attn, v) # [10, 4, 197, 192]out = out.permute([0, 2, 1, 3]) # [10, 197, 4, 192]out = out.reshape([B, N, -1]) # [10, 197, 768]out = self.proj(out) # [10, 197, 768]out = self.dropout(out)return outclass EncoderModule(nn.Module):def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):super().__init__()self.attn_norm = nn.LayerNorm(embed_dim)self.attn = Attention(embed_dim, num_heads)self.mlp_norm = nn.LayerNorm(embed_dim)self.mlp = Mlp(embed_dim, mlp_ratio)def forward(self, x):h = x # residualx = self.attn_norm(x)x = self.attn(x)x = x + hh = x # residualx = self.mlp_norm(x)x = self.mlp(x)x = x + hreturn xclass Encoder(nn.Module):def __init__(self, embed_dim, depth):super().__init__()Module_list = []for i in range(depth):encoder_Module = EncoderModule()Module_list.append(encoder_Module)self.Modules = nn.ModuleList(Module_list)self.norm = nn.LayerNorm(embed_dim)def forward(self, x):for Module in self.Modules:x = Module(x)x = self.norm(x)return xclass VisualTransformer(nn.Module):def __init__(self,image_size=224,patch_size=16,in_channels=3,num_classes=1000,embed_dim=768,depth=3,num_heads=8,):super().__init__()self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)self.encoder = Encoder(embed_dim, depth)self.classifier = nn.Linear(embed_dim, num_classes)def forward(self, x):# x:[N, C, H, W]x = self.patch_embedding(x) # torch.Size([10, 197, 768])x = self.encoder(x) # torch.Size([10, 197, 768])x = self.classifier(x[:, 0]) # 注意这里的处理很奇妙哦,参考 x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)return xvit = VisualTransformer()
print(vit)input_data = torch.randn([10,3,224,224]) # 每批次输入10张图片
print(vit(input_data).shape) # torch.Size([10, 1000])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/807838.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MATLAB源码-第12期】基于matlab的4FSK(4CPFSK)的误码率BER理论值与实际值仿真。

1、算法描述 4FSK在频移键控(FSK)编码的基础上有所扩展。FSK是一种调制技术,它通过在不同频率上切换来表示不同的数字或符号。而4FSK则是FSK的一种变种,表示使用了4个不同的频率来传输信息。 在4FSK中,每个数字或符号…

基于Java的图书借阅网站, java+springboot+vue开发的图书借阅管理系统 - 毕业设计 - 课程设计

基于Java的图书借阅网站, javaspringbootvue开发的图书借阅管理系统 - 毕业设计 - 课程设计 文章目录 基于Java的图书借阅网站, javaspringbootvue开发的图书借阅管理系统 - 毕业设计 - 课程设计一、功能介绍二、代码结构三、部署运行1、后端运行步骤2、…

PaddleDetection 项目使用说明

PaddleDetection 项目使用说明 PaddleDetection 项目使用说明数据集处理相关模块环境搭建 PaddleDetection 项目使用说明 https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.7/configs/ppyoloe/README_cn.md 自己项目: https://download.csdn.net/d…

Cheat Engine ceserver 连接手机记录

按照流程 电脑端的 cheat engine 7.5不显示任何内容 换一个 cheatengine 就好了 真神奇 链接:https://pan.baidu.com/s/14nMRHPEJ7enygI2nQf86YQ?pwdkxma 提取码:kxma

C++ stl容器vector的底层模拟实现

目录 前言: 1.成员变量,容量与大小 2.构造函数 无参构造: 带参的使用值进行构造: 使用迭代器区间进行构造: 3.交换 4.拷贝构造 5.赋值重载 6.迭代器 7.扩容 reserve: resize: 8.…

通过pre标签进行json格式化展示,并实现搜索高亮和通过鼠标进行逐个定位的功能

功能说明 实现一个对json进行格式化的功能添加搜索框,回车进行关键词搜索,并对关键词高亮显示搜索到的多个关键词,回车逐一匹配监听json框,如果发生了编辑,需要在退出时提示,在得到用户确认的情况下再退出…

30天精通Linux系统编程-----第一天:底层文件I/O (建议收藏)

目录 1.什么是底层文件I/O 2.底层文件I/O常用库函数 2.1 write函数 2.2 read函数 2.3 open函数 2.4 close函数 2.5 lseek函数 2.6 ioctl函数 2.7 fcntl()函数 2.8 pread()函数 2.9 pwrite()函数 1.什么是底层文件I/O 底层I/O指的是与硬件设备之间的直接输入输出操作…

Pytest精通指南(04)前后置和测试用例执行优先级

文章目录 Pytest 固件核心概念Pytest 固件原理Pytest 固件分类方法级函数级类级模块级夹具优先级测试用例执行优先级固件不仅如此后续大有文章 Pytest 固件核心概念 在 pytest 测试框架中,固件是一个核心概念; 它是一种特殊的函数,用于在测试…

蓝桥杯物联网竞赛_STM32L071KBU6_全部工程及国赛省赛真题及代码

包含stm32L071kbu6全部实验工程、源码、原理图、官方提供参考代码及国、省赛真题及代码 链接:https://pan.baidu.com/s/1pXnsMHE0t4RLCeluFhFpAg?pwdq497 提取码:q497

【Python】报错ModuleNotFoundError: No module named fileName解决办法

1.前言 当我们导入一个模块时: import xxx ,默认情况下python解释器会搜索当前目录、已安装的内置模块和第三方模块。 搜索路径存放在sys模块的path中。【即默认搜索路径可以通过sys.path打印查看】 2.sys.path.append() sys.path是一个列表 list ,它里…

JVM常用参数一

jvm启动参数 JVM(Java虚拟机)的启动参数是在启动JVM时可以设置的一些命令行参数。这些参数用于指定JVM的运行环境、内存分配、垃圾回收器以及其他选项。以下是一些常见的JVM启动参数: -Xms:设置JVM的初始堆大小。 -Xmx&#xff1…

证书生成和获取阿里云备案获取密钥流程

1.在java文件夹下 输入 cmd 打开命令行窗口 2. keytool -genkey -alias 证书名 -keyalg RSA -keysize 2048 -validity 36500 -keystore 证书名.keystore 输入这一行,把证书名三个字 改成 项目的名称(例如:D23102802) 3. 密码默认填…

天工 AI 爆赞的数据分析能力

分享一个 AI 应用。 天工 AI 天工AI - 首页 (tiangong.cn) 可以上传数据,给出数据分析命令,并能出图。 数据分析师岌岌可危。 又知道其他好用的数据分析应用么,可以告诉我下。

vscode + wsl1 搭建远程C/C++开发环境

记录第一次搭建环境过程。 如何选择开发环境 搭建C/C开发环境有很多种方式,如 MinGW vscode(MinGW 是GCC的Windows版本,本地编译环境)SSH隧道连接 vscode(远程Linux主机)wsl vscode(远程Li…

Axios网络请求

Axios网络请求主要用于前后端请求,前后端分离时前端需要通过url请求后端的接口,并且处理后端传过来的数据。 Axios官网教程 安装 npm install axios在main.js导入 import axios from axios;//声明一个http变量!!&#xff01…

初步了解Zookeeper

目录 1. Zookeeper定义 2. Zookeeper工作机制 3. Zookeeper特点 4. Zookeeper数据结构 5. Zookeeper应用场景 5.1 统一命名服务 5.2 统一配置管理 5.3 统一集群管理 5.4 服务器动态上下线 5.5 软负载均衡 6. Zookeeper 选举机制 6.1 第一次启动选举机制 6.2 非第一…

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测

分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测 目录 分类预测 | Matlab实现KPCA-IDBO-LSSVM基于核主成分分析和改进蜣螂优化算法优化最小二乘支持向量机分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述…

【Python数据分析】让工作自动化起来,无所不能的Python

这里写目录标题 前言一、Python是办公自动化的重要工具二、Python是提升职场竞争力的利器三、Python是企业数字化的重要平台四、Python是AI发展的重要通道之一编辑推荐内容简介作者简介前言读者对象如何阅读本书目录 前言 随着我国企业数字化和信息化的深入,企业对…

大屏可视化展示平台解决方案(word原件获取)

1.系统概述 1.1.需求分析 1.2.重难点分析 1.3.重难点解决措施 2.系统架构设计 2.1.系统架构图 2.2.关键技术 2.3.接口及要求 3.系统功能设计 3.1.功能清单列表 3.2.数据源管理 3.3.数据集管理 3.4.视图管理 3.5.仪表盘管理 3.6.移动端设计 3.1.系统权限设计 3.2.数据查询过程设…

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器

【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器 文章目录 【论文阅读】MCTformer: 弱监督语义分割的多类令牌转换器一、介绍二、联系工作三、方法四、实验结果 Multi-class Token Transformer for Weakly Supervised Semantic Segmentation 本文提出了一种新的基于变换…