|从零搭建网络| VisionTransformer网络详解及搭建

🌜|从零搭建网络| VisionTransformer系列网络详解及搭建🌛

文章目录

  • 🌜|从零搭建网络| VisionTransformer系列网络详解及搭建🌛
    • 🌜 前言 🌛
    • 🌜 VIT模型详解 🌛
    • 🌜 VIT模型架构 🌛
      • 🌜 Patch 🌛
      • 🌜 Encoder Block 🌛
      • 🌜 MLP Head 🌛
    • 🌜 VIT模型复现 🌛
      • 🌜 二维Patch Embedding的实现 🌛
      • 🌜 二维Multi_head Attention(多头自注意力机制)的实现 🌛
      • 🌜 二维Encoder Block的实现 🌛
      • 🌜 二维VisionTransformer的实现 🌛
      • 🌜 一维VisionTransformer的实现 🌛
    • 🌜 总结 🌛

🌜 前言 🌛

   最近学习的时候遇到了一点小小的瓶颈,导致停更了好长时间,最近突然想到能不能从大模型方面有个总体的改进,就想到最近比较火的VisionTransformer(VIT)模型,实验了一段时间之后确实有了很多想法,决定写一篇博客浅浅记录一下。
   本篇博客主要介绍一维Transformer(1D_VIT)模型二维Transformer(2D_VIT)模型网络的模型简介以及复现的相关细节。

🌜 VIT模型详解 🌛

    VisionTransformer(VIT)模型最早在2020年发表的一篇论文上的文章:https://arxiv.org/abs/2010.11929,从论文名称An Image Worth 16x16 Words:Transformers For Image Recognition At Scale可以看出VIT模型不同于传统的Transformer模型主要运用于自然语言处理领域,他本身就是为了图像分类任务而设计,同时保留了Transformer模型中较为重要的Positional EncodingSelf-Attention以及Patch等操作,其独特的自注意力机制、全局的感受野以及并行处理能力使得VIT模型得以打破传统卷积神经网络(CNN)在视觉任务中的主导地位。
在这里插入图片描述

🌜 VIT模型架构 🌛

    VIT模型整体采用了传统Transformer的结构,同时在某些地方做了更改。
在这里插入图片描述
   由于VIT模型整体用于图像的分类任务,所以在Embedding之前整体对于图片按照分辨率分成不同的Patch,并且在 MLP HeadClass Token中也做了相关改进,后续将分别从PatchAttentionEncoderMLP 四个方面详细介绍VIT模型的架构。

🌜 Patch 🌛

在这里插入图片描述
    上图是一个vit模型的简要工作图。其次介绍一下Patch Embedding的过程,假如输入数据为三通道的彩色RGB图像,并且每张图像像素为224×224。首先将每张图像分成大小相同的token,假设分成大小为16×16的token,则一共可以分成14×14个大小相同的token。分成每个信息互不相交的token之后,会将其进行一个线性的恒等映射

这种线性的恒等映射一般为一个线性层,从而将数据映射到更高维的空间,这个过程可以看做一种嵌入(embedding)。该线性层的作用是将展平后的图像块转换为固定维度的特征向量,是的每个图像块在Transformer中作为一个token处理。

并且在这个过程中恒等映射可以起到保持原始特征的作用。

   因此在进行embedding之后,数据的形状就会变为[num_token,dim_token],在上述例子中就会变成[196,embed_dim],这里的embed_dim是我们自己所设置的token嵌入维度。另外就是在patch embed之后,vit模型会在token中嵌入一个class token用于后续分类层的分类任务,而class token的形状是有patchtoken的大小所决定。如上例子所示,当数据经过embedding变为[196,embed_dim]之后,会拼接一个大小为[1,embed_dim]的class token,最后就是注意这里的是方式是cat拼接在一起,而不是单纯的相加,在进行拼接后,数据会变为[197,embed_dim]。
   进行完class token的拼接后会需要加上position token,这里position token的相加和正常transformerposition token的相加类似,求得每个token位置的余弦相似度后生成进行相加,但是和class token的拼接不同的地方是,这里的相加只会改变token大小,并不会对token的维度产生任何影响,所以继续回到刚才的例子,生成[197,embed_dim]的position token,相加后数据形状还是[197,embed_dim]。

🌜 Encoder Block 🌛

    Encoder block是vit模型的核心组件之一,其中单个编码器的结构和传统的Transformer模型的编码器非常相似,主要由多头注意力机制(Multi-Head Self-Attention)残差连接与层归一化(Residual Connection and Layer Normalization)前馈神经网络(Feed-Forward Neural Network,FFN) 所组成。下图为Encoder block的结构图。在这里插入图片描述
   将token首先进行归一化处理后使用多头注意力机制来在全局范围内捕捉不同token之间的依赖关系,而不仅限于局部感受野,这对于理解图像中的复杂结构和长距离依赖关系很重要,同时通过并行处理多个注意力头,可以很好地增强模型的表示能力和学习能力;而后使用残差连接将输入值与多头注意力的输出值相加一方面缓解深度网络中梯度消失或爆炸的问题,确保梯度在反向传播中更稳定地流动,并且某种程度上加速模型收敛速度;最后在经过层归一化、全连接层以及残差连接,在全连接层中引入了非线性激活函数,使得模型能够学习到更多的非线性特征,从而进一步提升其表征能力。vit 模型正是通过将这种Encoder block重复堆叠L次来完成tokenEncoder

🌜 MLP Head 🌛

    上述模块就是VIT模型提取特征最为重要的几个模块,而MLP Head多数情况下都只是用来做一个分类处理,下面为详细的结构图。(图片来自B站UP主@霹雳吧啦Wz ,感谢大佬)
在这里插入图片描述
   可以从中看出主要就是由线性层、激活函数以及Drop out层所组成。

🌜 VIT模型复现 🌛

   由于我的课题是关于一维信号分类,所以我在复现模型的时候会多复现一个可以处理一维信号的一维模型,所以本节VIT模型的实现包括VIT_1D(一维VIT)VIT_2D(二维VIT)的实现。并且由于VIT模型中一维和二维的实现差别不太大,所以后面在介绍的时候着重介绍二维VIT模型的实现,一维VIT模型会在本节末尾给出。其中VIT模型的实现主要包括Patch Embedding层的实现、Multi_head Attention(多头自注意力机制)的实现、Encoder block的实现以及最后MLP Head的实现。

🌜 二维Patch Embedding的实现 🌛

    本节代码实现部分默认输入数据为3通道224×224大小的RGB图像,并且patch_size(token大小)为16×16,embed_dim(嵌入后的数据维度)为768。即以下是Patch Embedding的初始化部分。

class PatchEmbed(torch.nn.Module):def __init__(self,img_size = 224,patch_size = 16,in_channel = 3,embed_dim = 768,norm_layer = None):'''初始化:param img_size: 输入数据大小:param patch_size: 分成的token大小:param in_channel: 输入数据通道数:param embed_dim: embed后的维度大小:param norm_layer: 是否使用归一化处理'''super().__init__()img_size = (img_size,img_size)patch_size = (patch_size,patch_size)self.image_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0],img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = torch.nn.Conv2d(in_channel,embed_dim,patch_size,patch_size[0])self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()

   由于VIT模型进行embedding的重要手段为使用卷积层进行维度映射,所以这里实例化一个卷积核大小为patch_size×patch_size的卷积层,并且设置输出维度为embed_dim,并且最后在初始化归一化层时,令如果有输入归一化层就是用输入的归一化层,否则不适用归一化措施。
   下面是前向传播的代码:

    def forward(self,x):# B,C,H,W = x.shapex = self.proj(x).flatten(2).transpose(1,2)#[B,C,HW] [B,HW,C] [B,token,dim]x = self.norm(x)return x

   假设输入数据x的初始形状为[B,C,H,W]。从代码中可以看到,输入数据首先是经过初始化中定义的卷积层。此时数据形状为[B,embed_dim,H_patch,W_patch];而后使用flatten(2)将数据从切片为2的地方开始进行展平操作,即对数据后两个维度展平,展平后的数据形状为[B,embed_dim,H_patch×W_patch];最后使用transpose(1,2)将数据切片的第1个维度和第2个维度进行位置的互换。最后数据形状变为:[B,H_patch×W_patch,embed_dim],并且此时第1个维度为num_token,第二个维度为token_dimtranspose函数的主要作用就是改变数据维度位置,下面是一段实例代码,更清晰展示transpose函数的作用。

import torchx = torch.randn(1,3,2)
print(x.shape)#torch.size([1,3,2])
x = x.transpose(1,2)
print(x.shape)#torch.size([1,2,3])

   Patch Embedding完整代码为:

class PatchEmbed(torch.nn.Module):def __init__(self,img_size = 224,patch_size = 16,in_channel = 3,embed_dim = 768,norm_layer = None):'''初始化:param img_size: 输入数据大小:param patch_size: 分成的token大小:param in_channel: 输入数据通道数:param embed_dim: embed后的维度大小:param norm_layer: 是否使用归一化处理'''super().__init__()img_size = (img_size,img_size)patch_size = (patch_size,patch_size)self.image_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0],img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = torch.nn.Conv2d(in_channel,embed_dim,patch_size,patch_size[0])self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()def forward(self,x):# B,C,H,W = x.shapex = self.proj(x).flatten(2).transpose(1,2)#[B,C,HW] [B,HW,C] [B,token,dim]x = self.norm(x)return x

🌜 二维Multi_head Attention(多头自注意力机制)的实现 🌛

   首先附上自注意力机制的实现公式:
在这里插入图片描述
   多头自注意力机制的实现就不进行过多赘述了,就是分别寻找qkv向量使其分别进行信息交互,与正常多头自注意力机制不一样的地方是这里实现的时候加入了Dropout层,并且在输入参数的地方可以自主输入缩放因子的数值。
   多头自注意力机制实现代码:

class Multihead_Attention(torch.nn.Module):def __init__(self,dim, #输入token的dimensionnum_heads = 8,#head数量qkv_bias = False,#生成QKV时是否使用偏置qk_scale = None,#自定义缩放因子attn_drop_ratio = 0.,proj_drop_ratio = 0.):'''多头自注意力机制:param dim: 输入token的维度:param num_heads: 注意力头的数量:param qkv_bias: 生成三个向量时是否使用偏置:param qk_scale: 是否自定义缩放因子:param attn_drop_ratio: 注意力机制层Dropout的概率:param proj_drop_ratio: 映射层Dropout的概率'''super().__init__()self.num_heads = num_headshead_dim = dim // self.num_heads#每一个head的dimensionself.scale = qk_scale or head_dim ** -0.5self.qkv = torch.nn.Linear(dim,dim*3,bias=qkv_bias)self.attn_drop = torch.nn.Dropout(attn_drop_ratio)self.proj = torch.nn.Linear(dim,dim)self.proj_drop = torch.nn.Dropout(proj_drop_ratio)def forward(self,x):B,N,C = x.shapeqkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)q,k,v = qkv[0],qkv[1],qkv[2]attn = (q @ k.transpose(-2,-1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1,2).reshape(B,N,C)x = self.proj(x)x = self.proj_drop(x)return x

这里生成三个向量的时候,并没有选择一个一个生成,而是一次性生成三个然后在分别拿出来进行信息交互处理。

🌜 二维Encoder Block的实现 🌛

在这里插入图片描述
   上图为Encoder block的结构,直接根据上述结构进行搭建即可。下面为搭建代码:

class Block(torch.nn.Module):def __init__(self,dim,num_heads,mlp_ratio = 4.,qkv_bias = False,qk_scale = None,drop_ratio = 0,attn_drop_ratio = 0.,drop_path_ratio = 0.,act_layer = torch.nn.GELU,norm_layer = torch.nn.LayerNorm):'''Encoder block:param dim: token 的输入维度:param num_heads: 注意力头的数量:param mlp_ratio: mlp隐藏层层倍数:param qkv_bias: qkv是否使用偏置:param qk_scale: 是否自定义缩放因子:param drop_ratio: 映射层dropout概率:param attn_drop_ratio: 注意力机制dropout概率:param drop_path_ratio: 是否使用Droppath:param act_layer: 是否自定义激活函数:param norm_layer: 是否自定义归一化层'''super().__init__()self.norm1 = norm_layer(dim)self.attn = Multihead_Attention(dim,num_heads,qkv_bias,qk_scale,attn_drop_ratio,drop_ratio)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else torch.nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = MLP(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop_ratio=drop_ratio)def forward(self,x):x += self.drop_path(self.attn(self.norm1(x)))x += self.drop_path(self.mlp(self.norm2(x)))return x

🌜 二维VisionTransformer的实现 🌛

在这里插入图片描述
   上图为VisionTransformer的总体结构,在初始化的过程中要记得把class tokenposition token也一并进行初始化,下面为初始化部分代码:

class VisionTransformer(torch.nn.Module):def __init__(self,img_size = 224, #输入图片大小patch_size = 16, #每个token大小in_channel = 3, #输入图片通道num_classes = 1000, #类别embed_dim = 768, #token维度depth = 12, #encoder重复次数num_heads = 12, #多头注意力机制mlp_ratio = 4.0,#mlp隐藏层倍数qkv_bias = True, #查询QKV时是否使用偏置qk_scale = None, #自定义缩放因子representation_size = None, #是否使用representationdistilled = False, #是否知识蒸馏drop_ratio = 0, #dropout比例attn_drop_ratio = 0, #attention中dropout比例drop_path_ratio = 0, #encoder中dropout比例embed_layer = PatchEmbed, #patchembednorm_layer = None, #归一化act_layer = None #激活函数):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.num_tokens = 2 if distilled else 1 #num_token默认为1norm_layer = norm_layer or partial(torch.nn.LayerNorm,eps = 1e-6)#默认为layernormact_layer = act_layer or torch.nn.GELU #激活函数默认为geluself.patch_embed = embed_layer(img_size,patch_size,in_channel,embed_dim,norm_layer)num_patches = self.patch_embed.num_patchesself.cls_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) #可训练参数(batch,1,embed_dim)self.dist_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None #知识蒸馏self.pos_embed = torch.nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))#(batch,197,768)self.pos_drop = torch.nn.Dropout(p=drop_ratio)#position后的Dropoutdpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]#创建depth个递增的dprself.blocks = torch.nn.Sequential(*[Block(embed_dim,num_heads,mlp_ratio,qkv_bias,qk_scale,drop_ratio,attn_drop_ratio,dpr[i],norm_layer=norm_layer,act_layer=act_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)#representation layerif representation_size and not distilled:#是否使用representationself.has_logits = Trueself.num_features = representation_sizeself.pre_logits = torch.nn.Sequential(OrderedDict([('fc',torch.nn.Linear(embed_dim,representation_size)),('act',torch.nn.Tanh())]))else:self.has_logits = Falseself.pre_logits = torch.nn.Identity()#classifier headsself.head = torch.nn.Linear(self.num_features,num_classes) if num_classes > 0 else torch.nn.Identity()#linearself.head_dist = Noneif distilled:self.head_dist = torch.nn.Linear(self.embed_dim,self.num_classes) if num_classes > 0 else torch.nn.Identity()#weight inittorch.nn.init.trunc_normal_(self.pos_embed,std=0.02)if self.dist_token is not None:torch.nn.init.trunc_normal_(self.dist_token,std=0.02)torch.nn.init.trunc_normal_(self.cls_token,std = 0.02)self.apply(_init_vit_weights)

   在传入的参数中有一个representation_size指的是是否使用representation。如果使用的话在最后的MLP Head中会加入一个卷积层和一个Tanh激活函数,在源码中预训练阶段使用了representation,而在迁移学习之后没有使用,后续使用的话我们可以根据自己的 需求来看是否使用。还有一个就是distilled指的是是否使用知识蒸馏,没有这方面需求的话可以直接将其设置为False
   初始化结束后,会先进行一个特征提取的前向传播,下面是实现代码:

    def forward_features(self,x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0],-1,-1)if self.dist_token is None:x = torch.cat((cls_token,x),dim=1)else:x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)x = self.pos_drop(x + self.pos_embed)x = self.blocks(x)x = self.norm(x)if self.dist_token is None:return self.pre_logits(x[:,0])else:return x[:,0],x[:,1]

   这里比较需要注意class token是使用cat拼接上去的,而position token是直接进行相加。
最后是整体的前向传播代码:

    def forward(self,x):x = self.forward_features(x)if self.head_dist is not None:x,x_dist = self.head(x[0]),self.head_dist(x[1])if self.training and not torch.jit.is_scripting():return x,x_distelse:return (x+x_dist) / 2else:x = self.head(x)return x

   不需要使用知识蒸馏的情况下,可以直接将参数distilled设置为False或者是将代码中相关内容直接删除。
   二维VisionTransformer完整实现代码:

import torch
from collections import OrderedDict
from functools import partial#二维网络
def drop_path(x,drop_prob:float = 0.,training:bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape,dtype=x.dtype,device=x.device)random_tensor.floor_()output = x.div(keep_prob) * random_tensorreturn outputclass DropPath(torch.nn.Module):def __init__(self,drop_prob = None):super().__init__()self.drop_prob = drop_probdef forward(self,x):return drop_path(x,self.drop_prob,self.training)class PatchEmbed(torch.nn.Module):def __init__(self,img_size = 224,patch_size = 16,in_channel = 3,embed_dim = 768,norm_layer = None):'''初始化:param img_size: 输入数据大小:param patch_size: 分成的token大小:param in_channel: 输入数据通道数:param embed_dim: embed后的维度大小:param norm_layer: 是否使用归一化处理'''super().__init__()img_size = (img_size,img_size)patch_size = (patch_size,patch_size)self.image_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0],img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = torch.nn.Conv2d(in_channel,embed_dim,patch_size,patch_size[0])self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()def forward(self,x):# B,C,H,W = x.shapex = self.proj(x).flatten(2).transpose(1,2)#[B,C,HW] [B,HW,C] [B,token,dim]x = self.norm(x)return xclass Multihead_Attention(torch.nn.Module):def __init__(self,dim, #输入token的dimensionnum_heads = 8,#head数量qkv_bias = False,#生成QKV时是否使用偏置qk_scale = None,#自定义缩放因子attn_drop_ratio = 0.,proj_drop_ratio = 0.):'''多头自注意力机制:param dim: 输入token的维度:param num_heads: 注意力头的数量:param qkv_bias: 生成三个向量时是否使用偏置:param qk_scale: 是否自定义缩放因子:param attn_drop_ratio: 注意力机制层Dropout的概率:param proj_drop_ratio: 映射层Dropout的概率'''super().__init__()self.num_heads = num_headshead_dim = dim // self.num_heads#每一个head的dimensionself.scale = qk_scale or head_dim ** -0.5self.qkv = torch.nn.Linear(dim,dim*3,bias=qkv_bias)self.attn_drop = torch.nn.Dropout(attn_drop_ratio)self.proj = torch.nn.Linear(dim,dim)self.proj_drop = torch.nn.Dropout(proj_drop_ratio)def forward(self,x):B,N,C = x.shapeqkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)q,k,v = qkv[0],qkv[1],qkv[2]attn = (q @ k.transpose(-2,-1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1,2).reshape(B,N,C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(torch.nn.Module):def __init__(self,in_features,hidden_features = None,#一般为in_features的四倍out_features = None,act_layer = torch.nn.GELU,drop_ratio = 0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = torch.nn.Linear(in_features,hidden_features)self.act = act_layer()self.fc2 = torch.nn.Linear(hidden_features,out_features)self.drop = torch.nn.Dropout(drop_ratio)def forward(self,x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(torch.nn.Module):def __init__(self,dim,num_heads,mlp_ratio = 4.,qkv_bias = False,qk_scale = None,drop_ratio = 0,attn_drop_ratio = 0.,drop_path_ratio = 0.,act_layer = torch.nn.GELU,norm_layer = torch.nn.LayerNorm):'''Encoder block:param dim: token 的输入维度:param num_heads: 注意力头的数量:param mlp_ratio: mlp隐藏层层倍数:param qkv_bias: qkv是否使用偏置:param qk_scale: 是否自定义缩放因子:param drop_ratio: 映射层dropout概率:param attn_drop_ratio: 注意力机制dropout概率:param drop_path_ratio: 是否使用Droppath:param act_layer: 是否自定义激活函数:param norm_layer: 是否自定义归一化层'''super().__init__()self.norm1 = norm_layer(dim)self.attn = Multihead_Attention(dim,num_heads,qkv_bias,qk_scale,attn_drop_ratio,drop_ratio)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else torch.nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = MLP(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop_ratio=drop_ratio)def forward(self,x):x += self.drop_path(self.attn(self.norm1(x)))x += self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(torch.nn.Module):def __init__(self,img_size = 224, #输入图片大小patch_size = 16, #每个token大小in_channel = 3, #输入图片通道num_classes = 1000, #类别embed_dim = 768, #token维度depth = 12, #encoder重复次数num_heads = 12, #多头注意力机制mlp_ratio = 4.0,#mlp隐藏层倍数qkv_bias = True, #查询QKV时是否使用偏置qk_scale = None, #自定义缩放因子representation_size = None, #是否使用representationdistilled = False, #是否知识蒸馏drop_ratio = 0, #dropout比例attn_drop_ratio = 0, #attention中dropout比例drop_path_ratio = 0, #encoder中dropout比例embed_layer = PatchEmbed, #patchembednorm_layer = None, #归一化act_layer = None #激活函数):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.num_tokens = 2 if distilled else 1 #num_token默认为1norm_layer = norm_layer or partial(torch.nn.LayerNorm,eps = 1e-6)#默认为layernormact_layer = act_layer or torch.nn.GELU #激活函数默认为geluself.patch_embed = embed_layer(img_size,patch_size,in_channel,embed_dim,norm_layer)num_patches = self.patch_embed.num_patchesself.cls_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) #可训练参数(batch,1,embed_dim)self.dist_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None #知识蒸馏self.pos_embed = torch.nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))#(batch,197,768)self.pos_drop = torch.nn.Dropout(p=drop_ratio)#position后的Dropoutdpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]#创建depth个递增的dprself.blocks = torch.nn.Sequential(*[Block(embed_dim,num_heads,mlp_ratio,qkv_bias,qk_scale,drop_ratio,attn_drop_ratio,dpr[i],norm_layer=norm_layer,act_layer=act_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)#representation layerif representation_size and not distilled:#是否使用representationself.has_logits = Trueself.num_features = representation_sizeself.pre_logits = torch.nn.Sequential(OrderedDict([('fc',torch.nn.Linear(embed_dim,representation_size)),('act',torch.nn.Tanh())]))else:self.has_logits = Falseself.pre_logits = torch.nn.Identity()#classifier headsself.head = torch.nn.Linear(self.num_features,num_classes) if num_classes > 0 else torch.nn.Identity()#linearself.head_dist = Noneif distilled:self.head_dist = torch.nn.Linear(self.embed_dim,self.num_classes) if num_classes > 0 else torch.nn.Identity()#weight inittorch.nn.init.trunc_normal_(self.pos_embed,std=0.02)if self.dist_token is not None:torch.nn.init.trunc_normal_(self.dist_token,std=0.02)torch.nn.init.trunc_normal_(self.cls_token,std = 0.02)self.apply(_init_vit_weights)def forward_features(self,x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0],-1,-1)if self.dist_token is None:x = torch.cat((cls_token,x),dim=1)else:x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)x = self.pos_drop(x + self.pos_embed)x = self.blocks(x)x = self.norm(x)if self.dist_token is None:return self.pre_logits(x[:,0])else:return x[:,0],x[:,1]def forward(self,x):x = self.forward_features(x)if self.head_dist is not None:x,x_dist = self.head(x[0]),self.head_dist(x[1])if self.training and not torch.jit.is_scripting():return x,x_distelse:return (x+x_dist) / 2else:x = self.head(x)return xdef _init_vit_weights(m):if isinstance(m,torch.nn.Linear):torch.nn.init.trunc_normal_(m.weight,std=.01)if m.bias is not None:torch.nn.init.zeros_(m.bias)elif isinstance(m,torch.nn.Conv2d):torch.nn.init.kaiming_normal(m.weight,mode = 'fan_out')if m.bias is not None:torch.nn.init.zeros_(m.bias)elif isinstance(m,torch.nn.LayerNorm):torch.nn.init.zeros_(m.bias)torch.nn.init.ones_(m.weight)def vit_base_patch16_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  密码: eu9f"""model = VisionTransformer(img_size=224,patch_size=16,embed_dim=768,depth=12,num_heads=12,representation_size=None,num_classes=num_classes)return modelif __name__ == '__main__':x = torch.randn(1,3,224,224)model = vit_base_patch16_224(1000)y = model(x)print(y.shape)print(model)

   代码最后实例化了一个vit_base_patch16_224模型,其中需要预训练的权重可以根据连接下载自取。

🌜 一维VisionTransformer的实现 🌛

   一维模型的实现主要是将代码中二维网络以及相关展平操作进行更改,下列是实现代码:

import torch
import os
from collections import OrderedDict
from functools import partial#一维网络
def drop_path(x,drop_prob:float = 0.,training:bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape,dtype=x.dtype,device=x.device)random_tensor.floor_()output = x.div(keep_prob) * random_tensorreturn outputclass DropPath(torch.nn.Module):def __init__(self,drop_prob = None):super().__init__()self.drop_prob = drop_probdef forward(self,x):return drop_path(x,self.drop_prob,self.training)class PatchEmbed(torch.nn.Module):def __init__(self,input_size,patch_size,in_channel,embed_dim,norm_layer = None):super().__init__()self.input_size = input_sizeself.patch_size = patch_sizeself.grid_size = input_size // patch_sizeself.proj = torch.nn.Conv1d(in_channel,embed_dim,patch_size,patch_size)self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()def forward(self,x):x = self.proj(x)x = x.transpose(1,2)x = self.norm(x)return xclass Multihead_Attention(torch.nn.Module):def __init__(self,dim, #输入token的dimensionnum_heads = 8,#head数量qkv_bias = False,#生成QKV时是否使用偏置qk_scale = None,#自定义缩放因子attn_drop_ratio = 0,proj_drop_ratio = 0):super().__init__()self.num_heads = num_headshead_dim = dim // self.num_heads#每一个head的dimensionself.scale = qk_scale or head_dim ** -0.5self.qkv = torch.nn.Linear(dim,dim*3,bias=qkv_bias)self.attn_drop = torch.nn.Dropout(attn_drop_ratio)self.proj = torch.nn.Linear(dim,dim)self.proj_drop = torch.nn.Dropout(proj_drop_ratio)def forward(self,x):B,N,C = x.shape# print(f'input shape:{x.shape}')qkv = self.qkv(x)# print(f'qkv shape:{qkv.shape}')assert C % self.num_heads == 0, "Embedding dimension must be divisible by number of heads"qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)# print(f'reshape shape:{qkv.shape}')qkv = qkv.permute(2, 0, 3, 1, 4)q,k,v = qkv[0],qkv[1],qkv[2]attn = (q @ k.transpose(-2,-1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1,2).reshape(B,N,C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(torch.nn.Module):def __init__(self,in_features,hidden_features = None,#一般为in_features的四倍out_features = None,act_layer = torch.nn.GELU,drop_ratio = 0):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = torch.nn.Linear(in_features,hidden_features)self.act = act_layer()self.fc2 = torch.nn.Linear(hidden_features,out_features)self.drop = torch.nn.Dropout(drop_ratio)def forward(self,x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(torch.nn.Module):def __init__(self,dim,num_heads,mlp_ratio = 4,qkv_bias = False,qk_scale = None,drop_ratio = 0,attn_drop_ratio = 0,drop_path_ratio = 0,act_layer = torch.nn.GELU,norm_layer = torch.nn.LayerNorm):super().__init__()self.norm1 = norm_layer(dim)self.attn = Multihead_Attention(dim,num_heads,qkv_bias,qk_scale,attn_drop_ratio,drop_ratio)self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0 else torch.nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = MLP(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer,drop_ratio=drop_ratio)def forward(self,x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(torch.nn.Module):def __init__(self,img_size = 224, #输入数据大小patch_size = 16, #每个token大小in_channel = 3, #输入数据通道num_classes = 1000, #类别embed_dim = 768, #token维度depth = 12, #encoder重复次数num_heads = 12, #多头注意力机制mlp_ratio = 4.0,#mlp隐藏层倍数qkv_bias = True, #查询QKV时是否使用偏置qk_scale = None, #自定义缩放因子representation_size = None, #是否使用representationdistilled = False, #是否知识蒸馏drop_ratio = 0., #dropout比例attn_drop_ratio = 0., #attention中dropout比例drop_path_ratio = 0., #encoder中dropout比例embed_layer = PatchEmbed, #patchembednorm_layer = None, #归一化act_layer = None #激活函数):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.num_tokens = 2 if distilled else 1 #num_token默认为1norm_layer = norm_layer or partial(torch.nn.LayerNorm,eps = 1e-6)#默认为layernormact_layer = act_layer or torch.nn.GELU #激活函数默认为geluself.patch_embed = embed_layer(img_size,patch_size,in_channel,embed_dim,norm_layer)num_patches = self.patch_embed.grid_sizeself.cls_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) #可训练参数(batch,1,embed_dim)self.dist_token = torch.nn.Parameter(torch.zeros(1,1,embed_dim)) if distilled else None #知识蒸馏self.pos_embed = torch.nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))#(batch,197,768)self.pos_drop = torch.nn.Dropout(p=drop_ratio)#position后的Dropoutdpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]#创建depth个递增的dprself.blocks = torch.nn.Sequential(*[Block(embed_dim,num_heads,mlp_ratio,qkv_bias,qk_scale,drop_ratio,attn_drop_ratio,dpr[i],norm_layer=norm_layer,act_layer=act_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)#representation layerif representation_size and not distilled:#是否使用representationself.has_logits = Trueself.num_features = representation_sizeself.pre_logits = torch.nn.Sequential(OrderedDict([('fc',torch.nn.Linear(embed_dim,representation_size)),('act',torch.nn.Tanh())]))else:self.has_logits = Falseself.pre_logits = torch.nn.Identity()#classifier headsself.head = torch.nn.Linear(self.num_features,num_classes) if num_classes > 0 else torch.nn.Identity()#linearself.head_dist = Noneif distilled:self.head_dist = torch.nn.Linear(self.embed_dim,self.num_classes) if num_classes > 0 else torch.nn.Identity()#weight inittorch.nn.init.trunc_normal_(self.pos_embed,std=0.02)if self.dist_token is not None:torch.nn.init.trunc_normal_(self.dist_token,std=0.02)torch.nn.init.trunc_normal_(self.cls_token,std = 0.02)self.apply(_init_vit_weights)def forward_features(self,x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0],-1,-1)if self.dist_token is None:x = torch.cat((cls_token,x),dim=1)else:x = torch.cat((cls_token,self.dist_token.expand(x.shape[0],-1,-1),x),dim=1)x = self.pos_drop(x + self.pos_embed)x = self.blocks(x)x = self.norm(x)if self.dist_token is None:return self.pre_logits(x[:,0])else:return x[:,0],x[:,1]def forward(self,x):x = self.forward_features(x)if self.head_dist is not None:x,x_dist = self.head(x[0]),self.head_dist(x[1])if self.training and not torch.jit.is_scripting():return x,x_distelse:return (x+x_dist) / 2else:x = self.head(x)return xdef _init_vit_weights(m):if isinstance(m,torch.nn.Linear):torch.nn.init.trunc_normal_(m.weight,std=.01)if m.bias is not None:torch.nn.init.zeros_(m.bias)elif isinstance(m,torch.nn.Conv2d):torch.nn.init.kaiming_normal(m.weight,mode = 'fan_out')if m.bias is not None:torch.nn.init.zeros_(m.bias)elif isinstance(m,torch.nn.LayerNorm):torch.nn.init.zeros_(m.bias)torch.nn.init.ones_(m.weight)def vit_base_patch16_224(num_classes: int = 1000):"""ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.weights ported from official Google JAX impl:链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  密码: eu9f"""model = VisionTransformer(img_size = 13, #输入数据大小patch_size = 1, #每个token大小in_channel = 1024, #输入数据通道num_classes = 125, #类别embed_dim = 768, #token维度depth = 12, #encoder重复次数num_heads = 12, #多头注意力机制mlp_ratio = 4.0,#mlp隐藏层倍数qkv_bias = True, #查询QKV时是否使用偏置qk_scale = None, #自定义缩放因子representation_size = None, #是否使用representationdistilled = False, #是否知识蒸馏drop_ratio = 0., #dropout比例attn_drop_ratio = 0., #attention中dropout比例drop_path_ratio = 0., #encoder中dropout比例embed_layer = PatchEmbed, #patchembednorm_layer = None, #归一化act_layer = None)return modelif __name__ == '__main__':model = vit_base_patch16_224()x = torch.randn(400,1024,13)y = model(x)print(y.shape)

🌜 总结 🌛

   感觉这个模型应该是我复现过的所有模型中最大的一个了,而且最后感觉使用VIT去训练一维信号不如CNN。。。。因为一方面这个方面本身就是为了图片分类而创造的Transformer模型,另一方面他性能好很大一部分原因是因为使用了Image-Net上预训练得出的权重,而我们如果使用他来训练一维模型的话那些预训练的参数指定是用不了的。。。。
   见仁见智吧这个问题,有写的不好的地方我们可以一起探讨。
   最近好像还有个模型是使用ResNet进行特征提取,然后后面接VIT模型进行后续训练,感觉那个应该对于一维信号或许好使,找个时间可以实现以下试试看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/40238.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mybatis、mybatis-plus插件开发,实现数据脱敏功能

首先说一下mybatis中四大组件的作用,下面开发的插件拦截器会使用 四大组件Executor、StatementHandler、ParameterHandler、ResultSetHandler Executor: Executor 是 MyBatis 中的执行器,负责 SQL 语句的执行工作。它通过调度 StatementHan…

python基础语法 004-3流程控制- while

1 while while 主要用的场景没有 for 循环多。 while循环&#xff1a;主要运行场景 我不知道什么时候结束。。。不知道运行多少次 1.1 基本用法 # while 4 > 3: #一直执行 # print("hell0")while 4 < 3: #不会打印&#xff0c;什么都没有print("…

IT之旅启航:高考后IT专业预习全攻略

✨作者主页&#xff1a; Mr.Zwq✔️个人简介&#xff1a;一个正在努力学技术的Python领域创作者&#xff0c;擅长爬虫&#xff0c;逆向&#xff0c;全栈方向&#xff0c;专注基础和实战分享&#xff0c;欢迎咨询&#xff01; 您的点赞、关注、收藏、评论&#xff0c;是对我最大…

Java知识点大纲

文章目录 第一阶段&#xff1a;JavaSE1、面向对象编程(基础)1)面向过程和面向对象区别2)类和对象的概述3)类的属性和方法4)创建对象内存分析5)构造方法(Construtor)及其重载6)对象类型的参数传递7)this关键字详解8)static关键字详解9)局部代码块、构造代码块和静态代码块10)pac…

2-24 基于图像处理的细胞计数方法

基于图像处理的细胞计数方法。经过初次二值化、中值滤波后二值化、优化后二值化图像、填充背景色的二进制图像、开运算后的图像一系列运算后&#xff0c;进行标签设置&#xff0c;最终得到细胞总数。程序已调通&#xff0c;可直接运行。 2-24 细胞计数方法 中值滤波后二值化 - …

【C++】 解决 C++ 语言报错:Invalid Cast

文章目录 引言 无效类型转换&#xff08;Invalid Cast&#xff09;是 C 编程中常见且严重的错误之一。当程序试图进行不合法或不安全的类型转换时&#xff0c;就会发生无效类型转换错误。这种错误不仅会导致程序崩溃&#xff0c;还可能引发不可预测的行为。本文将深入探讨无效…

图像增强方法汇总OpenCV+python实现【第一部分:常用图像增强方法】

图像增强方法汇总OpenCVpython实现【第一部分】 前言常用的图像增强方法1. 旋转&#xff08;Rotation&#xff09;&#xff1a;2. 平移&#xff08;Translation&#xff09;&#xff1a;3. 缩放&#xff08;Scaling&#xff09;&#xff1a;4. 剪切变换&#xff08;Shear Trans…

UserWarning: IPython History requires SQLite, your history will not be saved

UserWarning: IPython History requires SQLite, your history will not be saved 很久未打开pycharm&#xff0c;控制台出现爆红 解决方法&#xff1a; 重启pycharm&#xff0c;就好啦&#xff01;&#xff01;&#xff01;我猜测可能是上次pycharm没有关闭就电脑关机&…

《企业实战分享 · 内存溢出分析》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; 近期刚转战 CSDN&#xff0c;会严格把控文章质量&#xff0c;绝不滥竽充数&#xff0c;如需交流&#xff…

用PyQt5打造炫酷界面:深入解析pyqt5-custom-widgets

在PyQt5中&#xff0c;使用自定义小部件可以为应用程序增添更多实用性和时尚感。pyqt5-custom-widgets是一个开源项目&#xff0c;提供了一系列有用且时尚的自定义小部件&#xff0c;如开关按钮、动画按钮等。本文将详细介绍pyqt5-custom-widgets的安装和使用方法。 安装 可以…

权限维持Linux---监控功能Strace后门命令自定义Alias后门

免责声明:本文仅做技术交流与学习... 目录 监控功能Strace后门 1、记录 sshd 明文 监控 筛选查看 2、记录sshd私钥 命令自定义Alias后门 1、简单粗鲁实现反弹&#xff1a; 靶机替换命令 攻击机监听上线 2.升级(让命令正常) 将反弹命令进行base64编码 替换alias命令 …

【Linux】--help,man page , info page

我们知道Linux有很多的命令&#xff0c;那LInux要不要背命令&#xff1f; 答案是背最常用的那些就行了 那有的时候我们想查询一些命令的详细用法该怎么办呢&#xff1f; 这里我给出3种方法 1.--help --help的使用方法很简单啊 要查询的命令 --help 我们看个例子 这里我只…

java版企业工程管理系统源码:全方位的项目管理解决方案

工程管理系统是一款专注于建设工程项目全生命周期管理的软件。它覆盖了项目从策划、设计、施工到竣工的每一个阶段&#xff0c;提供全方位的管理功能。系统采用模块化设计&#xff0c;包括系统管理、系统设置、项目管理、合同管理、预警管理、竣工管理、质量管理、统计报表和工…

6月30日功能测试Day10

3.4.4拼团购测试点 功能位置&#xff1a;营销-----拼团购 后台优惠促销列表管理可以添加拼团&#xff0c;查看拼团活动&#xff0c;启动活动&#xff0c;编辑活动&#xff0c;删除活动。 可以查看拼团活动中已下单的订单以状态 需求分析 功能和添加拼团 商品拼团活动页 3…

python使用pywebview集成vue3和element-plus开发桌面系统框架

随着web技术越来越成熟&#xff0c;就连QQ的windows客户端都用web技术来开发&#xff0c;所以在未来&#xff0c;web技术来开发windows桌面软件也会越来越多&#xff0c;所以在此发展驱动之下&#xff0c;将最近流程的python与web技术相结合&#xff0c;使用vue3和element-plus…

图像增强 目标检测 仿射变换 图像处理 扭曲图像

1.背景 在目标检测中&#xff0c;需要进行图像增强。这里的代码模拟了旋转、扭曲图像的功能&#xff0c;并且在扭曲的时候&#xff0c;能够同时把标注的结果也进行扭曲。 这里忽略了读取xml的过程&#xff0c;假设图像IMG存在对应的标注框&#xff0c;且坐标为左上、右下两个…

[C++初阶]vector的初步理解

一、标准库中的vector类 1.vector的介绍 1. vector是表示可变大小数组的序列容器 &#xff0c; 和数组一样&#xff0c;vector可采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元素进行访问&#xff0c;和数组一样高效。但是又不像数组&#xff0c;它的大…

Java学习高级一

修饰符 static 类变量的应用场景 成员方法的分类 成员变量的执行原理 成员方法的执行原理 Java之 main 方法 类方法的常见应用场景 代码块 设计模式 单例设计模式 饿汉式单例设计模式 懒汉式单例设计模式 继承 权限修饰符

小红书 达芬奇:生活问答 AI 机器人

小红书去年 9 月开始内测的生活问答 AI 机器人&#xff1a;达芬奇&#xff0c;现在可以在小红书 APP 上用了 得益于小红书平台的特性&#xff0c;该助手擅长吃、住、宠、喝、学等等各类生活知识&#xff0c;目前还在搞活动&#xff0c;写评测笔记最高得 666 元

为什么不能在foreach中删除元素

文章目录 快速失败机制&#xff08;fail-fast&#xff09;for-each删除元素为什么报错原因分析逻辑分析 如何正确的删除元素remove 后 breakfor 循环使用 Iterator 总结 快速失败机制&#xff08;fail-fast&#xff09; In systems design, a fail-fast system is one which i…