Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

6、WindowAttention类

6.1 构造函数

class WindowAttention(nn.Module):def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)
  1. dim:输入特征维度
  2. window_size:窗口大小
  3. num_heads:多头注意力头数
  4. head_dim:每头注意力的头数
  5. scale :缩放因子
  6. relative_position_bias_table:相对位置偏置表,它对每个头存储不同窗口位置之间的偏置,以模拟位置信息
  7. coords_h 、coords_w、coords:窗口内每个位置的坐标
  8. coords_flatten :将坐标展平,为计算相对位置做准备
  9. 第1个relative_coords:计算窗口内每个位置相对于其他位置的坐标差
  10. 第2个relative_coords:重排坐标差的维度以符合预期的格式
  11. relative_coords[:, :, 0]、relative_coords[:, :, 1]、relative_coords[:, :, 0]:调整坐标差,使其能够映射到相对位置偏置表中的索引
  12. relative_position_index :计算每对位置之间的相对位置索引
  13. register_buffer:将相对位置索引注册为模型的缓冲区,这样它就不会在训练过程中被更新
  14. qkv :创建一个线性层,用于生成QKV
  15. attn_drop、proj、proj_drop:初始化注意力dropout、输出投影层及其dropout
  16. trunc_normal_:使用截断正态分布初始化相对位置偏置表
  17. softmax :初始化softmax层,用于计算注意力权重

6.2 前向传播

在Swin Transformer中,一共有4个stage,每次stage都分别包含W-MSA和SW-MSA,W-MSA和SW-MSA都是在执行WindowAttention模块,也就是说在一次batch的执行过程中WindowAttention的前向传播过程会执行8次。
每次stage执行后,都会进行下采样操作,这个下采样操作,长宽会减半特征图会增多。

    def forward(self, x, mask=None):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x
  1. B_, N, C = x.shape原始输入: torch.Size([256, 49, 96]),B_, N, C即原始输入的维度
  2. qkv = self.qkv(x).reshape...qkv: torch.Size([3, 256, 3, 49, 32]),被重塑的一个五维张量,分别代表qkv三个维度、256个窗口、3个注意力头数但是不会一直是3越往后会越多、49是一个窗口有7*7=49元素、每个头的特征维度。在之前的Transformer以及Vision Transformer中,都是用x接上各自的全连接后分别生成QKV,这这里直接一起生成了。
  3. q: torch.Size([256, 3, 49, 32]),k: torch.Size([256, 3, 49, 32]),v: torch.Size([256, 3, 49, 32]),从qkv中分解出q、k、v,而且已经包含了多头注意力机制
  4. attn: torch.Size([256, 3, 49, 49]),attn是q和k的点积
  5. relative_position_bias: torch.Size([49, 49, 3]),从相对位置偏置表中索引出每对位置之间的偏置,并重塑以匹配注意力分数的形状
  6. relative_position_bias: torch.Size([3, 49, 49]),重新排列,位置编码在Transformer中一直当成偏置加进去的,而这个位置编码是对一个窗口的,所以每一个窗口的都对应了相同的位置编码
  7. attn: torch.Size([256, 3, 49, 49]),将位置编码加到注意力分数上,到这里就算完了全部的注意力机制了
  8. attn: torch.Size([256, 3, 49, 49]),掩码加到注意力分数上,使用softmax函数归一化注意力分数,得到注意力权重,应用注意力dropout。
  9. x: torch.Size([256, 49, 96]),使用注意力权重对v向量进行重构,然后对结果进行转置和重塑
  10. x: torch.Size([256, 49, 96]),将加权的注意力输出通过一个线性投影层,应用输出dropout,这就是最后WindowAttention的输出,一共256个窗口,每个窗口有49个特征,每个特征对应96维的向量

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/674456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

php 函数三

一 对称加密 1.1 openssl 1.1.1 openssl_get_cipher_methods(bool $aliases false) 获取可用的加密算法。包含可用加密算法的array。 请注意:在 OpenSSL 1.1.1 版本之前,返回加密算法的拼法大小写都有; 从 OpenSSL 1.1.1 开始&#xff0c…

【机器学习】数据清洗之识别缺失点

🎈个人主页:甜美的江 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步…

Linux笔记之expect和bash脚本监听输出并在匹配到指定字符串时发送中断信号

Linux笔记之expect和bash脚本监听输出并在匹配到指定字符串时发送中断信号 code review! 文章目录 Linux笔记之expect和bash脚本监听输出并在匹配到指定字符串时发送中断信号1.expect2.bash 1.expect 在Expect脚本中,你可以使用expect来监听程序输出,…

Github 2024-02-08 开源项目日报 Top9

根据Github Trendings的统计,今日(2024-02-08统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Ruby项目1HTML项目1Python项目1Scala项目1PLpgSQL项目1Rust项目1NASL项目1C项目1TypeScript项目1非开发语言项目…

谷歌seo搜索引擎优化有什么思路?

正常做seo哪有那么多思路,其实就那么几种方法,无非就关键词,站内优化,外链,可以说万变不离其宗,但如果交给我们,你就可以实现其他的思路,或者说玩法 收录可以说是一个网站的基础&…

【Linux】vim的基本操作与配置(下)

Hello everybody!今天我们继续讲解vim的操作与配置,希望大家在看过这篇文章与上篇文章后都能够轻松上手vim! 1.补充 在上一篇文章中我们说过了,在底行模式下set nu可以显示行号。今天补充一条:set nonu可以取消行号。这两条命令大家看看就可…

10个常考的前端手写题,你全都会吗?(上)

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 今天来分享一下10个常见的JavaScript手写功能。 目录 1.实现new 2.call、apply、…

[office] excel表格怎么绘制股票的CCI指标- #媒体#学习方法#笔记

excel表格怎么绘制股票的CCI指标? excel表格怎么绘制股票的CCI指标?excel表格中想要绘制一个股票cci指标,该怎么绘制呢?下面我们就来看看详细的教程,需要的朋友可以参考下 CCI指标是一种在股票,贵金属,货…

《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第6章 逻辑斯谛回归与最大熵模型(2)6.2 最大熵模型

文章目录 6.2 最大熵模型6.2.1 最大熵原理6.2.3 最大熵模型的学习6.2.4 极大似然估计 《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第3章 k邻近邻法 《统计学习方法:李航》笔记 从原理到实现(基于python&am…

Mysql报错:too many connections

1 问题原因 MySQL报错“too many connections”通常是由于数据库的最大连接数超过了MySQL配置的最大限制。有以下几个原因: (1)访问量过高:当MySQL服务器面对大量的并发请求时,已经建立的连接数可能会不足以处理所有的请求,从而导致连接池耗尽、连接被拒绝、出现“too …

VMware17上安装centos7.9成功后,进入linux命令行以后,运行没几分钟直接卡死,或者说非常卡

VMware17上安装centos7.9成功后,进入linux命令行以后,运行没几分钟直接卡死,或者说非常卡 解决方案:关闭windows的Hyper-V服务,重启虚拟机

Biu懂AI:Object Detection训练数据的Label格式

Bui~ 新系列博文将专注AI相关领域,想要学习高通蓝牙相关知识请查看之前的系列或关注大博主声波电波就看今朝 在CV(computer vision)中,Object detection是其中的一个核心任务,它可以在输入图像或视频中识别并框出目标。…

Rust 初体验2

变量类型 Rust 语言的变量数据类型,主要包括整型、浮点型、字符、布尔型、元组、数组、字符串、枚举、结构体和可变变量等。 fn main() { // 整型 let integer: i32 100; println!("整型: {}", integer); // 浮点型 let floating_point: f64 3.1…

15.2 Linux入门(❤❤❤❤)

15.2 Linux入门 1. Linux基础1.1 基础概念1. 操作系统2. Linux操作系统3. CentOS操作系统1.2 CentOS安装配置1. 运行要求2. 虚拟机与CentOS安装1.3 Linux目录结构1.4 Linux远程管理配置2. Linux高级操作2.1 命令:vim文本编辑器(❤❤)2.2 命令:常用文本工具(❤❤)1. echo命令<

【网页设计期末】茶文化网站

本文资源&#xff1a;https://download.csdn.net/download/weixin_47040861/88818886 1.题目要求 设计要求&#xff1a; &#xff08;1&#xff09;网站页面数量不少于4个&#xff0c;文件命名规范&#xff0c;网站结构要求层次清楚&#xff0c;目录结构清晰&#xff0c;代码…

使用ORM模型操作MySQL数据库:Python爬虫数据持久化实践

源码分享 https://docs.qq.com/sheet/DUHNQdlRUVUp5Vll2?tabBB08J2 在Python爬虫开发中&#xff0c;数据持久化是一个重要的步骤。通常&#xff0c;我们会将爬取的数据保存到数据库中。本篇博客将介绍如何使用对象关系映射&#xff08;ORM&#xff09;模型在Python中操作MySQ…

Redis的数据类型与示例演示

目录 一、KEY操作 1.1 相关命令 说明&#xff1a; 1.2示例演示 二、String类型 2.1 结构图 2.2 示例演示 三、List类型 3.1 结构图 3.2 相关命令 3.3 示例演示 四、SET类型 4.1 结构图 4.2 相关命令 4.3 示例演示 五、ZSET类型 5.1 结构图 5.2 相关命令 六、…

NumPy基础之花式索引

1 NumPy基础之花式索引 NumPy的花式索引(Fancy indexing)指ndarray数组使用整数数组进行索引。这的整数数组可以是python的列表等可迭代对象&#xff0c;也可以是NumPy数组。 花式索引&#xff0c;用整数数组的元素作为对应轴的索引&#xff0c;并且按数组元素顺序选取子集。…

负载均衡SLB

1. 什么是阿里云上的负载均衡SLB&#xff1f;它的主要功能是什么&#xff1f; 阿里云上的负载均衡SLB是一种流量分发服务&#xff0c;它的主要功能是扩展应用系统的吞吐能力和提升系统可用性。 负载均衡SLB&#xff08;Server Load Balancer&#xff09;在阿里云中是一个核心…

useEffect的4种使用情况

useeffect的用法是&#xff1a;useEffect就是指定一个副效应函数&#xff0c;组件每渲染一次&#xff0c;该函数就自动执行一次。组件首次在网页 DOM 加载后&#xff0c;副效应函数也会执行。 useEffect使用时有以下4种情况 1、不传递 useEffect不传递第二个参数会导致每次渲染…