Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer(ViT)PyTorch代码全解析

最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文,及其PyTorch实现,将整个ViT的代码做一个全面的解析。

对原Transformer还不熟悉的读者可以看一下Attention is All You Need原文,中文讲解推荐李宏毅老师的视频 YouTube,BiliBili 个人觉得讲的很明白。

话不多说,直接开始。

下图是ViT的整体框架图,我们在解析代码时会参照此图:
在这里插入图片描述

以下是文中给出的符号公式,也是我们解析的重要参照:
z=[xclass;xp1E,xp2E,…;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D(1)\mathbf{z}=[\mathbf{x}_{class};\mathbf{x}^1_p\mathbf{E},\mathbf{x}^2_p\mathbf{E},\dots;\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\ \ \ \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N+1)\times D} \ \ \ \ \ \ \ \ \ \ \ \ \ (1) z=[xclass;xp1E,xp2E,;xpNE]+Epos,   ER(P2C)×D,EposR(N+1)×D             (1)

zℓ′=MSA(LN(zℓ−1))+zℓ−1(2)\mathbf{z'_\ell}=MSA(LN(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) z=MSA(LN(z1))+z1                                      (2)

zℓ=MLP(LN(z′ℓ))+z′ℓ(3)\mathbf{z}_{\ell}=MLP(LN(\mathbf{z'}_{\ell}))+\mathbf{z'}_{\ell}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (3) z=MLP(LN(z))+z                                 (3)

y=LN(zL0)(4)\mathbf{y}=LN(\mathbf{z}_L^0)\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (4) y=LN(zL0)                                  (4)

导入需要的包

import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeat
from einops.layers.torch import Rearrange

都是搭建网络时常用的PyTorch包,其中在卷积神经网络的搭建中并不常用的einops和einsum,还不熟悉的读者可以参考博客:einops和einsum:直接操作张量的利器。

pair函数

def pair(t):return t if isinstance(t, tuple) else (t, t)

作用是:判断t是否是元组,如果是,直接返回t;如果不是,则将t复制为元组(t, t)再返回。
用来处理当给出的图像尺寸或块尺寸是int类型(如224)时,直接返回为同值元组(如(224, 224))。

PreNorm


class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)

PreNorn对应框图中最下面的黄色的Norm层。其参数dim是维度,而fn则是预先要进行的处理函数,是以下的Attention、FeedForward之一,分别对应公式(2)(3)。
zℓ′=MSA(LN(zℓ−1))+zℓ−1(2)\mathbf{z'_\ell}=MSA(LN(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) z=MSA(LN(z1))+z1                                      (2)

zℓ=MLP(LN(z′ℓ))+z′ℓ(3)\mathbf{z}_{\ell}=MLP(LN(\mathbf{z'}_{\ell}))+\mathbf{z'}_{\ell}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (3) z=MLP(LN(z))+z                                 (3)

FeedForward


class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim), nn.Dropout(dropout))def forward(self, x):return self.net(x)

FeedForward层由线性层,配合激活函数GELU和Dropout实现,对应框图中蓝色的MLP。参数dim和hidden_dim分别是输入输出的维度和中间层的维度,dropour则是dropout操作的概率参数p。

Attention

   
class Attention(nn.Module):              def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim=-1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout),) if project_out else nn.Identity()def forward(self, x):b, n, _, h = *x.shape, self.headsqkv = self.to_qkv(x).chunk(3, dim=-1)           # (b, n(65), dim*3) ---> 3 * (b, n, dim)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)          # q, k, v   (b, h, n, dim_head(64))dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scaleattn = self.attend(dots)out = einsum('b h i j, b h j d -> b h i d', attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)

Attention,Transformer中的核心部件,对应框图中的绿色的Multi-Head Attention。参数heads是多头自注意力的头的数目,dim_head是每个头的维度。

本层的对应公式就是经典的Tansformer的计算公式:
Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V

Transformer


class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn x

定义好几个层之后,我们就可以构建整个Transformer Block了,即对应框图中的整个右半部分Transformer Encoder。有了前面的铺垫,整个Block的实现看起来非常简洁。

参数depth是每个Transformer Block重复的次数,其他参数与上面各个层的介绍相同。

笔者也在图中也做了标注与代码的各部分对应。
在这里插入图片描述

ViT

class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert  image_height % patch_height ==0 and image_width % patch_width == 0num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),nn.Linear(patch_dim, dim))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))					# nn.Parameter()定义可学习参数self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.to_patch_embedding(img)        # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dimb, n, _ = x.shape           # b表示batchSize, n表示每个块的空间分辨率, _表示一个块内有多少个值cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)  # self.cls_token: (1, 1, dim) -> cls_tokens: (batchSize, 1, dim)  x = torch.cat((cls_tokens, x), dim=1)               # 将cls_token拼接到patch token中去       (b, 65, dim)x += self.pos_embedding[:, :(n+1)]                  # 加位置嵌入(直接加)      (b, 65, dim)x = self.dropout(x)x = self.transformer(x)                                                 # (b, 65, dim)x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]                   # (b, dim)x = self.to_latent(x)                                                   # Identity (b, dim)print(x.shape)return self.mlp_head(x)                                                 #  (b, num_classes)

笔者在forward()函数代码中的注释说明了各中间state的尺寸形状,可供参考比对。

在 x 送入transformer之前,都是对应公式(1)的预处理操作:
z=[xclass;xp1E,xp2E,…;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D(1)\mathbf{z}=[\mathbf{x}_{class};\mathbf{x}^1_p\mathbf{E},\mathbf{x}^2_p\mathbf{E},\dots;\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\ \ \ \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N+1)\times D} \ \ \ \ \ \ \ \ \ \ \ \ \ (1) z=[xclass;xp1E,xp2E,;xpNE]+Epos,   ER(P2C)×D,EposR(N+1)×D             (1)
positional embedding和class token由nn.Parameter()定义,该函数会将送到其中的Tensor注册到Parameters列表,随模型一起训练更新,对nn.Parameter()不熟悉的同学可参考博客:PyTorch中的torch.nn.Parameter() 详解。

在这里插入图片描述

我们知道,transformer模型最后送到mlp中做预测的只有cls_token的输出结果(如上图红框所示),而其他的图像块的输出全都不要了,是由这一步实现:

x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]                   # (b, dim)

可以看到,如果指定池化方式为'mean'的话,则会对全部token做平均池化,然后全部进行送到mlp中,但是我们可以看到,默认的self.pool='cls',也就是说默认不会进行平均池化,而是按照ViT的设计只使用cls_token,即x[:, 0]只取第一个token(cls_token)。

最后经过mlp_head,得到各类的预测值。

笔者也简单做了一张图展示整个过程中的信号流,可以结合代码中注释的维度的变化来看:

在这里插入图片描述

图中各符号含义:H,W,CH,W,CH,W,C 分别是某一张输入图像的长、宽、通道数,h,wh,wh,w 是图块的长、宽,如此这张图中块的个数就是 Hh×Ww\frac{H}{h}\times \frac{W}{w}hH×wW ,用 NpN_pNp 表示,DDD 是维度数dim,NcN_cNc 是类的个数。

至此,ViT模型的定义就全部完成了,在训练脚本中实例化一个ViT模型来进行训练即可,以下脚本可验证ViT模型正常运作。

model_vit = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)img = torch.randn(16, 3, 256, 256)preds = model_vit(img) print(preds.shape)  # (16, 1000)

有疑惑或异议欢迎留言讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532357.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux下的ELF文件、链接、加载与库(含大量图文解析及例程)

Linux下的ELF文件、链接、加载与库 链接是将将各种代码和数据片段收集并组合为一个单一文件的过程,这个文件可以被加载到内存并执行。链接可以执行与编译时,也就是在源代码被翻译成机器代码时;也可以执行于加载时,也就是被加载器加…

java 按钮 监听_Button的四种监听方式

Button按钮设置点击的四种监听方式注:加粗放大的都是改变的代码1.使用匿名内部类的形式进行设置使用匿名内部类的形式,直接将需要设置的onClickListener接口对象初始化,内部的onClick方法会在按钮被点击的时候执行第一个活动的java代码&#…

linux查看java虚拟机内存_深入理解java虚拟机(linux与jvm内存关系)

本文转载自美团技术团队发表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux与进程内存模型要理解jvm最重要的一点是要知道jvm只是linux的一个进程,把jvm的视野放大,就能很好的理解JVM细分的一些概念下图给出了硬件系统进程三个层面内存之间的关系.从硬件上…

java function void_Java8中你可能不知道的一些地方之函数式接口实战

什么时候可以使用 Lambda?通常 Lambda 表达式是用在函数式接口上使用的。从 Java8 开始引入了函数式接口,其说明比较简单:函数式接口(Functional Interface)就是一个有且仅有一个抽象方法,但是可以有多个非抽象方法的接口。 java8…

java jvm内存地址_JVM--Java内存区域

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域,如图:1.程序计数器可以看作是当前线程所执行的字节码的行号指示器,通俗的讲就是用来指示执行哪条指令的。为了线程切换后能恢复到正确的执行位置Java多线程是…

java情人节_情人节写给女朋友Java Swing代码程序

马上又要到情人节了,再不解风情的人也得向女友表示表示。作为一个程序员,示爱的时候自然也要用我们自己的方式。这里给大家上传一段我在今年情人节的时候写给女朋友的一段简单的Java Swing代码,主要定义了一个对话框,让女友选择是…

java web filter链_filter过滤链:Filter链是如何构建的?

在一个Web应用程序中可以注册多个Filter程序,每个Filter程序都可以针对某一个URL进行拦截。如果多个Filter程序都对同一个URL进行拦截,那么这些Filter就会组成一个Filter链(也叫过滤器链)。Filter链用FilterChain对象来表示,FilterChain对象中…

java final static_Java基础之final、static关键字

一、前言关于这两个关键字,应该是在开发工作中比较常见的,使用频率上来说也比较高。接口中、常量、静态方法等等。但是,使用频繁却不代表一定是能够清晰明白的了解,能说出个子丑演卯来。下面,对这两个关键字的常见用法…

java语言错误的是解释运行的_Java基础知识测试__A卷_答案

考试宣言:同学们, 考试考多少分不是我们的目的! 排在班级多少的名次也不是我们的初衷!我的考试的目的是要通过考试中的题目,检查大家在这段时间的学习中,是否已经把需要掌握的知识掌握住了,如果哪道题目你不会做,又或者做错了, 那么不用怕, 考完试后, 导师讲解的时候你要注意听…

java 持续集成工具_Jenkins-Jenkins(持续集成工具)下载 v2.249.2官方版--pc6下载站

Jenkins是一款基于java开发的持续集成工具,是一款开源软件,主要用于监控持续重复的工作,为开发者提供一个开发易用的软件平台,使软件的持续集成变成可能。。相关软件软件大小版本说明下载地址Jenkins是一款基于java开发的持续集成…

java中线程调度遵循的原则_深入理解Java多线程核心知识:跳槽面试必备

多线程相对于其他 Java 知识点来讲,有一定的学习门槛,并且了解起来比较费劲。在平时工作中如若使用不当会出现数据错乱、执行效率低(还不如单线程去运行)或者死锁程序挂掉等等问题,所以掌握了解多线程至关重要。本文从基础概念开始到最后的并…

java类构造方法成员方法练习_面向对象方法论总结 练习(一)

原标题:面向对象方法论总结 & 练习(一)学习目标1.面向对象与面向过程2.类与对象的概念3.类的定义,对象的创建和使用4.封装5.构造方法6.方法的重载内容1.面向对象与面向过程为什么会出现面向对象反分析方法?因为现实世界太复杂多变&#x…

mysql 统计查询不充电_MySql查询语句介绍,单表查询,来充电吧

mysql在网站开发中,越来越多人使用了,方便部署,方便使用。我们要掌握mysql,首先要学习查询语句。查询单个表的数据,和多个表的联合查询。下面以一些例子来先简单介绍下单表查询。操作方法01首先看下我们例子用到的数据表&#xff…

MySQL线上优化_线上MySQL千万级大表,如何优化?

前段时间应急群有客服反馈,会员管理功能无法按到店时间、到店次数、消费金额进行排序。经过排查发现是 SQL 执行效率低,并且索引效率低下。图片来自 Pexels应急问题商户反馈会员管理功能无法按到店时间、到店次数、消费金额进行排序,一直转圈…

php创建表设置编码,教您在Zend Framework里如何设置数据库编码以及怎样给数据表设定前缀!...

当我们在开发项目时..大家都会遇到一个问题就是:数据库的编码问题.当然我们不用Zend Framework做为项目开发的框架时..我们可以很快,很容易搞定这个小问题..但是当我们要使用Zend Framewok开发项目时..我们可能一时会不知道如何解决这个小问题..比如我就是这样的人..在开发这个…

python 怎么将数组转为列表_怎么将视频转为GIF动态图 表情包怎么制作

说到GIF,大家应该都不陌生了吧!尤其是在聊天中使用较多,似乎一言不合就开启了斗图模式,但是我们平时使用的GIF一般都是软件中自带的,其实自己制作也是很方便的,而且会发现很有趣,不但可以直接录…

proteus里面没有stm32怎么办_嵌入式单片机之stm32串口你懂了多少!!

stm32作为现在嵌入式物联网单片机行业中经常要用多的技术,相信大家都有所接触,今天这篇就给大家详细的分析下有关于stm32的出口,还不是很清楚的朋友要注意看看了哦,在最后还会为大家分享有些关于stm32的视频资料便于学习参考。点击…

tomcat不能解析php,tomcat不支持php怎么办

tomcat不支持php的解决办法:首先将“PHP/Java Bridge”下的相关文件复制到tomcat的lib目录下;然后修改tomcat安装目录下conf文件夹里的“web.xml”文件;最后重启tomcat即可。java开发者都知道,tomcat是用来部署java web项目的。这…

c++ dicom图像切割_【高训智造】原创专业课堂第225期--定位滑座的线切割加工

原标题:【高训智造】原创专业课堂第225期--定位滑座的线切割加工欢迎来到【高训智造】原创专业课堂第225期,本期由郭沃沛老师给大家带来线切割小课堂。定位滑座的线切割加工郭沃沛1零件图如图1所示为定位滑座零件图,其材料为45钢,…

c iostream.源码_通达信指标公式源码精准买卖主图指标公式免费分享

V0:EMA(C,5),COLOR00FF66;V1:EMA(C,10),COLOR00FF66;V2:EMA(C,15),LINETHICK2,COLORFFFFFF;V3:EMA(C,30);V4:EMA(C,60),COLOR3366FF;年线:EMA(C,90),COLORBLUE;M1:1000*V1/V4<1015 AND 1000*V1/V4>975;M2:1000*V2/V4<1020 AND 1000*V2/V4>980;M3:1000*V3/V4<101…