[Paper Notes] BiFormer: Vision Transformer with Bi-Level Routing Attention

Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention

Code: https://github.com/rayleizhu/BiFormer

Attention is an essential module in vision transformers, but it has a serious drawback: its computational cost is very high.

BiFormer proposes Bi-Level Routing Attention (BRA), which restricts the attention computation to the most relevant tokens, thereby reducing the computational cost.
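As a rough note of my own (not from the original post, and the second figure is hedged from my recollection of the paper's complexity analysis): vanilla global attention over an H×W feature map with C channels costs O((HW)^2 · C), while the paper derives a complexity of roughly O((HW)^{4/3}) for Bi-Level Routing Attention, since expensive token-to-token attention is only computed inside the few routed regions.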

1. Bi-Level Routing Attention

The figure below compares the regions attended to by several attention modules: (a) is vanilla (full) attention, while the others are sparse attention designs. Bi-Level Routing Attention is shown in (f).

Unlike other sparse attention designs, Bi-Level Routing Attention first partitions the feature map into regions (each region is an S×S window). Each region is linearly projected to obtain Q, K, V; Q and K are then averaged within each S×S window to form region-level tokens (see the code), giving Q^r and K^r (the superscript r denotes "region", i.e., one S×S window). The region-to-region adjacency matrix A^r is obtained from the matrix product of Q^r and K^r:

A^r = Q^r (K^r)^T
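To make this step concrete, here is a minimal, self-contained sketch of the region-level routing input (my own illustration, not the author's code; S, H, W, C and the random tensors are placeholder values, and the linear projections are assumed to have been applied already):

import torch
from einops import rearrange

S = 8                                   # region (window) size, illustrative
H = W = 56; C = 64                      # illustrative feature-map size and channel count
q = torch.randn(1, H, W, C)             # already-projected queries, NHWC
k = torch.randn(1, H, W, C)             # already-projected keys, NHWC

# split the map into S x S windows and average Q, K inside each window
q_r = rearrange(q, 'n (j h) (i w) c -> n (j i) (h w) c', h=S, w=S).mean(dim=2)  # Q^r: (n, num_regions, C)
k_r = rearrange(k, 'n (j h) (i w) c -> n (j i) (h w) c', h=S, w=S).mean(dim=2)  # K^r: (n, num_regions, C)

# region-to-region adjacency A^r = Q^r (K^r)^T
A_r = q_r @ k_r.transpose(-2, -1)       # (n, num_regions, num_regions)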

Given the adjacency matrix A^r, we take, for each row, the indices I^r of its k most related regions, so that each window knows which k windows it is most strongly correlated with.
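Continuing the sketch above, the routing indices can be taken row-wise with torch.topk (the value of k is illustrative):

topk = 4                                      # number of routed regions per query region, illustrative
_, I_r = torch.topk(A_r, k=topk, dim=-1)      # I^r: (n, num_regions, topk)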

With I^r in hand, a gather operation collects the key/value tokens of the routed regions, yielding K^g and V^g:

K^g = gather(K, I^r),  V^g = gather(V, I^r)
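Continuing the same sketch, the gather step could be written as follows (again my own illustration; the official implementation does this inside KVGather, shown in the code section below):

v = torch.randn(1, H, W, C)                                               # already-projected values, NHWC
k_pix = rearrange(k, 'n (j h) (i w) c -> n (j i) (h w) c', h=S, w=S)      # (n, num_regions, S^2, C)
v_pix = rearrange(v, 'n (j h) (i w) c -> n (j i) (h w) c', h=S, w=S)      # (n, num_regions, S^2, C)

n, p2, w2, c = k_pix.shape
idx = I_r.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c)
# for every query region, gather the tokens of its topk routed regions
K_g = torch.gather(k_pix.unsqueeze(1).expand(-1, p2, -1, -1, -1), dim=2, index=idx)  # (n, num_regions, topk, S^2, C)
V_g = torch.gather(v_pix.unsqueeze(1).expand(-1, p2, -1, -1, -1), dim=2, index=idx)  # (n, num_regions, topk, S^2, C)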

Finally, fine-grained token-to-token attention is computed on the gathered key–value pairs (the paper additionally adds a depth-wise-convolution local context term LCE(V), which corresponds to lepe in the code):

O = Attention(Q, K^g, V^g) + LCE(V)
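As a simplified single-head sketch that continues the snippets above (the real module is multi-head and also adds the lepe depth-wise convolution term):

q_pix = rearrange(q, 'n (j h) (i w) c -> n (j i) (h w) c', h=S, w=S)          # (n, num_regions, S^2, C)
K_flat = K_g.flatten(2, 3)                                                    # (n, num_regions, topk*S^2, C)
V_flat = V_g.flatten(2, 3)                                                    # (n, num_regions, topk*S^2, C)

scale = C ** -0.5
attn = torch.softmax((q_pix * scale) @ K_flat.transpose(-2, -1), dim=-1)      # (n, num_regions, S^2, topk*S^2)
out = attn @ V_flat                                                           # (n, num_regions, S^2, C)
out = rearrange(out, 'n (j i) (h w) c -> n (j h) (i w) c',
                j=H // S, i=W // S, h=S, w=S)                                 # back to (n, H, W, C)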

The figure below illustrates the computation pipeline of Bi-Level Routing Attention; the k in it is the parameter used when taking the top-k relevance indices.

2. Code

The code of Bi-Level Routing Attention (from the official repository) is as follows:

"""
Core of BiFormer, Bi-Level Routing Attention.To be refactored.author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.comThis source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tupleimport torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensorclass TopkRouting(nn.Module):"""differentiable topk routing with scalingArgs:qk_dim: int, feature dimension of query and keytopk: int, the 'topk'qk_scale: int or None, temperature (multiply) of softmax activationwith_param: bool, wether inorporate learnable params in routing unitdiff_routing: bool, wether make routing differentiablesoft_routing: bool, wether make output value multiplied by routing weights"""def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):super().__init__()self.topk = topkself.qk_dim = qk_dimself.scale = qk_scale or qk_dim ** -0.5self.diff_routing = diff_routing# TODO: norm layer before/after linear?self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()# routing activationself.routing_act = nn.Softmax(dim=-1)def forward(self, query:Tensor, key:Tensor)->Tuple[Tensor]:"""Args:q, k: (n, p^2, c) tensorReturn:r_weight, topk_index: (n, p^2, topk) tensor"""if not self.diff_routing:query, key = query.detach(), key.detach()query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c) attn_logit = (query_hat*self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2)topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k)r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k)return r_weight, topk_indexclass KVGather(nn.Module):def __init__(self, mul_weight='none'):super().__init__()assert mul_weight in ['none', 'soft', 'hard']self.mul_weight = mul_weightdef forward(self, r_idx:Tensor, r_weight:Tensor, kv:Tensor):"""r_idx: (n, p^2, topk) tensorr_weight: (n, p^2, topk) tensorkv: (n, p^2, w^2, c_kq+c_v)Return:(n, p^2, topk, w^2, c_kq+c_v) tensor"""# select kv according to routing indexn, p2, w2, c_kv = kv.size()topk = r_idx.size(-1)# print(r_idx.size(), r_weight.size())# FIXME: gather consumes much memory (topk times redundancy), write cuda kernel? topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), # (n, p^2, p^2, w^2, c_kv) without mem cpydim=2,index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) # (n, p^2, k, w^2, c_kv))if self.mul_weight == 'soft':topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv)elif self.mul_weight == 'hard':raise NotImplementedError('differentiable hard routing TBA')# else: #'none'#     topk_kv = topk_kv # do nothingreturn topk_kvclass QKVLinear(nn.Module):def __init__(self, dim, qk_dim, bias=True):super().__init__()self.dim = dimself.qk_dim = qk_dimself.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)def forward(self, x):q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim+self.dim], dim=-1)return q, kv# q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)# return q, k, vclass BiLevelRoutingAttention(nn.Module):"""n_win: number of windows in one side (so the actual number of windows is n_win*n_win)kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. 
Similar to n_win, the actual number is kv_per_win*kv_per_win.topk: topk for window filteringparam_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attentionparam_routing: extra linear for routingdiff_routing: wether to set routing differentiablesoft_routing: wether to multiply soft routing weights """def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,auto_pad=False):super().__init__()# local attention settingself.dim = dimself.n_win = n_win  # Wh, Wwself.num_heads = num_headsself.qk_dim = qk_dim or dimassert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'self.scale = qk_scale or self.qk_dim ** -0.5################side_dwconv (i.e. LCE in ShuntedTransformer)###########self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \lambda x: torch.zeros_like(x)################ global routing setting #################self.topk = topkself.param_routing = param_routingself.diff_routing = diff_routingself.soft_routing = soft_routing# routerassert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=Falseself.router = TopkRouting(qk_dim=self.qk_dim,qk_scale=self.scale,topk=self.topk,diff_routing=self.diff_routing,param_routing=self.param_routing)if self.soft_routing: # soft routing, always diffrentiable (if no detach)mul_weight = 'soft'elif self.diff_routing: # hard differentiable routingmul_weight = 'hard'else:  # hard non-differentiable routingmul_weight = 'none'self.kv_gather = KVGather(mul_weight=mul_weight)# qkv mapping (shared by both global routing and local attention)self.param_attention = param_attentionif self.param_attention == 'qkvo':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Linear(dim, dim)elif self.param_attention == 'qkv':self.qkv = QKVLinear(self.dim, self.qk_dim)self.wo = nn.Identity()else:raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')self.kv_downsample_mode = kv_downsample_modeself.kv_per_win = kv_per_winself.kv_downsample_ratio = kv_downsample_ratioself.kv_downsample_kenel = kv_downsample_kernelif self.kv_downsample_mode == 'ada_avgpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'ada_maxpool':assert self.kv_per_win is not Noneself.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)elif self.kv_downsample_mode == 'maxpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'avgpool':assert self.kv_downsample_ratio is not Noneself.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()elif self.kv_downsample_mode == 'identity': # no kv downsamplingself.kv_down = nn.Identity()elif self.kv_downsample_mode == 'fracpool':# assert self.kv_downsample_ratio is not None# assert self.kv_downsample_kenel is not None# TODO: fracpool# 1. kernel size should be input size dependent# 2. 
there is a random factor, need to avoid independent sampling for k and v raise NotImplementedError('fracpool policy is not implemented yet!')elif kv_downsample_mode == 'conv':# TODO: need to consider the case where k != v so that need two downsample modulesraise NotImplementedError('conv policy is not implemented yet!')else:raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!')# softmax for local attentionself.attn_act = nn.Softmax(dim=-1)self.auto_pad=auto_paddef forward(self, x, ret_attn_mask=False):"""x: NHWC tensorReturn:NHWC tensor"""# NOTE: use padding for semantic segmentation###################################################if self.auto_pad:N, H_in, W_in, C = x.size()pad_l = pad_t = 0pad_r = (self.n_win - W_in % self.n_win) % self.n_winpad_b = (self.n_win - H_in % self.n_win) % self.n_winx = F.pad(x, (0, 0, # dim=-1pad_l, pad_r, # dim=-2pad_t, pad_b)) # dim=-3_, H, W, _ = x.size() # padded sizeelse:N, H, W, C = x.size()assert H%self.n_win == 0 and W%self.n_win == 0 ##################################################### patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv sizex = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)#################qkv projection#################### q: (n, p^2, w, w, c_qk)# kv: (n, p^2, w, w, c_qk+c_v)# NOTE: separte kv if there were memory leak issue caused by gatherq, kv = self.qkv(x) # pixel-wise qkv# q_pix: (n, p^2, w^2, c_qk)# kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)##################side_dwconv(lepe)################### NOTE: call contiguous to avoid gradient warning when using ddplepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)############ gather q dependent k/v #################r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensorskv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)# kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)# v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)######### do attention as normal ####################k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m)# param-free multihead attentionattn_weight = (q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)attn_weight = self.attn_act(attn_weight)out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,h=H//self.n_win, w=W//self.n_win)out = out + lepe# output linearout = self.wo(out)# NOTE: use padding 
for semantic segmentation# crop padded regionif self.auto_pad and (pad_r > 0 or pad_b > 0):out = out[:, :H_in, :W_in, :].contiguous()if ret_attn_mask:return out, r_weight, r_idx, attn_weightelse:return out
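A minimal smoke test of the module above (my own usage sketch, assuming the imports and classes from the listing; shapes and hyper-parameters are illustrative, and H/W must be divisible by n_win unless auto_pad=True):

x = torch.randn(2, 56, 56, 64)                                        # NHWC input, 56 is divisible by n_win=7
attn = BiLevelRoutingAttention(dim=64, num_heads=4, n_win=7, topk=4)  # default kv_downsample_mode='identity'
y = attn(x)
print(y.shape)                                                        # torch.Size([2, 56, 56, 64])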

