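[Paper Notes] BiFormer: Vision Transformer with Bi-Level Routing Attention

Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention

Code: https://github.com/rayleizhu/BiFormer

Attention is a core module of vision transformers, but it has a major drawback: its computational cost is high, because every token attends to every other token.

BiFormer proposes Bi-Level Routing Attention, which restricts each query to the most relevant tokens during the attention computation and thereby reduces the cost.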

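1. Bi-Level Routing Attention

The paper's comparison figure shows the regions attended to by several attention modules: (a) is vanilla attention, and all the others are sparse attention patterns. Bi-Level Routing Attention is panel (f).

Unlike the other sparse attention designs, Bi-Level Routing Attention first partitions the feature map into S×S non-overlapping regions (S regions per side, n_win in the code). The tokens of each region are linearly projected to Q, K, and V. Q and K are then averaged within each region to give one token per region (see the code), yielding Q^{r} and K^{r} (r stands for region). A matrix product of Q^{r} and K^{r} gives the region-to-region affinity matrix A^{r}:

A^{r} = Q^{r} (K^{r})^{T}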
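The region-level pooling and affinity step can be sketched as follows. This is a minimal illustration, not the author's implementation: the helper name `region_affinity` and the toy shapes are assumptions, and the projections that produce Q and K are taken as already applied.

```python
import torch

def region_affinity(q, k, n_win):
    """q, k: (N, H, W, C) per-token queries/keys. Returns A^r of shape (N, S^2, S^2), S = n_win."""
    N, H, W, C = q.shape
    assert H % n_win == 0 and W % n_win == 0, "H and W must be divisible by n_win"
    h, w = H // n_win, W // n_win  # tokens per region along each side

    def pool(t):
        # Split H into (S, h) and W into (S, w), then average over the h and w axes.
        t = t.reshape(N, n_win, h, n_win, w, C).mean(dim=(2, 4))  # (N, S, S, C)
        return t.reshape(N, n_win * n_win, C)                     # (N, S^2, C)

    q_r, k_r = pool(q), pool(k)
    return q_r @ k_r.transpose(-2, -1)  # A^r = Q^r (K^r)^T

torch.manual_seed(0)
q = torch.randn(2, 8, 8, 16)  # hypothetical toy input
k = torch.randn(2, 8, 8, 16)
a_r = region_affinity(q, k, n_win=4)
print(a_r.shape)  # torch.Size([2, 16, 16])
```

Entry (i, j) of the result scores how relevant region j is to region i, so routing operates on S^2 region tokens rather than on all H·W pixel tokens.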

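Given the adjacency matrix A^{r}, take the indices of the k highest-scoring regions in each row to obtain I^{r}. This tells us, for every region, which k regions are most relevant to it.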
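As a small illustration of the row-wise top-k selection, the sketch below applies `torch.topk` to a made-up 4-region affinity matrix. The numbers are invented for the example and do not come from the paper.

```python
import torch

# Toy affinity matrix A^r for 4 regions (rows: query regions, columns: key regions).
a_r = torch.tensor([[0.90, 0.10, 0.50, 0.30],
                    [0.20, 0.80, 0.70, 0.10],
                    [0.40, 0.60, 0.95, 0.05],
                    [0.30, 0.20, 0.10, 0.85]])

k = 2  # number of regions each region is routed to
topk_vals, i_r = torch.topk(a_r, k=k, dim=-1)  # values and indices, sorted in descending order
print(i_r.tolist())  # [[0, 2], [1, 2], [2, 1], [3, 0]]
```

Row i of `i_r` lists the regions whose keys and values region i will attend to.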

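With I^{r}, a gather operation collects the keys and values of the routed regions, K^{g} and V^{g}:

K^{g} = gather(K, I^{r}), V^{g} = gather(V, I^{r})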
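The gather step can be sketched in isolation as below. The expand-then-gather pattern mirrors the `KVGather` module in the code section. The helper name `gather_kv` and the toy tensors are assumptions made for illustration.

```python
import torch

def gather_kv(kv, r_idx):
    """kv: (n, p2, w2, c) tokens per region; r_idx: (n, p2, topk) routed region indices.
    Returns (n, p2, topk, w2, c): for each region, the tokens of its topk routed regions."""
    n, p2, w2, c = kv.shape
    topk = r_idx.size(-1)
    src = kv.view(n, 1, p2, w2, c).expand(-1, p2, -1, -1, -1)      # broadcast view, no copy
    idx = r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c)  # same index for every token/channel
    return torch.gather(src, dim=2, index=idx)

# 1 image, 3 regions, 2 tokens per region, 2 channels.
kv = torch.arange(12, dtype=torch.float32).view(1, 3, 2, 2)
r_idx = torch.tensor([[[2, 0], [1, 1], [0, 2]]])
kv_g = gather_kv(kv, r_idx)
print(kv_g.shape)  # torch.Size([1, 3, 2, 2, 2])
```

For a single image the result matches plain advanced indexing, `kv[0][r_idx[0]]`. The gather form simply handles the batch dimension in one call.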

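Finally, attention is computed on the gathered keys and values:

O = Attention(Q, K^{g}, V^{g}) + LCE(V)

Here LCE(V) is a local context enhancement term, implemented in the code as a depth-wise convolution (side_dwconv).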
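A single-head sketch of this last step is shown below, assuming the gathered tensors have the shapes used in the code section. It omits the multi-head split and the LCE term, and the function name `routed_attention` is hypothetical.

```python
import torch

def routed_attention(q, k_g, v_g):
    """
    q:   (n, p2, w2, c)        queries of each region
    k_g: (n, p2, topk, w2, c)  gathered keys
    v_g: (n, p2, topk, w2, c)  gathered values
    Returns (n, p2, w2, c). Single head, no LCE term.
    """
    n, p2, topk, w2, c = k_g.shape
    k_flat = k_g.reshape(n, p2, topk * w2, c)  # concatenate the tokens of the routed regions
    v_flat = v_g.reshape(n, p2, topk * w2, c)
    logits = (q * c ** -0.5) @ k_flat.transpose(-2, -1)  # (n, p2, w2, topk*w2)
    return torch.softmax(logits, dim=-1) @ v_flat

torch.manual_seed(0)
n, p2, w2, c = 1, 4, 3, 8
q, k, v = (torch.randn(n, p2, w2, c) for _ in range(3))
# Route every region to all regions (topk == p2): this should reduce to full attention.
k_g = k.unsqueeze(1).expand(-1, p2, -1, -1, -1)
v_g = v.unsqueeze(1).expand(-1, p2, -1, -1, -1)
out = routed_attention(q, k_g, v_g)
print(out.shape)  # torch.Size([1, 4, 3, 8])
```

Setting topk equal to the number of regions is only a sanity check. In practice a small topk keeps each query's key set much smaller than the full token set.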

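The overall computation of Bi-Level Routing Attention is illustrated in the paper's pipeline figure. The k shown there is the top-k parameter used when computing the routing indices.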
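To get a feel for the savings, the following back-of-the-envelope sketch counts how many keys each query attends to. The 56×56 feature map is an assumed example size. `n_win=7` and `topk=4` are the defaults in the code below, and no key/value downsampling is assumed.

```python
def keys_per_query(height, width, n_win, topk):
    """Return (keys seen by a query under full attention, keys seen under top-k routing)."""
    tokens_per_region = (height // n_win) * (width // n_win)
    full = height * width               # every token attends to every token
    routed = topk * tokens_per_region   # only the tokens of the topk routed regions
    return full, routed

full, routed = keys_per_query(56, 56, n_win=7, topk=4)
print(full, routed)  # 3136 256
```

The routing itself adds an S^2 × S^2 affinity computation, which is small because it operates on region tokens only.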

2. Code

The code for Bi-Level Routing Attention is as follows:

"""
Core of BiFormer, Bi-Level Routing Attention.

To be refactored.

author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


class TopkRouting(nn.Module):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        param_routing: bool, whether to incorporate learnable params in routing unit
        diff_routing: bool, whether to make routing differentiable
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)
        return r_weight, topk_index


class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)
        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        topk_kv = torch.gather(
            kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),  # (n, p^2, p^2, w^2, c_kv) without mem cpy
            dim=2,
            index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)  # (n, p^2, k, w^2, c_kv)
        )
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # else: #'none'
        #     topk_kv = topk_kv # do nothing
        return topk_kv


class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v


class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows in one side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window.
        Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo': linear for q, k, v and o; 'qkv': linear for q, k, v only
    param_routing: extra linear for routing
    diff_routing: whether to set routing differentiable
    soft_routing: whether to multiply soft routing weights
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3, auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, \
            'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        ################ side_dwconv (i.e. LCE in ShuntedTransformer) ###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1,
                              padding=side_dwconv // 2, groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be param_routing=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif self.kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not supported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NHWC tensor
        Return:
            NHWC tensor
        """
        # NOTE: use padding for semantic segmentation
        ###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0
        ###################################################

        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        ################# qkv projection ###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x)

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])

        ################## side_dwconv (lepe) ##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)',
                                   j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################
        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)

        ######### do attention as normal ####################
        # k is laid out as (n*p^2, m, c_qk//m, topk*h_kv*w_kv), i.e. already transposed for the matmul below
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)
        # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)
        # to BMLC tensor (n*p^2, m, w^2, c_qk//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)

        # param-free multihead attention
        # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = (q_pix * self.scale) @ k_pix_sel
        attn_weight = self.attn_act(attn_weight)
        # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = attn_weight @ v_pix_sel
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out
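The `auto_pad` branch of `forward` pads the right and bottom edges so that H and W become multiples of `n_win`. The arithmetic can be checked in isolation with the small sketch below. The helper name `pad_to_multiple` is hypothetical and not part of the repository.

```python
def pad_to_multiple(size, n_win):
    """Padding needed so that `size` becomes a multiple of `n_win` (0 if it already is)."""
    return (n_win - size % n_win) % n_win

for size in (56, 60, 1):
    pad = pad_to_multiple(size, n_win=7)
    print(size, pad, size + pad)
# 56 0 56
# 60 3 63
# 1 6 7
```

The outer `% n_win` is what maps the already-divisible case to zero padding instead of a full extra window.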
