论文阅读 - Video Swin Transformer

文章目录

    • 1 概述
    • 2 模型介绍
      • 2.1 整体架构
        • 2.1.1 backbone
        • 2.1.2 head
      • 2.2 模块详述
        • 2.2.1 Patch Partition
        • 2.2.2 3D Patch Merging
        • 2.2.3 W-MSA
        • 2.2.4 SW-MSA
        • 2.2.5 Relative Position Bias
    • 3 模型效果
    • 参考资料

1 概述

Vision Transformer是transformer应用到图像领域的一个里程碑,它将CNN完全剔除,只使用了transformer来完成网络的搭建,并且在图像分类任务中取得了state-of-art的效果。

Swin Transformer则更进一步,引入了一些inductive biases,将CNN的结构和transformer结合在了一起,使得transformer在图像全领域都取得了state of art的效果。Swin Transformer中也有用到CNN,但是并不是把CNN当做CNN来用的,只是用CNN的模块来写代码比较方便。所以,也可以认为是完全没有使用CNN。

网上关于Swin Transformer的解读多的不得了,这里来说说Swin Transformer在视频领域的应用,也就是Video Swin Transformer。如果非常熟悉Swin Transformer的话,那这篇文章就非常容易读懂了,只是多了一个时间的维度,做attention和构建window的时候略有区别。本文的参考资料也大多是Swin Transformer的。

这篇文章会从视频的角度来解读Swin Transformer。

2 模型介绍

2.1 整体架构

2.1.1 backbone

Video Swin Transformer的backbone的整体架构和Swin Transformer大同小异,多了一个时间维度TTT,在做Patch Partition的时候会有个时间维度的patch size。
tiny video swin tranformer架构

图2-1 tiny video swin tranformer架构

以图2-1为例,输入为一个尺寸为T×H×W×3T \times H \times W \times 3T×H×W×3的视频,通常还会有个batch size,这里省略掉了。TTT一般设置为32,表示从视频的所有帧中采样得到32帧,采样的方法可以自行选择,不同任务可能会有不同的采样方法,一般为等间隔采样。这里其实也就天然限制了模型不能处理和训练数据时长相差太多的视频。通常视频分类任务的视频会在10s左右,太长的视频也很难分到一个类别里。

输入经过Patch Partition之后会变成一个T2×H4×W4×96\frac{T}{2} \times \frac{H}{4} \times \frac{W}{4} \times 962T×4H×4W×96的向量。这是因为patch size在这里为(2,4,4)(2,4,4)(2,4,4),分别是时间,高度和宽度三个维度的尺寸,其中969696是因为2×4×4×3=962 \times 4 \times 4 \times 3 = 962×4×4×3=96,也就是一个patch内的所有像素点的rgb三个通道的值。Patch Partition会在2.2中详述。

Patch Partiion之后会紧跟一个Linear Embedding,这两个模块在代码中是写在一起的,可以参见PatchEmbed3D,就是直接用一个3D的卷积,用这个卷积充当全连接。如果embedding的dim为96,那么经过embedding之后的尺寸还是2×4×4×3=962 \times 4 \times 4 \times 3 = 962×4×4×3=96

之后分别会经过多个video swin transformer block和patch merging。video swin transformer是利用attention同一个window内的特征进行特征融合的模块;patch merging则是用来改变特征的shape,可以当作CNN模型当中的pooling,不过规则不同,而且patch merging还会改变特征的dim,也就是CCC改变。整个过程模仿了CNN模块中的下采样过程,这也是为了让模型可以针对不同尺度生成特征。浅层可以看到小物体,深层则着重关注大物体。

video swin transformer block的结构如下图2-2所示。

video swin transformer block结构

图2-2 video swin transformer block结构

图2-2的左和右是两个不同的blocks,需要连在一起搭配使用。在图2-1中的video swin tranformer block下方有×2\times 2×2或是×6\times 6×6这样的符号,表示有几个blocks,这必定是个偶数,比如×2\times 2×2就表示图2-2这样1组blocks,×6\times 6×6就表示图2-2这样3组blocks相连。

不难看出,有两种blocks,每个block都是先过一个LN(LayerNorm),再过一个MSA(multi-head self-attention),再过一个LN,最后过一个MLP(multilayer perceptron),其中有两处使用了残差模块。残差块主要是为了缓解梯度弥散。

两种blocks的区别在于前者的MSA是window MSA,后者是shifted-window MSA。前者是为了window内的信息交流(局部),后者是为了window间的信息交流(全局)。这个会在2.2中进行详述。

2.1.2 head

backbone的作用是提取视频的特征,真正来做分类的还是接在backbone后面的head,这个部分就很简单了,就是一层全连接,代码中使用的是I3DHead。顺便还带了AdaptiveAvgPool3d,这是用来将输入变成适合全连接的shape的。这部分就不说了,没啥说的。

2.2 模块详述

2.2.1 Patch Partition

下图2-3是一段视频中的8帧,每帧都被分成了8×8=648 \times 8=648×8=64个网格,假设每个网格的像素为4×44 \times 44×4,那么当patch size为(1,4,4)(1, 4, 4)(1,4,4)时,每个小网格就是一个patch;当patch size为(2,4,4)(2,4,4)(2,4,4)时,每相邻两帧的同一个位置的网格组成一个patch。这里和vision tranformer中的划分方式相同,只不过多了时间的概念。
patch partition

图2-3 Patch Partition示意图

2.2.2 3D Patch Merging

3D Patch Merging这一块直接看代码会比较好理解,它和swin transformer中的2D patch merging一模一样,3D Patch Merging虽然多了时间维度,但是并没有对时间维度做merging的操作,也就是输出的时间维度不变。

x0 = x[:, :, 0::2, 0::2, :]  # B T H/2 W/2 C
x1 = x[:, :, 1::2, 0::2, :]  # B T H/2 W/2 C
x2 = x[:, :, 0::2, 1::2, :]  # B T H/2 W/2 C
x3 = x[:, :, 1::2, 1::2, :]  # B T H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B T H/2 W/2 4*C

看代码再结合图就更好理解了。图中每个颜色都是一个patch。
patch merging

图2-4 Patch Merging示意图

2.2.3 W-MSA

MSA(multihead self attention)的原理这里就不说了,不懂的可以参见搞懂Transformer,这里主要来说一说这个window。W-MSA(window based MSA)相比于MSA多了一个window的概念,相比于vision transformer引入window的目的是减小计算复杂度,使得复杂度和输入图片的尺寸成线性关系。这里不推导复杂度的计算,有兴趣的可以看Swin Transformer论文精读,这里有很详细的推导,3D和2D的复杂度计算方法是一致的。
w-msa

图2-5 W-MSA示意图

窗口的划分方式如图2-5所示,每个窗口的大小由window size决定。图2-5的window size为(4,4,4)(4,4,4)(4,4,4)就表示在时间,高度和宽度的window尺寸都是4个patch,划分后的结果如图2-5右半所示。之后的attention每个window单独做,window之间不互相干扰。

2.2.4 SW-MSA

由于W-MSA的attention是局部的,作者就提出了SW-MSA(shifted window based MSA)。

SW-MSA

图2-6 SW-MSA示意图

SW-MSA如图2-6所示,图中shift size为(2,2,2)(2,2,2)(2,2,2),一般shift size都是window size的一半,也就是(P2,M2,M2)(\frac{P}{2}, \frac{M}{2}, \frac{M}{2})(2P,2M,2M)。shift了之后,window会往前,往右,往下分别移动对应的size,目的是让patch可以和不同window的patch做特征的融合,这样多过几层之后,也就相当于做了全局的特征融合。

不过这里有一个问题,shift了之后,window的数量从原来的2×2×2=82 \times 2 \times 2=82×2×2=8变成了3×3×3=273 \times 3 \times 3=273×3×3=27。这带来的弊端就是计算时窗口不统一会比较麻烦。为了解决这个问题,作者引入了mask,并将窗口的位置进行了移动,使得进行shift和不进行shift的MSA计算方式相同,只不过mask不同。

shift window示意图

图2-7 shift window示意图

我用PPT画了一下shift的过程,画图能力有限,能看懂就好。我们的目的是把图2-6中最右侧的27个windows变成和图2-6中间那样的8个window。我给每个window都标了序号,标序号的方式是从前往后,从上往下,从左往右。shift window的方法就是把左上角的移到右下角,把前面的移到后面。这样一来,比如[27,25,21,19,9,7,3,1][27, 25, 21, 19, 9, 7, 3, 1][27,25,21,19,9,7,3,1]就组成了1个window,[18,16,12,10][18, 16, 12, 10][18,16,12,10]就组成了1个window,依此类推,一共有8个windows。平移的方式可以和上述的不同,只要保证可以把27个windows变成和8个windows的计算方式一样即可。

这样在每个window做self-attention的时候,需要加一层mask,可以说是引入了inductive bias。因为在组合而成的window内,各个小window我们不希望他们交换信息,因为这不是图像原有的位置,比如17和11经过shift之后,会在同一个window内做attention,但是11是从上面移下来的,只是为了计算的统一,并不是物理意义上的同一个window。有了mask就不一样了,mask的目的是告诉17号窗口内的每一个patch,只和17号窗口内的patches做attention,不和11号窗口内的做attention,依此类推其他

mask的生成方法可以参见源码,这里不细讲,主要思路是就像图2-7这样,给每个patch一个window的编号,编号相同的patch之间mask为0,否则为-100。

def compute_mask(D, H, W, window_size, shift_size, device):img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1cnt = 0for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):img_mask[:, d, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))return attn_mask

如果window的大小为图2-6中的(P,M,M)(P, M, M)(P,M,M)的话,attention mask就是一个(P×M×M,P×M×M)(P \times M \times M,P \times M \times M)(P×M×MP×M×M)的矩阵,这是一个对称矩阵,第iii行第jjj列就表示window中的第iii个patch和第jjj个patch的window编号是否是相同的,相同则为0,不同则为-100。对角线上的元素必为0。

有人认为浅层的网络需要SW-MSA,深层的就不需要了,因为浅层已经讲全局的信息都交流了,深层不需要进一步交流了。这种说法的确有一定的道理,但也要看网络的深度和shift的尺寸。

2.2.5 Relative Position Bias

在上述的所有内容中,都没有涉及到位置的概念,也就是模型并不知道每个patch在图片中和其他patches的位置关系是怎么样的,最有也就是知道某几个patch是在同一个window内的,但window内的具体位置也是不知道的,因此就有了Relative Position Bias。它是加在attention的部分的,下式(2−1)(2-1)(21)中的BBB就是Relative Position Bias。

Attention(Q,K,V)=Softmax(QKT/d+B)V(2-1)Attention(Q,K,V) = Softmax(QK^T/\sqrt{d} + B)V \tag{2-1} Attention(Q,K,V)=Softmax(QKT/d+B)V(2-1)

很多swin tranformer的文章都会将这个BBB是如何得到的,但是却没有讲为什么要这样生成BBB。其实只要知道了设计这个BBB的目的,就可以不用管是如何生成的了,甚至自己设计一种生成的方法都行。

BBB是为了表示一个windows内,每个patch的相对位置,给每个相对位置一个特殊的embedding值。其实也正是因为这个BBB的存在,SW-MSA才必须要有mask,因为SW-MSA内的patches可能来自于多个windows,相对位置不能按照这个方法给,如果BBB可以表示全图的相对位置,那就不用这个mask了。

这个B和mask的shape是一致的,也是(P×M×M,P×M×M)(P \times M \times M,P \times M \times M)(P×M×MP×M×M)的矩阵,第iii行第jjj列就表示window中的第jjj个patch相对于第iii个patch的位置。

下图2-8是我画的一个示意图,即使是一个(2,2,2)(2, 2, 2)(2,2,2)的window,我也感到工作量太大,矩阵没填满,画了几个示意了一下。如果window size为(P,M,M)(P, M, M)(P,M,M)的话,那么相对位置状态就会有(2P−1)×(2M−1)×(2M−1)(2P-1) \times (2M-1) \times (2M-1)(2P1)×(2M1)×(2M1)种状态,我把(2,2,2)(2, 2, 2)(2,2,2)的window的27种相对位置状态全都在图2-8上写出来了。

Relative Position Bias

图2-8 Relative Position Bias示意图

有了状态之后,就只需要在BBB这个矩阵中将相对位置的状态对号入座即可。这就是很多其他博客写的相对位置坐标相减,然后加个偏置,再乘个系数的地方。理解了为什么要这么做,看那些操作也就不会觉得奇怪了。

但最终使用的不是状态,而是状态对应的embedding值,这就需要有一个table来根据状态查找embedding,这个embedding是模型训练出来的。

3 模型效果

作者在三个数据集上进行了测试,分别是kinetics-400,kinetics-600和something-something v2。每个数据集上都有着state-of-art的表现。

表3-1 kinetics-400模型对比指标

kinetics-400

表3-2 kinetics-600模型对比指标

kinetics-600

表3-3 something-something v2模型对比指标

something-something v2

参考资料

[1] Video Swin Transformer
[2] Swin-Transformer网络结构详解
[3] Swin Transformer论文精读
[4] Swin Transformer从零详细解读
[5] https://github.com/SwinTransformer/Video-Swin-Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/470524.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

rocketmq queue_RocketMQ 实战(三) - 消息的有序性

■ RocketMQ有序消息的使用1 为什么需要消息的有序性比如用户张三终于挣了一百存在在银行卡里存取款,对应两个异步的短信消息,肯定要保证先存后取吧,不然都没钱怎么发了取钱的消息呢! M1 - 存钱 M2 - 取钱而mq默认发消息到不同q显然是行不通的,会乱序 需要发往同一个q,先进先出…

三十、PHP框架Laravel学习笔记——模型的预加载

一.预加载 预加载,就是解决关联查询中产生的 N1 次查询带来的资源消耗我们要获取所有书籍的作者(或拥有者),普通查询方案如下: //获取所有书籍列表 $books Book::all(); //遍历每一本书 foreach ($books as $book) { //每一本…

论文阅读:Spatial Transformer Networks

文章目录1 概述2 模型说明2.1 Localisation Network2.2 Parameterised Sampling Grid3 模型效果参考资料1 概述 CNN的机理使得CNN在处理图像时可以做到transition invariant,却没法做到scaling invariant和rotation invariant。即使是现在火热的transformer搭建的图…

dataframe 排序_疯狂Spark之DataFrame创建方式详解一(九)

创建DataFrame的几种方式1、读取json格式的文件创建DataFrame注意:1. json文件中的json数据不能嵌套json格式数据。2. DataFrame是一个一个Row类型的RDD,df.rdd()/df.javaRdd()。3. 可以两种方式读取json格式的文件。4. df.show()默认显示前20行数据。5.…

【原】npm 常用命令详解

今年上半年在学习gulp的使用,对npm的掌握是必不可少的,经常到npm官网查询文档让我感到不爽,还不如整理了一些常用的命令到自己博客上,于是根据自己的理解简单翻译过来,终于有点输出,想学习npm这块的朋友不可…

论文阅读 - CRNN

文章目录1 概述2 模型介绍2.1 输入2.2 Feature extraction2.3 Sequence modeling2.4 Transcription2.4.1 训练部分2.4.2 预测部分3 模型效果参考资料1 概述 CRNN(Convolutional Recurrent Neural Network)是2015年华科的白翔老师团队提出的,直至今日,仍…

python easygui_Python里的easygui库

想要用python开发一些简单的图形界面,于是接触了easygui库,由于这是新手教程,我会把它写的尽量简单,希望大家都能看懂。1.msgboxmsgbox( )有一个标题,内容和一个ok键(是可以更改的)。举个例子:import easyg…

recv发送失败 缓冲区太小_从 GFS 失败的架构设计来看一致性的重要性

作者简介 陈东明,饿了么北京技术中心架构组负责人,负责饿了么的产品线架构设计以及饿了么基础架 构研发工作。曾任百度架构师,负责百度即时通讯产品的架构设计。具有丰富的大规模系统构 建和基础架构的研发经验,善于复杂业务需求下…

好用的记事本_分类记事本软件哪个好用?大家推荐一个苹果手机用的分类记事本便签呗...

随着“互联网”的发展,现在都开始在软件上记事备忘了。那么,都有哪些好用的记事本软件可以选择使用呢?大家在选择记事本软件的时候,都有哪些标准呢?不知道大家的标准是什么,小编有一个不能妥协的标准&#…

bootstrap table 分页_Java入门007~springboot+freemarker+bootstrap快速实现分页功能

本节是建立在上节的基础上,上一节给大家讲了管理后台表格如何展示数据,但是当我们的数据比较多的时候我们就需要做分页处理了。这一节给大家讲解如何实现表格数据的分页显示。准备工作1,项目要引入freemarker和bootstrap,如果不知…

1+X web中级 Laravel学习笔记——查询构造器简介及新增、更新、删除、查询数据

一、新增数据 插入多条数据: 二、更新数据 更新某条数据: 自增某字段的值: 自减某字段的值: 自增的同时改变其他字段的值: 三、删除数据 四、查询 查面构造器查面数据 有以下几种方法 get(&…

【HTML5】Canvas画布

什么是 Canvas? HTML5 的 canvas 元素使用 JavaScript 在网页上绘制图像。 画布是一个矩形区域,您可以控制其每一像素。 canvas 拥有多种绘制路径、矩形、圆形、字符以及添加图像的方法。 * 添加 canvas 元素。规定元素的 id、宽度和高度: &l…

SynthText流程解读 - 不看代码不知道的那些事

文章目录1 概述2 流程解读2.1 生成文字mask2.2 plane2xyz的bug2.3 文字上色2.4 图像融合参考资料1 概述 SynthText是OCR领域生成数据集非常经典,且至今看来无人超越的方法。整体可以分为三个大的步骤,分别是生成文字的mask,这里用到了图像的…

python if name main 的作用_Python中if __name__ == '__main__':的作用和原理

if __name__ __main__:的作用 一个python文件通常有两种使用方法,第一是作为脚本直接执行,第二是 import 到其他的 python 脚本中被调用(模块重用)执行。因此 if __name__ main: 的作用就是控制这两种情况执行代码的过程&#x…

1+X web中级 Laravel学习笔记——Eloquent ORM查询、更新、删除、新增

Eloquent ORM简介 larave1所自带的Eloquent oRM是一个非常优美简洁的ActiveRecord实现,用来实现数据库的操作他的每个数据的表都有对应的模型(model)用于数据表的交互模型的建立 一、Eloquent ORM的查询 二、Eloquent ORM新增 通过模型新增…

使用复合设计模式扩展持久化的CURD,Select能力

大家可能会经常遇到接口需要经常增加新的方法和实现,可是我们原则上是不建议平凡的增加修改删除接口方法,熟不知这样使用接口是不是正确的接口用法,比如我见到很多的项目分层都是IDAL,DAL,IBLL,BLL&#xf…