【深度学习】深刻理解Swin Transformer

Swin Transformer 是一种基于 Transformer 的视觉模型,由 Microsoft 研究团队提出,旨在解决传统 Transformer 模型在计算机视觉任务中的高计算复杂度问题。其全称是 Shifted Window Transformer,通过引入分层架构和滑动窗口机制,Swin Transformer 在性能和效率之间取得了平衡,广泛应用于图像分类、目标检测、分割等视觉任务,称为新一代的backbone,可直接套用在各项下游任务中。在Swin Transformer中,提供大、中、小不同版本模型,可以进行自由选择合适的使用。

论文原文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

1.  介绍

        Transformer 最初在自然语言处理(NLP)领域大获成功,但直接将 Transformer 应用于计算机视觉任务存在很多挑战。传统Transformer中,拿到了图像数据,将图片进行划分成一个个patch,尽可能patch细一些。但是图像中像素点太多了,如果需要更多的特征,就必须构建很长的序列。而越长的序列算起注意力肯定越慢,自注意力机制的计算复杂度是O(n^2),当处理高分辨率图像时,这种复杂度会快速增长,这就导致了效率问题。

        而且图像中许多视觉信息依赖于局部关系,而标准 Transformer 处理的是全局关系,可能无法有效捕获局部特征。Swin Transformer便采用窗口和分层的形式来替代长序列的方法,CNN中经常提到感受野,在Transformer中对应的就是分层。也就是说,我们可以将当前这件事做L次(Lx),每次都会两两进行合并,向量数越来越小(400个token-200个token-100个token),窗口的大小也会增大。分层操作也就是,第一层的时候token很多,第二层合并token,第三层合并token,就像我们的卷积和池化的操作。而在传统的Transformer中,第一层怎么做,第二层第三层也会采用同样的尺寸进行,都是一样的操作。

2. Swin Transformer 整体架构

2.1. Patch Embedding

        在 Swin Transformer 中,Patch Embedding 负责将输入图像分割成多个小块(patches),并将这些小块的像素值嵌入到一个高维空间中,形成适合 Transformer 处理的特征表示。在传统的卷积神经网络(CNN)中,卷积操作可以用来提取局部特征。在 Swin Transformer 中,为了将输入图像转化为适合 Transformer 模型处理的 patch 序列,首先对输入图像进行分块。假设输入图像的大小为 224x224x3,其通过一个卷积操作实现。卷积操作可以将每个局部区域的像素值映射为一个更高维的特征向量。假设输入图像大小为 224x224x3,应用一个卷积层,参数为 Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)),这表示卷积核的大小是 4x4,步长是 4,输入的通道数是 3(RGB图像),输出的通道数是 96。卷积后,图像的空间维度会变小,输出的特征图的尺寸会变为 56x56(通过计算:(224 - 4) / 4 + 1 = 56)。所以,卷积后的输出大小是 56x56x96,这表示每个空间位置(56x56)都有一个96维的特征向量。

        在 Swin Transformer 中,通常将图像通过卷积操作分割成不重叠的小块(patches)。每个小块对应一个特征向量。例如,56x56x96 的输出可以视为有 3136 个 patch,每个 patch 是一个 96 维的向量。这些特征向量将作为 Transformer 模型的输入序列。根据不同的卷积参数(如 kernel_size 和 stride),你可以控制生成的 patch 的数量和每个 patch 的维度。例如,如果使用更小的卷积核和步长,可以得到更细粒度的 patch,反之则可以得到较大的 patch。

  • kernel_size 决定了每个 patch 的空间大小。
  • stride 决定了每个 patch 之间的间隔,即步长。

2.2. window_partition

        在 Swin Transformer 中,图像的特征表示不仅仅是通过 Patch Embedding 来获得,还通过 窗口划分(Window Partition) 来进一步细化和处理,通过窗口内的局部注意力机制来增强计算效率并捕捉局部特征。

        假设输入的图像经过卷积处理后得到了大小为 56x56x96 的特征图,将这个特征图划分为多个小窗口(window),每个窗口包含一部分局部信息,其中窗口大小为7x7,特征图大小为56x56。为了将特征图划分成大小为 7x7 的窗口,我们首先计算在空间维度(高和宽)上可以分成多少个窗口,水平和垂直方向上,每个 7x7 窗口可以覆盖 56 / 7 = 8 个窗口(总共 8x8 = 64 个窗口),窗口内部的特征图由 96 个通道组成。因此,在划分后,特征图的维度将变为 (64, 7, 7, 96),其中:

  • 64 表示窗口的数量(即 8x8 = 64 个窗口)。
  • 7x7 是每个窗口的空间维度。
  • 96 是每个窗口内的特征通道数。

        在 Swin Transformer 中,Token 通常指的是图像中的局部特征,每个 Token 是图像的一个小区域。在 Window Partition 过程中,我们将整个图像的 Token 重新组织成窗口(Window)。之前每个 Token 对应一个图像位置,现在每个 Token 对应一个窗口的内部特征。所以,原来每个 Token(如卷积后的每个空间位置)代表了图像的一部分信息,现在我们通过窗口划分来捕捉更大范围的局部信息。这种划分有助于模型专注于图像的局部结构,同时减少计算量,因为每个窗口只在局部范围内进行注意力计算。

2.3. W-MSA(Windwow multi-head self attention)

        在 Swin Transformer 中,W-MSA (Window Multi-Head Self Attention) 是关键的注意力机制,它通过在每个窗口内部独立地计算自注意力(Self-Attention)来减少计算复杂度,并捕捉局部特征。

        通过 Window Partition 将特征图划分为 64 个窗口,每个窗口的尺寸为 7x7,并且每个位置的特征通道数为 96,因此每个窗口的形状为 (7, 7, 96),这些窗口将作为 W-MSA 的输入。在 Multi-Head Self-Attention 中,首先需要将输入特征矩阵(窗口内的特征)通过三个不同的矩阵进行线性变换,得到 查询(Q)键(K)值(V),这三个矩阵用于计算注意力得分。对于每个头(Head),计算过程是独立的。假设有 3 个头,那么每个头的输入特征维度为 96 / 3 = 32,因为 96 维的输入被平均分成了 3 个头,每个头负责 32 维的特征。在 W-MSA 中,针对每个窗口独立计算自注意力得分,计算方法如下:

  • 对每个窗口中的 49 个像素点(即每个位置的特征向量)进行查询Q、键K、值V的计算。

  • 自注意力得分(Attention Score) 是通过计算查询与键的点积(或者其他相似度度量)得到的,这可以表示为:

    \text{Attention Score} = \frac{Q \cdot K^T}{\sqrt{d_k}}

    其中,d_k 是每个头的维度(在这里是 32),Q 和 K 的乘积衡量了每个位置之间的相似性。

  • Softmax:通过 Softmax 操作将得分归一化,使其成为概率分布,得到每个位置与其他位置的相关性。

  • 加权值(Weighted Sum):使用得分对值V进行加权求和,得到每个位置的最终输出表示。

        每个头的自注意力计算都会产生一个形状为 (64, 3, 49, 49) 的结果,其中,64 表示窗口的数量,3 表示头的数量,49 是每个窗口中位置的数量(7x7),49 代表每个位置对其他位置的注意力得分(自注意力矩阵)。因此,每个头会计算出每个窗口内所有位置之间的自注意力得分,输出的形状为 (64, 3, 49, 49)

2.4. window_reverse

   Window Reverse 操作的目的是将计算得到的 (64, 49, 96) 特征图恢复回原始的空间维度 (56, 56, 96)。为此,我们需要将每个窗口的 49 个位置(7x7)重新排列到原始的图像空间中。步骤:

  • Reshape 操作: 每个窗口的特征图形状是 (49, 96),我们将其转换成 (7, 7, 96) 的形状,表示每个窗口中的每个像素点都有一个 96 维的特征向量。

  • 按窗口拼接: 将所有 64 个窗口按照它们在特征图中的位置重新排列成 56x56 的大特征图。原始的输入特征图大小是 56x56,这意味着 64 个窗口将按照 8x8 的网格排列,并恢复到一个 (56, 56, 96) 的特征图。

        在 Window Reverse 操作后,恢复得到的特征图形状是 (56, 56, 96),这与卷积后的特征图的形状一致。56x56 是恢复后的空间维度,代表每个像素点在特征图中的位置;96 是每个像素点的特征维度,表示每个位置的特征信息。

2.5. SW-MASA

为什么要滑动窗口(Shifted Window)?

        原始的 Window MSA 将图像划分为固定的窗口(例如 7x7),并在每个窗口内计算自注意力。这样做的一个问题是每个窗口内部的信息相对封闭,没有与相邻窗口之间的信息交流。因此,模型容易局限于各自的小区域,无法充分捕捉不同窗口之间的关联。

        通过引入 滑动窗口Shifted Window)机制,窗口在原来位置的基础上向四个方向移动一部分,重叠区域与原窗口有交集。这样,原本相互独立的窗口就可以共享信息,增强了模型的表达能力和全局感知。

位移操作(Shift Operation)

        位移操作的细节如下:

  • 初始的窗口被划分为 4x4 的块(例如 7x7 窗口),每个块进行独立的自注意力计算。
  • 在进行位移时,原来 4x4 的窗口将被平移,变成新的大小为 9x9 的窗口,窗口重叠区域包含了不同窗口之间的信息。
  • 通过平移,模型能获取到更广泛的信息,使得窗口之间能够通过共享信息来融合彼此的特征,避免局部化。

        Shifted Window MSA 会导致计算量的增加,特别是在窗口滑动后,窗口数量从 4x4 变为 9x9,计算量几乎翻倍。为了控制计算量的增长,可以通过 mask 操作 来减少不必要的计算。在位移后,窗口之间会重叠。为了避免重复计算,我们可以使用 mask 来屏蔽掉不需要计算的部分。在计算注意力时,对于每个位置的 QK 的匹配,使用 softmax 时,设置不需要计算的位置的值为负无穷,这样对应位置的注意力值将接近零,不会对结果产生影响。

        在进行 SW-MSA 后,输出的特征图的形状仍然是 56x56x96,与输入特征图的大小一致。通过 shifted windowmask 操作,模型不仅保留了原始的窗口内的自注意力计算,还增强了窗口之间的信息交换和融合。即使窗口被移动了,经过计算后的特征也需要回到其原本的位置,也就是还原平移,保持图像的完整性。

2.6. PatchMerging

        PatchMergingSwin Transformer 中的一种下采样操作,但是不同于池化,这个相当于间接的(对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4),目的是将输入特征图的空间维度(即高和宽)逐渐减小,同时增加通道数,从而在保持计算效率的同时获得更高层次的特征表示。它是下采样的过程,但与常规的池化操作不同,PatchMerging 通过将相邻的 patch 拼接在一起,并对拼接后的特征进行线性变换,从而实现下采样。具体来说,在 Swin Transformer 中,随着网络层数的加深,输入的特征图会逐渐减小其空间尺寸(即 H 和 W 维度),而同时增加其通道数(即 C 维度),以便模型可以捕捉到更为复杂的高层次信息。

        假设输入的特征图形状为 H x W x CPatchMerging 通过以下步骤来实现下采样和通道数扩展:

  • 分割和拼接(Splitting and Concatenation)

    • 输入的特征图会按照一定的步长(通常是 2)进行分割,即对每个 2x2 的 patch 进行合并。
    • 这样原本的 H x W 的空间尺寸会缩小一半,变成 H/2 x W/2
    • 然后,将每个 2x2 的 patch 内部的特征进行拼接,得到新的特征维度。假设原始通道数为 C,拼接后的通道数为 4C
  • 卷积操作

    • 对拼接后的特征进行 卷积,以进一步增强特征表达。卷积操作用于转换特征空间,虽然通道数增加了,但通过卷积,特征能够更加丰富。

2.7. 分层计算

        在 Swin Transformer 中,模型的每一层都会进行下采样操作,同时逐步增加通道数。每次 PatchMerging 后的特征图会作为输入进入下一层的 Attention 计算。通过这种方式,Swin Transformer 能够逐渐提取到越来越复杂的特征,同时保持计算效率。每一层的 PatchMerging 操作实际上是将输入的特征图通过 线性变换(通常是卷积)合并成更高维度的特征图,从而为后续的注意力计算提供更丰富的表示。

        从图中可以得到,通道数在每层中并不是从C变成4C而是2C,这是因为中间又加了一层卷积操作。

3. 实验结果

        在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。

参考资料:

【深度学习】详解 Swin Transformer (SwinT)-CSDN博客

深度学习之Swin Transformer学习篇(详细 - 附代码)_swintransformer训练-CSDN博客

图解Swin Transformer - 知乎

【论文精读】Swin Transformer - 知乎

ICCV2021最佳论文:Swin Transformer论文解读+源码复现,迪哥带你从零解读霸榜各大CV任务的Swin Transformer模型!_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63152.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql客户端命令

目录 结束符 ; \g \G 中断输入 ctrl c 查看命令列表 help ? (\?) connect (\r) status (\s) delimiter (\d) exit (\q) quit (\q) tee (\T) ​编辑 notee (\t) prompt (\R) source (\.) system (\!) ​编辑 use (\u) help contents 结束符 ; \g \G 当我…

Tomcat原理(4)——尝试手动Servlet的实现

目录 一、什么是Servlet 1.servlet的定义 2.servlet的结构 二、实现servlet的流程图 三、具体实现代码 1、server 2.实体类request&response 3.HttpServlet抽象类 4.再定义三个servlet进行测试 Tomcat原理(3)——静&动态资源以及运行项…

D3 基础1

D3 D3.js (Data-Driven Documents) 是一个基于 JavaScript 的库&#xff0c;用于生成动态、交互式数据可视化。它通过操作文档对象模型 (DOM) 来生成数据驱动的图形。官方网站是 https://d3js.org/ <!DOCTYPE html> <html lang"en"><head><me…

基线检查:Windows安全基线.【手动 || 自动】

基线定义 基线通常指配置和管理系统的详细描述&#xff0c;或者说是最低的安全要求&#xff0c;它包括服务和应用程序设置、操作系统组件的配置、权限和权利分配、管理规则等。 基线检查内容 主要包括账号配置安全、口令配置安全、授权配置、日志配置、IP通信配置等方面内容&…

Python -- Linux中的Matplotlib图中无法显示中文 (中文为方框)

目的 用matplotlib生成的图中文无法正常显示 方法 主要原因: 没找到字体 进入windows系统的C:\Windows\Fonts目录, 复制自己想要的字体 粘贴到Linux服务器中对应python文件所处的文件夹内 设置字体: 设置好字体文件的路径在需要对字体设置的地方设置字体 效果 中文正常显…

快速理解类的加载过程

当程序主动使用某个类时&#xff0c;如果该类还未加载到内存中&#xff0c;则系统会通过如下三个步骤来对该类进行初始化&#xff1a; 1.加载&#xff1a;将class文件字节码内容加载到内存中&#xff0c;并将这些静态数据转换成方法区的运行时数据结构&#xff0c;然后生成一个…

宝塔-docker拉取宝塔镜像,并运行宝塔镜像

宝塔-拉取宝塔镜像&#xff0c;并运行镜像 第1步&#xff1a;查询 docker search btpanel/baota此docker镜像由堡塔安全官方发布&#xff0c;镜像版本为宝塔面板9.2.0正式版和9.0.0_lts 稳定版&#xff0c;镜像会随着宝塔面板更新。 目前支持x86_64和arm架构可供下载使用 版本…

穷举vs暴搜vs深搜vs回溯vs剪枝专题一>子集

题目&#xff1a; 两个方法本质就是决策树的画法不同 方法一解析&#xff1a; 代码&#xff1a; class Solution {private List<List<Integer>> ret;//返回结果private List<Integer> path;//记录路径&#xff0c;注意返回现场public List<List<Int…

leecode双指针部分题目

leecode双指针部分题目 1. 验证回文串2. 判断子序列3. 两数之和 II - 输入有序数组4. 盛最多水的容器5. 三数之和 1. 验证回文串 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后&#xff0c;短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 …

TCP协议简单分析和握手挥手过程

TCP介绍 TCP是可靠的传输层协议&#xff0c;建立连接之前会经历3次握手的阶段。 确认机制&#xff1a;接受方 收到数据之后会向 发送方 回复ACK重传机制&#xff1a;发送方 在一定时间内没有收到 接收方的ACK就会重新发送 握手目的&#xff1a;与端口建立连接 TCP的三次握手 …

opencv所有常见函数

一、opencv图像操作 二、opencv图像的数值运算 三、opencv图像的放射变换 四、opencv空间域图像滤波 五、图像灰度化与直方图 六、形态学图像处理 七、阈值处理与边缘检测 八、轮廓和模式匹配

【Excel】单元格分列

目录 分列&#xff08;新手友好&#xff09; 1. 选中需要分列的单元格后&#xff0c;选择 【数据】选项卡下的【分列】功能。 2. 按照分列向导提示选择适合的分列方式。 3. 分好就是这个样子 智能分列&#xff08;进阶&#xff09; 高级分列 Tips&#xff1a; 新手推荐基…

【STM32练习】基于STM32的PM2.5环境监测系统

一.项目背景 最近为了完成老师交付的任务&#xff0c;遂重制了一下小项目用STM32做一个小型的环境监测系统。 项目整体示意框图如下&#xff1a; 二.器件选择 单片机&#xff08;STM32F103&#xff09;数字温湿度模块&#xff08;DHT11&#xff09;液晶显示模块&#xff08;0.8…

ReactPress最佳实践—搭建导航网站实战

Github项目地址&#xff1a;https://github.com/fecommunity/easy-blog 欢迎Star。 近期&#xff0c;阮一峰在科技爱好者周刊第 325 期中推荐了一款开源工具——ReactPress&#xff0c;ReactPress一个基于 Next.js 的博客和 CMS 系统&#xff0c;可查看 demo站点。&#xff08;…

2024,大模型杀进“决赛圈”

Henry Chesbrough在著作《通过技术创新盈利势在必行》中&#xff0c;曾提出过一个创新的“漏斗模型”。开放式创新一开始鼓励百花齐放&#xff0c;但最终只有10%的技术能够通过这个漏斗&#xff0c;成功抵达目标市场target market&#xff0c;进入到商业化与产业化的下一个阶段…

STM8单片机学习笔记·GPIO的片上外设寄存器

目录 前言 IC基本定义 三极管基础知识 单片机引脚电路作用 STM8GPIO工作模式 GPIO外设寄存器 寄存器含义用法 CR1&#xff1a;Control Register 1 CR2&#xff1a;Control Register 2 ODR&#xff1a;Output Data Register IDR&#xff1a;Input Data Register 赋值…

【CSS in Depth 2 精译_081】 13.1:CSS 渐变效果(下)——CSS 径向渐变(13.1.3)+ CSS 锥形渐变(13.1.4)

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第四部分 视觉增强技术 ✔️【第 13 章 渐变、阴影与混合模式】 ✔️ 13.1 渐变 ✔️ 13.1.1 使用多个颜色节点&#xff08;上&#xff09;13.1.2 颜色插值方法&#xff08;中&#xff09;13.1.3 径…

ubuntu 用 ss-tproxy的最终网络结构

1、包含了AD广告域名筛选 2、Ss-tproxy 国内国外地址分类 3、chinadns-ng解析 4、透明网关 更多细节看之前博客 ubuntu 用ss-TPROXY实现透明代理&#xff0c;基于TPROXY的透明TCP/UDP代理,在 Linux 2.6.28 后进入官方内核。ubuntu 用 ss-tproxy的内置 DNS 前挂上 AdGuardHome…

BUUCTF Pwn [HarekazeCTF2019]baby_rop2 题解

下载 得到两个文件 checksec 64位 拖入IDA64 查看main函数 看到给了个libc说明这题是ret2libc题 这里的打印函数是printf 所以利用printf函数的plt输出真实地址got 但printf的got好像不行 所以换成了read的got 因为这是64位程序 所以用寄存器传参&#xff1b;又因为printf得…

语音识别失败 chrome下获取浏览器录音功能,因为安全性问题,需要在localhost或127.0.0.1或https下才能获取权限

环境&#xff1a; Win10专业版 谷歌浏览器 版本 131.0.6778.140&#xff08;正式版本&#xff09; &#xff08;64 位&#xff09; 问题描述&#xff1a; 局域网web语音识别出现识别失败 chrome控制台出现下获取浏览器录音功能&#xff0c;因为安全性问题&#xff0c;需要在…