文章目录
- 1.介绍
- 2.原理
- 3.代码
- 4.SE模块的应用
论文:Squeeze-and-Excitation Networks
论文链接:Squeeze-and-Excitation Networks
代码链接:Github
1.介绍
卷积算子使网络能够在每一层的局部感受野中融合空间(spatial)和通道(channel)信息来构造信息特征。本文将重点放在通道(channel)关系上,提出SE(Squeeze-and-Excitation Block)模块,其显式建模通道之间的相互依赖性,自适应的重新校准通道方向上的特征响应,来提高所提取特征的质量。将SE模块堆叠在一起,就形成了SENet(Squeeze-and-Excitation Networks)。
通俗来说,SENet的核心在于通过网络根据损失函数学习特征权重,使得特征图中有效通道的权重变大,无效或效果小的通道权重变小的方式训练模型达到更好的结果。而SE(Squeeze-and-Excitation Block)模块是一个子结构,可嵌入其他模型当中。
2.原理
给定输入 x x x,其经一系列卷积操作(定义为 F t r ( ⋅ ; θ ) F_{tr}(·;θ) Ftr(⋅;θ))后得到通道数为 c w c_w cw的特征,其形状为 ( C , H , W ) (C,H,W) (C,H,W)。 之后通过三种运算来实现SE模块的功能:
【1. S q u e e z e Squeeze Squeeze操作】
卷积核只能关注到局部感受野的空间信息,感受野区域之外的信息无法利用,这使得输出特征图就很难获得足够的信息来提取通道之间的关系。 S q u e e z e Squeeze Squeeze操作,定义为 F s q ( ⋅ ) F_{sq}(·) Fsq(⋅),顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。这一操作通过全局平均池化实现:
F s q ( c 2 ) = 1 H × W ∑ i = 1 H ∑ j = 1 W c 2 ( i , j ) F_{sq}(c_2)=\frac{1}{H×W}\sum^{H}_{i=1}\sum^{W}_{j=1}c_2(i,j) Fsq(c2)=H×W1i=1∑Hj=1∑Wc2(i,j)
特征图经过 F s q ( ) F_{sq}() Fsq()运算后得到全局统计向量,形状为 ( 1 , 1 , c 2 ) (1,1,c_2) (1,1,c2)。此时一个像素值代表一个通道,从而屏蔽掉空间上的分布信息,更好的利用通道间的相关性。
【2. E x c i t a t i o n Excitation Excitation操作】
E x c i t a t i o n Excitation Excitation操作,定义为 F e x ( ⋅ ; w ) F_{ex}(·;w) Fex(⋅;w),用于捕获通道之间的依赖关系。这里使用了神经网络的门机制,即使用两个全连接层+两个激活函数组成的结构输出和输入与特征同样数目的权重值,也就是每个特征通道的权重系数。并且,为了限制模型复杂度和辅助泛化,在构造全连接层时对通道 c 2 c_2 c2进行了降维处理,降维比例为 r r r。计算公式:
其中, W 1 ∈ R C r × C , W 2 ∈ R C r × C W_1∈R^{\frac{C}{r}}×C,W_2∈R^{\frac{C}{r}}×C W1∈RrC×C,W2∈RrC×C,两个激活函数依次为 R e L U 、 s i g m o i d ReLU、sigmoid ReLU、sigmoid。原理图:
【3. S c a l e Scale Scale操作】
S c a l e Scale Scale操作定义为 F s c a l e ( ⋅ , ⋅ ) F_{scale}(·,·) Fscale(⋅,⋅),用于将前面得到的注意力权重加权到每个通道的特征上。论文中通过逐通道乘以权重系数,即在在通道维度上引入attention机制来实现。如下图所示:
不同颜色代表不同通道的重要程度。
3.代码
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction=16):#channel:输入通道数;reduction:缩减比率super(SELayer, self).__init__()#1.Squeezeself.avg_pool = nn.AdaptiveAvgPool2d(1)#2.Excitationself.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)#3.Scalereturn x * y.expand_as(x)
4.SE模块的应用
例如,可将SE模块集成在残差块中:
以此形成集成后的ResNet网络: