Vit Transformer

一 VitTransformer 介绍

vit : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

        论文是基于Attention Is All You Need,由于图像数据和词数据数据格式不一样,经典的transformer不能处理图像数据,在视觉领域的应用有限。本文提出的方法可以将transformer直接应用图像分类任务,引入Patch Embedding,位置编码等方法,克服了Transformer在处理图像数据时的限制。整体流程如下。

从图中可以看出, Vision Transformer 主要有三个部分组成: 1 ) 第一部分是Linear Projection of Flattened Patches ,也就是 Emdedding 层,主要的工作就是将图像数据转换成transformer可以处理的数据格式。2)第二部分是Transformer Encoder部分,它是vit 最核心的组件(原始的NLP的transformer还有Decoder部分)。它主要是层归一化,多头注意力机制,MLP,Dropout/DropPath四个小block组成,用于学习图像数据。3) 第三部分就是MLP head ,用于分类。

二 PatchEmbedding & Positional Encoding

        首先,每个图像被分割成一系列不重叠的块(16x16或者 32x32),然后做一个线性的embedding ,由于这些块如果并行的输入到transformer中,不提供位置信息,模型不知道这些块的顺序。因此要加一个 positional encoding。 

        在实际的实现上,图像数据是[batch_size, C , H, W] 的格式,要将其变成[batch_size , token_len , dim],其中token_len 可以理解成图像patch token的数量。以[4,3,224,224]的图像为例子,首先我们模拟分割块,对于一个图像,我们要将其分割成 (H*w)/(patch_size*patch_size)个patches,即(224x224)// (16x16) = 196个 patches 。每个patch的大小是(3,16,16),然后我们将其flatten一个768( 3x16x16)dim的 token。这样数据格式就变成[4,196,768]。

代码分割图像块 : 

def split_patches(x, patch_size=16):batch_size, channels, height, width = x.shapex = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5)x = x.reshape(batch_size, -1, channels * patch_size * patch_size)return x

当然这个过程可以通过卷积实现,官方代码其实就是用卷积来实现的。

class PatchEmbed(nn.Module) :def __init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768,norm_layer=None):super().__init__()img_size = (img_size,img_size)patch_size = (patch_size,patch_size)self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0],img_size[1]//patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1] # 卷积层self.proj = nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)# 归一化self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self,x) :B,C,H,W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}x{W} does not match input image size ({self.img_size[0]}x{self.img_size[1]}"x = self.proj(x).flatten(2).transpose(1,2)x = self.norm(x)return x

 positional Embedding 

由于输入的图像数据patch序列没有能够表达patch之间相对位置关系,因此需要加入位置编码(Positional encoding)这个特征,为了得到不同位置的对应的编码,Transformer模型使用不同频率的正余弦函数

PE(Pos,2i) = sin(\frac{pos}{10000^{2i/d}})

PE(Pos,2i+1) = cos(\frac{pos}{10000^{2i/d}})

 其中 pos是表示token(flattened image patch)的位置,2i和2i+1表示位置编码向量中对应的维度,d是对应位置编码的总维度。

def add_positional_encoding(x, max_len):batch_size, patch_numbers, dim = x.shapeposition = torch.arange(max_len).reshape(-1, 1)div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))pe = torch.zeros((max_len, dim))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)x += pe[:patch_numbers]return x

三 Self-Attention 

        其中最核心的部分就是对于注意力部分了。在基于Transformer的机器翻译模型中,要建模源语言和目标语言任意两个单词的依赖关系,引入自注意力K(键)Q(查询)V(值)。这三个用来计算上下文单词所对应的权重得分,这些权重反映了在编码当前单词时,对于上下文不同部分所需要关注程度。

    同样在vision transformer中,对于一个个image patch token 来说,也需要建模任意 token之间的相互关注关系,当处理当前token时,哪些token与它有更高的关联度。

        上图是论文中的Scaled Dot-Product AttentionMulti-head Attention,我们首先定义三个矩阵 Q,K,V,这三个矩阵是由 输入X([4,196,768])分别经过三个权重矩阵$w_{q},w_{k},w_{v}$得到的。其中 Q 矩阵和K 矩阵,V矩阵是“同源”的,因为它们都是来自于同一个输入序列(图像patch token)的某种表示(线性变换的嵌入表示)。

根据Attenton分数的计算公式,Q(shape=[4,196,768])左乘一个K(shape=[4,196,768])矩阵的转置,得到一个相似度矩阵(shape=[4,196,196]),为了防止过大的相似度数值在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,因此使用一个缩放因子\sqrt{d}缩放来稳定优化。放缩后的得分经过Softmax函数归一化为概率后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。

def self_attention(x, w_q, w_k, w_v):query = torch.matmul(x, w_q)key = torch.matmul(x, w_k)value = torch.matmul(x, w_v)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))attention_scores = softmax(scores)output = torch.matmul(attention_scores, value)return attention_scores, output

过程如下

四 Multi-head self-attention (MSA)

        为了进一步提升自注意力机制的全局信息聚合能力,提出了Multi-head attention机制,具体来说,上下文的每个token 向量的表示x_{i}的经过多组的线性{W_{j}^{Q},W_{j}^{K},W_{j}^{V}}映射到不同的表示空间。计算出不同子空间得到的attention score得到{Z_{j}}_{j=1}^{N},再用一个线性变换w^{o} 用于综合不同子空间中的上下文表示形成最后的输出。

import torch
import torch.nn.functional as Fclass MultiHeadAttention(torch.nn.Module):def __init__(self, input_dim, num_heads, head_dim):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = head_dimassert input_dim % self.num_heads == 0self.projection_dim = input_dim // self.num_heads# 定义 权重矩阵 self.weight_q = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_k = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_v = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))self.weight_combine = torch.nn.Parameter(torch.randn(num_heads * self.projection_dim, input_dim))def forward(self, x):batch_size, seq_length, _ = x.size()queries = torch.matmul(x, self.weight_q)keys = torch.matmul(x, self.weight_k)values = torch.matmul(x, self.weight_v)queries = queries.view(batch_size, seq_length, self.num_heads, self.projection_dim)keys = keys.view(batch_size, seq_length, self.num_heads, self.projection_dim)values = values.view(batch_size, seq_length, self.num_heads, self.projection_dim)queries = queries.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# 计算得分scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.projection_dim ** 0.5)attention_weights = F.softmax(scores, dim=-1)attention_output = torch.matmul(attention_weights, values)attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)output = torch.matmul(attention_output, self.weight_combine)return outputinput_dim = 64
num_heads = 8
head_dim = input_dim // num_heads
seq_length = 10
batch_size = 4multihead_attention = MultiHeadAttention(input_dim, num_heads, head_dim)
x = torch.rand(batch_size, seq_length, input_dim)
output = multihead_attention(x)print("输入形状:", x.shape)
print("输出形状:", output.shape)


五 代码实现

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
import numpy as npdef softmax(x):return torch.nn.functional.softmax(x, dim=-1)def split_patches(x, patch_size=16):batch_size, channels, height, width = x.shapex = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5)x = x.reshape(batch_size, -1, channels * patch_size * patch_size)return xdef add_positional_encoding(x, max_len):batch_size, patch_numbers, dim = x.shapeposition = torch.arange(max_len).reshape(-1, 1)div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))pe = torch.zeros((max_len, dim))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)x += pe[:patch_numbers]return xdef plot_heatmap(scores,index,name ):plt.figure(figsize=(8, 6))plt.imshow(scores, cmap='hot', interpolation='nearest')plt.xlabel('Keys')plt.ylabel('Queries')plt.title(f'Attention Scores Heatmap {name}')plt.colorbar()plt.savefig(f"./attention_heatmap{index}.png")# plt.show()
def self_attention(x, w_q, w_k, w_v):query = torch.matmul(x, w_q)key = torch.matmul(x, w_k)value = torch.matmul(x, w_v)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))attention_scores = softmax(scores)output = torch.matmul(attention_scores, value)return attention_scores, outputdef plot_heatmap_on_image(image, attention_scores, patch_size=16,index=0):# 对每个patch的注意力分数求平均attention_scores_mean = attention_scores.mean(dim=1)# 将注意力分数转换为与原始图像大小相匹配的热力图attention_map = attention_scores_mean.view(1, 1, int(224 / patch_size), int(224 / patch_size))attention_map = torch.nn.functional.interpolate(attention_map, size=(224, 224), mode='bilinear', align_corners=False)attention_map = attention_map.squeeze().cpu().detach().numpy()plt.figure(figsize=(6, 6))plt.imshow(image)plt.imshow(attention_map, cmap='jet', alpha=0.5)plt.axis('off')plt.savefig(f'attention_map{index}.png')# plt.show()if __name__ == '__main__':batch_size = 4channels = 3height = 224width = 224input_dim = 768output_dim = 64transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])image_paths = ["./images/11.jpg","./images/15.jpg","./images/16.jpg","./images/17.jpg"]images = torch.zeros((4, 3, 224, 224), dtype=torch.float32)for i, path in enumerate(image_paths):img = Image.open(path).convert('RGB')img_tensor = transform(img)images[i] = img_tensorpatch_embeddings = split_patches(images, patch_size=16)patch_embeddings_pe = add_positional_encoding(patch_embeddings, max_len=196)w_q = torch.normal(0, 0.01, size=(input_dim, output_dim))w_k = torch.normal(0, 0.01, size=(input_dim, output_dim))w_v = torch.normal(0, 0.01, size=(input_dim, output_dim))attention_scores, output = self_attention(patch_embeddings_pe, w_q, w_k, w_v)# plot_heatmap(attention_scores[0])# for index in range(4) :##     name = image_paths[index].split('/')[-1].split('.')[0]#     plot_heatmap(attention_scores[index],index,name)  # 选择第一张图像的注意力分数进行绘制# 将热力图叠加到原始图像上for index in range(4) :image_path = image_paths[index]img = Image.open(image_path).convert('RGB')img_tensor = transform(img)img_np = np.array(img)plot_heatmap_on_image(img_np, attention_scores[index],16,index=index)

参考

  • LLM(廿四):Transformer 的结构改进与替代方案 - 知乎

  • 【深度学习系列】五、Self Attention_self attention 加入位置信息-CSDN博客

  • NLP(五):Transformer及其attention机制 - 知乎

  • 有关vision transformer的一个综述 - 知乎

  • 为什么 Vision transformer 训练和推理很慢? - 知乎

  • 大规模语言模型:从理论到实践 -- 张奇、桂韬、郑锐、黄萱菁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/772485.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4、事件修饰符、过滤器、自定义指令、生命周期

一、事件修饰符 按键别名enter 回车 delete 删除键 esc取消键 space 空格键 <script> export default {name: "KeyUp",methods:{keyUp(e){ console.log(e) }},skip(){window.location.href "http:www.xx.com"} } </script> <template>…

鸿蒙应用开发-录音保存并播放音频

功能介绍&#xff1a; 录音并保存为m4a格式的音频&#xff0c;然后播放该音频&#xff0c;参考文档使用AVRecorder开发音频录制功能(ArkTS)&#xff0c;更详细接口信息请查看接口文档&#xff1a;ohos.multimedia.media (媒体服务)。 知识点&#xff1a; 熟悉使用AVRecorder…

super的使用细节

1、super的使用细节 2、super和this的比较

159.乐理基础-和声模板是什么?优缺点与运用要点

如果到这五线谱还没记住还不认识的话去看102.五线谱-高音谱号与103.五线谱-低音谱号这两个里&#xff0c;这里面有五线谱对应的音名&#xff0c;对比着看 如果一章没落下&#xff0c;看到这里&#xff0c;但是看不懂什么意思&#xff0c;那就强行下看&#xff0c;看着看着指不…

[leetcode]118.杨辉三角

前言&#xff1a;剑指offer刷题系列 问题&#xff1a; 给定一个非负整数 *numRows&#xff0c;*生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例&#xff1a; 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,…

CKS之镜像漏洞扫描工具:Trivy

目录 Trivy介绍 Trivy安装 Trivy使用命令 容器镜像扫描 打印指定&#xff08;高危、严重&#xff09;漏洞信息 JSON格式输出 HTML格式输出 离线扫描命令 离线更新Trivy数据库 Harbor安装Trivy Trivy介绍 Trivy是一款用于扫描容器镜像、文件系统、Git仓库等的漏洞扫描…

Matlab|基于两阶段鲁棒优化的微网电源储能容量优化配置

目录 主要内容 1.1 目标函数 1.2 约束条件 1.3 不确定变量 部分代码 结果一览 下载链接 主要内容 程序主要复现的是《考虑寿命损耗的微网电池储能容量优化配置》&#xff0c;解决微网中电源/储能容量优化配置的问题&#xff0c;即风电、光伏、储能以及燃气轮机…

LeetCode - 执行子串操作后的字典序最小字符串

题目要求经过操作后的字符串的字典序要比之前小。 在做这道题的之后陷入了一个误区&#xff0c;就是看a的位置&#xff0c;a-1之后z&#xff0c;z的字典序比a大&#xff0c;所以要尽可能的避免a变成z&#xff0c;但是字典序的比较是从前往后比较的&#xff0c;纠结于a变成z&am…

NSCaching: Simple and Efficient NegativeSampling for Knowledge Graph Embedding

摘要 知识图嵌入是数据挖掘研究中的一个基本问题&#xff0c;在现实世界中有着广泛的应用。它的目的是将图中的实体和关系编码到低维向量空间中&#xff0c;以便后续算法使用。负抽样&#xff0c;即从训练数据中未观察到的负三元组中抽取负三元组&#xff0c;是KG嵌入的重要步…

第四百二十六回

文章目录 1. 概念介绍2. 实现方法2.1 原生方式2.1 插件方式 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何修改程序的桌面图标"相关的内容&#xff0c;本章回中将介绍如何处理ListView中的事件冲突.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介…

利用vite创建vue3项目

vue3 项目推荐使用vue官方推荐的vite手脚架创建&#xff0c;vue3项目&#xff0c;使用vue-cli 会存在一些问题 1.node的版本 目前的vue3需要至少需要node18及以上&#xff0c;可以安装nvm node包管理器可以快速切换node版本&#xff0c;因为node的版本的兼容性真是一言难尽。…

第十四届蓝桥杯C++A组(A/B/C/D/E/H)

文章目录 A.幸运数B.有奖问答C.平方差D.更小的数E.颜色平衡树H.异或和之和 A.幸运数 /*纯暴力*/ #include <bits/stdc.h>using namespace std;void solve() {int sum 0;for(int i 1; i < 100000000; i ){int n i;int a[11];int j 1;for(; n ! 0; j ){a[j] n % …

C++ 友元函数

目录 如果觉得有用的话&#xff0c;给小弟点个赞吧&#xff01;哈哈哈哈&#xff0c;谢谢嘞&#xff01; 概念&#xff1a; 如何理解&#xff1f; 概念&#xff1a; 友元&#xff1a;慎用&#xff08;突破封装&#xff09; 友元函数&#xff1a;在函数前加friend的函数称为…

网页代理ip怎么设置的

众所周知&#xff0c;现在网络安全和隐私保护是我们非常关注的问题。为了更好地保护自己的隐私&#xff0c;提高上网的安全性&#xff0c;使用代理IP成为了很多人的首选。 那么&#xff0c;网页代理IP是怎么设置的呢&#xff1f;下面&#xff0c;就让我来一一为大家介绍。 一、…

CMake学习笔记(二)从PROJECT_BINARY_DIR看外部编译和内部编译

目录 外部编译 内部编译 总结 外部编译 看如下例子&#xff1a;我在EXE_OUT_PATH中建立了文件夹build、文件夹src2 和 文件CMakeLists.txt 其中EXE_OUT_PATH/CMakeLists.txt的内容如下&#xff1a; PROJECT(out_path) ADD_SUBDIRECTORY(src2 bin2) MESSAGE(STATUS "m…

(一)whatsapp 语音通话基本流程

经过了一整年的开发测试&#xff0c;终于将whatsapp 语音通话完成&#xff0c;期间主要参考webrtc的源码来实现.下面简要说一下大致的步骤 XMPP 协商 发起或者接受语音通话第一步是发起XMPP 协商&#xff0c;这个协商过程非常重要。下面是协商一个包 <call toxxxs.whatsap…

【大模型基础】什么是KV Cache?

哪里存在KV Cache&#xff1f; KV cache发生在多个token生成的步骤中&#xff0c;并且只发生在decoder中&#xff08;例如&#xff0c;decoder-only模型&#xff0c;如 GPT&#xff0c;或在encoder-decoder模型&#xff0c;如T5的decoder部分&#xff09;&#xff0c;BERT这样…

Protocol Buffers设计要点

概述 一种开源跨平台的序列化结构化数据的协议。可用于存储数据或在网络上进行数据通信。它提供了用于描述数据结构的接口描述语言&#xff08;IDL&#xff09;&#xff0c;也提供了根据 IDL 产生代码的程序工具。Protocol Buffers的设计目标是简单和性能&#xff0c;所以与 XM…

(执行上下文作用域链)前端八股文修炼Day4

一 作用域作用域链 作用域&#xff08;Scope&#xff09;是指程序中定义变量的区域&#xff0c;作用域规定了在这个区域内变量的可访问性。在 JavaScript 中&#xff0c;作用域可以分为全局作用域和局部作用域。 全局作用域&#xff1a;在代码中任何地方都可以访问的作用域&am…

基于Springboot的狱内罪犯危险性评估系统的设计与实现(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的狱内罪犯危险性评估系统的设计与实现&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#…