1. Transformer结构
1.1 编码器和解码器
翻译这个过程需要中间体。也就是说,编码,解码之间需要一个中介,英文先编码成一个意思,再解码成中文。
那么查字典这个过程就是编码和解码的体现。首先我们的大脑会把它编码,编码这个句子的意思,然后通过字典映射解码。但是这样的过程太过于繁琐,如果让机器做,超长文本就对应着超长的数据量,也不利于机器学习的上下文理解。那么就有了Attention注意力机制。
1.2 Attention:注意力机制
Attention机制的核心思想是,要想翻译一个句子并不需要完全编码。像我们人类一样,仅凭借几个词就可以猜出整句话的大概意义,即使我们不懂日语,也可以根据一些汉字推出来大概的意思,这是准确度低的情况;而“中译中”这种情况,准确度当然就更高一些。
Attention注意力机制:
Attention示意图本质上是加权平均。如果我非常注意某个地方,我想要多看,那就分配更高的权重。
计算权重是使用相似度计算,
Attention机制的优点
Attention的优点是能够实现并行计算和全局视野。
并行计算的可能性是因为它不像RNN一样,依赖时序数据。它只是加权计算,但并不需要像时序数据那样,依赖像是队列一样的进出顺序。
全局视野是因为在加权计算的时候,这个计算就是涉及了整体的,它一看就能看到全部。
1.3 Self Attention
对于Self Attention来说,它的输入是一个序列,序列的获得依靠的是vector。我们把一个词转换为序列模块,需要用到vector向量去指向。而vector的指向,是有空间性的,比如说两个意思很相近或者同义的词汇,它们在空间中的距离就会比较小。相反,意思差的多,距离当然就远。
这样也可以理解为vector是和词语的意义有关系的。
注意力分配的多少取决于公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V
其中,Q代表Queries(输入的信息),K代表Keys(内容信息),V代表Values(信息本身,只表达了信息的特征)。
Q,K,V的获得
本质上,是input的线性变换。计算使用的是矩阵乘法,实现方法是nn.Linear。
用点积表示相似度的方法是因为cos角投影长度可以很直观地理解两个量的相似度。
在1.3的式子里除以 d k \sqrt{dk} dk这一步看起来很多余,但是它是为了避免较大的数值。较大的数值会导致softmax之后值更极端,softmax的极端值会导致梯度消失。这一步相当于控制了数值范围,让它在可观测范围内。
为什么是 d k \sqrt{dk} dk?假设q,k是均值为0,方差为1的标准正态分布的独立随机变量,那么它们的点积的均值和方差分别为0以及dk。
之后做逐元素相乘(enterwise)。
我们需要将单词意思转换为句中意思,这就涉及一词多义的问题,在逐元素相乘得到的sum之后的z就涉及这个问题。那么我们需要根据句子中其他词来推理当前词的意思。
就比如说Mine,它有两种意思,一种是“我的”,一种是“矿石”。这当然是截然不同的词性和词义。假设我们完全不知道这个多义词,但我们可以通过观察它们在句子中的位置和与上下文的联系来推理这是什么意思。
1.4 Multihead Attention
这一步的核心是复读机()
这一步就是有多个W_q,W_k,W_v,那么上述操作重复多次,将结果用concat串在一起。
这样的复读机机制就是给注意力提供多种可能性。
应用了multihead的Conditional DETR就发现不同的head会将注意力放到物体的不同边上。
1.5 输入端适配
直接把图片切分成patches,flatten操作拉平patches,然后过一个linear projection使patches维度变小,然后编号123456789…输入网络即可。
就是切蛋糕喂给encoder和decoder。
这块儿有个patch 0的原因,有一种说法是从NLP来的:为了保持整体结构,变换尽可能的少。而NLP需要一些token负责输出,需要“终止输出”的功能模块。另一种CV里的说法是整合信息,设置在1-9之外就保持了1-9本身无干扰。Patch 0本质上是dynamic pooling layer。
1.6 位置编码
图像切分重排后失去了位置信息,而transformer的内部运算与空间信息无关。这样一来,就需要把位置信息编码重新传进网络。ViT使用了一个可学习的vector来编码,维度和patch维度一样,所以编码vector和patch vector直接相加组成输入。本质是相加,而相加是concat的一种特例。
1.7 ViT结构的数据流
输入图像是256x256像素大小,然后切开,切成N(8x8=64)个小块,每一块则是256/8=32单位长(宽)度。也就是说,现在每一小块儿是32x32。把切开的每一小块都拉平,RGB值为3,每一块儿的维度就是3x32x32=3072维。但是3072维太高了,所以过一个linear projection把维度变成1024。但是此时每个小块儿的空间位置丢失了。所以需要加上position embedding这个可学习的向量,维度一样也是1024,让他们相加。position embedding放在patch0这里,一起进入transformer。
进入了transformer encoder之后,首先因为多了一个patch 0,Patches的表示向量里数量取N+1,即(b,65,1024)。在这个norm层里patches会被归一化,一直检验维度,保证维度是一样的。
最后到MLP Head手里,就只输入负责整合信息的patch 0,此时它表示为(b,1,1024)。这样就可以做分类任务了。
1.8 训练方法
Transformer 非常吃数据量,需要大量的样本,大规模使用Pre-Train。它先在大数据集(ImageNet)上预训练,然后到小数据集上做 Fine Tuning。
迁移过去之后,需要把原本的MLP Head换掉,换成对应类别数的FC层。处理不同尺寸输入的时候需要对Postional Encoding的结果进行插值。
插值方法:图片切好了之后,编号,但不同的input size和patch size会切出不同数量的patch,position embedding也会变。所以编号的方法需要缩放。
1.9 实验结果
Transformer的性能需要庞大数据量的保证,很吃资源。否则,无法充分的发挥出它的性能。它和ResNet的性能不相上下。
Attention的距离可以等价为Conv的感受野大小。越深的层数,Attention跨越的距离越远。在最底层,也有head能覆盖到很远的距离。这说明它确实在捕捉信息,做信息整合。
模型注意力集中的地方,都和分类的语义高度相关。
2. 代码复现
VIT仓库链接:https://github.com/lucidrains/vit-pytorch
Usage:
import torch
from vit_pytorch import ViT # 抽象出了一个VIT类v = ViT(image_size = 256,# 图片像素大小patch_size = 32, # patch的大小num_classes = 1000, # 分类数量dim = 1024,# 维度depth = 6, # transformer的block数量heads = 16, # 线性变换后输出张量的最后维数,多头注意力层中的头数mlp_dim = 2048, # MLP前馈层维度dropout = 0.1, # 每个训练步骤中被关闭神经元的比例,可以调成0emb_dropout = 0.1 # 嵌入丢失率
)img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)
2.1 切图重排
输入端适配涉及到切图和reshape。
以下是ViT部分.py代码:
import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim = -1)self.dropout = nn.Dropout(dropout)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)attn = self.dropout(attn)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Linear(dim, num_classes)
# 解析代码段def forward(self, img):x = self.to_patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x)
transpose函数的作用:重排张量(tensor)或者数组的维度。有一个形状为(batch_size, channels, height, width)的四维张量,代表一批图像数据。那就可以将将channels维度移动到最前面,即形状变为(channels, batch_size, height, width)。这时,你就可以使用transpose操作来实现这一转换。(类似于矩阵行变换)
img = torch.randn(1,3,256,256)b=1
c=3
h=256 = h*p1,h = 8
w =256
self.to_patch.embedding = nn.Sequential(Rearrange(‘b c (h p1)(w p2)-> b(h w)(p1 p2 c)’,p1 = patch_height, p2 = patch_width),# 图片切分重排nn.Linear(patch_dim, dim) # Linear Projection of Flattened Patches)
2.2 构造Patch 0
这一步:
cls_tokens = repeat(self.cls_token, '() n d -> b n d',b = b)x = torch.cat((cls_tokens, x), dim = 1) # concat方法,维度为1,在n的维度上
2.3 positional embedding
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 位置编码:它是一个可学习的参数,初始化为随机值。
x += self.pos_embedding[:,:(n+1)]
# 将位置编码加到输入序列上。
2.4 代码示例
首先准备数据集。
from __future__ import print_functionimport glob
from itertools import chain
import os
import random
import zipfileimport matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DataSet
from torchvision import datasets, transforms
from tqdm.notebook import tqdmfrom vit_pytorch import ViT
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
def seed_everything(seed):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueseed_everything(seed)device = 'cuda'
os.makedirs('data',exist_ok = True)train_dir = 'data/train'
test_dir = 'data/test'with zipfile.ZipFile('data/train.zip') as train_zip:train_zip.extractall('data')
with zipfile.ZipFile('data/test.zip') as test_zip:test_zip.extractall('data')train_list = glob.glob(os.path.join(train_dir,'*.jpg')) # 查找匹配的jpg文件
test_list = glob.glob(os.path.join(test_dir,'*.jpg'))print(f'Train Data:{len(train_list)}')
print(f'Test Data:{len(test_list)}')labels = [path.split('/')[-1].split('\\')[-1].split[0] for path in train_list]
print(train_list[0]
print(labels[0]))
random_idx = np.random.randint(1,len(train_list),size = 9)
fig, axes = plt.subplots(3,3,figsize = (16,12))for idx,ax in enumerate(axes.ravel()):img = Imag.open(train_list[idx])ax.set_title(labels[idx])ax.imshow(img)
train_list, valid_list = train_test_split(train_list, test_size = 0.2, stratify = labels, random_state = seed)print(f'Train Data:{len(train_list)}')
print(f'Validation Data:{len(valid_list)}')
print(f'Test Data:{len(test_list)}')
train_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)val_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)test_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)
class CatsDogsDataset(Dataset):def __init__(self, file_list, transform = None):self.file_list = file_listself.transform = transformdef __len__(self):self.filelength = len(self.file_list)return self.filelengthdef __getitem__(self,idx):img_path = self.file_list[idx]img = Image.open(img_path)img_transformed = self.transform(img)label = img_path.split('/')[-1].split("\\")[-1].split("、")[0]label = 1 if label == "dog" else 0return img_transformed, label
train_data = CatsDogsDataset(train_list, transform = train_transforms)
valid_data = CatsDogsDataset(valid_list, transform = val_transforms)
test_data = CatsDogsDataset(test_list, transform = test_transforms)
train_loader = DataLoader(dataset = train_data,batch_size = batch_size,shuffle = True)
valid_loader = DataLoader(dataset = valid_data,batch_size = batch_size,shuffle = True)
test_loader = DataLoader(dataset = test_data,batch_size = batch_size,shuffle = True)
print(len(train_data), len(train_loader))
print(len(valid_data), len(valid_loader))
模型建立:
model = ViT(image_size = 224,patch_size = 16,num_classes = 2,dim = 768,depth = 12,heads = 12,mlp_dim = 3072,dropout = 0.1,emb_dropout = 0.1
).to(device) # 将Transformer模型移动到指定设备上,比如GPUmodel.load_state_dict(torch.load('vit_base_patch16_224_r.pth'),strict = False)