Transformer之Swin-Transformer结构解读

写在最前面之如何只用nn.Linear实现nn.Conv2d的功能

很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么?
首先,Convolution是什么?

Convolution是一种矩形区域内参数共享的Linear
在这里插入图片描述

这么说可能不好理解,那么我们上代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Conv2D(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride):"""为了简单且便于理解,我们设定图片的Size是Kernel_size的整数倍,且Kernel_size等于Stride"""super(LinearConv2d, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = kernel_sizeself.stride = stride# 计算权重矩阵的维度weight_size = in_channels * kernel_size * kernel_sizeself.linear = nn.Linear(weight_size, out_channels, bias=False)def forward(self, x):# 计算输出特征图的尺寸B, C, H, W = x.size()output_height = H // self.strideoutput_width = W // self.stride# 展开输入特征,沿着kernel_size的窗口展开x_flatten = x.view(B, H // self.kernel_size, self.kernel_size, W // self.kernel_size, self.kernel_size, C)x_flatten = x_flatten.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.kernel_size, self.kernel_size, C)# 应用线性变换output_flatten = self.linear(x_flatten)# 重塑输出形状output = output_flatten.view(B, self.out_channels, output_height, output_width)return output# 使用nn.Linear实现nn.Conv2d(256, 256, k=7, s=7)
conv2d_manual = Conv2D(256, 256, 7, 7)# 创建一个随机初始化的输入张量,确保尺寸是7的整数倍
input_tensor = torch.randn(1, 256, 56, 56)  # 假设输入图像大小为56x56,56是7的倍数# 应用卷积操作
output = conv2d_manual(input_tensor)
# 输出形状应为[1, 256, 8, 8]
print(output.shape)  

上述代码通过了使用输入数据的维度变换,实现了利用nn.Linear来进行nn.Conv2d的过程,当然,nn.Conv1d甚至nn.Conv3d等也是同样操作。这里我们先记住,后面我们详细解释

Swin-Transformer为什么这么叫

首先,需要理解为什么叫Swin!
作者依然使用了Vision Transformer的主题架构,核心区别是对数据处理的区别!
在Vision Transformer中,数据根据spatial维度进行拉伸,并成为[Batch, HW, C]的样子,如图所示,具体参考Transformer之Vision Transformer结构解读
在这里插入图片描述而在Swin-Transformer中,额外增加了一步,就是把维度为 [ B a t c h , H × W , C ] [Batch, H\times W, C] [Batch,H×W,C]的patch_embedding,进行二次分割,变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],如图所示,

  • 第一张图片就是经过patch_embed的patch_embedding
  • 第二张图片就是经过window_partrition分割后的图片
  • 第三张图片就是处理成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C]的图片
    在这里插入图片描述这里还有一个操作,就是在第偶数个Attention-Block中,把输入的patch_embedding进行torch.roll操作,这个操作就是循环位移
    在这里插入图片描述
    这时候就可以解释为什么说Swin-Transformer就是另一种形式的CNN
    从上面的图片中可以看到如下过程:
  • 一张图片,经过nn.Conv2d(k=patch_size, stride=patch_size),将其分割成 N 2 N^2 N2个patch_embedding
  • patch_embedding经过维度重整,从 [ B , H × W , C ] [B, H\times W, C] [B,H×W,C]变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],然后送入nn.Linear()。这里的维度重整加上nn.Linear(),等于nn.Conv2d,可以通过写在最前面的"如何只用nn.Linear()实现nn.Conv2d的功能"看出
  • 上一步可以总结为:经过nn.Conv2d的patch_embedding继续经过若干nn.Conv2d

Swin-Transformer的位置编码

绝对位置编码

详情参考Transformer之位置编码的通俗理解
在patch_embedding过程中,依然将Token和PE相加,如上图二所示。
但是既然有了相对位置编码,为什么还要加上绝对位置编码呢?

  • 数学解释如下:

Q E + P E × K E + P E T = X E + P E × W q × [ X E + P E × W k ] T = X E + P E × W q × W k T × X E + P E T = ( X q + P E q ) × W q × W k T × ( X k + P E k ) T = X q × W q ⏞ Q u e r y × W k T × X k T ⏞ K e y ⏟ 第一项 + P E q × W q ⏞ a × W k T × X k T ⏞ K e y ⏟ 第二项 + X q × W q ⏞ Q u e r y × W k T × P E k T ⏞ b ⏟ 第三项 + P E q × W q ⏞ a × W k T × P E k T ⏞ b ⏟ 第四项 \begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array} QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项 Xq×Wq Query×WkT×XkT Key+第二项 PEq×Wq a×WkT×XkT Key+第三项 Xq×Wq Query×WkT×PEkT b+第四项 PEq×Wq a×WkT×PEkT b
绝对位置编码只能消去第三项和第四项中的d项,依然需要第二项中的a项,才能具有完整的偏置

  • 直觉解释如下
    如果只有相对位置编码,也就是相当于只有相对位置偏置,这个过程和只有绝对位置偏置的意义是相同的,所以只有同时具有相对位置编码和绝对位置编码,才能避免两者是等效的

相对位置编码

详情参考Transformer之位置编码的通俗理解
相对位置编码,实际上是Attention机制的偏置的位置编码:
A t t = s o f t m a x ( Q × K T D i m + r e l a t i v e _ p o s i t i o n _ b i a s ) × V Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V Att=softmax(Dim Q×KT+relative_position_bias)×V
在这里插入图片描述
这里受到CSDN图片尺寸的限制,只能发这种清晰度的,点击这里下载无损svg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49233.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java网络编程TCP和UDP协议

一、TCP 1、服务器端 package udpTest;import javax.management.MBeanRegistrationException; import java.io.*; import java.net.ServerSocket; import java.net.Socket;public class TCPService {public static void main(String[] args) {try {byte[] bufnew byte[512];Se…

什么是离线语音识别芯片?与在线语音识别的区别

离线语音识别芯片是一种不需要联网和其他外部设备支持,‌上电即可使用的语音识别系统。‌它的应用场合相对单一,‌主要适用于智能家电、‌语音遥控器、‌智能玩具等,‌以及车载声控和一部分智能家居。‌离线语音识别芯片的特点包括小词汇量、…

Python文件写入读取,文件复制以及一维,二维,多维数据存储

基础解释 在 Python 中,文件操作的模式除了 w (只写)、 a (追加写)、 r (只读)外,还有以下几种常见模式:- r :可读可写。该文件必须已存在,写操…

生成 HTTPS 证书并配置到 Nginx 的完整步骤

步骤 1: 安装 acme.sh 如果你还没有安装 acme.sh,可以通过以下命令进行安装: curl https://get.acme.sh | sh步骤 2: 生成 HTTPS 证书 使用 acme.sh 生成 forum.selectious.fun 的证书。你可以使用 standalone 模式,这意味着 acme.sh 会在…

视觉SLAM--回环检测

文章目录 创建字典相似度计算增加字典规模 回环检测的意义:可以使 后端位姿图得到一个 全局一致估计。 视觉SLAM的主流做法: 基于外观的回环检测方法,仅 根据两幅图像的相似性确定回环检测关系。这种方法,摆脱了累计误差&…

分类损失函数 (一) torch.nn.CrossEntropyLoss()

1、交叉熵 是一种用于衡量两个概率分布之间的距离或相似性的度量方法。机器学习中,交叉熵常用于损失函数,用于评估模型的预测结果和实际标签的差异。公式: y:真是标签的概率分布,y:模型预测的概率分布 …

数据库中的内、外、左、右连接

常用的数据库连表形式: 内连接 :inner join 外连接 :outer join 左外连接 :left outer join 左连接 :left join 右外连接 right outer join 右连接: right join 全连接 full join 、union 一、内连接 内…

企业私有云的部署都有哪些方式?

如今常见的企业私有云的部署方式有自建私有云、托管私有云、虚拟私有云、混合云、容器化私有云、本地数据中心部署等。如今,企业私有云的部署呈多样化趋势,以用来满足各个企业的具体需求。以下是RAK部落小编为大家汇总的企业私有云常见的部署方式&#x…

LeetCode 58.最后一个单词的长度 C++

LeetCode 58.最后一个单词的长度 C 思路🤔: 先解决当最后字符为空格的情况,如果最后字符为空格下标就往后移动,直到不为空格才停止,然后用rfind查询空格找到的就是最后一个单词的起始位置,最后相减就是单词…

flowable执行监听器动态指定审批人在退回时产生的bug

场景: 退回产生的bug,有一个结点,本身是通过执行监听器判断上一个结点的审批人来得到这个结点的审批人。之前是通过直接的获取最新task来拿到,但是在退回场景下,最新task为退回结点,故产生错误。 解决&…

C++ 正则库与HTTP请求

正则表达式的概念和语法 用于描述和匹配字符串的工具,通过特定的语法规则,灵活的定义复杂字符串匹配条件 常用语法总结 基本字符匹配 a:匹配字符aabc:匹配字符串abc 元字符(特殊含义的字符) .:匹…

stable diffusion webui环境配置遇到的问题

环境配置步骤: conda创建一个python3.10的环境,起个名叫sdenv, 使用命令conda create -n denv python3.10进入创建好的环境在webui的路径下直接运行python launch.py会自动开始安装所需的包(可能需要梯子或者在系统配置中添加pip的国内源&am…

1Panel面板配置java运行环境及网站的详细操作教程

本篇文章主要讲解,通过1Panel面板实现java运行环境,部署网站并加载的详细教程。 日期:2024年7月21日 作者:任聪聪 独立博客:https://rccblogs.com/501.html 一、实际效果 二、详细操作 步骤一、给我的项目进行打包&am…

在jsPsych中使用Vue

jspsych 介绍 jsPsych是一个非常好用的心理学实验插件,可以用来构建心理学实验。具体的就不多介绍了,大家可以去看官网:https://www.jspsych.org/latest/ 但是大家在使用时就会发现,这个插件只能使用js绘制界面,或者…

陌陌聊天数据案例分析

目录 背景介绍和需求分析基于hive数仓实现需求开发根据聊天数据建库建表加载数据ETL数据清洗背景分析原始数据出现的问题ETL实现 需求指标统计思路需求开发 基于FineBI实现可视化报表配置流程构建可视化报表 总结 背景介绍和需求分析 陌陌是一个聊天平台,每天都会产…

不能包含中文的正则表达式

原文 1、不包含汉字[^\u4e00-\u9fa5] var r /^[^\u4e00-\u9fa5]$/ if(r.test(str)){} 2、只能包含汉字 [\u4e00-\u9fa5]

STM32自己从零开始实操10:PCB全过程

一、PCB总体分布 分布主要参考有: 方便供电布线。方便布信号线。方便接口。人体工学。 以下只能让大家看到各个模块大致分布在板子的哪一块,只能说每个人画都有自己的理由,我的理由如下。 还有很多没有表达出来的东西,我也不知…

二叉树---二叉搜索树中的众数

题目: 给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数,可以按 任意顺序 返回。 假定 BST 满…

PingCAP 王琦智:下一代 RAG,tidb.ai 使用知识图谱增强 RAG 能力

导读 随着 ChatGPT 的流行,LLMs(大语言模型)再次进入人们的视野。然而,在处理特定领域查询时,大模型生成的内容往往存在信息滞后和准确性不足的问题。如何让 RAG 和向量搜索技术在实际应用中更好地满足企业需求&#…

昇思25天学习打卡营第14天|计算机视觉

昇思25天学习打卡营第14天 文章目录 昇思25天学习打卡营第14天FCN图像语义分割语义分割模型简介网络特点数据处理数据预处理数据加载训练集可视化 网络构建网络流程 训练准备导入VGG-16部分预训练权重损失函数自定义评价指标 Metrics 模型训练模型评估模型推理总结引用 打卡记录…