论文信息
标题: Channel Prior Convolutional Attention for Medical Image Segmentation
论文链接: arxiv.org
代码链接: GitHub
创新点
本文提出了一种新的通道优先卷积注意力(CPCA)机制,旨在解决医学图像分割中存在的低对比度和显著器官形状变化等问题。CPCA通过在通道和空间维度上动态分配注意力权重,增强了模型对信息丰富通道和重要区域的关注能力。
方法
CPCA方法的核心在于以下几个方面:
-
动态注意力分配: 在通道和空间维度上支持动态分配注意力权重,使得模型能够自适应地关注不同的特征。
-
多尺度深度可分离卷积模块: 该模块有效提取空间关系,同时保持通道优先,增强了特征表示能力。
-
CPCANet网络结构: 基于CPCA的医学图像分割网络CPCANet被设计用于实现更高效的分割性能。
CPCA模块详解
通道优先卷积注意力(Channel Prior Convolutional Attention,CPCA)是一种新型的注意力机制,旨在提升深度学习模型在医学图像分割等任务中的性能。CPCA通过动态分配通道和空间维度的注意力权重,增强了模型对重要特征的关注能力。
CPCA结合了通道注意力和空间注意力的机制,具体实现步骤如下:
-
通道注意力(Channel Attention):
- 特征聚合: 通过全局平均池化和全局最大池化操作,生成两个不同的特征表示。
- 特征处理: 将这两个特征表示分别通过卷积层和激活函数处理,以提取通道之间的关系。
- 权重生成: 通过Sigmoid函数生成通道注意力权重,用于动态调整每个通道的重要性。
-
空间注意力(Spatial Attention):
- 空间关系捕捉: 使用多尺度深度可分离卷积模块来提取特征图中不同位置之间的关系。
- 多尺度卷积: 采用不同大小的卷积核来捕获多尺度信息,从而更好地理解特征图的空间结构。
-
整体机制:
- CPCA通过结合通道和空间注意力,动态分配注意力权重,并保持通道优先。这种结合使得网络能够更好地捕捉重要特征,提高特征的表征能力。
效果
在多个公开数据集上进行的实验表明,CPCANet相较于其他最先进的算法在分割性能上表现更佳,同时所需计算资源更少。具体来说,CPCA在处理低对比度和复杂形状变化的医学图像时,显著提高了分割精度。
实验结果
实验结果显示,CPCANet在两个主要数据集(如ACDC和ISIC2016)上均取得了优于现有方法的分割效果。通过对比分析,CPCA方法在多个指标上均表现出色,尤其是在处理复杂背景和形状变化时,展现了其优越性。
总结
本文通过提出通道优先卷积注意力(CPCA)机制,显著提升了医学图像分割的性能。CPCA不仅增强了模型对重要特征的关注能力,还通过动态分配注意力权重,优化了计算资源的使用。未来的研究可以进一步探索CPCA在其他领域的应用潜力,以及与其他深度学习技术的结合。
代码
import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nnclass Mlp(nn.Module):""" Multilayer perceptron."""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xdef conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):result = nn.Sequential()result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,bias=False))result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))return resultdef conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1):result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,padding=padding, groups=groups)result.add_module('relu', nn.ReLU())return resultdef fuse_bn(conv_or_fc, bn):std = (bn.running_var + bn.eps).sqrt()t = bn.weight / stdt = t.reshape(-1, 1, 1, 1)if len(t) == conv_or_fc.weight.size(0):return conv_or_fc.weight * t, bn.bias - bn.running_mean * bn.weight / stdelse:repeat_times = conv_or_fc.weight.size(0) // len(t)repeated = t.repeat_interleave(repeat_times, 0)return conv_or_fc.weight * repeated, (bn.bias - bn.running_mean * bn.weight / std).repeat_interleave(repeat_times, 0)class ChannelAttention(nn.Module):def __init__(self, input_channels, internal_neurons):super(ChannelAttention, self).__init__()self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1,bias=True)self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1,bias=True)self.input_channels = input_channelsdef forward(self, inputs):x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))# print('x:', x.shape)x1 = self.fc1(x1)x1 = F.relu(x1, inplace=True)x1 = self.fc2(x1)x1 = torch.sigmoid(x1)x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1))# print('x:', x.shape)x2 = self.fc1(x2)x2 = F.relu(x2, inplace=True)x2 = self.fc2(x2)x2 = torch.sigmoid(x2)x = x1 + x2x = x.view(-1, self.input_channels, 1, 1)return xclass RepBlock(nn.Module):def __init__(self, in_channels, out_channels,channelAttention_reduce=4):super().__init__()self.C = in_channelsself.O = out_channelsassert in_channels == out_channelsself.ca = ChannelAttention(input_channels=in_channels, internal_neurons=in_channels // channelAttention_reduce)self.dconv5_5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)self.dconv1_7 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 7), padding=(0, 3), groups=in_channels)self.dconv7_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(7, 1), padding=(3, 0), groups=in_channels)self.dconv1_11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 11), padding=(0, 5), groups=in_channels)self.dconv11_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(11, 1), padding=(5, 0), groups=in_channels)self.dconv1_21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 21), padding=(0, 10), groups=in_channels)self.dconv21_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(21, 1), padding=(10, 0), groups=in_channels)self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1), padding=0)self.act = nn.GELU()def forward(self, inputs):# Global Perceptroninputs = self.conv(inputs)inputs = self.act(inputs)channel_att_vec = self.ca(inputs)inputs = channel_att_vec * inputsx_init = self.dconv5_5(inputs)x_1 = self.dconv1_7(x_init)x_1 = self.dconv7_1(x_1)x_2 = self.dconv1_11(x_init)x_2 = self.dconv11_1(x_2)x_3 = self.dconv1_21(x_init)x_3 = self.dconv21_1(x_3)x = x_1 + x_2 + x_3 + x_initspatial_att = self.conv(x)out = spatial_att * inputsout = self.conv(out)return outif __name__ == "__main__":dim=64# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(2,dim,40,40).to(device)# 初始化 CPCA模块block = RepBlock(dim,dim)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)
输出结果: