SwinTransformer的相对位置索引的原理以及源码分析

文章目录

  • 1. 理论分析
  • 2. 完整代码


引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

    如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)(0,1)=(0,1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

在这里插入图片描述

对应源代码为:

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_window_size = [2,2]
num_heads = 3# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

    请注意,这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

    其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。
    
在这里插入图片描述

对应源代码为:

relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

    
接着将所有的行标都乘上2M-1。

在这里插入图片描述

    
对应源代码为:

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

在这里插入图片描述

relative_position_index = relative_coords.sum(-1)

    刚刚上面也说了,之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M1)×(2M1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

在这里插入图片描述

对应源代码为:

'''
relative_position_bias_table:其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)'''
index:虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_window_size = [2,2]
num_heads = 3# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1relative_coords[:, :, 0] *= 2 * window_size[1] - 1relative_position_index = relative_coords.sum(-1)relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwprint('relative_position_bias:',relative_position_bias.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/40772.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

绝了,华为伸缩摄像头如何突破影像边界?

自华为Pura70 Ultra超聚光伸缩镜头诞生以来,备受大家的关注,听说这颗镜头打破了传统手机的摄像头体积与镜头的设计,为我们带来了不一样的拍照体验。 智能手机飞速发展的今天,影像功能已经成为我们衡量一款手机性能的重要指标。想…

MySQL中mycat与mha应用

目录 一.Mycat代理服务器 1.Mycat应用场景 2.mycat安装目录结构说明 3.Mycat的常用配置文件 4.Mycat日志 5.mycat 实现读写分离 二.MySQL高可用 1.原理过程 2.MHA软件 3.实现MHA 一.Mycat代理服务器 1.Mycat应用场景 Mycat适用的场景很丰富,以下是几个典型…

沪上繁花:上海电信的5G-A之跃

2024年6月18日下午,在上海举行的3GPP RAN第104次会议上,3GPP正式宣布R18标准冻结。R18是无线网络面向5G-A的第一个版本,其成功冻结正式宣布了5G发展迎来新机遇,5G-A商用已进入全新的发展阶段。 在5G-A滚滚而来的时代洪流中&#x…

C#实战|账号管理系统:通用登录窗体的实现。

哈喽,你好啊,我是雷工! 本节记录登录窗体的实现方法,比较有通用性,所有的项目登录窗体实现基本都是这个实现思路。 一通百通,以下为学习笔记。 01 登录窗体的逻辑 用户在登录窗输入账号和密码,如果输入账号和密码信息正确,点击【登录】按钮,则跳转显示主窗体,同时在固…

上海外贸建站公司wordpress模板推荐

Sora索啦高端制造业wordpress主题 红色高端制造业wordpress主题,适合外贸企业出海建独立站的wordpress模板。 https://www.jianzhanpress.com/?p5885 Yamal外贸独立站wordpress主题 绿色的亚马尔Yamal外贸独立站wordpress模板,适用于外贸公司建独立站…

Redis 中 Set 和 Zset 类型

目录 1.Set类型 1.1 Set集合 1.2 普通命令 1.3 集合操作 1.4 内部编码 1.5 使用场景 2.Zset类型 2.1 Zset有序集合 2.2 普通命令 2.3 集合间操作 2.4 内部编码 2.5 使用场景 1.Set类型 1.1 Set集合 集合类型也是保存多个字符串类型的元素,但是和列表类型不同的是&…

【Go】excelize库实现excel导入导出封装(四),导出时自定义某一列或多列的单元格样式

大家好,这里是符华~ 查看前三篇: 【Go】excelize库实现excel导入导出封装(一),自定义导出样式、隔行背景色、自适应行高、动态导出指定列、动态更改表头 【Go】excelize库实现excel导入导出封装(二&…

WY-35A4T三相电压继电器 导轨安装 约瑟JOSEF

功能简述 WY系列电压继电器是带延时功能的数字式交流电压继电器。 可用于发电机,变压器和输电线的继电保护装置中,作为过电压或欠电压闭锁的动作元件 LCD实时显示当前输入电压值 额定输入电压Un:100VAC、200VAC、400VAC产品满足电磁兼容四级标准 产品…

VBA初学:零件成本统计之一(任务汇总)

经过前期一年多对金蝶K3生产任务流程和操作的改造和优化,现在总算可以将零件加工各个环节的成本进行归集了。 原本想写存储过程,通过直接SQL报表做到K3中去的,但财务原本就是用EXCEL,可以方便调整和保存,加上还有一部分…

便携式气象站:探索自然的智慧伙伴

在探索自然奥秘、追求科学真理的道路上,气象数据始终是我们不可或缺的指引。然而,传统的气象站往往庞大而笨重,难以在偏远地区或移动环境中灵活部署。 便携式气象站,顾名思义,是一种小巧轻便、易于携带和安装的气象观测…

由于找不到xinput1 3.dll无法继续执行重新安装程序

如果您的计算机提示无法找到xinput1_3.dll文件,这可能表明您的计算机存在问题。在这种情况下,您需要立即对xinput1_3.dll文件进行修复,否则您的某些程序将无法启动。以下是解决无法找到xinput1_3.dll文件的方法。 一、关于xinput1_3.dll文件的…

Elasticsearch 实现 Word、PDF,TXT 文件的全文内容提取与检索

文章目录 一、安装软件:1.通过docker安装好Es、kibana安装kibana:2.安装原文检索与分词插件:之后我们可以通过doc命令查看下载的镜像以及运行的状态:二、创建管道pipeline名称为attachment二、创建索引映射:用于存放上传文件的信息三、SpringBoot整合对于原文检索1、导入依赖…

安全及应用(更新)

一、账号安全 1.1系统帐号清理 #查看/sbin/nologin结尾的文件并统计 [rootrootlocalhost ~]# grep /sbin/nologin$ /etc/passwd |wc -l 40#查看apache登录的shell [rootrootlocalhost ~]# grep apache /etc/passwd apache:x:48:48:Apache:/usr/share/httpd:/sbin/nologin#改变…

Android增量更新----java版

一、背景 开发过程中,随着apk包越来越大,全量更新会使得耗时,同时浪费流量,为了节省时间,使用增量更新解决。网上很多文章都不是很清楚,没有手把手教学,使得很多初学者,摸不着头脑&a…

边缘概率密度、条件概率密度、边缘分布函数、联合分布函数关系

目录 二维随机变量及其分布离散型随机变量连续型随机变量边缘分布边缘概率密度举例边缘概率密度 条件概率密度边缘概率密度与条件概率密度的区别边缘概率密度条件概率密度举个具体例子 参考资料 二维随机变量及其分布 离散型随机变量 把所有的概率,都理解成不同质量…

逻辑图框架图等结构图类图的高效制作方式不妨进来看看

**逻辑图框架图等结构图类图的高效制作方式不妨进来看看** 基于我们每天都在处理大量的数据和信息。为了更清晰地理解和传达这些信息,结构图、逻辑图和框架图等可视化工具变得越来越重要。然而,如何高效地制作这些图表并确保其准确性和易读性呢&#xf…

Windows密码凭证获取

Windows HASH HASH简介 hash ,一般翻译做散列,或音译为哈希,所谓哈希,就是使用一种加密函数进行计算后的结果。这个 加密函数对一个任意长度的字符串数据进行一次数学加密函数运算,然后返回一个固定长度的字符串。…

服装购物商城系统小程序-计算机毕业设计源码35058

摘要 服装购物商城系统小程序,依托Spring Boot框架的强大支持,为用户呈现了一个功能丰富、体验流畅的在线购物平台。该系统不仅涵盖了商品展示、用户注册登录、购物车管理、订单处理、支付集成等核心购物流程,还引入了个性化推荐算法&#xf…

Jmeter使用JSON Extractor提取多个变量

1.当正则不好使时,用json extractor 2.提取多个值时,默认值必填,否则读不到变量

Java | Leetcode Java题解之第212题单词搜索II

题目&#xff1a; 题解&#xff1a; class Solution {int[][] dirs {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};public List<String> findWords(char[][] board, String[] words) {Trie trie new Trie();for (String word : words) {trie.insert(word);}Set<String> a…