论文地址:BiFormer: Vision Transformer with Bi-Level Routing Attention
代码地址:https://github.com/rayleizhu/BiFormer
vision transformer中Attention是极其重要的模块,但是它有着非常大的缺点:计算量太大。
BiFormer提出了Bi-Level Routing Attention,在Attention计算时,只关注最重要的token,由此来降低计算量。
一、Bi-Level Routing Attention
下图是多个不同的Attention模块关注的区域,(a)是原始的attention,其他的都是稀疏的Attention结构。Bi-Level Routing Attention如下图(f)所示。
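与其他的稀疏Attention结构有所不同,Bi-Level Routing Attention首先将特征图 $X \in \mathbb{R}^{H \times W \times C}$ 划分为 $S \times S$ 个互不重叠的区域(即窗口,对应代码中的 n_win × n_win),每个区域经过线性映射,得到 $Q$、$K$、$V$:$Q = X^r W^q,\ K = X^r W^k,\ V = X^r W^v$。然后 $Q$、$K$ 在每个区域(窗口)内取平均,作为该区域的token(可参考代码),得到 $Q^r, K^r \in \mathbb{R}^{S^2 \times C}$(上标 r 表示 region,即划分出的 $S \times S$ 个窗口),再通过 $Q^r$ 与 $(K^r)^T$ 的矩阵运算得到区域之间的邻接矩阵 $A^r \in \mathbb{R}^{S^2 \times S^2}$,如下式:
$$A^r = Q^r (K^r)^T$$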
得到了邻接矩阵 $A^r$ 后,对其每一行取相关性最高的 k 个区域的索引,得到索引矩阵 $I^r \in \mathbb{N}^{S^2 \times k}$:
$$I^r = \mathrm{topkIndex}(A^r)$$
这样就知道每个窗口与哪 k 个窗口相关性更高了。
得到了索引矩阵 $I^r$ 后,用gather运算得到 $K^g$ 和 $V^g$:
$$K^g = \mathrm{gather}(K, I^r), \quad V^g = \mathrm{gather}(V, I^r)$$
最后计算Attention(其中 LCE 是对 $V$ 做的深度可分离卷积,对应代码中的 lepe):
$$O = \mathrm{Attention}(Q, K^g, V^g) + \mathrm{LCE}(V)$$
Bi-Level Routing Attention的计算过程如下图所示,其中的k就是计算相关性索引时设置的参数。
二、代码
Bi-Level Routing Attention的代码如下:
"""
Core of BiFormer, Bi-Level Routing Attention.

To be refactored.

author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


class TopkRouting(nn.Module):
    """Top-k routing between windows, with scaling.

    Args:
        qk_dim: int, feature dimension of query and key.
        topk: int, number of windows kept for every query window.
        qk_scale: float or None, temperature (multiplier) applied to the
            query before the dot product. Falls back to ``qk_dim ** -0.5``
            when not given.
        param_routing: bool, whether to pass query and key through a
            learnable linear embedding before routing.
        diff_routing: bool, whether routing is differentiable. When False,
            query and key are detached before the routing logits are computed.
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute routing weights and the indices of the top-k windows.

        Args:
            query, key: (n, p^2, c) tensors (one token per window).

        Return:
            r_weight, topk_index: (n, p^2, topk) tensors. ``r_weight`` is the
            softmax over the selected top-k logits only.
        """
        if not self.diff_routing:
            # block gradients from flowing back through the routing branch
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)

        return r_weight, topk_index


class KVGather(nn.Module):
    """Gather the key/value tokens of the routed windows for every query window.

    Args:
        mul_weight: one of 'none', 'soft', 'hard'. 'soft' multiplies the
            gathered tokens by the routing weights, 'none' returns them
            untouched, 'hard' is accepted here but not implemented in
            ``forward``.
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """Select kv according to the routing index.

        Args:
            r_idx: (n, p^2, topk) tensor of window indices.
            r_weight: (n, p^2, topk) tensor of routing weights.
            kv: (n, p^2, w^2, c_kq+c_v) tensor.

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor.

        Raises:
            NotImplementedError: if ``mul_weight`` is 'hard'.
        """
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        # expand() broadcasts without copying: every query window sees all p^2 windows
        kv_all = kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1)  # (n, p^2, p^2, w^2, c_kv)
        gather_idx = r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)  # (n, p^2, k, w^2, c_kv)
        topk_kv = torch.gather(kv_all, dim=2, index=gather_idx)

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # 'none': nothing to do

        return topk_kv


class QKVLinear(nn.Module):
    """One linear layer producing the query and the concatenated key/value.

    Args:
        dim: input feature dimension; also the dimension of the value.
        qk_dim: dimension of the query and of the key.
        bias: whether the linear layer has a bias.
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        """Return ``(q, kv)``: last dims are ``qk_dim`` and ``qk_dim + dim``."""
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv


class BiLevelRoutingAttention(nn.Module):
    """Bi-Level Routing Attention.

    Windows are first routed to their top-k most related windows using
    window-level (mean-pooled) queries and keys; regular multi-head attention
    is then computed between the pixels of a window and the pixels of its
    routed windows.

    Args:
        dim: input/output (and value) feature dimension.
        num_heads: number of attention heads.
        n_win: number of windows in one side (so the actual number of windows
            is n_win*n_win).
        qk_dim: query/key dimension, defaults to ``dim``.
        qk_scale: attention temperature, defaults to ``qk_dim ** -0.5``.
        kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of
            key/values per window. Similar to n_win, the actual number is
            kv_per_win*kv_per_win.
        kv_downsample_ratio: for kv_downsample_mode='maxpool'/'avgpool' only,
            pooling kernel size.
        kv_downsample_kernel: stored only; not used by any implemented mode.
        kv_downsample_mode: 'identity', 'ada_avgpool', 'ada_maxpool',
            'maxpool' or 'avgpool'. 'fracpool' and 'conv' are not implemented.
        topk: topk for window filtering.
        param_attention: 'qkvo' - linear for q,k,v and o; 'qkv' - linear for
            q,k,v only (identity output projection).
        param_routing: extra linear for routing.
        diff_routing: whether to make routing differentiable.
        soft_routing: whether to multiply soft routing weights.
        side_dwconv: kernel size of the depth-wise convolution applied to the
            value (LCE in ShuntedTransformer); values <= 0 disable it.
        auto_pad: pad the input so H and W are divisible by n_win, and crop
            the output back to the input size.

    Raises:
        NotImplementedError: for kv_downsample_mode 'fracpool' or 'conv'.
        ValueError: for an unknown ``param_attention`` or ``kv_downsample_mode``.
    """

    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3, auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, \
            'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        ################ side_dwconv (i.e. LCE in ShuntedTransformer) ###########
        if side_dwconv > 0:
            self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1,
                                  padding=side_dwconv // 2, groups=dim)
        else:
            self.lepe = lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be param_routing=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not surpported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kernel = kv_downsample_kernel
        # misspelled name kept so existing code reading it keeps working
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif self.kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            # the original formatted a misspelled attribute here, which raised
            # AttributeError instead of the intended ValueError
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not surpported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """Apply bi-level routing attention.

        Args:
            x: NHWC tensor. Unless ``auto_pad`` is set, H and W must be
                divisible by ``n_win``.
            ret_attn_mask: when True, also return the routing weights, the
                routing indices and the attention weights.

        Return:
            NHWC tensor, or ``(out, r_weight, r_idx, attn_weight)`` when
            ``ret_attn_mask`` is True.
        """
        # NOTE: use padding for semantic segmentation
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0

        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        ################# qkv projection ###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issue caused by gather
        q, kv = self.qkv(x)

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)
        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])

        ################## side_dwconv (lepe) ##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)',
                                   j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################
        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)

        ######### do attention as normal ####################
        # key is laid out already transposed: (n*p^2, m, c_qk//m, topk*h_kv*w_kv)
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)
        # (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)
        # (n*p^2, m, w^2, c_qk//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)

        # param-free multihead attention
        # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = (q_pix * self.scale) @ k_pix_sel
        attn_weight = self.attn_act(attn_weight)
        # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = attn_weight @ v_pix_sel
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out