Module Source
[MICCAI 2022] [link] [code] Lesion-aware Dynamic Kernel for Polyp Segmentation
Module Name
Efficient Self-Attention (ESA)
Module Function
Efficient self-attention (long-range modeling at reduced computational and memory cost)
Module Structure
Module Idea
Self-attention offers excellent long-range modeling ability, but it comes with high computational and memory cost, so it needs to be made more efficient. In the ESA module proposed in this paper, the Query is obtained by directly reshaping the original feature map, while the Key and Value are obtained through a pyramid pooling (PPM) operation. Because the pooled Key/Value sequence has a small, fixed length (1 + 9 + 25 = 35 tokens for pooling sizes 1, 3, 5) regardless of the input resolution, the attention map shrinks from (HW) × (HW) to (HW) × 35, which greatly reduces the cost of the attention computation.
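To make the saving concrete, here is a minimal sketch (my own illustration, not code from the paper) comparing the number of entries in the attention map of vanilla self-attention and of ESA for one head at an assumed example feature-map resolution; the pooling sizes (1, 3, 5) give a fixed 1 + 9 + 25 = 35 key/value tokens.

# Illustration only (not from the paper): attention-map sizes for one head.
h, w = 44, 44                              # assumed example feature-map resolution
n_q = h * w                                # queries: one per spatial location
n_kv_vanilla = h * w                       # vanilla self-attention: one key/value per location
n_kv_esa = sum(s * s for s in (1, 3, 5))   # ESA: pooled key/value tokens, 1 + 9 + 25 = 35

print(n_q * n_kv_vanilla)                  # 3748096 attention entries (vanilla)
print(n_q * n_kv_esa)                      # 67760 attention entries (ESA)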
Module Code
import torch.nn.functional as F
import torch.nn as nn
import torch
from einops import rearrange


class PPM(nn.Module):
    # Pyramid Pooling Module: adaptively average-pools the feature map at several
    # output sizes, flattens each pooled map, and concatenates them along the token axis.
    def __init__(self, pooling_sizes=(1, 3, 5)):
        super().__init__()
        self.layer = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes])

    def forward(self, feat):
        b, c, h, w = feat.shape
        output = [layer(feat).view(b, c, -1) for layer in self.layer]
        output = torch.cat(output, dim=-1)  # shape: (b, c, sum(size * size))
        return output


class ESA(nn.Module):
    # Efficient Self-Attention: queries come from the full-resolution feature map,
    # while keys and values are compressed by pyramid pooling.
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.ppm = PPM(pooling_sizes=(1, 3, 5))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # input x: (b, c, h, w)
        b, c, h, w = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=1)  # q/k/v shape: (b, inner_dim, h, w)
        q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)  # q shape: (b, head, n_q, d)

        k, v = self.ppm(k), self.ppm(v)  # k/v shape: (b, inner_dim, n_kv)
        k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # k shape: (b, head, n_kv, d)
        v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # v shape: (b, head, n_kv, d)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # shape: (b, head, n_q, n_kv)

        attn = self.attend(dots)
        out = torch.matmul(attn, v)  # shape: (b, head, n_q, d)
        out = rearrange(out, 'b head n d -> b n (head d)')  # shape: (b, n_q, inner_dim)
        return self.to_out(out)  # shape: (b, n_q, dim)


if __name__ == '__main__':
    x = torch.randn([1, 256, 11, 11])
    esa = ESA(dim=256)
    out = esa(x)
    print(out.shape)  # torch.Size([1, 121, 256])
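Note that ESA returns a token sequence of shape (b, h*w, dim) rather than a feature map. Below is a minimal usage sketch (an assumption on my part, not shown in the code above) that reuses the ESA class and rearrange from above to restore the spatial layout so the output can be passed to subsequent convolutional layers.

# Usage sketch (assumption, not part of the original code): restore the spatial layout.
x = torch.randn([1, 256, 11, 11])
esa = ESA(dim=256)
tokens = esa(x)                                                # (1, 121, 256)
feat = rearrange(tokens, 'b (h w) c -> b c h w', h=11, w=11)   # (1, 256, 11, 11)
print(feat.shape)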