Transformer实战-系列教程10:SwinTransformer 源码解读3(SwinTransformerBlock类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

5、SwinTransformerBlock类

class SwinTransformerBlock(nn.Module):def extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

5.1 构造函数

SwinTransformerBlock 是 Swin Transformer 模型中的一个基本构建块。它结合了自注意力机制和多层感知机(MLP),并通过窗口划分和可选的窗口位移来实现局部注意力

def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioif min(self.input_resolution) <= self.window_size:self.shift_size = 0self.window_size = min(self.input_resolution)assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)if self.shift_size > 0:H, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask)
  1. dim:输入特征的通道数。
  2. input_resolution:输入特征的分辨率(高度和宽度)
  3. num_heads:自注意力头的数量
  4. window_size:窗口大小,决定了注意力机制的局部范围
  5. shift_size:窗口位移的大小,用于实现错位窗口多头自注意力(SW-MSA)
  6. mlp_ratio:MLP隐层大小与输入通道数的比率
  7. qkv_bias:QKV的偏置
  8. qk_scale:QKV的缩放因子
  9. drop:丢弃率
  10. drop_path:分别控制QKV的偏差、缩放因子、丢弃率、注意力丢弃率和随机深度率
  11. norm_layer:激活层和标准化层,默认分别为 GELU 和 LayerNorm
  12. WindowAttention:窗口注意力模块
  13. Mlp:一个包含全连接层、激活函数、Dropout的模块
  14. img_mask :图像掩码,用于生成错位窗口自注意力
  15. h_slicesw_slices:水平和垂直方向上的切片,用于划分图像掩码
  16. cnt :计数器,标记不同的窗口
  17. mask_windows :图像掩码划分为窗口,并将每个窗口的掩码重塑为一维向量
  18. window_partition
  19. attn_mask :注意力掩码,用于在自注意力计算中排除窗口外的位置
  20. register_buffer:注意力掩码注册为一个模型的缓冲区

5.2 前向传播

def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)if self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xx_windows = window_partition(shifted_x, self.window_size)x_windows = x_windows.view(-1, self.window_size * self.window_size, C)attn_windows = self.attn(x_windows, mask=self.attn_mask)attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)if self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)x = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return x
  1. 原始输入: torch.Size([4, 3136, 96]),输入的是一个长度为3136的序列,每个向量的维度为96,在
    被多次调用的时候,维度也发生了变化原始输入: torch.Size([4, 784, 192])、torch.Size([4, 196, 384])、torch.Size([4, 49, 768])
  2. H,W=[ 56,56],输入分辨率中的高度和宽度
  3. B, L, C=[ 4,3136,96],当前输入的维度,批次大小、序列长度和向量的维度
  4. shortcut是用于残差
  5. norm1(x): torch.Size([4, 3136, 96]),经过一个层归一化,维度不变
  6. x.view(B, H, W, C): torch.Size([4, 56, 56, 96]),将序列重塑为(Batch_size,Height,Width,Channel)的形状
  7. shift_size用来判断是否做偏移,最开始的时候shift_size为0,window_partition用来做windows的划分,特征图为[56,56,96],默认窗口为77的大小,因为可以划分出88个窗口
  8. shifted_x: torch.Size([4, 56, 56, 96]),位移操作后的x,此处的偏移只用了torch的内置函数就可以完成,torch.roll()
  9. x_windows: torch.Size([256, 7, 7, 96]),使用窗口划分函数,划分256个窗口,每个窗口7*7个特征,每个特征96维向量
  10. x_windows: torch.Size([256, 49, 96]),将窗口重塑为一维向量,以便进行自注意力计算
  11. attn_windows: torch.Size([256, 7, 7, 96]),WindowAttention处理x,对每个窗口应用窗口注意力机制,考虑到可能的注意力掩码
  12. shifted_x: torch.Size([4, 56, 56, 96]),注意力操作后的窗口重塑回原始形状,并将它们合并回完整的特征图
  13. torch.Size([4, 56, 56, 96]),如果进行了循环位移,则执行逆向循环位移操作,以恢复原始特征图的位置
  14. torch.Size([4, 3136, 96]),特征图重塑回原始的[B, L, C]形状
  15. torch.Size([4, 3136, 96]),应用残差连接,并通过随机深度(如果设置了的话)
  16. torch.Size([4, 3136, 96]),应用第二个标准化层,然后是MLP,并再次应用随机深度,完成残差连接的最后一步。

5.3 窗口划分函数window_partition()

def window_partition(x, window_size):B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows

在第1个stage中,不同的stage中会做下采样操作

  1. B, H, W, C =torch.Size([4, 56, 56, 96]),这是函数的原始输入,56*56等于原来的3136
  2. x = torch.Size([4, 8, 7, 8, 7, 96])
  3. windows = torch.Size([256, 7, 7, 96]),256的计算方法为原始数据为5656,要分出77大小为一个窗口,所以窗口数量为(56/7)(56/7)=88=64,而64需要乘上batch=4,所以就是256,向量维度96不变

而Windows Attention就是在这7*7里面的窗口进行计算

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/678721.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PLC在物联网中位置—承上启下,与上位机下位机的关联。

谈到物联网&#xff0c;就绕不开PLC&#xff0c;本文着重介绍PLC的定义、与单片机的区分&#xff0c;价值、物联网中的位置&#xff0c;以及和上位机、下位机的关联&#xff0c;让友友们对PLC有个全面的认知。 一、什么是PLC PLC是可编程逻辑控制器&#xff08;Programmable L…

UI自动刷新大法:DataBinding数据绑定

之前我们讲了DataBinding在Activity、Fragment、RecyclerView中的基础使用&#xff0c;而那些常规使用方法里&#xff0c;每当绑定的变量发生数据变化时&#xff0c;都需要ViewDataBinding重新设值才会刷新对应UI。而DataBinding通过内部实现的观察者模式来进行自动刷新UI&…

go消息队列RabbitMQ - 订阅模式-direct

1.发布订阅 在Fanout模式中&#xff0c;一条消息&#xff0c;会被所有订阅的队列都消费。但是&#xff0c;在某些场景下&#xff0c;我们希望不同的消息被不同的队列消费。这时就要用到Direct类型的Exchange。 在Direct模型下&#xff1a; 队列与交换机的绑定&#xff0c;不能…

第 384 场 LeetCode 周赛题解

A 修改矩阵 模拟 class Solution { public:vector<vector<int>> modifiedMatrix(vector<vector<int>> &matrix) {int m matrix.size(), n matrix[0].size();vector<int> mx(n, INT32_MIN);for (int i 0; i < m; i)for (int j 0; j &l…

Java微服务学习Day1

文章目录 认识微服务服务拆分及远程调用服务拆分服务远程调用提供者与消费者 Eureka注册中心介绍构建EurekaServer注册user-serviceorder-service完成服务拉取 Ribbon负载均衡介绍原理策略饥饿加载 Nacos注册中心介绍配置分级存储负载均衡环境隔离nacos注册中心原理 认识微服务…

ChatGPT学习大纲

引言 在2023年2月份左右开始使用ChatGPT时&#xff0c;就被它强大的理解能力和应答效果所折服&#xff0c;这期间一直在断断续续的学习和使用&#xff0c;也没形成一个完整的学习过程&#xff0c;最近刚好有空&#xff0c;就寻思着好好再学习总结一下&#xff0c;故写出了ChatG…

Python : 使用python实现学生管理系统的功能,详细注释

一、学生管理系统 学生描述&#xff1a;姓名、年龄、成绩 学生管理系统功能&#xff1a;添加学生信息、删除学生信息、根据姓名修改学生信息、根据姓名查询学生信息、显示所有学生信息、退出系统 二、代码说明 1. 将每一个学生的信息放一个元组中&#xff0c;再把元组添加到列表…

java中使用Lambda表达式实现参数化方法

Lambda表达式实现参数化方法说明 Lambda表达式在Java中是一种简洁、函数式的方式来表示匿名函数。它们特别适用于那些需要一个函数作为参数的方法&#xff0c;即函数式接口。参数化方法&#xff08;通常指的是泛型方法&#xff09;是那些可以接受类型参数的方法&#xff0c;这…

2.3 Verilog 数据类型

Verilog 最常用的 2 种数据类型就是线网&#xff08;wire&#xff09;与寄存器&#xff08;reg&#xff09;&#xff0c;其余类型可以理解为这两种数据类型的扩展或辅助。 线网&#xff08;wire&#xff09; wire 类型表示硬件单元之间的物理连线&#xff0c;由其连接的器件输…

单片机基础入门:简单介绍51单片机的工作原理

在电子技术领域&#xff0c;单片机是实现智能化控制不可或缺的关键元件。它们集成了许多功能于一身&#xff0c;成为了各种电子系统的心脏。为了更好地理解单片机如何工作&#xff0c;本文将重点介绍51单片机的基本组成和工作原理。 51单片机是一种广泛使用的微控制器&#xf…

Android 车载应用之快速入门

一、Android Automotive OS 概览 车载 Android 系统也被称为 Android Automotive OS,是对原始 Android 系统的一个功能扩充版本。与手机系统一样,Android Automotive OS 源代码完全开放,第三方供应商和汽车制造商可以官方源码的基础上自行开发和拓展,无论是编程语言还是各…

CI/CD到底是啥?持续集成/持续部署概念解释

前言 大家好&#xff0c;我是chowley&#xff0c;日常工作中&#xff0c;我每天都在接触CI/CD&#xff0c;今天就给出我心中的答案。 在现代软件开发中&#xff0c;持续集成&#xff08;Continuous Integration&#xff0c;CI&#xff09;和持续部署&#xff08;Continuous D…

【UE 求职】学了虚幻引擎可以应聘哪些岗位?

目录 1 领域1.1 游戏开发领域1.2 影视和动画制作1.3 建筑和工程可视化1.4 模拟和训练1.5 其他领域 2 如何做好一份简历1. 明确简历目标2. 突出UE5相关技能3. 展示相关项目经验4. 教育背景5. 专业经验6. 软技能7. 证书和奖项8. 定制化和校对 &#x1f64b;‍♂️ 作者&#xff1…

Day41- 动态规划part09

一、打家劫舍 题目一&#xff1a;198. 打家劫舍 198. 打家劫舍 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯…

使用PHPStudy搭建本地web网站并实现任意浏览器公网访问

文章目录 [toc]使用工具1. 本地搭建web网站1.1 下载phpstudy后解压并安装1.2 打开默认站点&#xff0c;测试1.3 下载静态演示站点1.4 打开站点根目录1.5 复制演示站点到站网根目录1.6 在浏览器中&#xff0c;查看演示效果。 2. 将本地web网站发布到公网2.1 安装cpolar内网穿透2…

springcloud分布式架构网上商城源码和论文

首先,论文一开始便是清楚的论述了系统的研究内容。其次,剖析系统需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,更进一步明确系统的需求。然后在明白了系统的需求基础上需要进一步地设计系统,主要包罗软件架构模式、整体功能模块、数据库设计。本项…

Unity Meta Quest MR 开发(四):使用 Scene API 和 Depth API 实现深度识别和环境遮挡

文章目录 &#x1f4d5;教程说明&#x1f4d5;Scene API 实现遮挡&#x1f4d5;Scene API 实现遮挡的缺点&#x1f4d5;Depth API 实现遮挡⭐导入 Depth API⭐修改环境配置⭐添加 EnvironmentDepthOcclusion 预制体⭐给物体替换遮挡 Shader⭐取消现实手部的遮挡效果 此教程相关…

从Unity到Three.js(模型文件加载)

模型加载功能探索&#xff0c;用blender导出了个glb格式的cube进行的测试。 初接触js语法&#xff0c;回调注册的地方直接使用匿名函数总感觉脑子跟不上&#xff0c;反应不过来&#xff0c;就把加载后的回调简单封装了下&#xff0c; 官方文档是直接使用的匿名函数。 另外看官方…

带你了解软件系统架构的演变详解

在这个数字时代&#xff0c;我们身边无处不在的软件系统扮演着无比重要的角色。你曾想过背后那复杂的系统是如何演变而来的吗&#xff1f;本文将深入浅出&#xff0c;以小白的视角&#xff0c;描绘软件系统架构的绚丽蜕变历程&#xff0c;让我们一同踏上这场感性而技术的冒险之…

Peter算法小课堂—背包问题

我们已经学过好久好久的动态规划了&#xff0c;动态规划_Peter Pan was right的博客-CSDN博客 那么&#xff0c;我用一张图片来概括一下背包问题。 大家有可能比较疑惑&#xff0c;优化决策怎么优化呢&#xff1f;答案是&#xff0c;滚动数组&#xff0c;一个神秘而简单的东西…