Transformer之Swin-Transformer结构解读

写在最前面之如何只用nn.Linear实现nn.Conv2d的功能

很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么?
首先,Convolution是什么?

Convolution是一种矩形区域内参数共享的Linear
在这里插入图片描述

这么说可能不好理解,那么我们上代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Conv2D(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride):"""为了简单且便于理解,我们设定图片的Size是Kernel_size的整数倍,且Kernel_size等于Stride"""super(LinearConv2d, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizeself.stride = stride# 计算权重矩阵的维度weight_size = in_channels * kernel_size * kernel_sizeself.linear = nn.Linear(weight_size, out_channels, bias=False)def forward(self, x):# 计算输出特征图的尺寸B, C, H, W = x.size()output_height = H // self.strideoutput_width = W // self.stride# 展开输入特征,沿着kernel_size的窗口展开x_flatten = x.view(B, H // self.kernel_size, self.kernel_size, W // self.kernel_size, self.kernel_size, C)x_flatten = x_flatten.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.kernel_size, self.kernel_size, C)# 应用线性变换output_flatten = self.linear(x_flatten)# 重塑输出形状output = output_flatten.view(B, self.out_channels, output_height, output_width)return output# 使用nn.Linear实现nn.Conv2d(256, 256, k=7, s=7)
conv2d_manual = Conv2D(256, 256, 7, 7)# 创建一个随机初始化的输入张量,确保尺寸是7的整数倍
input_tensor = torch.randn(1, 256, 56, 56)  # 假设输入图像大小为56x56,56是7的倍数# 应用卷积操作
output = conv2d_manual(input_tensor)
# 输出形状应为[1, 256, 8, 8]
print(output.shape)  

上述代码通过了使用输入数据的维度变换,实现了利用nn.Linear来进行nn.Conv2d的过程,当然,nn.Conv1d甚至nn.Conv3d等也是同样操作。这里我们先记住,后面我们详细解释

Swin-Transformer为什么这么叫

首先,需要理解为什么叫Swin!
作者依然使用了Vision Transformer的主题架构,核心区别是对数据处理的区别!
在Vision Transformer中,数据根据spatial维度进行拉伸,并成为[Batch, HW, C]的样子,如图所示,具体参考Transformer之Vision Transformer结构解读
在这里插入图片描述而在Swin-Transformer中,额外增加了一步,就是把维度为 [ B a t c h , H × W , C ] [Batch, H\times W, C] [Batch,H×W,C]的patch_embedding,进行二次分割,变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],如图所示,

  • 第一张图片就是经过patch_embed的patch_embedding
  • 第二张图片就是经过window_partrition分割后的图片
  • 第三张图片就是处理成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C]的图片
    在这里插入图片描述这里还有一个操作,就是在第偶数个Attention-Block中,把输入的patch_embedding进行torch.roll操作,这个操作就是循环位移
    在这里插入图片描述
    这时候就可以解释为什么说Swin-Transformer就是另一种形式的CNN
    从上面的图片中可以看到如下过程:
  • 一张图片,经过nn.Conv2d(k=patch_size, stride=patch_size),将其分割成 N 2 N^2 N2个patch_embedding
  • patch_embedding经过维度重整,从 [ B , H × W , C ] [B, H\times W, C] [B,H×W,C]变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],然后送入nn.Linear()。这里的维度重整加上nn.Linear(),等于nn.Conv2d,可以通过写在最前面的"如何只用nn.Linear()实现nn.Conv2d的功能"看出
  • 上一步可以总结为:经过nn.Conv2d的patch_embedding继续经过若干nn.Conv2d

Swin-Transformer的位置编码

绝对位置编码

详情参考Transformer之位置编码的通俗理解
在patch_embedding过程中,依然将Token和PE相加,如上图二所示。
但是既然有了相对位置编码,为什么还要加上绝对位置编码呢?

  • 数学解释如下:

Q E + P E × K E + P E T = X E + P E × W q × [ X E + P E × W k ] T = X E + P E × W q × W k T × X E + P E T = ( X q + P E q ) × W q × W k T × ( X k + P E k ) T = X q × W q ⏞ Q u e r y × W k T × X k T ⏞ K e y ⏟ 第一项 + P E q × W q ⏞ a × W k T × X k T ⏞ K e y ⏟ 第二项 + X q × W q ⏞ Q u e r y × W k T × P E k T ⏞ b ⏟ 第三项 + P E q × W q ⏞ a × W k T × P E k T ⏞ b ⏟ 第四项 \begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array} QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项 Xq×Wq Query×WkT×XkT Key+第二项 PEq×Wq a×WkT×XkT Key+第三项 Xq×Wq Query×WkT×PEkT b+第四项 PEq×Wq a×WkT×PEkT b
绝对位置编码只能消去第三项和第四项中的d项,依然需要第二项中的a项,才能具有完整的偏置

  • 直觉解释如下
    如果只有相对位置编码,也就是相当于只有相对位置偏置,这个过程和只有绝对位置偏置的意义是相同的,所以只有同时具有相对位置编码和绝对位置编码,才能避免两者是等效的

相对位置编码

详情参考Transformer之位置编码的通俗理解
相对位置编码,实际上是Attention机制的偏置的位置编码:
A t t = s o f t m a x ( Q × K T D i m + r e l a t i v e _ p o s i t i o n _ b i a s ) × V Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V Att=softmax(Dim Q×KT+relative_position_bias)×V
在这里插入图片描述
这里受到CSDN图片尺寸的限制,只能发这种清晰度的,点击这里下载无损svg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49233.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是离线语音识别芯片?与在线语音识别的区别

离线语音识别芯片是一种不需要联网和其他外部设备支持,‌上电即可使用的语音识别系统。‌它的应用场合相对单一,‌主要适用于智能家电、‌语音遥控器、‌智能玩具等,‌以及车载声控和一部分智能家居。‌离线语音识别芯片的特点包括小词汇量、…

Python文件写入读取,文件复制以及一维,二维,多维数据存储

基础解释 在 Python 中,文件操作的模式除了 w (只写)、 a (追加写)、 r (只读)外,还有以下几种常见模式:- r :可读可写。该文件必须已存在,写操…

分类损失函数 (一) torch.nn.CrossEntropyLoss()

1、交叉熵 是一种用于衡量两个概率分布之间的距离或相似性的度量方法。机器学习中,交叉熵常用于损失函数,用于评估模型的预测结果和实际标签的差异。公式: y:真是标签的概率分布,y:模型预测的概率分布 …

数据库中的内、外、左、右连接

常用的数据库连表形式: 内连接 :inner join 外连接 :outer join 左外连接 :left outer join 左连接 :left join 右外连接 right outer join 右连接: right join 全连接 full join 、union 一、内连接 内…

企业私有云的部署都有哪些方式?

如今常见的企业私有云的部署方式有自建私有云、托管私有云、虚拟私有云、混合云、容器化私有云、本地数据中心部署等。如今,企业私有云的部署呈多样化趋势,以用来满足各个企业的具体需求。以下是RAK部落小编为大家汇总的企业私有云常见的部署方式&#x…

LeetCode 58.最后一个单词的长度 C++

LeetCode 58.最后一个单词的长度 C 思路🤔: 先解决当最后字符为空格的情况,如果最后字符为空格下标就往后移动,直到不为空格才停止,然后用rfind查询空格找到的就是最后一个单词的起始位置,最后相减就是单词…

C++ 正则库与HTTP请求

正则表达式的概念和语法 用于描述和匹配字符串的工具,通过特定的语法规则,灵活的定义复杂字符串匹配条件 常用语法总结 基本字符匹配 a:匹配字符aabc:匹配字符串abc 元字符(特殊含义的字符) .:匹…

1Panel面板配置java运行环境及网站的详细操作教程

本篇文章主要讲解,通过1Panel面板实现java运行环境,部署网站并加载的详细教程。 日期:2024年7月21日 作者:任聪聪 独立博客:https://rccblogs.com/501.html 一、实际效果 二、详细操作 步骤一、给我的项目进行打包&am…

在jsPsych中使用Vue

jspsych 介绍 jsPsych是一个非常好用的心理学实验插件,可以用来构建心理学实验。具体的就不多介绍了,大家可以去看官网:https://www.jspsych.org/latest/ 但是大家在使用时就会发现,这个插件只能使用js绘制界面,或者…

STM32自己从零开始实操10:PCB全过程

一、PCB总体分布 分布主要参考有: 方便供电布线。方便布信号线。方便接口。人体工学。 以下只能让大家看到各个模块大致分布在板子的哪一块,只能说每个人画都有自己的理由,我的理由如下。 还有很多没有表达出来的东西,我也不知…

PingCAP 王琦智:下一代 RAG,tidb.ai 使用知识图谱增强 RAG 能力

导读 随着 ChatGPT 的流行,LLMs(大语言模型)再次进入人们的视野。然而,在处理特定领域查询时,大模型生成的内容往往存在信息滞后和准确性不足的问题。如何让 RAG 和向量搜索技术在实际应用中更好地满足企业需求&#…

昇思25天学习打卡营第14天|计算机视觉

昇思25天学习打卡营第14天 文章目录 昇思25天学习打卡营第14天FCN图像语义分割语义分割模型简介网络特点数据处理数据预处理数据加载训练集可视化 网络构建网络流程 训练准备导入VGG-16部分预训练权重损失函数自定义评价指标 Metrics 模型训练模型评估模型推理总结引用 打卡记录…

FPGA开发在verilog中关于阻塞和非阻塞赋值的区别

一、概念 阻塞赋值:阻塞赋值的赋值号用“”表示,对应的是串行执行。 对应的电路结构往往与触发沿没有关系,只与输入电平的变化有关系。阻塞赋值的操作可以认为是只有一个步骤的操作,即计算赋值号右边的语句并更新赋值号左边的语句…

Transformer-Bert---散装知识点---mlm,nsp

本文记录的是笔者在了解了transformer结构后嗑bert中记录的一些散装知识点,有时间就会整理收录,希望最后能把transformer一个系列都完整的更新进去。 1.自监督学习 bert与原始的transformer不同,bert是使用大量无标签的数据进行预训…

规范:前后端接口规范

1、前言 随着互联网的高速发展,前端页面的展示、交互体验越来越灵活、炫丽,响应体验也要求越来越高,后端服务的高并发、高可用、高性能、高扩展等特性的要求也愈加苛刻,从而导致前后端研发各自专注于自己擅长的领域深耕细作。 然…

volatile,最轻量的同步机制

目录 一、volatile 二、如何使用? 三、volatile关键字能代替synchronized关键字吗? 四、总结: 还是老样子,先来看一段代码: 我们先由我们自己的常规思路分析一下代码:子线程中,一直循环&…

NoSQL之Redis非关系型数据库

目录 一、数据库类型 1)关系型数据库 2)非关系型数据库 二、Redis远程字典服务器 1)redis介绍 2)redis的优点 3)Redis 为什么那么快? 4)Redis使用场景 三、Redis安装部署 1&#xff0…

kail-linux如何使用NAT连接修改静态IP

1、Contos修改静态IP vi /etc/sysconfig/network-scripts/ifcfg-ens33, 标记红色处可能序号会变动 参考linux配置网络不通解决方案_kylinv10sp2 网关不通-CSDN博客https://tanrt06.blog.csdn.net/article/details/132430485?spm1001.2014.3001.5502 Kail时候NAT连…

从 NextJS SSRF 漏洞看 Host 头滥用所带来的危害

前言 本篇博文主要内容是通过代码审计以及场景复现一个 NextJS 的安全漏洞(CVE-2024-34351)来讲述滥用 Host 头的危害。 严正声明:本博文所讨论的技术仅用于研究学习,旨在增强读者的信息安全意识,提高信息安全防护技能…

浅谈断言之XML Schema断言

浅谈断言之XML Schema断言 “XML Schema断言”是一种专门用于验证基于XML的响应是否遵循特定XML Schema定义的标准和结构的断言类型。下面我们将详细探讨XML Schema断言的各个方面。 XML Schema断言简介 XML Schema断言(XML Schema Assertion)允许用户…