VIT用于图像分类 学习笔记(附代码)

论文地址:https://arxiv.org/abs/2010.11929

代码地址:https://github.com/bubbliiiing/classification-pytorch

1.是什么?

Vision Transformer(VIT)是一种基于Transformer架构的图像分类模型。它将图像分割成一系列的图像块,并将每个图像块作为输入序列传递给Transformer模型。VIT通过自注意力机制来捕捉图像中的全局上下文信息,并使用多层感知机(MLP)来进行特征提取和分类。

VIT的核心思想是将图像转换为序列数据,这使得模型能够利用Transformer的强大表达能力来处理图像。通过将图像分割成图像块,并将它们展平为序列,VIT能够在不依赖传统卷积神经网络的情况下实现图像分类任务。

2.为什么?

从2020年,transformer开始在CV领域大放异彩:图像分类(ViT, DeiT),目标检测(DETR,Deformable DETR),语义分割(SETR,MedT),图像生成(GANsformer)等。而从深度学习暴发以来,CNN一直是CV领域的主流模型,而且取得了很好的效果,相比之下transformer却独霸NLP领域,transformer在CV领域的探索正是研究界想把transformer在NLP领域的成功借鉴到CV领域。对于图像问题,卷积具有天然的先天优势(inductive bias):平移等价性(translation equivariance)和局部性(locality)。而transformer虽然不并具备这些优势,但是transformer的核心self-attention的优势不像卷积那样有固定且有限的感受野,self-attention操作可以获得long-range信息(相比之下CNN要通过不断堆积Conv layers来获取更大的感受野),但训练的难度就比CNN要稍大一些。

ViT(vision transformer)是Google在2020年提出的直接将transformer应用在图像分类的模型,后面很多的工作都是基于ViT进行改进的。这篇论文也是受到其启发,尝试将Transformer应用到CV领域通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。

3.怎么样?

3.1网络结构

与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。

在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

3.2特征提取部分介绍

3.2.1Patch

Patch的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列

该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。也就是说,不是把图片分割,是做了一次简单的卷积,可以理解为初步特征提取,或者说是映射。

由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
 

3.2.2Position Embedding

Position Embedding的作用主要是对组合序列加上[class]token以及Position Embedding

在原论文中,作者说参考BERT,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。

对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb.,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.比起来没太大差别。

3.2.3Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,下图是太阳花的小绿豆绘制的Encoder Block,主要由以下几部分组成:

  1. Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接
  2. Multi-Head Attention,看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入,每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。
  1. Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
  2. MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
     

3.3 分类部分

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

 3.4别人画的网络结构图

 

3.5代码实现

Patch+Position Embedding

class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)self.flatten        = flattenself.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):x = self.proj(x)if self.flatten:x = x.flatten(2).transpose(1, 2)  # BCHW -> BNCx = self.norm(x)return xclass VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------##   224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches         = (224 // patch_size) * (224 // patch_size)self.num_features   = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------##   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。##   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------##   196, 768 -> 197, 768self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------##   为网络提取到的特征添加上位置信息。#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------##   197, 768 -> 197, 768self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)

TransformerBlock 

class Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):super().__init__()out_features    = out_features or in_featureshidden_features = hidden_features or in_featuresdrop_probs      = (drop, drop)self.fc1    = nn.Linear(in_features, hidden_features)self.act    = act_layer()self.drop1  = nn.Dropout(drop_probs[0])self.fc2    = nn.Linear(hidden_features, out_features)self.drop2  = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):super().__init__()self.norm1      = norm_layer(dim)self.attn       = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.norm2      = norm_layer(dim)self.mlp        = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)self.drop_path  = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x

VIT

整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。 

class VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------##   224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches         = (224 // patch_size) * (224 // patch_size)self.num_features   = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------##   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。##   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------##   196, 768 -> 197, 768self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------##   为网络提取到的特征添加上位置信息。#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------##   197, 768 -> 197, 768self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop       = nn.Dropout(p=drop_rate)#-----------------------------------------------##   197, 768 -> 197, 768  12次#-----------------------------------------------#dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.Sequential(*[Block(dim         = num_features, num_heads   = num_heads, mlp_ratio   = mlp_ratio, qkv_bias    = qkv_bias, drop        = drop_rate,attn_drop   = attn_drop_rate, drop_path   = dpr[i], norm_layer  = norm_layer, act_layer   = act_layer)for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)x = self.blocks(x)x = self.norm(x)return x[:, 0]def forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef freeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Falseexcept:module.requires_grad = Falsedef Unfreeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Trueexcept:module.requires_grad = True

 Vision Transforme的构建代码

import math
from collections import OrderedDict
from functools import partialimport numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F#--------------------------------------#
#   Gelu激活函数的实现
#   利用近似的数学公式
#--------------------------------------#
class GELU(nn.Module):def __init__(self):super(GELU, self).__init__()def forward(self, x):return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))def drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob       = 1 - drop_probshape           = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor   = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() output          = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)self.flatten        = flattenself.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):x = self.proj(x)if self.flatten:x = x.flatten(2).transpose(1, 2)  # BCHW -> BNCx = self.norm(x)return x#--------------------------------------------------------------------------------------------------------------------#
#   Attention机制
#   将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
#   然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
#   然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads  = num_headsself.scale      = (dim // num_heads) ** -0.5self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop  = nn.Dropout(attn_drop)self.proj       = nn.Linear(dim, dim)self.proj_drop  = nn.Dropout(proj_drop)def forward(self, x):B, N, C     = x.shapeqkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v     = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):super().__init__()out_features    = out_features or in_featureshidden_features = hidden_features or in_featuresdrop_probs      = (drop, drop)self.fc1    = nn.Linear(in_features, hidden_features)self.act    = act_layer()self.drop1  = nn.Dropout(drop_probs[0])self.fc2    = nn.Linear(hidden_features, out_features)self.drop2  = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):super().__init__()self.norm1      = norm_layer(dim)self.attn       = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.norm2      = norm_layer(dim)self.mlp        = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)self.drop_path  = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------##   224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches         = (224 // patch_size) * (224 // patch_size)self.num_features   = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------##   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。##   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------##   196, 768 -> 197, 768self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------##   为网络提取到的特征添加上位置信息。#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------##   197, 768 -> 197, 768self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop       = nn.Dropout(p=drop_rate)#-----------------------------------------------##   197, 768 -> 197, 768  12次#-----------------------------------------------#dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.Sequential(*[Block(dim         = num_features, num_heads   = num_heads, mlp_ratio   = mlp_ratio, qkv_bias    = qkv_bias, drop        = drop_rate,attn_drop   = attn_drop_rate, drop_path   = dpr[i], norm_layer  = norm_layer, act_layer   = act_layer)for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)x = self.blocks(x)x = self.norm(x)return x[:, 0]def forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef freeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Falseexcept:module.requires_grad = Falsedef Unfreeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Trueexcept:module.requires_grad = Truedef vit(input_shape=[224, 224], pretrained=False, num_classes=1000):model = VisionTransformer(input_shape)if pretrained:model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))if num_classes!=1000:model.head = nn.Linear(model.num_features, num_classes)return model

参考:Vision Transformer详解

神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/599978.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python-实现高并发的常见方式

高并发能帮支持快速处理大量执行任务,提高代码的执行效率,以下是在日常开发中常见的高并发方式 多线程(Threading) Python 的 threading 模块可以非常容易地创建和管理线程。线程共享内存空间,这意味着它们可以更高效的…

JS tostring()和join()方法

在JavaScript中,toString()和join()都是用于处理数组的方法。它们的功能和用法如下: 1.toString()方法: toString()方法将数组转换为一个由每个元素字符串形式拼接而成的字符串。该方法不会改变原始数组,而是返回一个新的字符串。…

如何实现安卓端与苹果端互通

在移动应用开发中,如何实现安卓端和苹果端的互通是一个重要的问题。二者缺少一个都会有损失,那如何实现安卓端跟苹果端互通,下面简单的介绍几点方法来帮助你再不同的平台上实现数据交互和功能互通。 基于Web技术 使用Web技术是一种常见并且…

构建可伸缩和高性能系统的设计原则和最佳实践

在当今快节奏的软件开发环境中,构建可伸缩和高性能的系统对于满足用户需求至关重要。采用设计原则和最佳实践是确保系统具备良好性能和可扩展性的关键。本文将介绍一些构建可伸缩和高性能系统的设计原则和最佳实践。 1. 分布式架构 采用分布式系统架构&#xff0c…

数字孪生与大数据和分析技术的结合

数字孪生与大数据和分析技术的结合可以为系统提供更深入的见解、支持实时决策,并优化模型的性能。以下是数字孪生在大数据和分析技术中的一些应用,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流…

2024阿里云服务器配置推荐方案

阿里云服务器配置怎么选择合适?CPU内存、公网带宽和ECS实例规格怎么选择合适?阿里云服务器网aliyunfuwuqi.com建议根据实际使用场景选择,例如企业网站后台、自建数据库、企业OA、ERP等办公系统、线下IDC直接映射、高性能计算和大游戏并发&…

美年大健康黄伟:从选型到迁移,一个月升级核心数据库

核心生产系统的数据库,从接到替换需求到完成分布式升级,需要多久?一个月,这是美年大健康的回答。一个月集中调配各种资源,美年大健康完成了应用程序基本零改造的平滑迁移,新数据库在成本更低的前提下&#…

迪拜公司注册优势 迪拜公司注册条件 迪拜公司注册流程

迪拜作为阿 拉伯联合酋长国(United Arab Emirates,简称UAE)的一个城市,拥有独特的优势和吸引力。以下是迪拜公司注册的优势、条件和流程: 迪拜公司注册优势 1、无外汇管制:在迪拜注册的公司可以自 由转移资…

2023-RunwayML-Gen-2 AI视频生成功能发展历程

RunwayML是一个人工智能工具,它为设计师、艺术家和创意人士提供了一种简单的方式来探索和应用机器学习技术。 RunwayML官方网页地址:Runway - Advancing creativity with artificial intelligence. RunwayML专区RunwayML-喜好儿aigcRunwayML 是一种先进…

P1192 台阶问题————C++

目录 台阶问题题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示 解题思路Code运行结果 台阶问题 题目描述 有 N N N 级台阶,你一开始在底部,每次可以向上迈 1 ∼ K 1\sim K 1∼K 级台阶,问到达第 N N N 级台阶有多少种不同方…

itextpdf中文不显示问题

原因1.没有指定中文字体 方法一&#xff1a;使用itext-asian <dependency><groupId>com.itextpdf</groupId><artifactId>itext-asian</artifactId><version>5.2.0</version> </dependency> BaseFont baseFont BaseFont.crea…

电商平台低价品牌要如何处理

低价会影响品牌渠道的长期发展&#xff0c;同时还会衍生很多问题&#xff0c;如为了追求低价而导致的店铺窜货、商品假货等&#xff0c;对于渠道来说&#xff0c;都是要及时解决的问题&#xff0c;否则渠道乱了&#xff0c;最终腐蚀的是品牌价值&#xff0c;同时还会影响经销商…

【LeetCode-剑指offer】--1.两数相除

1.两数相除 方法&#xff1a;使用减法实现除法 用“被减数”能减去几次“减数”来衡量最后的结果&#xff0c;这时候我们想到求x的幂次的快速解法&#xff0c;将x成倍成倍的求幂&#xff0c;这里将减数成倍成倍的增大&#xff0c;次数对应也是成倍成倍的增大&#xff0c;例如&…

力扣labuladong——一刷day86

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、力扣496. 下一个更大元素 I二、力扣739. 每日温度 前言 单调栈实际上就是栈&#xff0c;只是利用了一些巧妙的逻辑&#xff0c;使得每次新元素入栈后&#…

虚幻UE 增强输入-第三人称模板增强输入分析与扩展

本篇是增强输入模块&#xff0c;作为UE5.0新增加的模块。 其展现出来的功能异常地强大&#xff01; 让我们先来学习学习一下第三人称模板里面的增强输入吧&#xff01; 文章目录 前言一、增强输入四大概念二、使用步骤1、打开增强输入模块2、添加IA输入动作2、添加IMC输入映射内…

【亚马逊云科技】自家的AI助手 - Amazon Q

写在前面&#xff1a;博主是一只经过实战开发历练后投身培训事业的“小山猪”&#xff0c;昵称取自动画片《狮子王》中的“彭彭”&#xff0c;总是以乐观、积极的心态对待周边的事物。本人的技术路线从Java全栈工程师一路奔向大数据开发、数据挖掘领域&#xff0c;如今终有小成…

C++面对对象编程

面对对象编程入门 1.类与对象2.公有和私有概念3.类的成员函数4.类的实例化5.构造函数6.析构函数7.常成员函数8.静态属性和静态方法总结 1.类与对象 在python中&#xff0c;我们提到过类这个概念。所谓类&#xff0c;就是一个包含着元素和函数的数据类型&#xff0c;在C中&…

C语言预备知识_hello world_数据类型_变量(入门到入神)

为什么要学习 C语言 学习 C语言是非常接近底层的一种编程语言C语言是学习其它编程语言第基础&#xff0c;基础不牢&#xff0c;地动山摇考研会用到 C语言 C语言预备知识 CPU 内存条 硬盘 显卡 主板 显示器之间的关系 当你在电脑上观看一部存储在硬盘上的电影时&#xff0c;各…

如何在 ChatGPT 上使用 Wolfram 插件回答数学问题

这里写自定义目录标题 写在最前面Wolfram是什么&#xff1f;ChatGPT 如何与 Wolfram 相结合&#xff0c;为什么有效&#xff1f;如何在 ChatGPT 上安装 Wolfram 插件&#xff1f; 写在最前面 参考&#xff1a;https://clickthis.blog/zh-CN/how-to-answer-math-questions-usin…

大一C语言查缺补漏 12.28

在C语言中&#xff0c;5%&#xff08;-3&#xff09;答案是什么 在C语言中&#xff0c;5 % -3的结果是2。因为在C语言中&#xff0c;取余运算&#xff08;%&#xff09;的结果的符号与被除数相同。所以&#xff0c;5 % -3的计算结果为2。 在C语言种引用数组元素时&#xff0c;其…