音视频开发之旅(90)-Vision Transformer论文解读与源码分析

目录

1.背景和问题

2.Vision Transformer(VIT)模型结构

3.Patch Embedding

4.实现效果

5.代码解析

6.资料

一、背景和问题

上一篇我们学习了Transformer的原理,主要介绍了在NLP领域上的应用,那么在CV(图像视频)领域该如何使用?

最直观的想法就是把每一个像素像NLP中一个文字一样处理,理论上可行,但是这样做有什么不足吗?

Transformer的自注意力机制的计算复杂度是O(n^2),其中n是序列长度,一张720*1280的图片就需要921600个token,这将导致巨大的计算开销,使得模型的训练和推理非常缓慢。图像不同像素之间存在很多冗余信息(编码时会进行帧内压缩),是否可以采用类似编码压缩技术中的宏块方案呐(把图像分割为固定大小的16x16、8x8、4x4的的块)。

二、VIT模型结构

VIT的思路和视频编码的宏块思想类似,把图像分割为固定大小pathchs,然后通过线性变换得到patch embedding,将图像的patch embeddings送入transformer的Encoder进行特征提取,在根据不同任务添加不同的Head。ViT模型原理如下图所示:

图片

模型由三个模块组成:

  • Linear Projection of Flattened Patches(该网络的前处理,把图像分割为patch,然后进行Embedding)

  • Transformer Encoder(该网络的backbone,用于特征提取)

  • MLP Head(该网络的head,用于分类任务)

主要的公式如下:

图片

图片

可以看到VIT只用到了Transfomer的Encoder作为backbone进行特征提取,TransfomerEncoderLayer也是使用Multi-head Attention,不同的是LayerNormalation放在了Multi-head Attention的前面。和Transfromer的结构主要区别在于Embedding的过程,如果对于注意力机制还不太清楚,建议复习下上一篇。

三、Patch Embedding

图片

关键点包括:

  1. 图像被分割成固定大小的patches。

  2. 每个patch通过线性投影映射到嵌入空间。

  3. 添加一个特殊的分类token。

  4. 加入位置编码以保留空间信息。

将2D图像转换为一个1D序列,使得标准Transformer架构可以直接处理图像数据,允许ViT像处理文本序列一样处理图像,充分利用了Transformer的自注意力机制来捕捉图像中的全局依赖关系。

下面我们用一个示例来说明PatchEmbedding的过程。

输入一张:256x256的rgb图像,然后把它分割为64个32x32的patchs,对patchs进行线性投影得到序列长度为64,dim为1024的Embedding,然后加上用于分类的可训练的classToken(随机初始化),最后在加上相同形状的PosEmbedding 作为TransformEncodeer的输入。

图片

图片来自:详解 Vision Transformer

图片

不同于Transfromer的PositionEmbedding(采用sin和cos固定编码),VIT中的PositionEmbedding采用了符合正态分布随机初始化,可训练的方案(bert也采用了类似方式)

论文中对学习到的positional embedding进行了可视化,发现相近的patchs的positional embedding比较相似,而且同行或同列的positional embedding也相近:

图片

需要注意的是:如果改变图像的输入大小,ViT不会改变patchs的大小,patchs的数量会发生变化,之前学习的pos_embed就维度对不上了,通常ViT采用插值的方式来解决这个问题,但效果不好,另外一篇论文给出了说明和解决措施 https://arxiv.org/pdf/2102.10882,有兴趣可以进一步研究下。

四、实验效果

ViT的训练策略:先在大数据集上做预训练,然后在小数据集上做迁移使用。

图片

如果在小数据集ImageNet上做预训练时,VIT的模型架构效果普遍低于ResNet搭建的BiT网络;当在中等数据集ImageNet-21k上做预训练时,VIT的模型架构基本位于BiT最好和最差的之间;而当在大数据集JFT-300M上做预训练时,VIT的模型架构最好的效果已经超过了BiT。

结论:VIT模型需要在大数据集上进行预训练,在大数据集上预训练的效果会比卷积神经网络的上限高

例如下图先在有3亿张图像的JFT大数据集上预训练,然后在ImageNet上进行微调,准确率达到88.55%

图片

ViT 还可根据 Attention Map 来可视化,得知模型具体关注图像的哪个部分,

图片

五、代码解析

源码地址:https://github.com/lucidrains/vit-pytorch

图片

图片来自:Vision Transformer详解

3.1、调用

import torchfrom vit_pytorch import ViT
def test():    #VIT的具体实现在vit.py中    v = ViT(        #原始图像尺寸        image_size = 256,        #切割的每个图像块的尺寸        patch_size = 32,        #类别数量        num_classes = 1000,        #Transformer隐变量维度大小        dim = 1024,        #Transformer Encoder层的个数        depth = 6,        #Multi-Head Attention 头的个数        heads = 16,        #mlp层 hid层的维度        mlp_dim = 2048,        dropout = 0.1,        emb_dropout = 0.1    )
    img = torch.randn(1, 3, 256, 256)
    preds = v(img)

3.2、Attention和FFN的实现

# helpers#确保t为元组def pair(t):    return t if isinstance(t, tuple) else (t, t)
# classes#前馈网络class FeedForward(nn.Module):    def __init__(self, dim, hidden_dim, dropout = 0.):        super().__init__()        self.net = nn.Sequential(            nn.LayerNorm(dim),            nn.Linear(dim, hidden_dim),            nn.GELU(),            nn.Dropout(dropout),            nn.Linear(hidden_dim, dim),            nn.Dropout(dropout)        )
    def forward(self, x):        return self.net(x)
#VIT中的self-Attention实现,这里也是多头注意力机制class Attention(nn.Module):    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):        super().__init__()        inner_dim = dim_head *  heads #多头的个数heads:16 * 每个头的维度:64 =1024        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads        self.scale = dim_head ** -0.5 # dim_head =64, scale=1/8
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)        self.dropout = nn.Dropout(dropout)        #to_qkv线性变化,将输入映射到一个三维空间,以便在多头注意力机制中生成QKV 输入特征维度为dim (1024),输出维度为inner_dim*3 (1024*3)        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) #dim:1024,inner_dim:1024
        self.to_out = nn.Sequential(            nn.Linear(inner_dim, dim),            nn.Dropout(dropout)        ) if project_out else nn.Identity()
    def forward(self, x):        x = self.norm(x)        #将输入数据x映射到三维空间,x.shape为[1,65,1024],to_qkv经过线性变换后输出维度为[1,65,1024*3]; chunk(3,-1)将最后一个维度分割为3个子张量,生成qkv元组        qkv = self.to_qkv(x).chunk(3, dim = -1)        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) #进行形状转换,生成[batchsize,heads,squcelen,dim] 值为[1,16,65,64]        #经典的attention计算, 把q和K的转置相乘除以缩放系数,得到相似性系数        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale        #沿最后一维度进行softmax归一化        attn = self.attend(dots)        attn = self.dropout(attn)        #attn[1, 16, 65, 65]点乘V [1, 16, 65, 64]输出[1, 16, 65, 64]        out = torch.matmul(attn, v)        out = rearrange(out, 'b h n d -> b n (h d)') #对多头进行concate,得到[1, 65, 1024]        return self.to_out(out)

3.3、Transfromer Encoder层的实现

#VIT中Transfromer的实现,用到了Transformer的Encoder层. 和原始的Transfromer稍微有些差异,主要是layernormalization的位置class Transformer(nn.Module):    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):#dim:1024,depth:6;heads:16;dim_head:64;mlp_dim:2048;dropout:0.1        super().__init__()        self.norm = nn.LayerNorm(dim)        self.layers = nn.ModuleList([])        for _ in range(depth):            self.layers.append(nn.ModuleList([                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),                FeedForward(dim, mlp_dim, dropout = dropout)            ]))
    def forward(self, x):        for attn, ff in self.layers:            x = attn(x) + x #Attention进行残差            x = ff(x) + x #MLP进行残差
        return self.norm(x)

3.4、ViT的实现

#入口Module,这里的posEmbedding没有使用固定编码,而是像bert一样可训练的. 把image切分成多个patch,展平进行to_patch_embedding处理class ViT(nn.Module):    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):        super().__init__()        image_height, image_width = pair(image_size)        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'        # num_patches =(256//32)*(256//32)=64;  patch_dim:3*32*32=3072; dim=1024        num_patches = (image_height // patch_height) * (image_width // patch_width)        patch_dim = channels * patch_height * patch_width        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'        #使用einops的Rearrange优雅地处理张量维度        self.to_patch_embedding = nn.Sequential(            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),#这里(h p1) (w p2)就相当于h与w变为原来的1/p1,1/p2            nn.LayerNorm(patch_dim),            nn.Linear(patch_dim, dim),#patch_dim3072,dim 1024 线性变换            nn.LayerNorm(dim),        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 创建一个形状为 (1, 65, 1024) 的随机张量,VIT中PE和Transformer中positionEmbedding的定义不同,这里是一个可以训练的模块        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))#创建一个随机的张量(1,1,1024)的cls_token        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)
    def forward(self, img):        x = self.to_patch_embedding(img)        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)        x = torch.cat((cls_tokens, x), dim=1)        x += self.pos_embedding[:, :(n + 1)]        x = self.dropout(x)        #输入和输出的形状都是 torch.Size([1, 65, 1024])        x = self.transformer(x)         #这里的pool为cls分类,所以沿dim=1,取第1个数据        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]        #这里的to_latent目前就是一个恒等变换层nn.Identity(),即输入和输出每个任何变化,可以去掉,这里起到占位的作用        x = self.to_latent(x)        return self.mlp_head(x)

六、资料

1.论文VIT:https://arxiv.org/pdf/2010.11929

2.源码:https://github.com/lucidrains/vit-pytorch

3.timm/models/vision_transformer.py: https://github.com/huggingface/pytorch-image-4.models/blob/main/timm/models/vision_transformer.py

5.ViT论文逐段精读【论文精读】https://www.bilibili.com/video/BV15P4y137jb

6.Vision Transformer(vit)网络详解 https://www.bilibili.com/video/BV1Jh411Y7WQ

7.李宏毅-Transformer 

https://www.bilibili.com/video/av56239558

8.详解VisionTransformer

 https://blog.csdn.net/qq_39478403/article/details/118704747

9.Vision Transformer详解  https://blog.csdn.net/qq_37541097/article/details/118242600

10.ViT代码超详细解读 https://blog.csdn.net/weixin_43334693/article/details/131836233

11.ViT PyTorch代码全解析(附图解)

https://blog.csdn.net/weixin_44966641/article/details/118733341

12.Vision Transformer(VIT)代码分析 https://blog.csdn.net/qq_38683460/article/details/127346916

13.ViT:视觉Transformer backbone网络ViT论文与代码详解 https://mp.weixin.qq.com/s/Nok5UQ2nzex94GXyrltiBg

14.可视化VIT中的注意力 https://mp.weixin.qq.com/s/O-56hxVa6Fgiz2YpjXTodQ

15."未来"的经典之作 ViT:transformer is all you need! https://www.cvmart.net/community/detail/4461

16.搞懂 Vision Transformer 原理和代码 https://mp.weixin.qq.com/s/ozUHHGMqIC0-FRWoNGhVYQ

17.3W字长文带你轻松入门视觉transformer https://zhuanlan.zhihu.com/p/308301901

18.Vision Transformer, LLM, Diffusion Model 超详细解读 (原理分析+代码解读) https://zhuanlan.zhihu.com/p/348593638

19.einops.repeat, rearrange, reduce优雅地处理张量维度 https://blog.csdn.net/qq_37297763/article/details/120348764

感谢你的阅读

接下来我们继续学习输出AI相关内容,欢迎关注公众号“音视频开发之旅”,一起学习成长。

欢迎交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/52936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法复盘——LeetCode hot100:哈希

文章目录 哈希表哈希表的基本概念哈希表的使用1. 插入操作2. 查找操作3. 删除操作 哈希表的优点和缺点1.两数之和复盘 242.有效的字母异位词复盘 49.字母异位词分组复盘 128. 最长连续序列复盘HashSet 哈希表 先来搞清楚什么是哈希表吧~ 概念不清楚方法不清楚怎么做题捏 哈希表…

使用mysql保存密码

登录MySQL 这行命令告诉MySQL客户端程序用户root准备登录,-p表示告诉 MySQL 客户端程序提示输入密码。 mysql -u root -p创建数据库 create database wifi; use wifi;create table password(user_password CHAR(8),primary key(user_password));源码 代码编译 …

QT实战项目之音乐播放器

项目效果演示 myMusicShow 项目概述 在本QT音乐播放器实战项目中,开发环境使用的是QT Creator5.14版本。该项目实现了音乐播放器的基本功能,例如开始播放、停止播放、下一首播放、上一首播放、调节音量、调节倍速、设置音乐播放模式等。同时还具备搜索功…

Centos 下载和 VM 虚拟机安装

1. Centos 下载 阿里云下载地址 centos-7.9.2009-isos-x86_64安装包下载_开源镜像站-阿里云 2. VM 中创建 Centos 虚拟机 2.1 先打开 VM 虚拟机,点击首页的创建新的虚拟机 2.2 选择自定义,然后点击下一步。 2.3 这里默认就好,继续选择下一…

gitlab SSH的使用

一、 安装git bash https://git-scm.com/download/win 下载windows 版本,默认安装即可。 二、使用命令 打开本地git bash,使用如下命令生成ssh公钥和私钥对 ssh-keygen -t rsa -C ‘xxxxxx.com’ 然后一路回车 (-C 参数是你的邮箱地址) 若是想输入密码可以输入…

算法-最长连续序列

leetcode的题目链接 这道题的思路主要是要求在O(n)的时间复杂度下,所以你暴力解决肯定不行,暴力至少两层for循环,所以要在O(n)的时间复杂度下,你可以使用HashSet来存储数组,对于每个数字&#…

黑马JavaWeb开发笔记07——Ajax、Axios请求、前后端分离开发介绍、Yapi详细配置步骤

文章目录 前言一、Ajax1. 概述2. 作用3. 同步异步4. 原生Ajax请求(了解即可)5. Axios(重点)5.1 基本使用5.2 Axios别名(简化书写) 二、前后端分离开发1. 介绍1.1 前后台混合开发1.2 前后台分离开发方式&…

Docker续6:容器网络

1.bridge-utils 一个用于Linux系统的网络桥接工具集。它提供了一些命令行工具,帮助用户创建、管理和配置网络桥接。网络桥接是一种将多个网络接口连接在一起,以使它们能够作为单个网络段进行通信的技术。 bridge-utils 常用的命令包括: b…

【 OpenHarmony 系统应用源码魔改 】-- Launcher 之「桌面布局定制」

前言 阅读本篇文章之前,有几个需要说明一下: 调试设备:平板,如果你是开发者手机,一样可以加 Log 调试,源码仍然是手机和平板一起分析;文章中的 Log 信息所显示的数值可能跟你的设备不一样&…

单片机编程魔法师-并行多任务程序

程序架构 程序代码 小结 数码分离,本质上就是将数据和代码逻辑进行分离,跟第一章使用数据驱动程序一样的道理。 不过这里不同之处在于。这里使用通过任务线程,但是却有2个任务在运行,两个任务都通过先初始化任务数据参数&#x…

SQLite的安装和使用

一、官网链接下载安装包 点击跳转 步骤:点击安装这个红框的dll以及红框下面的tools (如果有navicat可以免上面这个安装步骤,安装上面这个是为了能在命令行敲SQL而已) 二、SQLite的特点 嵌入的(无服务器的&#x…

hello树先生——AVL树

AVL树 一.什么是AVL树二.AVL树的结构1.AVL树的节点结构2.插入函数3.旋转调整 三.平衡测试 一.什么是AVL树 二叉搜索树虽可以缩短查找的效率,但如果数据有序或接近有序二叉搜索树将退化为单支树,查找元素相当于在顺序表中搜索元素,效率低下。…

python学习——爬虫之session请求处理cookie

import requestssessionrequests.session() url"https://passport.17k.com/ck/user/login" data{"loginName": "19139186287","password":"2001022600hzk"} ressession.post(url,datadata) print(res.text)# session通过会话…

Windows系统中批量管理Windows服务器远程桌面工具——RDCMan

一、背景 在公司没有部署对应的堡垒机系统之前,做运维测试工作的人员,需要管理大量的服务器,每天需要对服务器进行必要的巡检、系统更新发布等内容,特别是有很多Windows服务器的时候,如果我们使用Windows自带的“远程桌面连接”只能一台台连接,比较繁琐。并且不能知道那台…

【手撕数据结构】二叉树的性质

目录 叶子节点和边的性质概念小试牛刀 叶子节点和边的性质 概念 可以看到度为0的节点如F没有边,度为1的节点如C有一条边,而度为2的节点如B有两条边。那么设度为2的节点为a个,度为1的节点为b个。二叉树边 2ab另⼀⽅⾯,由于共有 a…

AcWing 897. 最长公共子序列

动态规划就是多见识应用题就完事儿了&#xff0c;也没有什么好说的。 讲解参考&#xff1a; 【E05 线性DP 最长公共子序列】 #include<iostream> #include<algorithm> #define N 1010 using namespace std; char a[N],b[N]; int n,m; int f[N][N]; int main(){…

Loki Unable to fetch labels from Loki (no org id)

应该是多租户相关导致的 参考文档: 参考文档cMulti-tenancy | Grafana Loki documentationDescribes how Loki implements multi-tenancy to isolate tenant data and queries.https://grafana.com/docs/loki/latest/operations/multi-tenancy/ https://github.com/grafana…

主流AI绘画工具-StableDiffusion本地部署方法(mac电脑版本)

Stable Diffusion是一款强大的AI生成图像模型&#xff0c;它可以基于文本描述生成高质量的图像。对于想要在本地运行此模型的用户来说&#xff0c;使用Mac电脑部署Stable Diffusion是一个非常吸引人的选择&#xff0c;特别是对于M1或M2芯片的用户。本文将详细介绍如何在Mac上本…

高效能低延迟:EasyCVR平台WebRTC支持H.265在远程监控中的优势

TSINGSEE青犀视频EasyCVR视频汇聚平台在WebRTC方面确实支持H.265编码&#xff0c;尽管标准的WebRTC API在大多数浏览器中默认并不支持H.265&#xff08;也称为HEVC&#xff0c;高效视频编码&#xff09;编码。EasyCVR平台通过一系列创新的技术手段&#xff0c;实现了在WebRTC协…

食家巷中秋美食,味蕾上的团圆盛宴

月到中秋分外明&#xff0c;在这个充满温情与思念的节日里&#xff0c;美食成为了人们传递情感、共享团圆的重要载体。而食家巷&#xff0c;以其独特的中秋美食&#xff0c;为这个佳节增添了一抹别样的风味。 走进食家巷&#xff0c;仿佛踏入了一个美食的宝藏之地。这里的传统…