在学习 ViT 之前,建议先把 Transformer 搞明白,可参考:【transformer】入门与理解
ViT 做了哪些改进?
看图就比较明白了:ViT 只用了 Transformer 的 Encoder 部分。它把每张图片切分成若干个互不重叠的子图(patch),再把每个子图展平(flatten)并线性映射到 embed_dim 维,当成 NLP 中的一个 token 来处理。
值得注意的是,在 token 序列的最前面额外拼接了一个 class_token,维度为 (1, embed_dim=768)。这个 class_token 在预测时比较有意思:分类头只使用 class_token 所在位置(第 0 个 token)的编码器输出,见下图:
注意上图遗漏了一些细节,完整流程应该是:先对输入做 patch_embedding 得到 visual tokens,然后与 class_token 拼接,最后加上 position_embedding。
另外需要注意的是,class_token 是一个可学习的参数,而不是每次前向传播时需要传入的类别标签:
self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98) #(1,1,768)
代码
其实有了 Transformer 的基础后,直接看代码就知道VIT是怎么做的了。
import copy
import torch
import torch.nn as nn# 所有基于nn.Module结构的模版,可以删掉
class Identity(nn.Module):
    """Pass-through module: returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Mlp(nn.Module):
    """Feed-forward sub-layer of a Transformer encoder layer.

    Linear(embed_dim -> int(embed_dim * mlp_ratio)) -> GELU -> Dropout
    -> Linear(back to embed_dim) -> Dropout. The last dimension is preserved.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        mlp_ratio: Expansion factor of the hidden layer.
        dropout: Dropout probability, applied after the activation and after
            the second linear layer.
    """

    def __init__(self, embed_dim, mlp_ratio, dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)  # the hidden layer widens the features
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches, each projected to ``embed_dim``. The learnable class token is
    prepended, the learnable position embedding is added, and dropout is applied.

    Args:
        image_size: Height/width of the square input image; the position
            embedding is sized for exactly this resolution.
        patch_size: Side length of each patch.
        in_channels: Number of input image channels.
        embed_dim: Token dimension.
        dropout: Dropout probability applied to the output tokens.

    Shapes: ``[N, C, H, W]`` -> ``[N, n_patches + 1, embed_dim]``
    (``[10, 3, 224, 224]`` -> ``[10, 197, 768]`` with the defaults).
    """

    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.):
        super().__init__()
        n_patches = (image_size // patch_size) * (image_size // patch_size)  # 196 with the defaults
        # kernel_size == stride, so each output position sees exactly one patch:
        # this is a linear projection of every flattened patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.dropout = nn.Dropout(dropout)
        # Learnable class token, shape (1, 1, embed_dim).
        # NOTE(review): constant initialisation kept from the original tutorial;
        # reference ViT implementations usually use a small random init instead.
        self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)
        # Learnable position embedding for the class token plus every patch:
        # shape (1, n_patches + 1, embed_dim).
        self.position_embedding = nn.Parameter(torch.ones(1, n_patches + 1, embed_dim) * 0.98)

    def forward(self, x):
        # x: [N, C, H, W]
        cls_tokens = self.class_token.expand(x.shape[0], -1, -1)  # [N, 1, embed_dim], one per sample
        x = self.patch_embedding(x)  # [N, embed_dim, H', W']
        x = x.flatten(2)  # [N, embed_dim, H' * W']
        x = x.permute(0, 2, 1)  # [N, n_patches, embed_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # [N, n_patches + 1, embed_dim]
        x = x + self.position_embedding
        # The dropout layer is built in __init__ and must actually be applied.
        x = self.dropout(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention. Input and output are both ``[B, N, embed_dim]``.

    Args:
        embed_dim: Token dimension.
        num_heads: Number of heads; each head has ``embed_dim // num_heads`` dims.
        qkv_bias: Whether the joint q/k/v projection has a bias term.
        dropout: Dropout probability applied after the output projection.
        attention_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=True, dropout=0., attention_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.all_head_dim = self.head_dim * num_heads
        self.scales = self.head_dim ** -0.5
        # One linear layer produces q, k and v together: embed_dim -> 3 * all_head_dim.
        self.qkv = nn.Linear(embed_dim, self.all_head_dim * 3, bias=qkv_bias)
        # The merged heads carry all_head_dim features, which is smaller than
        # embed_dim whenever embed_dim is not divisible by num_heads.
        self.proj = nn.Linear(self.all_head_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention_dropout = nn.Dropout(attention_dropout)
        # Normalise over the key axis (last dim). nn.Softmax() without dim picks
        # dim=1 for a 4-D tensor, i.e. it would normalise across the heads.
        self.softmax = nn.Softmax(dim=-1)

    def transpose_multihead(self, x):
        """Split the last dim into heads: [B, N, all_head_dim] -> [B, num_heads, N, head_dim]."""
        batch, n_tokens, _ = x.shape
        x = x.reshape(batch, n_tokens, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)  # three tensors, each [B, N, all_head_dim]
        q, k, v = map(self.transpose_multihead, qkv)  # each [B, num_heads, N, head_dim]
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scales  # [B, num_heads, N, N]
        attn = self.softmax(attn)
        attn = self.attention_dropout(attn)
        out = torch.matmul(attn, v)  # [B, num_heads, N, head_dim]
        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)  # merge heads: [B, N, all_head_dim]
        out = self.proj(out)  # [B, N, embed_dim]
        out = self.dropout(out)
        return out


class EncoderModule(nn.Module):
    """One pre-norm Transformer encoder layer; the input shape is preserved.

    x = x + Attention(LayerNorm(x)), then x = x + Mlp(LayerNorm(x)).
    """

    def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        # Forward every hyper-parameter to the sub-layers instead of dropping them.
        self.attn = Attention(
            embed_dim,
            num_heads,
            qkv_bias=qkv_bias,
            dropout=dropout,
            attention_dropout=attention_dropout,
        )
        self.mlp_norm = nn.LayerNorm(embed_dim)
        self.mlp = Mlp(embed_dim, mlp_ratio, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # residual around attention
        x = x + self.mlp(self.mlp_norm(x))  # residual around the MLP
        return x


class Encoder(nn.Module):
    """Stack of ``depth`` encoder layers followed by a final LayerNorm.

    The new keyword arguments default to the values every layer previously
    used, so ``Encoder(embed_dim, depth)`` keeps working.
    """

    def __init__(self, embed_dim, depth, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        # Each layer must receive embed_dim and num_heads. Building them with
        # EncoderModule() defaults would fix the width at 768 and ignore the
        # caller's head count.
        # The attribute name `Modules` is kept so existing state_dict keys still match.
        self.Modules = nn.ModuleList(
            EncoderModule(
                embed_dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attention_dropout=attention_dropout,
            )
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        for layer in self.Modules:
            x = layer(x)
        return self.norm(x)


class VisualTransformer(nn.Module):
    """Vision Transformer (ViT) classifier.

    Pipeline: PatchEmbedding -> Encoder -> linear head on the class-token output.

    Args:
        image_size, patch_size, in_channels: Input geometry (see PatchEmbedding).
        num_classes: Number of output logits.
        embed_dim: Token dimension.
        depth: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        mlp_ratio, qkv_bias, dropout, attention_dropout: Encoder layer options.
            These are optional keyword arguments whose defaults match the
            previous hard-coded values.
    """

    def __init__(self, image_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=3, num_heads=8, mlp_ratio=4.0, qkv_bias=True, dropout=0., attention_dropout=0.):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim, dropout)
        self.encoder = Encoder(
            embed_dim,
            depth,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: [N, C, H, W]
        x = self.patch_embedding(x)  # [N, n_patches + 1, embed_dim]
        x = self.encoder(x)  # [N, n_patches + 1, embed_dim]
        # Only token 0 (the class token prepended in PatchEmbedding) feeds the classifier.
        return self.classifier(x[:, 0])  # [N, num_classes]


vit = VisualTransformer()
# Demo: show the architecture, then push a random batch of 10 RGB 224x224
# images through the model and print the logits' shape.
print(vit)
input_data = torch.randn(10, 3, 224, 224)
logits = vit(input_data)
print(logits.shape)  # torch.Size([10, 1000])