【论文复现】Vision Transformer(ViT)

1. Transformer结构

在这里插入图片描述

1.1 编码器和解码器

翻译这个过程需要中间体。也就是说,编码,解码之间需要一个中介,英文先编码成一个意思,再解码成中文。

那么查字典这个过程就是编码和解码的体现。首先我们的大脑会把它编码,编码这个句子的意思,然后通过字典映射解码。但是这样的过程太过于繁琐,如果让机器做,超长文本就对应着超长的数据量,也不利于机器学习的上下文理解。那么就有了Attention注意力机制。

1.2 Attention:注意力机制

Attention机制的核心思想是,要想翻译一个句子并不需要完全编码。像我们人类一样,仅凭借几个词就可以猜出整句话的大概意义,即使我们不懂日语,也可以根据一些汉字推出来大概的意思,这是准确度低的情况;而“中译中”这种情况,准确度当然就更高一些。

Attention注意力机制:

Attention示意图本质上是加权平均。如果我非常注意某个地方,我想要多看,那就分配更高的权重。

计算权重是使用相似度计算,

Attention机制的优点

Attention的优点是能够实现并行计算和全局视野。

并行计算的可能性是因为它不像RNN一样,依赖时序数据。它只是加权计算,但并不需要像时序数据那样,依赖像是队列一样的进出顺序。

全局视野是因为在加权计算的时候,这个计算就是涉及了整体的,它一看就能看到全部。

1.3 Self Attention

对于Self Attention来说,它的输入是一个序列,序列的获得依靠的是vector。我们把一个词转换为序列模块,需要用到vector向量去指向。而vector的指向,是有空间性的,比如说两个意思很相近或者同义的词汇,它们在空间中的距离就会比较小。相反,意思差的多,距离当然就远。

这样也可以理解为vector是和词语的意义有关系的。

注意力分配的多少取决于公式:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中,Q代表Queries(输入的信息),K代表Keys(内容信息),V代表Values(信息本身,只表达了信息的特征)。

在这里插入图片描述

Q,K,V的获得

本质上,是input的线性变换。计算使用的是矩阵乘法,实现方法是nn.Linear。

用点积表示相似度的方法是因为cos角投影长度可以很直观地理解两个量的相似度。

在1.3的式子里除以 d k \sqrt{dk} dk 这一步看起来很多余,但是它是为了避免较大的数值。较大的数值会导致softmax之后值更极端,softmax的极端值会导致梯度消失。这一步相当于控制了数值范围,让它在可观测范围内。

为什么是 d k \sqrt{dk} dk ?假设q,k是均值为0,方差为1的标准正态分布的独立随机变量,那么它们的点积的均值和方差分别为0以及dk。

之后做逐元素相乘(enterwise)。

我们需要将单词意思转换为句中意思,这就涉及一词多义的问题,在逐元素相乘得到的sum之后的z就涉及这个问题。那么我们需要根据句子中其他词来推理当前词的意思。

就比如说Mine,它有两种意思,一种是“我的”,一种是“矿石”。这当然是截然不同的词性和词义。假设我们完全不知道这个多义词,但我们可以通过观察它们在句子中的位置和与上下文的联系来推理这是什么意思。

1.4 Multihead Attention

这一步的核心是复读机()

这一步就是有多个W_q,W_k,W_v,那么上述操作重复多次,将结果用concat串在一起。

这样的复读机机制就是给注意力提供多种可能性。

应用了multihead的Conditional DETR就发现不同的head会将注意力放到物体的不同边上。
在这里插入图片描述

1.5 输入端适配

直接把图片切分成patches,flatten操作拉平patches,然后过一个linear projection使patches维度变小,然后编号123456789…输入网络即可。

就是切蛋糕喂给encoder和decoder。

在这里插入图片描述

这块儿有个patch 0的原因,有一种说法是从NLP来的:为了保持整体结构,变换尽可能的少。而NLP需要一些token负责输出,需要“终止输出”的功能模块。另一种CV里的说法是整合信息,设置在1-9之外就保持了1-9本身无干扰。Patch 0本质上是dynamic pooling layer。

1.6 位置编码

图像切分重排后失去了位置信息,而transformer的内部运算与空间信息无关。这样一来,就需要把位置信息编码重新传进网络。ViT使用了一个可学习的vector来编码,维度和patch维度一样,所以编码vector和patch vector直接相加组成输入。本质是相加,而相加是concat的一种特例。

1.7 ViT结构的数据流

输入图像是256x256像素大小,然后切开,切成N(8x8=64)个小块,每一块则是256/8=32单位长(宽)度。也就是说,现在每一小块儿是32x32。把切开的每一小块都拉平,RGB值为3,每一块儿的维度就是3x32x32=3072维。但是3072维太高了,所以过一个linear projection把维度变成1024。但是此时每个小块儿的空间位置丢失了。所以需要加上position embedding这个可学习的向量,维度一样也是1024,让他们相加。position embedding放在patch0这里,一起进入transformer。

进入了transformer encoder之后,首先因为多了一个patch 0,Patches的表示向量里数量取N+1,即(b,65,1024)。在这个norm层里patches会被归一化,一直检验维度,保证维度是一样的。

最后到MLP Head手里,就只输入负责整合信息的patch 0,此时它表示为(b,1,1024)。这样就可以做分类任务了。
在这里插入图片描述

1.8 训练方法

Transformer 非常吃数据量,需要大量的样本,大规模使用Pre-Train。它先在大数据集(ImageNet)上预训练,然后到小数据集上做 Fine Tuning。

迁移过去之后,需要把原本的MLP Head换掉,换成对应类别数的FC层。处理不同尺寸输入的时候需要对Postional Encoding的结果进行插值。

插值方法:图片切好了之后,编号,但不同的input size和patch size会切出不同数量的patch,position embedding也会变。所以编号的方法需要缩放。

1.9 实验结果

Transformer的性能需要庞大数据量的保证,很吃资源。否则,无法充分的发挥出它的性能。它和ResNet的性能不相上下

Attention的距离可以等价为Conv的感受野大小。越深的层数,Attention跨越的距离越远。在最底层,也有head能覆盖到很远的距离。这说明它确实在捕捉信息,做信息整合。

模型注意力集中的地方,都和分类的语义高度相关。

2. 代码复现

VIT仓库链接:https://github.com/lucidrains/vit-pytorch

Usage:

import torch
from vit_pytorch import ViT # 抽象出了一个VIT类v = ViT(image_size = 256,# 图片像素大小patch_size = 32, # patch的大小num_classes = 1000, # 分类数量dim = 1024,# 维度depth = 6, # transformer的block数量heads = 16, # 线性变换后输出张量的最后维数,多头注意力层中的头数mlp_dim = 2048, # MLP前馈层维度dropout = 0.1, # 每个训练步骤中被关闭神经元的比例,可以调成0emb_dropout = 0.1 # 嵌入丢失率
)img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)

2.1 切图重排

输入端适配涉及到切图和reshape。

以下是ViT部分.py代码:

import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim = -1)self.dropout = nn.Dropout(dropout)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)attn = self.dropout(attn)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Linear(dim, num_classes)
# 解析代码段def forward(self, img):x = self.to_patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x)

transpose函数的作用:重排张量(tensor)或者数组的维度。有一个形状为(batch_size, channels, height, width)的四维张量,代表一批图像数据。那就可以将将channels维度移动到最前面,即形状变为(channels, batch_size, height, width)。这时,你就可以使用transpose操作来实现这一转换。(类似于矩阵行变换)

img = torch.randn(1,3,256,256)b=1
c=3
h=256 = h*p1,h = 8
w =256
self.to_patch.embedding = nn.Sequential(Rearrange(‘b c (h p1)(w p2)-> b(h w)(p1 p2 c),p1 = patch_height, p2 = patch_width),# 图片切分重排nn.Linear(patch_dim, dim) # Linear Projection of Flattened Patches)

2.2 构造Patch 0

这一步:

cls_tokens = repeat(self.cls_token, '() n d -> b n d',b = b)x = torch.cat((cls_tokens, x), dim = 1) # concat方法,维度为1,在n的维度上

2.3 positional embedding

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 位置编码:它是一个可学习的参数,初始化为随机值。
x += self.pos_embedding[:,:(n+1)]
# 将位置编码加到输入序列上。

2.4 代码示例

首先准备数据集。

from __future__ import print_functionimport glob
from itertools import chain
import os
import random
import zipfileimport matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DataSet
from torchvision import datasets, transforms
from tqdm.notebook import tqdmfrom vit_pytorch import ViT
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
def seed_everything(seed):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueseed_everything(seed)device = 'cuda'
os.makedirs('data',exist_ok = True)train_dir = 'data/train'
test_dir = 'data/test'with zipfile.ZipFile('data/train.zip') as train_zip:train_zip.extractall('data')
with zipfile.ZipFile('data/test.zip') as test_zip:test_zip.extractall('data')train_list = glob.glob(os.path.join(train_dir,'*.jpg')) # 查找匹配的jpg文件
test_list = glob.glob(os.path.join(test_dir,'*.jpg'))print(f'Train Data:{len(train_list)}')
print(f'Test Data:{len(test_list)}')labels = [path.split('/')[-1].split('\\')[-1].split[0] for path in train_list]
print(train_list[0]
print(labels[0]))
random_idx = np.random.randint(1,len(train_list),size = 9)
fig, axes = plt.subplots(3,3,figsize = (16,12))for idx,ax in enumerate(axes.ravel()):img = Imag.open(train_list[idx])ax.set_title(labels[idx])ax.imshow(img)
train_list, valid_list = train_test_split(train_list, test_size = 0.2, stratify = labels, random_state = seed)print(f'Train Data:{len(train_list)}')
print(f'Validation Data:{len(valid_list)}')
print(f'Test Data:{len(test_list)}')
train_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)val_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)test_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)
class CatsDogsDataset(Dataset):def __init__(self, file_list, transform = None):self.file_list = file_listself.transform = transformdef __len__(self):self.filelength = len(self.file_list)return self.filelengthdef __getitem__(self,idx):img_path = self.file_list[idx]img = Image.open(img_path)img_transformed = self.transform(img)label = img_path.split('/')[-1].split("\\")[-1].split("、")[0]label = 1 if label == "dog" else 0return img_transformed, label
train_data = CatsDogsDataset(train_list, transform = train_transforms)
valid_data = CatsDogsDataset(valid_list, transform = val_transforms)
test_data = CatsDogsDataset(test_list, transform = test_transforms)
train_loader = DataLoader(dataset = train_data,batch_size = batch_size,shuffle = True)
valid_loader = DataLoader(dataset = valid_data,batch_size = batch_size,shuffle = True)
test_loader = DataLoader(dataset = test_data,batch_size = batch_size,shuffle = True)
print(len(train_data), len(train_loader))
print(len(valid_data), len(valid_loader))

模型建立:

model = ViT(image_size = 224,patch_size = 16,num_classes = 2,dim = 768,depth = 12,heads = 12,mlp_dim = 3072,dropout = 0.1,emb_dropout = 0.1
).to(device) # 将Transformer模型移动到指定设备上,比如GPUmodel.load_state_dict(torch.load('vit_base_patch16_224_r.pth'),strict = False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/50132.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

遍历dom元素下面的子元素的方法,vue中原始标签的ref得到是该元素的dom及下面包含的子dom,与组件ref是引用不同

研究到这个的目的来源是 想用div 遍历方式 替代之前的table tr td 那种框选功能,觉得div灵活,可以随便在外面套层,td与tr之间就不能加div加了布局就乱,然后使用之前的原理( const cellList tableIdR.value.querySelec…

【反转链表 II】python刷题记录

印象中,这是遍历r2了,还好没放弃。 # Definition for singly-linked list. # class ListNode: # def __init__(self, val0, nextNone): # self.val val # self.next next class Solution:def reverseBetween(self, head: Optional…

了解Selenium中的WebElement

Selenium中到处都使用WebElement来执行各种操作。什么是WebElement?这篇文章将详细讨论WebElement。 Selenium中的WebElement是一个表示网站HTML元素的Java接口。HTML元素包含一个开始标记和一个结束标记,内容位于这两个标记之间。 HTML元素的重命名 …

SCADA系统易用性的重要性

对于中小企业而言,SCADA系统的易用性至关重要,因为它直接影响到系统的实施效率、员工的接受程度和培训成本。一个易用的SCADA系统可以减少员工对新技术的学习曲线,加快系统的部署速度,并降低长期的维护成本。此外,易用…

Parameter index out of range (2 > number of parameters, which is 1【已解决】

文章目录 1、SysLogMapper.xml添加注释导致的2、解决方法3、总结 1、SysLogMapper.xml添加注释导致的 <!--定义一个查询方法&#xff0c;用于获取日志列表--><!--方法ID为getLogList&#xff0c;返回类型com.main.server.api.model.SysLogModel,参数类型为com.main.se…

Unity UGUI 之 坐标转换

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 本文在发布时间选用unity 2022.3.8稳定版本&#xff0c;请注意分别 前置知识&#xff1a;…

牛客JS题(三)文件扩展名

注释很详细&#xff0c;直接上代码 涉及知识点&#xff1a; 正则表达式可选链操作符 题干&#xff1a; 我的答案 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /></head><body><script>/*** 可能…

快速上手,spring boot3整合task实现定时任务

在已经上线的项目中&#xff0c;定时任务是必不可少的。基于spring boot自动装配的原理&#xff0c;我们要集成task定时任务还是非常简单的。只需要简单的两步就可以实现。 1、创建一个spring boot项目&#xff0c;并在项目的启动类&#xff08;也不一定非要是启动类&#xff…

LabVIEW 实现用户授权与管理多项测试项目

在使用 LabVIEW 开发测试软件时&#xff0c;用户授权和项目管理是一个重要的功能。为了确保系统安全性、灵活性和可扩展性&#xff0c;可以设计一个用户管理系统&#xff0c;允许管理员增加或减少用户的测试项目权限。以下是一个详细的实现方案&#xff0c;包括用户授权管理、项…

buu做题(7)

[BJDCTF2020]Mark loves cat 开始的界面没啥东西, 看了下源码好像也没啥东西 用dirsearch扫描一下 有git 泄露 用工具githack下载源码 <?phpinclude flag.php;$yds "dog"; $is "cat"; $handsome yds;foreach($_POST as $x > $y){$$x $y; }f…

GNSS相关资料

常识 GNSS(二)&#xff0c;自动驾驶定位团队的“保护伞”&#xff1a;https://owwjm7oycuv.feishu.cn/docx/BAfsdC34zoN6htx4uUycbvknnub#CiWXdRZfWoqrzmxRptJcH7MYnug GNSS伪距差分和RTK&#xff1a;https://zhuanlan.zhihu.com/p/680687517 关于GNSS技术介绍&#xff08;一&…

江科大/江协科技 STM32学习笔记P6

文章目录 LED闪烁&LE流水&蜂鸣器一、操作STM32的GPIO步骤二、RCC库函数什么是AHB与APB&#xff1f; 三、GPIO库函数GPIO初始化选择IO接口工作方式 四、四种方法实现LED闪灯 LED闪烁&LE流水&蜂鸣器 一、操作STM32的GPIO步骤 1、使用RCC开启GPIO的时钟 2、使用…

CV Method:YOLOv10 vs YOLOv8

文章目录 前言一、介绍二、YOLOv8 and YOLOv10 Comparison1.模型结构YOLOv8&#xff1a;YOLOv10&#xff1a; 2. 推理和时延3. 检测表现4. 参数利用5. 关键比较 总结 前言 YOLOv10已经开源一段时间了&#xff0c;经过我实际使用测试&#xff0c;也确实性能更好一些&#xff0c…

静态IP地址在网络安全中的角色解析与实测分析

在这个网络边界日益模糊的时代&#xff0c;每一次点击、每一次数据传输都有着安全问题。作为网络安全体系中的基石&#xff0c;静态IP地址的角色显得尤为重要而复杂。今天&#xff0c;我们的测评团队将带您深入剖析静态IP地址在网络安全中的多重角色&#xff0c;并通过两家代理…

JavaScript(16)——定时器-间歇函数

开启定时器 setInterval(函数,间隔时间) 作用&#xff1a;每隔一段时间调用这个函数&#xff0c;时间单位是毫秒 例如&#xff1a;每一秒打印一个hello setInterval(function () { document.write(hello ) }, 1000) 注&#xff1a;如果是具名函数的话不能加小括号&#xf…

【图像标签转换】XML转为TXT图像数据集标签

引言 该脚本用于将包含对象标注的 XML 文件转换为 YOLO&#xff08;You Only Look Once&#xff09;对象检测格式的 TXT 文件。脚本读取 XML 文件&#xff0c;提取对象信息&#xff0c;规范化边界框坐标&#xff0c;并将数据写入相应的 TXT 文件。此外&#xff0c;它还生成一个…

做视频混剪都是去哪里找高清素材的?分享10个高清视频素材库

提升视频混剪质感的10个高清素材库推荐 在这个视觉体验至上的时代&#xff0c;视频的视觉质量对吸引观众至关重要。如果你正在寻找高清素材以提升视频混剪作品的层次&#xff0c;那么你来对地方了。今天&#xff0c;我将为你揭秘10个视频混剪达人常用的高清素材库&#xff0c;…

学习笔记-系统框图传递函数公式推导

目录 *待了解 现代控制理论和自动控制理论区别 自动控制系统的组成 信号流图 1、系统框图 1.1、信号线、分支点、相加点 1.2、系统各环节间的连接 1.3、 相加点和分支点的等效移动&#xff08;比较点、引出点&#xff09; 2、反馈连接公式推导 2.1、前向通路传递函数…

Windows:批处理脚本学习

目录 一、第一个批处理文件 1. &&和 | | 2. | 和 & 二、变量 1.传参变量%name 2.初始化变量set命令 3.变量的使用 4.局部变量与全局变量 5.使用环境变量 6.扩充变量语法 三、注释REM和 &#xff1a;&#xff1a; 四&#xff1a;函数 1.定义函数 2.…

js 习题 1

文章目录 前言T1T2T3T4T5T6T7T8T9结语 前言 『最孤独的人最亲切&#xff0c;受过伤的人总是笑的最灿烂。』—— 「素媛」 T1 let buf""; process.stdin.on("readable",function(){let chunkprocess.stdin.read();if(chunk){bufchunk.toString();} });pr…