【AI】计算机视觉VIT文章(Transformer)源码解析

论文:Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020

源码的Pytorch版:https://github.com/lucidrains/vit-pytorch

0.前言

Transformer提出后在NLP领域中取得了极好的效果,其全Attention的结构,不仅增强了特征提取能力,还保持了并行计算的特点,可以又快又好的完成NLP领域内几乎所有任务,极大地推动自然语言处理的发展。
VIT这篇文章就是将Transformer模型应用在了CV领域,它将图像处理成Transformer模型可以应用的形式,沿用NLP领域中Transformer的方法,直接验证了其精度可以和ResNet不相上下,展示了在计算机视觉中使用纯Transformer结构的可能,为Transformer在CV领域的应用打开了大门。

直接读文章通常比较抽象,英文的原文更能劝退一大部分人,但对于程序员来说,代码是通行于世界的语言,理解起来就比较简单,结合源码理解论文中的结构,就比较事半功倍。

在这里插入图片描述

上图是VIT文章中的结构,我们看图提问题,从数据的流向来看:图像怎么切分重排的?Linear Projection of Flattened Patches对图像作了什么,怎么让图像变成Transformer能够输入的格式?
Position Embedding是怎么做图像位置编码的,为什么会多出来一个0的位置编码?Transformer Encoder中的各个结构分别代表什么,是怎么实现的?输出的类别是什么?

1.图像是如何适配Transformer输入的?

Transformer的输入是一个一维的向量,而我们的图像是二维的,需要把图像拉伸成一维的,最简单的方式就是沿着x轴展开,将所有的行拼接在一起,也就是Flatten的操作,但是这样处理会导致向量维度比较大,而且同一张图片也只能生成一个Embedding,不能适配Multi-head Embedding(这种解释有点牵强,有点拿结果去解释原因)。

1.1 图像切分重排

如结构图所示,VIT采取了图像切分重排的方式,将一个完整的图像按照行列的方向切分成小块,然后再进行后续处理。切分重排是怎么实现的,我们看代码:

self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), #图片切分重排nn.Linear(patch_dim, dim), # Linear Projection of Flattened Patches)

这里的Rearrange使用到了einops的库,相关的介绍可以查看文档
这里主要是对b c (h p1) (w p2) -> b (h w) (p1 p2 c)表达式的理解,b是batchsize,c是chanel,h和w是图像的高度和宽度,表达式前边的(h p1)表示将输入的h重新拆分成一个h*p1的向量,同理(w p2)是将输入的w拆分成w*p2的向量,这里的p1和p2是模型定义的patch_height和patch_width,可以理解为切分后小图的高和宽;
表达式右边表示输出向量的维度,(p1 p2 c)表示这3个数相乘,表示的是切分后一个小图的一维向量大小,(h w)则表示总共有多少个切分后的小图所生成的向量。这里的h和w的值跟输入的h和w值不同,表示的是原有h和w除以patch_height或patch_width的值,也就是在高和宽上各能切分出几个小图。
通过这样的处理,就将输入的c个channel的h*w的二维向量,转换成小图展开后的一维向量。

1.2 构建patch0

图像在输入Transformer之前,还连接了一个patch0,这一步的操作可以认为是延续了nlp的操作,在VIT这边的操作,可以认为是将分类类别拼接上去。

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)

1.3 Positional Embedding

为了在模型中引入位置信息,VIT引入了位置编码的形式,也就是图中的0、1、2、3…

x += self.pos_embedding[:, :(n + 1)]

然后将position_embedding与图像的Embedding相加。

2.Transformer Encoder中的各个结构分别代表什么,是怎么实现的?

结构中的Norm可以认为是归一化处理,MLP是多个全连接层,理解起来都比较简单。其实最主要的结构就是Multi-head Attention。
在这里插入图片描述

2.1 Attention 定义

class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):b, n, _, h = *x.shape, self.headsqkv = self.to_qkv(x).chunk(3, dim = -1)#得到qkvq, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scaleattn = self.attend(dots)out = einsum('b h i j, b h j d -> b h i d', attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)

2.2 Attention 的理解

VIT不同于之前RNN的地方就是引入了Attention机制,其本质可以认为是相似度计算,是计算每一个输入值与其他输入值的相似度,然后带入到计算中,如图所示,q是输入的查询值,k是关键词,v是计算值,计算每一个q与其他k的相似度,然后再带入计算中去。
在这里插入图片描述

其对应的计算公式为:
在这里插入图片描述
相似度计算是用的点积的形式,除以根号下dk是为了抑制极端值,保证softmax之后数值不至于丧失梯度。对应的代码如下:

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

然后过一个softmax

attn = self.attend(dots) # self.attend = nn.Softmax(dim = -1)
attn = self.dropout(attn)

然后跟value值相乘

out = torch.matmul(attn, v)

2.3 Multihead Attention

多头注意力机制可以认为是定义多个Attention(每个attention关注的重点不同)分别来对数据进行处理,这里从源码中的循环结构可以体现出来:

class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)

3.使用示例

这里可以参考官方的readme文档

$ pip install vit-pytorch

使用示例

import torch
from vit_pytorch import ViTv = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1
)img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/587492.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

三角函数两角和差公式推导

一.几何推理 1.两角和公式 做一斜边为1的直角△ABC,任意旋转非 k Π , k N kΠ,kN kΠ,kN,补充如图,令 ∠ A B C ∠ α , ∠ C B F ∠ β ∠ABC∠α,∠CBF∠β ∠ABC∠α,∠CBF∠β ∴ ∠ D B F ∠ D B A ∠ α ∠ β 90 , ∠ D A …

Linux基础知识学习3

vim编辑器 其分为四种模式 1.普通(命令)模式 2.编辑模式 3.底栏模式 4.可视化模式 vim编辑器被称为编辑器之神,而Emacs更是神之编辑器 普通模式: 1.光标移动 ^ 移动到行首 w 跳到下一个单词的开头…

C#中使用as关键字将对象转换为指定类型

目录 一、定义 二、示例 三、生成 使用as关键字可以将对象转换为指定类型,与is关键字不同,is关键字用于检查对象是否与给定类型兼容,如果兼容则返回true,如果不兼容则返回false。而as关键字会直接进行类型转换,如果…

小白备战蓝桥杯:Java集合与数据结构

目录 什么是集合&#xff1f; 集合的分类 <> : 泛型 浅谈泛型 代码示例 细说泛型 泛型类 泛型方法 泛型接口 泛型通配符 Collection接口 集合的通用遍历方式 1、迭代器遍历 2、增强for循环 3、forEach方法 4、代码示例 List接口 方法 List集合的遍历方…

【哈希数组】697. 数组的度

697. 数组的度 解题思路 首先创建一个IndexMap 键表示元素 值表示一个列表List list存储该元素在数组的所有索引之后再次创建一个map1 针对上面的List 键表示列表的长度 值表示索引的差值遍历indexmap 将所有的list的长度 和 索引的差值存储遍历map1 找到最大的key 那么这个Ke…

基于Python的B站排行榜大数据分析与可视化系统

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 本文介绍了一项基于Python的B站排行榜大数据分析与可视化系统的研究。通过网络爬虫技术&#xff0c;系统能够自动分析B站网址&#xff0c;提取大量相关文本信息并存储在系统中。通过对这些信息进行…

LoongArch指令集-特权指令系统——摘抄自胡伟武体系结构和龙芯架构32位精简版参考手册

例外与中断 1 中断 1.1 中断类型 龙芯架构 32 位精简版下的中断采用线中断的形式。每个处理器核内部可记录 12 个线中断&#xff0c;分别是&#xff1a;1 个核间中断&#xff08;IPI&#xff09;&#xff0c;1 个定时器中断&#xff08;TI&#xff09;&#xff0c;8 个硬中断…

CSAPP: LinkBomb 重定位和链接题解(一)

前言 我看了一下&#xff0c;网上关于 LinkBomb 的题解不是很多&#xff0c;LinkBomb 不是 CSAPP 目前大纲的内容&#xff0c;大多数都是写的 LinkLab。如果你做的作业内容是要求每关输出学号&#xff0c;那么你就是跟我一样的 LinkBomb 的实验&#xff08;需要注意的是&#…

emacs:Searching for program: No such file or directory,sml;

首先&#xff0c;编辑一个现有的或新的 SML 文件&#xff08;如果没有其他方便的方法&#xff0c;可尝试C-x C-f test.smlC-x C-f test.sml 创建一个新文件&#xff09;。你会看到 Emacs 窗口底部的模式显示从 "基本"&#xff08;或其他任何模式&#xff09;变成了 S…

OSG 关于MVPW变换

目录 1、模型 Model 2、观察矩阵 ViewMatrix 4、窗口矩阵变化 5、总结 在osg中观察矩阵接口设置如下: 其中eye是相机的世界坐标位置,center是相机观察的位置,up是相机向上向量。 在计算机的三维世界中&#xff0c;相机如同我们的眼睛&#xff0c;捕捉眼前的每一副画面&#xff…

20231231_小米音箱接入chatgpt

参考资料&#xff1a; GitHub - yihong0618/xiaogpt: Play ChatGPT and other LLM with Xiaomi AI Speaker 小爱音箱ChatGPT的折腾记录&#xff1a;win平台部署并运行成功_哔哩哔哩_bilibili GitHub - chatanywhere/GPT_API_free: Free ChatGPT API Key&#xff0c;免费Chat…

UG装配-接触对齐

UG装配约束命令在如下位置 首选接触&#xff1a;含接触和对齐&#xff0c;自动判断两种类型 接触&#xff1a;约束对象使其曲面法向在相反方向&#xff0c;并共面或共线 对齐&#xff1a;约束对象使其曲面法向在同一方向&#xff0c;并共面或共线 自动判断中心/轴&#xff1…

Mysql实时数据同步工具Alibaba Canal 使用

目录 Mysql实时数据同步工具Alibaba Canal 使用Canal是什么&#xff1f;工作原理重要版本更新说明 环境准备安装Canalwindow Java : Canal Client 集成依赖编码 工作流程开启原生MQRocketMQ 安装部署 canal配置说明1.1 canal.properties常用配置介绍&#xff1a;2.common参数定…

分库分表之Mycat应用学习一

1 为什么要分库分表 1.1 数据库性能瓶颈的出现 对于应用来说&#xff0c;如果数据库性能出现问题&#xff0c;要么是无法获取连接&#xff0c;是因为在高并发的情况下连接数不够了。要么是操作数据变慢&#xff0c;数据库处理数据的效率除了问题。要么是存储出现问题&#xf…

C#中使用is关键字检查对象是否与给定类型兼容

目录 一、定义 二、示例 三、生成 在程序的开发过程中经常会使用类型转换&#xff0c;如果类型转换不成功则会出现异常&#xff0c;从抛出异常到捕获并处理异常&#xff0c;无形中增加了系统的开销&#xff0c;而且太过频繁地处理异常还会严重地影响系统的稳定性。is关键字可…

双指针刷题(三)

所有算法文章链接&#xff08;最底部&#xff09; http://t.csdnimg.cn/IbllR 1.有效三角形个数 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 1.分析题意 给一个非负的数组&#xff0c;判断这个数组能组成多少个三角形。 2.解题思路 补充知识…

前端vue uni-app使用Vue和ECharts构建交互式树形结构图

题目&#xff1a;使用Vue和ECharts构建交互式树形结构图 摘要&#xff1a;本文介绍了如何使用Vue.js和ECharts构建一个交互式的树形结构图。通过整合ECharts的强大可视化功能&#xff0c;我们创建了一个可拖拽移动、点击展开和收缩的树形结构图&#xff0c;并实现了无限添加子…

【ARMv8M Cortex-M33 系列 2.1 -- Cortex-M33 使用 .hex /.srec 文件介绍】

请阅读【嵌入式开发学习必备专栏 之Cortex-M33 专栏】 文章目录 HEX 文件介绍英特尔十六进制文件格式记录类型hex 示例Cortex-M 系列hex 文件的使用 hex 文件和srec 文件生成Motorola S-Record (srec) 格式 HEX 文件介绍 .hex 文件通常用于微控制器编程&#xff0c;包括 ARM C…

蜕变,我的2023

作者&#xff1a;苍何&#xff0c;前大厂高级 Java 工程师&#xff0c;阿里云专家博主&#xff0c;CSDN 2023 年 实力新星&#xff0c;土木转码&#xff0c;现任部门技术 leader&#xff0c;专注于互联网技术分享&#xff0c;职场经验分享。 &#x1f525;热门文章推荐&#xf…

react-router-dom5升级到6

前言 升级前版本为5.1.2 下载与运行 下载 npm install react-router-dom6运行 运行发现报错: 将node_modules删除&#xff0c;重新执行npm i即可 运行发现如下报错 这是因为之前有引用react-router-dom.min&#xff0c;v6中取消了该文件&#xff0c;所以未找到文件导致报错。…