Module source
[ECCV 24] [link] [code] Context-Guided Spatial Feature Reconstruction for Efficient Semantic Segmentation
Module name
Rectangular self-Calibration Attention (RCA)
Module function
Spatial attention
Module structure
Module code
import torch
import torch.nn as nn
import torch.nn.functional as F


class RCA(nn.Module):
    # Note: kernel_size, dw_size, padding, stride and relu are accepted but not used in this snippet.
    def __init__(self, inp, kernel_size=1, ratio=1, band_kernel_size=11, dw_size=(1,1), padding=(0,0), stride=1, square_kernel_size=3, relu=True):
        super(RCA, self).__init__()
        # Depthwise square convolution that extracts local features.
        self.dwconv_hw = nn.Conv2d(inp, inp, square_kernel_size, padding=square_kernel_size//2, groups=inp)
        # Axial average pooling: pool_h keeps the height, pool_w keeps the width.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        gc = inp // ratio
        # Two grouped strip convolutions (1 x k, then k x 1) followed by a sigmoid gate.
        self.excite = nn.Sequential(
            nn.Conv2d(inp, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size//2), groups=gc),
            nn.BatchNorm2d(gc),
            nn.ReLU(inplace=True),
            nn.Conv2d(gc, inp, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size//2, 0), groups=gc),
            nn.Sigmoid()
        )

    def sge(self, x):
        x_h = self.pool_h(x)    # [B, C, H, 1]
        x_w = self.pool_w(x)    # [B, C, 1, W]
        x_gather = x_h + x_w    # broadcast addition -> [B, C, H, W]
        ge = self.excite(x_gather)
        return ge

    def forward(self, x):
        loc = self.dwconv_hw(x)  # local features
        att = self.sge(x)        # attention weights in (0, 1)
        out = att*loc
        return out


if __name__ == '__main__':
    x = torch.randn([3, 256, 40, 40])
    rca = RCA(inp=256)
    out = rca(x)
    print(out.shape)  # torch.Size([3, 256, 40, 40])
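To make the data flow inside `sge` easier to follow, the sketch below traces tensor shapes through the two pooling branches and the broadcast addition. It is an illustration only: the toy sizes (batch 2, 8 channels, a 5×7 feature map) are assumptions chosen here, and the strip convolutions of `excite` are left out.

```python
import torch
import torch.nn as nn

# Assumed toy sizes, chosen only for illustration.
x = torch.randn(2, 8, 5, 7)  # (B, C, H, W)

pool_h = nn.AdaptiveAvgPool2d((None, 1))  # average over the width, keep the height
pool_w = nn.AdaptiveAvgPool2d((1, None))  # average over the height, keep the width

x_h = pool_h(x)        # (2, 8, 5, 1): one value per row
x_w = pool_w(x)        # (2, 8, 1, 7): one value per column
x_gather = x_h + x_w   # broadcast addition restores (2, 8, 5, 7)

# Each pooled tensor equals a plain mean along the pooled axis.
assert torch.allclose(x_h, x.mean(dim=3, keepdim=True), atol=1e-6)
assert torch.allclose(x_w, x.mean(dim=2, keepdim=True), atol=1e-6)

print(x_h.shape, x_w.shape, x_gather.shape)
```

Because `x_gather` has the same spatial size as the input, the sigmoid output of `excite` can be multiplied element-wise with the depthwise-convolved features in `forward`.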
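Original description
Rectangular self-calibration attention (RCA) explicitly models rectangular regions and calibrates the attention shape.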
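One way to see why this pooling scheme suits rectangular regions is a small hand-computed example. The sketch below uses plain Python lists instead of PyTorch so the arithmetic stays visible. The 5×6 map and the bright 2×3 block are made-up values, and only the pooling and broadcast-addition step of `sge` is mirrored, not the strip convolutions that follow it.

```python
# Toy 5x6 single-channel feature map with a bright 2x3 block (assumed values).
H, W = 5, 6
feat = [[0.0] * W for _ in range(H)]
for i in range(1, 3):        # rows 1..2
    for j in range(2, 5):    # columns 2..4
        feat[i][j] = 1.0

# Horizontal pooling: average over the width, one value per row.
row_mean = [sum(row) / W for row in feat]
# Vertical pooling: average over the height, one value per column.
col_mean = [sum(feat[i][j] for i in range(H)) / H for j in range(W)]

# Broadcast addition: gathered[i][j] = row_mean[i] + col_mean[j].
gathered = [[row_mean[i] + col_mean[j] for j in range(W)] for i in range(H)]

# Collect the positions with the largest response.
peak = max(max(r) for r in gathered)
region = [(i, j) for i in range(H) for j in range(W)
          if abs(gathered[i][j] - peak) < 1e-9]
print(region)
```

In this toy case the peak positions are exactly the rows and columns covered by the bright block, so the strongest response is an axis-aligned rectangle. In the module itself, the strip convolutions in `excite` then adjust this coarse map before the sigmoid.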