Contents
- 1. DBINet
  - 1.1 Encoder: ResNet50 + PVT dual-branch structure
  - 1.2 Decoder: applying the self-refinement module (SR)
  - 1.3 DFM: dual-branch fusion module
  - 1.4 Convertor: adapting encoder outputs for the decoder
  - 1.5 Deeply supervised loss function
- 2. GCPANet
  - 2.1 Encoder: ResNet50 backbone
  - 2.2 FIA: feature interweaved aggregation module
  - 2.3 SR: self-refinement module
  - 2.4 GCF: global context flow module
  - 2.5 HA: head attention module
  - 2.6 Overall pipeline
  - 2.7 Loss function: binary cross-entropy
- 3. The CPD decoder framework
  - 3.1 HA: holistic attention module
  - 3.2 Aggregation: feature fusion module
  - 3.3 Cascaded decoder
- 4. ACCoNet
  - 4.1 Encoder: VGG16 backbone
  - 4.2 ACCoM: adjacent context coordination module
  - 4.3 BAB: bifurcation-aggregation decoder block
  - 4.4 Network pipeline
- 5. FPS-U2Net
  - 5.1 Encoder architecture: U2Net
  - 5.2 MAM: multi-level aggregation module
1. DBINet
Paper: Dual Backbone Interaction Network for Burned Area Segmentation in Optical Remote Sensing Images
Paper link: IEEE
Code link: Github
1.1 Encoder: ResNet50 + PVT dual-branch structure
- 1. Encoders with a CNN backbone suffer from the long-range dependency problem (the current state of a system may be influenced by states from far in the past, a situation familiar from RNN models), which hurts accuracy.
- 2. ViT is known for its ability to model long-range dependencies, but by concentrating on global context it usually loses the ability to model local spatial correlations.

The paper therefore proposes a new encoder that fuses the two: ResNet50 serves as the main encoder and PVT-tiny as the auxiliary encoder.
`class DBIBkb(nn.Module)`: implementation of the encoder backbone.
- Input: (Batch_size, 3, 512, 512).
- Output: (Batch_size, 256, 128, 128), (Batch_size, 512, 64, 64), (Batch_size, 1024, 32, 32), (Batch_size, 2048, 16, 16), (Batch_size, 64, 16, 16).

Note that the last output is the output of the fusion DFM; the other four are the outputs of ResNet Block 1-4.
```python
class DBIBkb(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=256, embed_dims=[64, 128, 320, 512],
                 num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.0,
                 attn_drop_rate=0., drop_path_rate=0.1, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], num_stages=4, F4=False):
        super().__init__()
        # ResNet modules
        mid_ch = 64
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer( 64, 3, stride=1, dilation=1)
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
        self.CBkbs = [self.layer1, self.layer2, self.layer3, self.layer4]
        rout_chl = [256, 512, 1024, 2048]
        self.cdfm1 = DFM(cur_in_ch=rout_chl[1], sup_in_ch=embed_dims[1], out_ch=rout_chl[1])
        self.cdfm2 = DFM(cur_in_ch=rout_chl[2], sup_in_ch=embed_dims[2], out_ch=rout_chl[2])
        self.tdfm1 = DFM(cur_in_ch=embed_dims[1], sup_in_ch=rout_chl[1], out_ch=embed_dims[1])
        self.tdfm2 = DFM(cur_in_ch=embed_dims[2], sup_in_ch=rout_chl[2], out_ch=embed_dims[2])
        self.sumdfm = DFM(cur_in_ch=rout_chl[3], sup_in_ch=embed_dims[3], out_ch=mid_ch)
        self.cdfmx = [self.cdfm1, self.cdfm2]
        self.tdfmx = [self.tdfm1, self.tdfm2]
        # PVT modules
        self.depths = depths
        self.F4 = F4
        self.num_stages = num_stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        for i in range(1, num_stages):
            patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                     patch_size=patch_size if i == 0 else 2,
                                     in_chans=in_chans if i == 1 else embed_dims[i - 1],
                                     embed_dim=embed_dims[i])
            num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1
            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
            pos_drop = nn.Dropout(p=drop_rate)
            block = nn.ModuleList([Block(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
                norm_layer=norm_layer, sr_ratio=sr_ratios[i])
                for j in range(depths[i])])
            cur += depths[i]
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"pos_drop{i + 1}", pos_drop)
            setattr(self, f"block{i + 1}", block)
            trunc_normal_(pos_embed, std=.02)
        # init weights
        self.apply(self._init_weights)

    # parameter initialization
    def _init_weights(self, m):
        ...

    # resize a positional embedding to the current feature-map size
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        ...

    # build one ResNet stage by stacking `blocks` Bottleneck modules
    def make_layer(self, planes, blocks, stride, dilation):
        ...

    # forward pass
    def forward(self, x):
        # stores the four ResNet-stage features plus the final fused feature
        couts = []
        # batch size
        B = x.shape[0]
        # stem: 7x7 conv (stride 2, padding 3, 64 kernels) + BatchNorm + ReLU + max pooling
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        # ResNet Block1 produces the initial convolutional feature map out1
        out1 = self.CBkbs[0](out1)
        c_out = out1
        t_out = out1
        # save the ResNet Block1 output into couts
        couts.append(c_out)
        # num_stages=4, so the loop runs three times, covering ResNet Block2 + PVT Block1,
        # DFM + ResNet Block3 + PVT Block2, DFM + ResNet Block4 + PVT Block3, and the final DFM fusion
        for i in range(1, self.num_stages):
            # PVT branch
            # patch_embed splits the feature map into fixed-size patches and projects each to a vector
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            # positional embedding
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            # dropout applied after adding the positional embedding
            pos_drop = getattr(self, f"pos_drop{i + 1}")
            # Transformer blocks (self-attention + MLP)
            block = getattr(self, f"block{i + 1}")
            # In stages 2 and 3, t_out and c_out are fused by self.tdfmx[i-2] (a DFM), i.e. the CNN and
            # Transformer features are merged before patch_embed; in stage 1, t_out goes straight in.
            # t_out / c_out are the previous-stage Transformer / CNN outputs;
            # curt_out is the patch-embedded feature, (H, W) its spatial size.
            curt_out, (H, W) = patch_embed(self.tdfmx[i - 2](t_out, c_out) if i in [2, 3] else t_out)
            if i == self.num_stages - 1:
                # In the last stage, drop index 0 of the positional embedding before resizing to (H, W):
                # that entry is the class-token embedding, which carries image-level category information.
                # Feature extraction and segmentation care about per-position features rather than a
                # global class, so only the position-related entries are kept.
                pos_embed = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
            else:
                # resize the positional embedding to match the current feature map (H, W)
                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)
            # add the positional embedding to the features and apply dropout
            curt_out = pos_drop(curt_out + pos_embed)
            # ResNet branch
            # In stages 2 and 3, fuse c_out and t_out (CNN and Transformer features) with a DFM before
            # the corresponding ResNet Block(i+1); otherwise pass the original c_out directly.
            c_out = self.CBkbs[i](self.cdfmx[i - 2](c_out, t_out) if i in [2, 3] else c_out)
            # run the PVT blocks on the embedded features
            for blk in block:
                curt_out = blk(curt_out, H, W)
            # reshape [B, num_patches, embed_dim] -> [B, H, W, embed_dim], then permute to
            # [B, embed_dim, H, W]: the 1-D patch sequence becomes a 2-D feature map fit for convolutions
            t_out = curt_out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            # save the ResNet feature
            couts.append(c_out)
        # final stage: fuse with the DFM, save into couts, and return
        ct_fuse = self.sumdfm(c_out, t_out)
        couts.append(ct_fuse)
        return couts
```
Before running this, init_weights.py, pvt.py, and ResNet.py must also be imported. Test run:
```python
if __name__ == '__main__':
    data = torch.randn(16, 3, 512, 512)
    encoder = DBIBkb()
    outs = encoder(data)
    print('len(outs)', len(outs))
    for i in range(len(outs)):
        print('outs[%d]' % i, outs[i].shape)
```
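Given the output shapes listed above, this run should print something like:

```
len(outs) 5
outs[0] torch.Size([16, 256, 128, 128])
outs[1] torch.Size([16, 512, 64, 64])
outs[2] torch.Size([16, 1024, 32, 32])
outs[3] torch.Size([16, 2048, 16, 16])
outs[4] torch.Size([16, 64, 16, 16])
```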
1.2 Decoder: applying the self-refinement module (SR)
The decoder block simply adds the inputs $f_{dec}$, $f_{cvt}$ and $f_{enc}$, then aggregates the multi-level features through successive convolutions and a self-refinement step, capturing multi-scale context and increasing feature diversity.
- $f_{dec}$: feature from the previous decoder block.
- $f_{cvt}$: feature from the convertor.
- $f_{enc}$: feature from the main encoder block.
`class Decoder(nn.Module)`: decoder block code.
- Input: (Batch_size, C, H, W).
- Output: (Batch_size, C, H, W).
```python
class Decoder(nn.Module):
    # Decoder Block
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
        # the self-refinement module (SR)
        self.sr = SRM(c2)

    def forward(self, x):
        a, b = self.cv1(x).split((self.c, self.c), 1)
        out = self.cv2(torch.cat((self.m(a), b), 1))
        out = self.sr(out)
        return out
```
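A quick sanity check, assuming `Conv`, `Bottleneck` and `SRM` are imported from the DBINet repo; the block keeps both the channel count and the spatial size:

```python
# minimal sketch, assuming the repo's Conv/Bottleneck/SRM definitions are on the path
if __name__ == '__main__':
    dec = Decoder(64, 64)
    x = torch.randn(1, 64, 32, 32)
    print(dec(x).shape)  # expected: torch.Size([1, 64, 32, 32])
```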
In the DBINet forward pass, every encoder output (except that of the fusion DFM) is transformed by the convertor and resampled before being fed to the decoders. The relevant shape changes:
- Inputs from the encoder:
  - ResNet Block1: output [1, 256, 128, 128] -> [1, 64, 128, 128], fed to decoder 1.
  - ResNet Block2: output [1, 512, 64, 64] -> [1, 64, 64, 64], fed to decoder 2.
  - ResNet Block3: output [1, 1024, 32, 32] -> [1, 64, 32, 32], fed to decoder 3.
  - ResNet Block4: output [1, 2048, 16, 16] -> [1, 64, 16, 16], fed to decoder 4.
  - Fusion DFM: output [1, 64, 16, 16], fed directly to decoder 4.
- Inputs from the convertor + resampling:
  - Decoder 1: [1, 64, 128, 128].
  - Decoder 2: [1, 64, 64, 64].
  - Decoder 3: [1, 64, 32, 32].
  - Decoder 4: [1, 64, 16, 16].
- Inputs from the previous decoder:
  - Decoder 4 output: [1, 64, 16, 16], upsampled to [1, 64, 32, 32] as input to decoder 3.
  - Decoder 3 output: [1, 64, 32, 32], upsampled to [1, 64, 64, 64] as input to decoder 2.
  - Decoder 2 output: [1, 64, 64, 64], upsampled to [1, 64, 128, 128] as input to decoder 1.
  - Decoder 1 output: [1, 64, 128, 128].
The decoder-related code in `class DBINet`:
```python
class DBINet(nn.Module):
    def __init__(self, **kwargs):
        super(DBINet, self).__init__()
        ...
        # Decoder
        self.decoder4 = Decoder(mid_ch, mid_ch)
        self.decoder3 = Decoder(mid_ch, mid_ch)
        self.decoder2 = Decoder(mid_ch, mid_ch)
        self.decoder1 = Decoder(mid_ch, mid_ch)
        # out_ch=1, mid_ch=64
        self.dside1 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.dside2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.dside3 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.dside4 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        # initialise weights
        ...

    def forward(self, inputs):
        H, W = inputs.size(2), inputs.size(3)  # H, W = 512, 512
        ...  # encoder and convertor code
        # decoder
        up4 = self.decoder4(ca14 + c4 + c5)                   # in [B, 64, 16, 16],   out [B, 64, 16, 16]
        up3 = self.decoder3(ca13 + c3 + self.upsample2(up4))  # in [B, 64, 32, 32],   out [B, 64, 32, 32]
        up2 = self.decoder2(ca12 + c2 + self.upsample2(up3))  # in [B, 64, 64, 64],   out [B, 64, 64, 64]
        up1 = self.decoder1(ca1 + c1 + self.upsample2(up2))   # in [B, 64, 128, 128], out [B, 64, 128, 128]
        # side-output convolutions
        d1 = self.dside1(up1)  # in [B, 64, 128, 128], out [B, 1, 128, 128]
        d2 = self.dside2(up2)  # in [B, 64, 64, 64],   out [B, 1, 64, 64]
        d3 = self.dside3(up3)  # in [B, 64, 32, 32],   out [B, 1, 32, 32]
        d4 = self.dside4(up4)  # in [B, 64, 16, 16],   out [B, 1, 16, 16]
        # bilinear interpolation restores the maps to (512, 512)
        S1 = F.interpolate(d1, size=(H, W), mode='bilinear', align_corners=True)  # [B, 1, 512, 512]
        S2 = F.interpolate(d2, size=(H, W), mode='bilinear', align_corners=True)  # [B, 1, 512, 512]
        S3 = F.interpolate(d3, size=(H, W), mode='bilinear', align_corners=True)  # [B, 1, 512, 512]
        S4 = F.interpolate(d4, size=(H, W), mode='bilinear', align_corners=True)  # [B, 1, 512, 512]
        return S1, S2, S3, S4, torch.sigmoid(S1), torch.sigmoid(S2), torch.sigmoid(S3), torch.sigmoid(S4)
```
Testing the model's output sizes directly on DBINet:
```python
if __name__ == '__main__':
    model = DBINet()
    data = torch.randn(1, 3, 512, 512)
    out = model(data)
    for i in out:
        print(i.shape)
```
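Each of the eight returned tensors (S1-S4 and their sigmoid maps) should print as:

```
torch.Size([1, 1, 512, 512])
```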
1.3 DFM: dual-branch fusion module
The dual-branch fusion module (DFM) performs feature fusion.
```python
class DFM(nn.Module):
    def __init__(self, cur_in_ch, sup_in_ch, out_ch, mid_ch=64):
        super(DFM, self).__init__()
        self.cur_cv = Conv(cur_in_ch, mid_ch)
        self.sup_cv = Conv(sup_in_ch, mid_ch)
        self.fuse1 = Conv(mid_ch * 2, mid_ch)
        self.fuse2 = Conv(mid_ch, out_ch)

    def forward(self, x_cur, x_sup):
        x_cur1 = self.cur_cv(x_cur)
        x_sup1 = self.sup_cv(x_sup)
        xfuse = torch.cat([x_cur1, x_sup1], dim=1)
        x = self.fuse1(xfuse) + x_cur1
        x = self.fuse2(x)
        return x
```
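A quick shape check, assuming `Conv` is the repo's conv-BN-activation helper used above; this mirrors cdfm1, with the ResNet Block2 feature as the current branch and the PVT Block1 feature as the support branch:

```python
# minimal sketch, assuming the repo's Conv(in_ch, out_ch) helper is on the path
dfm = DFM(cur_in_ch=512, sup_in_ch=128, out_ch=512)
x_cur = torch.randn(1, 512, 64, 64)  # from ResNet Block2
x_sup = torch.randn(1, 128, 64, 64)  # from PVT Block1
print(dfm(x_cur, x_sup).shape)       # expected: torch.Size([1, 512, 64, 64])
```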
The DFM is instantiated five times in DBINet: twice in the main encoder (ResNet), twice in the auxiliary encoder (PVT), and once to fuse the outputs of the two encoders.
With an input image of (3, 512, 512), the outputs of PVT Block1, PVT Block2 and PVT Block3 are [1, 128, 64, 64], [1, 320, 32, 32] and [1, 512, 16, 16] respectively. In the figure above:
- 1. $P_i$: the number of patches in stage i.
- 2. $C_i$: the number of output channels of stage i.

DBINet only uses the first three PVT blocks, so on the PVT side the DFM modules are:
- tdfm[0]: a DFM in the PVT branch; its output keeps the shape of the incoming PVT feature.
  - Input from PVT Block1: [1, 128, 64, 64]
  - Input from ResNet Block2: [1, 512, 64, 64]
  - Output: [1, 128, 64, 64]
- tdfm[1]: a DFM in the PVT branch; its output keeps the shape of the incoming PVT feature.
  - Input from PVT Block2: [1, 320, 32, 32]
  - Input from ResNet Block3: [1, 1024, 32, 32]
  - Output: [1, 320, 32, 32]

The output of tdfm[1] then passes through PVT Block3 to give [1, 512, 16, 16].
Correspondingly, on the ResNet side the DFM modules are:
- cdfm[0]: a DFM in the ResNet branch; its output keeps the shape of the incoming ResNet feature.
  - Input from ResNet Block2: [1, 512, 64, 64]
  - Input from PVT Block1: [1, 128, 64, 64]
  - Output: [1, 512, 64, 64]
- cdfm[1]: a DFM in the ResNet branch; its output keeps the shape of the incoming ResNet feature.
  - Input from ResNet Block3: [1, 1024, 32, 32]
  - Input from PVT Block2: [1, 320, 32, 32]
  - Output: [1, 1024, 32, 32]

The output of cdfm[1] then passes through ResNet Block4 to give [1, 2048, 16, 16]. Finally, the outputs of the main and auxiliary encoders are fused while the channel dimension is reduced to 64. This DFM is defined as:
```python
# cur_in_ch (main-branch channels) = 2048, sup_in_ch (auxiliary-branch channels) = 512, out_ch (output channels) = 64
self.sumdfm = DFM(cur_in_ch=rout_chl[3], sup_in_ch=embed_dims[3], out_ch=mid_ch)
```
- sumdfm: fuses the outputs of the main and auxiliary encoders; H and W match the main encoder output, but the channel count is reduced.
  - Input from PVT Block3: [1, 512, 16, 16]
  - Input from ResNet Block4: [1, 2048, 16, 16]
  - Output: [1, 64, 16, 16]
1.4 Convertor: adapting encoder outputs for the decoder
The convertor fuses the multi-scale coarse features from the main encoder and produces adjusted features as decoder inputs. Its internal encoder and decoder are bridged by CBAM modules, whose channel and spatial attention capture salient spatial locations and edge information; at the bottom sits a dilated-convolution module.
```python
class Convertor(nn.Module):
    # Convertor
    def __init__(self, in_ch=64, out_ch=64):
        super(Convertor, self).__init__()
        mid_ch = in_ch
        self.rebnconv1 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.cbam1 = CBAM(mid_ch, mid_ch)
        self.cbam2 = CBAM(mid_ch, mid_ch)
        self.cbam3 = CBAM(mid_ch, mid_ch)
        self.cbam4 = CBAM(mid_ch, mid_ch)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x1, x2, x3, x4):
        hx1 = self.rebnconv1(x1)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx + x2)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx + x3)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx + x4)
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5, self.cbam4(hx4)), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, self.cbam3(hx3)), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, self.cbam2(hx2)), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, self.cbam1(hx1)), 1))
        return hx1d, hx2d, hx3d, hx4d
```
- Inputs from the encoder: [1, 64, 128, 128], [1, 64, 64, 64], [1, 64, 32, 32], [1, 64, 16, 16].
- Convertor outputs: [1, 64, 128, 128], [1, 64, 64, 64], [1, 64, 32, 32], [1, 64, 16, 16].
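A minimal shape check, assuming `REBNCONV`, `CBAM` and `_upsample_like` are imported from the DBINet repo:

```python
# minimal sketch, assuming the repo's REBNCONV/CBAM/_upsample_like definitions are on the path
conv = Convertor(in_ch=64, out_ch=64)
x1 = torch.randn(1, 64, 128, 128)
x2 = torch.randn(1, 64, 64, 64)
x3 = torch.randn(1, 64, 32, 32)
x4 = torch.randn(1, 64, 16, 16)
for y in conv(x1, x2, x3, x4):
    print(y.shape)  # [1, 64, 128, 128], [1, 64, 64, 64], [1, 64, 32, 32], [1, 64, 16, 16]
```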
1.5 Deeply supervised loss function
The study uses BCE loss (a pixel-level loss) plus IoU loss (a map-level loss) as the loss function:
- $S_1$: the final saliency map.
- $S_2$-$S_4$: intermediate-stage saliency maps.
The deeply supervised loss is applied in DBINet/EORSSD_train.py as follows:
```python
# ...
CE = torch.nn.BCEWithLogitsLoss()
MSE = torch.nn.MSELoss()
IOU = pytorch_iou.IOU(size_average=True)
# ...
# four saliency maps and their sigmoid outputs (s1 is the final map)
s1, s2, s3, s4, s1_sig, s2_sig, s3_sig, s4_sig = model(images)
loss1 = CE(s1, gts) + IOU(s1_sig, gts)
loss2 = CE(s2, gts) + IOU(s2_sig, gts)
loss3 = CE(s3, gts) + IOU(s3_sig, gts)
loss4 = CE(s4, gts) + IOU(s4_sig, gts)
# saliency maps from different stages carry different weights
total_loss = loss1 + loss2 / 2 + loss3 / 4 + loss4 / 8
running_loss += total_loss.data.item()
```
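`pytorch_iou` is a helper module shipped with the repo; a minimal sketch of such a map-level IoU loss (my assumption of its behavior, not the repo's exact code) could look like:

```python
import torch
import torch.nn as nn

class IOU(nn.Module):
    # map-level IoU loss: 1 - intersection/union, averaged over the batch
    def __init__(self, size_average=True):
        super(IOU, self).__init__()
        self.size_average = size_average

    def forward(self, pred, target):
        # pred and target: [B, 1, H, W]; pred has already passed through a sigmoid
        inter = (pred * target).sum(dim=(1, 2, 3))
        union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - inter
        iou = 1 - inter / (union + 1e-8)
        return iou.mean() if self.size_average else iou.sum()
```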
2. GCPANet
Paper: Global Context-Aware Progressive Aggregation Network for Salient Object Detection
Paper link: Global Context-Aware Progressive Aggregation Network for Salient Object Detection
Code: Github
Blog: Paper Reading (18): Global Context-Aware Progressive Aggregation Network for Salient Object Detection
Everything below assumes an input image of size [3, 512, 512].
2.1 Encoder: ResNet50 backbone
The encoder uses ResNet-50 to extract multi-level features; the decoder components then progressively integrate those features and generate the saliency map under supervision. The encoder is used inside `class GCPANet(nn.Module)`:

```python
class GCPANet(nn.Module):
    def __init__(self, cfg):
        super(GCPANet, self).__init__()
        ...
        self.bkbone = ResNet()
        ...
        self.initialize()

    def forward(self, x):
        # the ResNet50 encoder yields five feature maps
        out1, out2, out3, out4, out5_ = self.bkbone(x)
        ...
        return out2, out3, out4, out5

    # parameter initialization
    def initialize(self):
        ...
```
ResNet50 is defined in terms of residual blocks; its Bottleneck structure is implemented as follows:
```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

def weight_init(module):
    for n, m in module.named_children():
        print('initialize: ' + n)
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        else:
            m.initialize()

class Bottleneck(nn.Module):
    # inplanes: input channels; planes: base output channels; stride; downsample: downsampling layer; dilation
    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        # 1x1 conv
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=(3 * dilation - 1) // 2, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        # downsampling (used when stride != 1 or the channel counts mismatch)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)    # 1x1 conv
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)  # 3x3 conv
        out = self.bn3(self.conv3(out))                        # 1x1 conv
        # downsample the identity if it cannot be added to the features directly
        if self.downsample is not None:
            residual = self.downsample(x)
        # residual connection
        return F.relu(out + residual, inplace=True)
```
On this basis, ResNet50 serves as the backbone for feature extraction. It consists of four Block stages containing 3, 4, 6 and 3 Bottlenecks respectively. Overall structure:
```python
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        # track the running input-channel count
        self.inplanes = 64
        # conv1: 7x7 kernel, 3 input channels (RGB), 64 output channels, stride 2, padding 3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # first residual stage, corresponding to conv2_x
        self.layer1 = self.make_layer( 64, 3, stride=1, dilation=1)
        # second residual stage, corresponding to conv3_x
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        # third residual stage, corresponding to conv4_x
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        # fourth residual stage, corresponding to conv5_x
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
        # weight initialization
        self.initialize()

    def make_layer(self, planes, blocks, stride, dilation):
        downsample = None
        # downsample when stride != 1 or the channel counts mismatch
        if stride != 1 or self.inplanes != planes * 4:
            # 1x1 conv + batch normalization for downsampling
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * 4, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * 4))
        # first residual block of the stage, a Bottleneck (input channels, output channels, stride, downsample, dilation)
        layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)]
        # the stage output has four times as many channels
        self.inplanes = planes * 4
        # stack the remaining residual blocks
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        # conv1: 256x256 for a 512x512 input
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        # conv2_x: 128x128
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        # conv3_x: 64x64
        out3 = self.layer2(out2)
        # conv4_x: 32x32
        out4 = self.layer3(out3)
        # conv5_x: 16x16
        out5 = self.layer4(out4)
        return out1, out2, out3, out4, out5

    def initialize(self):
        # load pretrained weights, allowing partial matches (strict=False)
        self.load_state_dict(torch.load('resnet50-19c8e357.pth'), strict=False)
```
Note that pretrained weights are available for the ResNet50 encoder used in the paper; to test the code stand-alone, remove `self.load_state_dict(torch.load('resnet50-19c8e357.pth'), strict=False)` and the `self.initialize()` call. Test:
```python
if __name__ == '__main__':
    data = torch.randn(1, 3, 512, 512)
    Encoder = ResNet()
    out = Encoder(data)
    for i in out:
        print(i.shape)
```
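For the 512×512 test input the printed shapes should be:

```
torch.Size([1, 64, 128, 128])
torch.Size([1, 256, 128, 128])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 32, 32])
torch.Size([1, 2048, 16, 16])
```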
2.2 FIA: feature interweaved aggregation module
The FIA module fuses low-level features (rich in detail), high-level features (rich in semantics) and global-context features (useful for the relations between different salient objects or their parts), and returns discriminative, comprehensive features with global awareness; the equations below summarize it.
- $f_l$: low-level features.
- $f_h$: high-level features.
- $f_g$: global-context features.
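Reading off the code below ($\odot$ is element-wise multiplication, $\mathrm{Up}(\cdot)$ bilinear resizing to $f_l$'s resolution, and each input is first compressed to 256 channels by a Conv-BN-ReLU), the module computes:

$$
\begin{aligned}
f_{hl} &= \mathrm{ReLU}\big(\mathrm{Conv}_l(f_l) \odot \mathrm{Up}(f_h)\big),\\
f_{lh} &= \mathrm{ReLU}\big(\mathrm{Up}(\mathrm{Conv}_{d1}(f_h)) \odot f_l\big),\\
f_{gl} &= \mathrm{ReLU}\big(\mathrm{Up}(\mathrm{Conv}_{d2}(f_g)) \odot f_l\big),\\
f_{out} &= \mathrm{ReLU}\big(\mathrm{BN}(\mathrm{Conv}_3([f_{hl};\, f_{lh};\, f_{gl}]))\big).
\end{aligned}
$$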
""" Feature Interweaved Aggregation Module """
class FAM(nn.Module):def __init__(self, in_channel_left, in_channel_down, in_channel_right):#接受左、下、右三个方向的输入通道数(对应低级特征、高级特征、全局特征)super(FAM, self).__init__()#对低级特征f_l进行卷积、归一化self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1)self.bn0 = nn.BatchNorm2d(256)#对高级特征f_h进行卷积、归一化self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(256)#对全局特征f_g进行卷积、归一化self.conv2 = nn.Conv2d(in_channel_right, 256, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(256)self.conv_d1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)self.conv_d2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)self.conv_l = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)self.conv3 = nn.Conv2d(256*3, 256, kernel_size=3, stride=1, padding=1)self.bn3 = nn.BatchNorm2d(256)def forward(self, left, down, right):#依次将低级特征f_l、高级特征f_h、全局特征f_g卷积、归一化、ReLU激活,并压缩到256通道left = F.relu(self.bn0(self.conv0(left)), inplace=True)down = F.relu(self.bn1(self.conv1(down)), inplace=True)right = F.relu(self.bn2(self.conv2(right)), inplace=True) #256#上采样高级特征图down_1 = self.conv_d1(down)#对left特征图卷积,得到分割掩码w1w1 = self.conv_l(left)#检查高级特征图和低级特征图的空间维度,不匹配则使用线性插值调整高级特征图的大小.将分割掩码w1与高级特征图相乘并使用ReLU激活函数,得到f_{hl}if down.size()[2:] != left.size()[2:]:down_ = F.interpolate(down, size=left.size()[2:], mode='bilinear')z1 = F.relu(w1 * down_, inplace=True)else:z1 = F.relu(w1 * down, inplace=True)#将上采样后的高级特征图调整至与低级特征图相同的维度if down_1.size()[2:] != left.size()[2:]:down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear')#将高级特征图与低级特征图相乘得到f_{lh}z2 = F.relu(down_1 * left, inplace=True)#上采样全局特征图down_2 = self.conv_d2(right)if down_2.size()[2:] != left.size()[2:]:down_2 = F.interpolate(down_2, size=left.size()[2:], mode='bilinear')#将全局特征图与低级特征图相乘得到f_{gl}z3 = F.relu(down_2 * left, inplace=True)#将三个结果catout = torch.cat((z1, z2, z3), dim=1)#输入卷积层运算并返回return F.relu(self.bn3(self.conv3(out)), inplace=True)def initialize(self):weight_init(self)
2.3 SR: self-refinement module
The feature maps obtained so far are further refined and enhanced through multiplication and addition operations.
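Matching the code below: the input is first squeezed to 256 channels, then a 512-channel convolution is split in half into a pixel-wise weight $w$ and bias $b$:

$$
f_1 = \mathrm{ReLU}(\mathrm{BN}(\mathrm{Conv}_1(f_{in}))), \qquad [w;\, b] = \mathrm{Conv}_2(f_1), \qquad f_{out} = \mathrm{ReLU}(w \odot f_1 + b)
$$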
""" Self Refinement Module """
class SRM(nn.Module):def __init__(self, in_channel):super(SRM, self).__init__()self.conv1 = nn.Conv2d(in_channel, 256, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(256)self.conv2 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)def forward(self, x):#先将输入特征压缩为256通道大小,再分别通过Batch Normalization、ReLU层out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)#经过卷积运算转为512通道out2 = self.conv2(out1)#将前256通道作为权重,后256通道作为偏置0w, b = out2[:, :256, :, :], out2[:, 256:, :, :]#加权结合out1、w、b,并应用ReLU激活函数得到输出return F.relu(w * out1 + b, inplace=True)def initialize(self):weight_init(self)
2.4 GCF: global context flow module
Captures global context information from the highest-level feature map extracted by ResNet50.
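From the code below, with $\mathrm{GAP}$ denoting global average pooling and $\sigma$ the sigmoid, the module reweights the top feature channel-wise:

$$
f_{gcf} = \mathrm{ReLU}\big(\mathrm{BN}(\mathrm{Conv}_0(f_{top}))\big) \odot \sigma\Big(\mathrm{Conv}_2\big(\mathrm{ReLU}(\mathrm{Conv}_1(\mathrm{GAP}(f_{gap})))\big)\Big)
$$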
```python
class CA(nn.Module):
    # in_channel_left: channels of f_top; in_channel_down: channels of f_gap
    def __init__(self, in_channel_left, in_channel_down):
        super(CA, self).__init__()
        self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=1, stride=1, padding=0)
        self.bn0 = nn.BatchNorm2d(256)
        self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)

    def forward(self, left, down):
        # conv + batch normalization + ReLU on f_top
        left = F.relu(self.bn0(self.conv0(left)), inplace=True)
        # global average pooling collapses the spatial dimensions (H and W shrink to 1)
        down = down.mean(dim=(2, 3), keepdim=True)
        # conv + activation
        down = F.relu(self.conv1(down), inplace=True)
        # squash the channel weights into (0, 1)
        down = torch.sigmoid(self.conv2(down))
        return left * down

    def initialize(self):
        weight_init(self)
```
2.5 HA: head attention module
Removes information that is redundant for salient object detection from the highest-level feature map extracted by ResNet50.
```python
class SA(nn.Module):
    def __init__(self, in_channel_left, in_channel_down):
        super(SA, self).__init__()
        self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(in_channel_down, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, left, down):
        # left and down are both features extracted by ResNet
        # same operation as in the SR module
        left = F.relu(self.bn0(self.conv0(left)), inplace=True)  # 256 channels
        down_1 = self.conv2(down)
        # bilinearly resize down_1 if its spatial size differs from left's
        if down_1.size()[2:] != left.size()[2:]:
            down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear')
        # as in the SR module, split into weights w and biases b
        w, b = down_1[:, :256, :, :], down_1[:, 256:, :, :]
        # yields F1
        return F.relu(w * left + b, inplace=True)

    def initialize(self):
        weight_init(self)
```
2.6 Overall pipeline
```python
class GCPANet(nn.Module):
    def __init__(self, cfg):
        super(GCPANet, self).__init__()
        self.cfg = cfg
        # ResNet50: feature extraction
        self.bkbone = ResNet()
        # GCF: several channel-attention (CA) modules plus a spatial-attention (SA) module for feature weighting
        self.ca45 = CA(2048, 2048)
        self.ca35 = CA(2048, 2048)
        self.ca25 = CA(2048, 2048)
        self.ca55 = CA(256, 2048)
        self.sa55 = SA(2048, 2048)
        # FIA: feature interweaved aggregation modules for the different feature levels
        self.fam45 = FAM(1024, 256, 256)
        self.fam34 = FAM( 512, 256, 256)
        self.fam23 = FAM( 256, 256, 256)
        # SR: self-refinement modules for processing and enhancing the features
        self.srm5 = SRM(256)
        self.srm4 = SRM(256)
        self.srm3 = SRM(256)
        self.srm2 = SRM(256)
        # four convolutions mapping the 256-channel features to single-channel outputs
        self.linear5 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
        self.linear4 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
        self.linear3 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
        self.linear2 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
        # weight initialization
        self.initialize()

    def forward(self, x):
        # extract multi-level features with the ResNet backbone
        out1, out2, out3, out4, out5_ = self.bkbone(x)
        # GCF
        out4_a = self.ca45(out5_, out5_)
        out3_a = self.ca35(out5_, out5_)
        out2_a = self.ca25(out5_, out5_)
        # HA
        out5_a = self.sa55(out5_, out5_)
        out5 = self.ca55(out5_a, out5_)
        # FIA + SR
        out5 = self.srm5(out5)
        out4 = self.srm4(self.fam45(out4, out5, out4_a))
        out3 = self.srm3(self.fam34(out3, out4, out3_a))
        out2 = self.srm2(self.fam23(out2, out3, out2_a))
        # bilinearly interpolate the four SR outputs back to the original image size
        out5 = F.interpolate(self.linear5(out5), size=x.size()[2:], mode='bilinear')
        out4 = F.interpolate(self.linear4(out4), size=x.size()[2:], mode='bilinear')
        out3 = F.interpolate(self.linear3(out3), size=x.size()[2:], mode='bilinear')
        out2 = F.interpolate(self.linear2(out2), size=x.size()[2:], mode='bilinear')
        # return the four saliency maps
        return out2, out3, out4, out5

    def initialize(self):
        if self.cfg.snapshot:
            try:
                self.load_state_dict(torch.load(self.cfg.snapshot))
            except:
                print("Warning: please check the snapshot file:", self.cfg.snapshot)
                pass
        else:
            weight_init(self)
```
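A minimal smoke test, assuming the pretrained-weight loading in `ResNet.initialize()` has been replaced by a bare `pass` as noted above; `SimpleNamespace` is used here as a hypothetical stand-in for the repo's config object, which only needs a `snapshot` attribute:

```python
from types import SimpleNamespace

if __name__ == '__main__':
    cfg = SimpleNamespace(snapshot=None)  # hypothetical minimal config
    net = GCPANet(cfg)
    data = torch.randn(1, 3, 512, 512)
    for o in net(data):
        print(o.shape)  # each map should be torch.Size([1, 1, 512, 512])
```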
2.7 Loss function: binary cross-entropy
```python
import torch.nn.functional as F

# model outputs
out2, out3, out4, out5 = net(image)
# per-map losses
loss2 = F.binary_cross_entropy_with_logits(out2, mask)
loss3 = F.binary_cross_entropy_with_logits(out3, mask)
loss4 = F.binary_cross_entropy_with_logits(out4, mask)
loss5 = F.binary_cross_entropy_with_logits(out5, mask)
# weighted total loss
loss = loss2 * 1 + loss3 * 0.8 + loss4 * 0.6 + loss5 * 0.4
```
3. The CPD decoder framework
Paper: Cascaded Partial Decoder for Fast and Accurate Salient Object Detection
Paper link: Cascaded Partial Decoder for Fast and Accurate Salient Object Detection
Code: Github
Blog: Paper Reading (21): Cascaded Partial Decoder for Fast and Accurate Salient Object Detection
3.1 HA: holistic attention module
Edge information of salient objects may be filtered out by the initial saliency map, which makes some objects in complex scenes hard to segment precisely. The holistic attention module (HA) improves the effectiveness of the saliency map, helping to segment salient objects in their entirety and to refine their boundaries.
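In equation form (matching the code below), with $\mathrm{Conv}_g$ a convolution whose 31×31 kernel is Gaussian-initialized and learnable, and $f_{mm}$ min-max normalization, the initial saliency map $S$ is blurred and combined with itself by an element-wise maximum before weighting the feature $f$:

$$
S_h = \max\big(f_{mm}(\mathrm{Conv}_g(S)),\, S\big), \qquad \tilde{f} = f \odot S_h
$$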
```python
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
import scipy.stats as st

def gkern(kernlen=16, nsig=3):
    # build a 2-D Gaussian kernel of size kernlen with standard deviation nsig
    interval = (2 * nsig + 1.) / kernlen
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw / kernel_raw.sum()
    return kernel

def min_max_norm(in_):
    # per-sample min-max normalization over the spatial dimensions
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_ - min_ + 1e-8)

class HA(nn.Module):
    # holistic attention module
    def __init__(self):
        super(HA, self).__init__()
        # a 31x31 Gaussian kernel with standard deviation 4
        gaussian_kernel = np.float32(gkern(31, 4))
        # add batch and channel dimensions so it can be used in a convolution
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        # convert to a tensor and register as a learnable parameter
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        # blur the initial saliency map with the Gaussian kernel
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
        # min-max normalization
        soft_attention = min_max_norm(soft_attention)
        # weight the features with the element-wise maximum of the blurred and original maps
        x = torch.mul(x, soft_attention.max(attention))
        return x
```
3.2 Aggregation: feature fusion module
The aggregation module fuses its three input features:
```python
class aggregation(nn.Module):
    # dense aggregation; it can be replaced by other aggregation models, such as DSS, amulet, and so on
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * channel, 1, 1)

    def forward(self, x1, x2, x3):
        x1_1 = x1
        # upsample x1 and multiply with x2 to get x2_1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        # upsample x1 twice and x2 once, multiply with x3 to get x3_1
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3
        # upsample x1_1 and concatenate with x2_1 to get x2_2
        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)
        # upsample x2_2 and concatenate with x3_1 to get x3_2
        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)
        # final convolutions produce the single-channel output map
        x = self.conv4(x3_2)  # N, 96, H//8, W//8
        x = self.conv5(x)
        return x
```
3.3 Cascaded decoder
The cascaded decoder discards the high-resolution features of shallow layers for speed (only the outputs of the encoder's last three convolutional blocks are used, not the first two) and directly uses the generated saliency map to recurrently refine the deep features, which effectively suppresses distractors and markedly improves their representational power.
The RFB module in the model is not detailed here; the paper Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images also makes concrete use of this cascaded decoder. Complete cascaded-decoder code:
```python
class CPD_VGG(nn.Module):
    def __init__(self, channel=32):
        super(CPD_VGG, self).__init__()
        self.vgg = B2_VGG()
        self.rfb3_1 = RFB(256, channel)
        self.rfb4_1 = RFB(512, channel)
        self.rfb5_1 = RFB(512, channel)
        self.agg1 = aggregation(channel)
        self.rfb3_2 = RFB(256, channel)
        self.rfb4_2 = RFB(512, channel)
        self.rfb5_2 = RFB(512, channel)
        self.agg2 = aggregation(channel)
        self.HA = HA()
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

    def forward(self, x):
        x1 = self.vgg.conv1(x)
        x2 = self.vgg.conv2(x1)
        x3 = self.vgg.conv3(x2)
        x3_1 = x3
        x4_1 = self.vgg.conv4_1(x3_1)
        x5_1 = self.vgg.conv5_1(x4_1)
        x3_1 = self.rfb3_1(x3_1)
        x4_1 = self.rfb4_1(x4_1)
        x5_1 = self.rfb5_1(x5_1)
        # aggregation module fuses the attention-branch features
        attention = self.agg1(x5_1, x4_1, x3_1)
        # holistic attention module yields the attention-branch result
        x3_2 = self.HA(attention.sigmoid(), x3)
        x4_2 = self.vgg.conv4_2(x3_2)
        x5_2 = self.vgg.conv5_2(x4_2)
        x3_2 = self.rfb3_2(x3_2)
        x4_2 = self.rfb4_2(x4_2)
        x5_2 = self.rfb5_2(x5_2)
        # aggregation module fuses the detection-branch features
        detection = self.agg2(x5_2, x4_2, x3_2)
        return self.upsample(attention), self.upsample(detection)
```
4. ACCoNet
Paper: Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images
Paper link: Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images
Code: Github
Blog: Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images
The Adjacent Context Coordination Network (ACCoNet) is a salient object detection model for optical remote sensing images. Its core idea is to thoroughly exploit the contextual information contained in adjacent features, widen the coverage of feature interaction, and strengthen the context-capturing ability of plain decoder blocks.
4.1 Encoder: VGG16 backbone
VGG16 serves as the encoder network, with its last max-pooling layer and three fully connected layers truncated.
VGG16 is split into five blocks; the feature map output by the last convolutional layer of each block feeds the ACCoM modules. The encoder input size is 256×256×3. Code:
```python
import torch
import torch.nn as nn

class VGG(nn.Module):
    # pooling layer at the front of block
    def __init__(self, mode='rgb'):
        super(VGG, self).__init__()
        conv1 = nn.Sequential()
        conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
        conv1.add_module('bn1_1', nn.BatchNorm2d(64))
        conv1.add_module('relu1_1', nn.ReLU(inplace=True))
        conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
        conv1.add_module('bn1_2', nn.BatchNorm2d(64))
        conv1.add_module('relu1_2', nn.ReLU(inplace=True))
        self.conv1 = conv1

        conv2 = nn.Sequential()
        conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
        conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
        conv2.add_module('bn2_1', nn.BatchNorm2d(128))
        conv2.add_module('relu2_1', nn.ReLU())
        conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
        conv2.add_module('bn2_2', nn.BatchNorm2d(128))
        conv2.add_module('relu2_2', nn.ReLU())
        self.conv2 = conv2

        conv3 = nn.Sequential()
        conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
        conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
        conv3.add_module('bn3_1', nn.BatchNorm2d(256))
        conv3.add_module('relu3_1', nn.ReLU())
        conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('bn3_2', nn.BatchNorm2d(256))
        conv3.add_module('relu3_2', nn.ReLU())
        conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
        conv3.add_module('bn3_3', nn.BatchNorm2d(256))
        conv3.add_module('relu3_3', nn.ReLU())
        self.conv3 = conv3

        conv4 = nn.Sequential()
        conv4.add_module('pool3_1', nn.MaxPool2d(2, stride=2))
        conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
        conv4.add_module('bn4_1', nn.BatchNorm2d(512))
        conv4.add_module('relu4_1', nn.ReLU())
        conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('bn4_2', nn.BatchNorm2d(512))
        conv4.add_module('relu4_2', nn.ReLU())
        conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv4.add_module('bn4_3', nn.BatchNorm2d(512))
        conv4.add_module('relu4_3', nn.ReLU())
        self.conv4 = conv4

        conv5 = nn.Sequential()
        conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
        conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('bn5_1', nn.BatchNorm2d(512))
        conv5.add_module('relu5_1', nn.ReLU())
        conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('bn5_2', nn.BatchNorm2d(512))
        conv5.add_module('relu5_2', nn.ReLU())
        conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
        conv5.add_module('bn5_3', nn.BatchNorm2d(512))
        conv5.add_module('relu5_3', nn.ReLU())
        self.conv5 = conv5

        pre_train = torch.load('/home/lgy/20210206_ORSI_SOD/model/vgg16-397923af.pth')
        self._initialize_weights(pre_train)

    def forward(self, x):
        # apply the five blocks in sequence (ACCoNet calls conv1-conv5 individually to collect all five feature maps)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x
```
4.2 ACCoM: adjacent context coordination module
The adjacent context coordination module (ACCoM) coordinates the cross-scale features of the current, previous and subsequent blocks.
- Local branch: parallel convolutional layers with different kernels, equipped with attention mechanisms (CA, SA), capture local and global content within a single feature level.
- Two adjacent branches: feature interaction between adjacent levels captures cross-level complementary context; both adjacent branches use spatial attention (SA) to extract contextual cues from the neighboring-level features.

ACCoM then passes this information on to the decoder.
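For the middle variant (ACCoM-2/3/4 below; ACCoM-1 drops the previous branch and ACCoM-5 the latter one), the computation read off the code is:

$$
\begin{aligned}
F_{loc} &= \mathrm{Conv}\big([\mathrm{Conv}_{d=1}(F^t);\, \mathrm{Conv}_{d=2}(F^t);\, \mathrm{Conv}_{d=3}(F^t);\, \mathrm{Conv}_{d=4}(F^t)]\big),\\
F^t_{out} &= F_{loc}\odot \mathrm{SA}\big(F_{loc}\odot \mathrm{CA}(F_{loc})\big) + F_{loc}\odot \mathrm{SA}(F^{t-1}{\downarrow}) + F_{loc}\odot \mathrm{SA}(F^{t+1}{\uparrow}) + F^t.
\end{aligned}
$$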
[ACCoM-1]

```python
class ACCoM1(nn.Module):
    def __init__(self, cur_channel):
        super(ACCoM1, self).__init__()
        self.relu = nn.ReLU(True)
        # local branch: dilated convolutions
        self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
        self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
        self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
        self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
        self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1)
        self.cur_all_ca = ChannelAttention(cur_channel)
        self.cur_all_sa = SpatialAttention()
        # latter conv
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.lat_sa = SpatialAttention()

    def forward(self, x_cur, x_lat):
        # local branch: four dilated convolutions give four receptive fields
        x_cur_1 = self.cur_b1(x_cur)
        x_cur_2 = self.cur_b2(x_cur)
        x_cur_3 = self.cur_b3(x_cur)
        x_cur_4 = self.cur_b4(x_cur)
        # fuse the multi-receptive-field features
        x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
        # CA produces an attention map that reweights the local branch channel-wise
        cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
        # SA produces an attention map that reweights the CA output spatially
        cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
        # latter branch: upsample the subsequent-level feature
        x_lat = self.upsample2(x_lat)
        # its SA map reweights the local branch spatially
        lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
        # fuse the attended local branch, the SA-enhanced branch, and the original input
        x_LocAndGlo = cur_all_sa + lat_sa + x_cur
        return x_LocAndGlo
```
[ACCoM-2, 3, 4]

```python
class ACCoM(nn.Module):
    def __init__(self, cur_channel):
        super(ACCoM, self).__init__()
        self.relu = nn.ReLU(True)
        # current conv
        self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
        self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
        self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
        self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
        self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1)
        self.cur_all_ca = ChannelAttention(cur_channel)
        self.cur_all_sa = SpatialAttention()
        # previous conv
        self.downsample2 = nn.MaxPool2d(2, stride=2)
        self.pre_sa = SpatialAttention()
        # latter conv
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.lat_sa = SpatialAttention()

    def forward(self, x_pre, x_cur, x_lat):
        # local branch
        x_cur_1 = self.cur_b1(x_cur)
        x_cur_2 = self.cur_b2(x_cur)
        x_cur_3 = self.cur_b3(x_cur)
        x_cur_4 = self.cur_b4(x_cur)
        x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
        cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
        cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
        # previous branch: downsample, then reweight the local branch with its SA map
        x_pre = self.downsample2(x_pre)
        pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
        # latter branch: upsample, then reweight the local branch with its SA map
        x_lat = self.upsample2(x_lat)
        lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
        # branch fusion
        x_LocAndGlo = cur_all_sa + pre_sa + lat_sa + x_cur
        return x_LocAndGlo
```
[ACCoM-5]

```python
class ACCoM5(nn.Module):
    def __init__(self, cur_channel):
        super(ACCoM5, self).__init__()
        self.relu = nn.ReLU(True)
        # current conv
        self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
        self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
        self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
        self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
        self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1)
        self.cur_all_ca = ChannelAttention(cur_channel)
        self.cur_all_sa = SpatialAttention()
        # previous conv
        self.downsample2 = nn.MaxPool2d(2, stride=2)
        self.pre_sa = SpatialAttention()

    def forward(self, x_pre, x_cur):
        # local branch
        x_cur_1 = self.cur_b1(x_cur)
        x_cur_2 = self.cur_b2(x_cur)
        x_cur_3 = self.cur_b3(x_cur)
        x_cur_4 = self.cur_b4(x_cur)
        x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
        cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
        cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
        # previous branch
        x_pre = self.downsample2(x_pre)
        pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
        # branch fusion
        x_LocAndGlo = cur_all_sa + pre_sa + x_cur
        return x_LocAndGlo
```
4.3 BAB: bifurcation-aggregation decoder block
The BAB design derives from the cascaded partial decoder, extended with two attention branches that process the outputs of the current ACCoM module and the previous BAB, from which the salient-object mask is finally inferred. Inside the BAB, two dilated convolutions enlarge the receptive field so as to capture the contextual information of the merged branch $f^t_{accom}$. Parameters:

```python
class BAB_Decoder(nn.Module):
    def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2):
        super(BAB_Decoder, self).__init__()
        self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1)
        self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1)
        self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1)
        self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2)
        self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1)
        self.conv_fuse = BasicConv2d(channel_2 * 3, channel_3, 3, padding=1)

    def forward(self, x):
        x1 = self.conv1(x)
        # dilated convolution
        x1_dila = self.conv1_Dila(x1)
        x2 = self.conv2(x1)
        # dilated convolution
        x2_dila = self.conv2_Dila(x2)
        x3 = self.conv3(x2)
        # fuse the two dilated branches with the plain branch, convolve again and return the feature map
        x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1))
        return x_fuse
```
4.4 Network pipeline

```python
class ACCoNet_VGG(nn.Module):
    def __init__(self, channel=32):
        super(ACCoNet_VGG, self).__init__()
        # backbone model
        self.vgg = VGG('rgb')
        self.ACCoM5 = ACCoM5(512)
        self.ACCoM4 = ACCoM(512)
        self.ACCoM3 = ACCoM(256)
        self.ACCoM2 = ACCoM(128)
        self.ACCoM1 = ACCoM1(64)
        # self.agg2_rgbd = aggregation(channel)
        self.decoder_rgb = decoder(512)
        self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_rgb):
        # VGG yields five levels of feature maps
        x1_rgb = self.vgg.conv1(x_rgb)
        x2_rgb = self.vgg.conv2(x1_rgb)
        x3_rgb = self.vgg.conv3(x2_rgb)
        x4_rgb = self.vgg.conv4(x3_rgb)
        x5_rgb = self.vgg.conv5(x4_rgb)
        # feed the feature maps into the five ACCoM modules
        x5_ACCoM = self.ACCoM5(x4_rgb, x5_rgb)
        x4_ACCoM = self.ACCoM4(x3_rgb, x4_rgb, x5_rgb)
        x3_ACCoM = self.ACCoM3(x2_rgb, x3_rgb, x4_rgb)
        x2_ACCoM = self.ACCoM2(x1_rgb, x2_rgb, x3_rgb)
        x1_ACCoM = self.ACCoM1(x1_rgb, x2_rgb)
        # feed the ACCoM outputs into the decoder
        s1, s2, s3, s4, s5 = self.decoder_rgb(x5_ACCoM, x4_ACCoM, x3_ACCoM, x2_ACCoM, x1_ACCoM)
        # upsample the decoder outputs to obtain the saliency maps
        s3 = self.upsample2(s3)
        s4 = self.upsample4(s4)
        s5 = self.upsample8(s5)
        return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(s5)
```
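A smoke-test sketch, assuming the repo's `decoder` class (not shown above) is importable and the hard-coded pretrained-VGG path in `VGG.__init__` has been commented out:

```python
# minimal sketch under the assumptions stated above
if __name__ == '__main__':
    model = ACCoNet_VGG()
    x = torch.randn(1, 3, 256, 256)
    outs = model(x)
    # five logit maps followed by their sigmoid counterparts
    for o in outs:
        print(o.shape)
```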
5. FPS-U2Net
Paper: FPS-U2Net: Combining U2Net and multi-level aggregation architecture for fire point segmentation in remote sensing images
Paper link: FPS-U2Net: Combining U2Net and multi-level aggregation architecture for fire point segmentation in remote sensing images
Code link: Github
Blog: Paper Study (10): FPS-U2Net: Combining U2Net and multi-level aggregation architecture for fire point segmentation in remote sensing images
5.1 Encoder architecture: U2Net
The FPS-U2Net model is shown above; its encoder adopts the U2Net architecture, whose core component is the RSU-L block. The code is omitted here, as it is too long.
5.2 MAM: multi-level aggregation module
The multi-level aggregation modules sit between the encoder and decoder of the same stage, aggregating adjacent multi-scale features and capturing richer contextual information. The design appears to draw on the ACCoM module of ACCoNet.
```python
# MAM1
class MAM1(nn.Module):
    ### Multi-level Aggregation Module
    def __init__(self, in_ch, out_ch):
        super(MAM1, self).__init__()
        # latter conv
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.lat_sa = SpatialAttention()
        self.cbam = CBAM(in_ch)
        self.fuse = Bottleneck(in_ch, out_ch)

    def forward(self, x_cur, x_lat):
        # latter branch
        x_lat = self.upsample2(x_lat)
        lat_sa = x_cur.mul(self.lat_sa(x_lat))
        x = self.cbam(x_cur)
        x = x + lat_sa + x_cur
        x_LocAndGlo = self.fuse(x)
        return x_LocAndGlo
```

```python
# MAM2-4
class MAM(nn.Module):
    ### Multi-level Aggregation Module
    def __init__(self, in_ch, out_ch):
        super(MAM, self).__init__()
        # previous conv
        self.downsample2 = nn.MaxPool2d(2, stride=2)
        self.pre_sa = SpatialAttention()
        # latter conv
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.lat_sa = SpatialAttention()
        self.cbam = CBAM(in_ch)
        self.fuse = Bottleneck(in_ch, out_ch)

    def forward(self, x_pre, x_cur, x_lat):
        # previous branch
        x_pre = self.downsample2(x_pre)
        pre_sa = x_cur.mul(self.pre_sa(x_pre))
        # latter branch
        x_lat = self.upsample2(x_lat)
        lat_sa = x_cur.mul(self.lat_sa(x_lat))
        x = self.cbam(x_cur)
        x = x + pre_sa + lat_sa + x_cur
        x_LocAndGlo = self.fuse(x)
        return x_LocAndGlo
```

```python
# MAM5
class MAM5(nn.Module):
    ### Multi-level Aggregation Module
    def __init__(self, in_ch, out_ch):
        super(MAM5, self).__init__()
        # previous conv
        self.downsample2 = nn.MaxPool2d(2, stride=2)
        self.pre_sa = SpatialAttention()
        self.cbam = CBAM(in_ch)
        self.fuse = Bottleneck(in_ch, out_ch)

    def forward(self, x_pre, x_cur):
        # previous branch
        x_pre = self.downsample2(x_pre)
        pre_sa = x_cur.mul(self.pre_sa(x_pre))
        x = self.cbam(x_cur)
        x = x + pre_sa + x_cur
        x_LocAndGlo = self.fuse(x)
        return x_LocAndGlo
```