【Vit】Vision Transformer 入门与理解

在学习VIT之前,建议先把 Transformer 搞明白了:【transformer】入门与理解

做了那些改进?

在这里插入图片描述

看图就比较明白了,VIT只用了Encoder的部分,把每一个图片裁剪成若干子图,然后把一个子图flatten一下,当成nlp中的一个token处理。
值得注意的是,在首个 token中嵌入了一个 class_token,维度为(1,embed_dim=768),这个class_token在预测的时候比较有意思,见下图:

在这里插入图片描述
注意上图中有些细节遗漏,全流程应该是:先把输入进行 patch_embedding 变成 visual tokens,然后和 class_token 合并,最后 position_embedding。

另外需要注意的是,class_token 是一个可学习的参数,并不是每次输入时都需要输入的类别数值。

self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)

代码

其实有了 Transformer 的基础后,直接看代码就知道VIT是怎么做的了。

import copy
import torch
import torch.nn as nn# 所有基于nn.Module结构的模版,可以删掉
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass Mlp(nn.Module):def __init__(self, embed_dim, mlp_ratio, dropout=0.):super().__init__()self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) # 中间层扩增self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)self.act = nn.GELU()self.dropout = nn.Dropout(dropout)def forward(self, x):# TODOx = self.fc1(x)x = self.act(x)x = self.dropout(x)x = self.fc2(x)x = self.dropout(x)return xclass PatchEmbedding(nn.Module):def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.):super().__init__()n_patches = (image_size // patch_size) * (image_size // patch_size)  # 196 个 patchself.patch_embedding = nn.Conv2d(in_channels=in_channels,  # embedding 操作后变成 torch.Size([10, 768, 14, 14])out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)self.dropout = nn.Dropout(dropout)# TODO: add class tokenself.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)# TODO: add position embeddingself.position_embedding = nn.Parameter(torch.ones(1, n_patches+1, embed_dim) * 0.98)  #(1,196+1,768)def forward(self, x): # 先把 x patch_embedding,然后和 class_token 合并,最后 position_embedding# [n, c, h, w]cls_tokens = self.class_token.expand([x.shape[0], -1, -1]) #(10,1,768) 根据batch扩增 class_tokenx = self.patch_embedding(x) # [n, embed_dim, h', w']x = x.flatten(2) # torch.Size([10, 768, 196])x = x.permute([0, 2, 1]) # torch.Size([10, 196, 768])x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)x = x + self.position_embeddingreturn x # torch.Size([10, 197, 768])class Attention(nn.Module):"""multi-head self attention"""def __init__(self, embed_dim, num_heads, qkv_bias=True, dropout=0., attention_dropout=0.):super().__init__()self.num_heads = num_headsself.head_dim = int(embed_dim / num_heads) # 768/4=192self.all_head_dim = self.head_dim * num_headsself.scales = self.head_dim ** -0.5self.qkv = nn.Linear(embed_dim,self.all_head_dim * 3) # [768, 768*3]self.proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)self.attention_dropout = nn.Dropout(attention_dropout)self.softmax = nn.Softmax()def transpose_multihead(self, x):# x: [N, num_patches 197, all_head_dim 768] -> [N, n_heads, num_patches, head_dim]new_shape = [x.shape[:-1][0], x.shape[:-1][1], self.num_heads, self.head_dim] # [10, 197, 4, 192]x = x.reshape(new_shape) x = x.permute([0, 2, 1, 3]) # [10, 4, 197, 192]return xdef forward(self, x): # Attention 前后输入输出维度不变,都是 [10, 197, 768]B, N, _ = x.shape   # torch.Size([10, 197, 768])qkv = self.qkv(x).chunk(3, axis=-1) # 含有三个元素的列表,每一个元素大小 [10, 197, 768]q, k, v = map(self.transpose_multihead, qkv) # [10, 4, 197, 192]attn = torch.matmul(q, k.transpose(2,3)) # [10, 4, 197, 197]attn = attn * self.scalesattn = self.softmax(attn)attn = self.attention_dropout(attn)out = torch.matmul(attn, v) # [10, 4, 197, 192]out = out.permute([0, 2, 1, 3]) # [10, 197, 4, 192]out = out.reshape([B, N, -1]) # [10, 197, 768]out = self.proj(out) # [10, 197, 768]out = self.dropout(out)return outclass EncoderModule(nn.Module):def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):super().__init__()self.attn_norm = nn.LayerNorm(embed_dim)self.attn = Attention(embed_dim, num_heads)self.mlp_norm = nn.LayerNorm(embed_dim)self.mlp = Mlp(embed_dim, mlp_ratio)def forward(self, x):h = x # residualx = self.attn_norm(x)x = self.attn(x)x = x + hh = x # residualx = self.mlp_norm(x)x = self.mlp(x)x = x + hreturn xclass Encoder(nn.Module):def __init__(self, embed_dim, depth):super().__init__()Module_list = []for i in range(depth):encoder_Module = EncoderModule()Module_list.append(encoder_Module)self.Modules = nn.ModuleList(Module_list)self.norm = nn.LayerNorm(embed_dim)def forward(self, x):for Module in self.Modules:x = Module(x)x = self.norm(x)return xclass VisualTransformer(nn.Module):def __init__(self,image_size=224,patch_size=16,in_channels=3,num_classes=1000,embed_dim=768,depth=3,num_heads=8,):super().__init__()self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)self.encoder = Encoder(embed_dim, depth)self.classifier = nn.Linear(embed_dim, num_classes)def forward(self, x):# x:[N, C, H, W]x = self.patch_embedding(x) # torch.Size([10, 197, 768])x = self.encoder(x) # torch.Size([10, 197, 768])x = self.classifier(x[:, 0]) # 注意这里的处理很奇妙哦,参考 x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)return xvit = VisualTransformer()
print(vit)input_data = torch.randn([10,3,224,224]) # 每批次输入10张图片
print(vit(input_data).shape) # torch.Size([10, 1000])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/807838.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MATLAB源码-第12期】基于matlab的4FSK(4CPFSK)的误码率BER理论值与实际值仿真。

1、算法描述 4FSK在频移键控(FSK)编码的基础上有所扩展。FSK是一种调制技术,它通过在不同频率上切换来表示不同的数字或符号。而4FSK则是FSK的一种变种,表示使用了4个不同的频率来传输信息。 在4FSK中,每个数字或符号…

真随机数和伪随机数

真随机数和伪随机数 我先是看的TI的DL_TRNG_sendCommand(TRNG, DL_TRNG_CMD_NORM_FUNC);函数,能生成真随机数。要在microchip的八位机上移植同样的功能,但是那个库函数是伪随机数,我就看了两者的区别。区别就是,真随机数会出现随机…

基于Java的图书借阅网站, java+springboot+vue开发的图书借阅管理系统 - 毕业设计 - 课程设计

基于Java的图书借阅网站, javaspringbootvue开发的图书借阅管理系统 - 毕业设计 - 课程设计 文章目录 基于Java的图书借阅网站, javaspringbootvue开发的图书借阅管理系统 - 毕业设计 - 课程设计一、功能介绍二、代码结构三、部署运行1、后端运行步骤2、…

PaddleDetection 项目使用说明

PaddleDetection 项目使用说明 PaddleDetection 项目使用说明数据集处理相关模块环境搭建 PaddleDetection 项目使用说明 https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.7/configs/ppyoloe/README_cn.md 自己项目: https://download.csdn.net/d…

功能测试、自动化测试、性能测试的区别

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号:互联网杂货铺,回复1 ,免费获取软件测试全套资料,资料在手,薪资嘎嘎涨 按测试执行的类型来分:功能测试、自动化测…

Cheat Engine ceserver 连接手机记录

按照流程 电脑端的 cheat engine 7.5不显示任何内容 换一个 cheatengine 就好了 真神奇 链接:https://pan.baidu.com/s/14nMRHPEJ7enygI2nQf86YQ?pwdkxma 提取码:kxma

C++ stl容器vector的底层模拟实现

目录 前言: 1.成员变量,容量与大小 2.构造函数 无参构造: 带参的使用值进行构造: 使用迭代器区间进行构造: 3.交换 4.拷贝构造 5.赋值重载 6.迭代器 7.扩容 reserve: resize: 8.…

通过pre标签进行json格式化展示,并实现搜索高亮和通过鼠标进行逐个定位的功能

功能说明 实现一个对json进行格式化的功能添加搜索框,回车进行关键词搜索,并对关键词高亮显示搜索到的多个关键词,回车逐一匹配监听json框,如果发生了编辑,需要在退出时提示,在得到用户确认的情况下再退出…

30天精通Linux系统编程-----第一天:底层文件I/O (建议收藏)

目录 1.什么是底层文件I/O 2.底层文件I/O常用库函数 2.1 write函数 2.2 read函数 2.3 open函数 2.4 close函数 2.5 lseek函数 2.6 ioctl函数 2.7 fcntl()函数 2.8 pread()函数 2.9 pwrite()函数 1.什么是底层文件I/O 底层I/O指的是与硬件设备之间的直接输入输出操作…

Pytest精通指南(04)前后置和测试用例执行优先级

文章目录 Pytest 固件核心概念Pytest 固件原理Pytest 固件分类方法级函数级类级模块级夹具优先级测试用例执行优先级固件不仅如此后续大有文章 Pytest 固件核心概念 在 pytest 测试框架中,固件是一个核心概念; 它是一种特殊的函数,用于在测试…

蓝桥杯物联网竞赛_STM32L071KBU6_全部工程及国赛省赛真题及代码

包含stm32L071kbu6全部实验工程、源码、原理图、官方提供参考代码及国、省赛真题及代码 链接:https://pan.baidu.com/s/1pXnsMHE0t4RLCeluFhFpAg?pwdq497 提取码:q497

若依ruoyi 动态多数据源配置(多种不同类型的数据库mysql,oracle,sqlite3,sqlserver等等)

我使用的是若依mybaits-plus&#xff0c;具体根据自己的情况做更改 增加其他数据库的配置 &#xff0c;我这里是sqlite3与sqlserver <dependency><groupId>org.xerial</groupId><artifactId>sqlite-jdbc</artifactId><version>3.36.0.3&l…

箭头函数和普通函数的区别

箭头函数和普通函数在JavaScript中有几个关键的区别。以下是它们之间的一些主要差异&#xff1a; 1. 语法差异 普通函数可以使用function关键字进行定义&#xff1a; function regularFunction(arg1, arg2) { return arg1 arg2; } 箭头函数使用>符号进行定义&#xff0…

【Python】报错ModuleNotFoundError: No module named fileName解决办法

1.前言 当我们导入一个模块时&#xff1a; import xxx &#xff0c;默认情况下python解释器会搜索当前目录、已安装的内置模块和第三方模块。 搜索路径存放在sys模块的path中。【即默认搜索路径可以通过sys.path打印查看】 2.sys.path.append() sys.path是一个列表 list ,它里…

JVM常用参数一

jvm启动参数 JVM&#xff08;Java虚拟机&#xff09;的启动参数是在启动JVM时可以设置的一些命令行参数。这些参数用于指定JVM的运行环境、内存分配、垃圾回收器以及其他选项。以下是一些常见的JVM启动参数&#xff1a; -Xms&#xff1a;设置JVM的初始堆大小。 -Xmx&#xff1…

证书生成和获取阿里云备案获取密钥流程

1.在java文件夹下 输入 cmd 打开命令行窗口 2. keytool -genkey -alias 证书名 -keyalg RSA -keysize 2048 -validity 36500 -keystore 证书名.keystore 输入这一行&#xff0c;把证书名三个字 改成 项目的名称&#xff08;例如&#xff1a;D23102802&#xff09; 3. 密码默认填…

天工 AI 爆赞的数据分析能力

分享一个 AI 应用。 天工 AI 天工AI - 首页 (tiangong.cn) 可以上传数据&#xff0c;给出数据分析命令&#xff0c;并能出图。 数据分析师岌岌可危。 又知道其他好用的数据分析应用么&#xff0c;可以告诉我下。

vscode + wsl1 搭建远程C/C++开发环境

记录第一次搭建环境过程。 如何选择开发环境 搭建C/C开发环境有很多种方式&#xff0c;如 MinGW vscode&#xff08;MinGW 是GCC的Windows版本&#xff0c;本地编译环境&#xff09;SSH隧道连接 vscode&#xff08;远程Linux主机&#xff09;wsl vscode&#xff08;远程Li…

Axios网络请求

Axios网络请求主要用于前后端请求&#xff0c;前后端分离时前端需要通过url请求后端的接口&#xff0c;并且处理后端传过来的数据。 Axios官网教程 安装 npm install axios在main.js导入 import axios from axios;//声明一个http变量&#xff01;&#xff01;&#xff01…

linux安装maven和git

https://maven.apache.org/download.cgi 下载apache-maven-3.9.6-bin.tar.gz 创建文件夹 mkdir maven chmod 777 maven 解压 tar zxvf apache-maven-3.9.6-bin.tar.gz vim /etc/profile #文件添加以下内容 maven environment export M2_HOME/data/maven/apache-maven-3.9.6 …