论文:https://arxiv.org/abs/2212.04497
代码:GitHub - Amshaker/unetr_plus_plus: UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
机构:Mohamed Bin Zayed University of Artificial Intelligence1, University of California Merced2, Google Research3, Linkoping University4
UNETR++作者和UNETR居然完全不沾边来着,继续找思路所以主要写写方法部分,别的部分简略一点.....!感觉挺有收获的!
摘要
由于Transformer模型的成功,最近的工作研究了它们在三维医学分割任务中的适用性。在Transformer模型中,自注意力机制是努力获取远程依赖关系的主要构建块之一。然而,自注意运算具有二次复杂度,这被证明是一个计算瓶颈,特别是在体积医学成像中,其中输入是三维的,有许多切片。在本文中,我们提出了一种名为unetr++的三维医学图像分割方法,该方法既提供了高质量的分割mask,又在参数、计算成本和推理速度方面具有效率。我们设计的核心是引入一种新的高效成对注意(efficient paired attention, EPA)块,该块使用基于空间和通道注意的一对相互依赖的分支有效地学习空间和通道方面的判别特征。我们的空间注意公式是有效的,具有相对于输入序列长度的线性复杂性(linear complexity)。为了实现空间分支和以通道为中心的分支之间的通信,我们共享查询(query)和键映射(key mapping)功能的权重,这些功能提供了互补的好处(配对关注),同时还减少了整体网络参数。我们对Synapse、BTCV、ACDC、BRaTs和Decathlon-Lung这五个基准进行了广泛的评估,揭示了我们在效率和准确性方面的贡献的有效性。在Synapse上,我们的UNETR++设置了一个新的最先进的骰子得分为87.2%,同时与文献中最好的方法相比,在参数和FLOPs方面都降低了71%以上,效率显著。
背景
早期基于CNN的网络受限于他们的感受野,但是基于transformer的方法计算成本高
后面也冒出了一些混合方法,一些用基于transformer的encoder和卷积的decoder,另外一些设计编码器和解码器子网的混合块。但是这些网络主要关注于提高分割进度,这反过来又大大增加了模型在参数和FLOPs的大小,导致鲁棒性不理想。我们认为这是由于他们低效的self-attention的设计,在体数据分割中显露出更大的问题。此外,这些现有的方法没有捕捉到空间和通道特征之间的显式依赖关系,这可以提高分割质量。在这项工作中,我们的目标是在一个统一的框架中同时提高分割精度和模型效率。
贡献
1)我们提出了一种高效的混合分层结构用于三维医学图像分割,命名为unetr++,力求在参数、FLOPs和推理速度方面实现更好的分割精度和效率。基于最近的UNETR框架[13],我们提出的UNETR++分层方法引入了一种新的高效的对注意力(EPA)块,该块通过在两个分支中应用空间和通道注意力有效地捕获丰富的相互依赖的空间和通道特征。我们在EPA中的空间注意将键和值投射到一个固定的低维空间,使自注意计算相对于输入令牌的数量呈线性。另一方面,我们的通道注意通过在通道维度中执行查询和键之间的点积操作来强调通道特征映射之间的依赖关系。此外,为了捕获空间和通道特征之间的强相关性,查询和键的权重在分支之间共享,这也有助于控制网络参数的数量。相反,值的权重保持独立,以强制学习两个分支中的互补特征。
2)我们通过在五个基准上进行全面实验来验证我们的UNETR++方法:
Synapse[19]、BTCV[19]、ACDC[1]、BRaTs[24] Decathlon-Lungs[30]。定性和定量结果都证明了UNETR++的有效性,与文献中已有的方法相比,在分割精度和模型效率方面都有更好的表现。
相关工作
CNN-based Segmentation Methods
unet,多尺度三维全卷积,nnunet
金字塔[35]、大核[26]、扩展卷积[6]和可变形卷积[20]等方法,在基于cnn的框架内编码整体上下文信息
Transformers-based Segmentation Methods
ViT, 1d-embedding,shifted windows for 2D
Hybrid Segmentation Methods
TransFuse[34]提出了一种带有BiFusion模块的并行cnn-Transformer架构,用于融合编码器中的多级特征。
MedT[31]在自注意中引入了门控的位置敏感轴向注意机制来控制编码器中的位置嵌入信息,而解码器中的ConvNet模块产生分割模型。
TransUNet[5]结合了Transformer和U-Net架构,其中Transformer对来自卷积特征的嵌入图像补丁进行编码,解码器将上采样编码特征与高分辨率CNN特征相结合进行定位。
Ds-transunet[21]采用双尺度编码器 Swin-transformer[22]处理多尺度输入,并通过自注意编码来自不同语义尺度的局部和全局特征表示。
UNETR,三维混合模型, 该模型将变压器的远程空间依赖关系与CNN的感应偏置结合成“u形”编码器结构。其参数量是nnunet的2.5倍,但是如果nnFormer要在UNETR的基础上获得了更好的性能,需要进一步增加了1.6X参数和2.8Xflop。 UNETR:用于三维医学图像分割的Transformer-CSDN博客
nnFormer, 该方法适应swing - unet[3]架构。在这里,卷积层将输入扫描转换成三维patches,并引入基于体积的自关注模块来构建分层特征金字塔。nnFormer在取得良好性能的同时,其计算复杂度明显高于UNETR和其他混合方法。
我们认为上述混合方法难以有效捕获特征通道之间的相互依赖关系,以获得丰富的特征表示,既编码空间信息,也编码通道间的特征依赖关系。
方法
我们首先确定了我们要设计混合框架的两个理想属性:
1)Efficient Global Attention 高效的全局注意力:
在体积医学分割的情况下,计算上是昂贵的,并且在混合设计中交织窗口关注和卷积组件时变得更加成问题。与这些方法不同的是,我们认为跨特征通道计算自关注而不是计算体积维度,有望将相对于体积维度的复杂性从二次型降低到线性型。此外,通过将键和值的空间矩阵投影到较低维空间中,可以有效地学习空间注意信息。
2) Enriched Spatial-channel Feature Representation丰富的空间通道特征表示:
现有的混合体医学图像分割方法大多是通过注意力计算来捕获空间特征,而忽略了以编码不同通道特征映射之间相互依赖关系的形式来获取通道信息。
整体框架
我们的UNETR++框架基于最近推出的UNETR[13],在编码器和解码器之间使用跳过连接,然后是卷积块(ConvBlocks)来生成预测掩码。
我们的unetr++采用分层设计,而不是在整个编码器中使用固定的特征分辨率,其中特征的分辨率在每个阶段逐渐降低两倍。在我们的UNETR++框架中,编码器有四个阶段,其中第一阶段包括Patch embedding,将体积输入划分为3D补丁,然后是我们新颖的高效成对注意(EPA)块。
Patch embedding
UNETR++的这个部分和 UNETR挺像的呢,但是有点好奇的是为什么UNETR里面用的直接是P,而没有分为P1,P2,P3这样,到时候看看代码其中P1,P2,P3是否不同好了
把3D输入 x∈R HxWxD 变成不重叠的补丁 xu∈R Nx(P1,P2,P3),其中P1,P2,P3是每个patch的分辨率, N=H/P1 x W/P2 xD/P3,是序列长度。
然后,将这些补丁投影到C通道维度,得到的特征图尺寸为 H/P1 x W/P2 xD/P3 x C
对于每个剩余的编码器阶段,我们使用非重叠卷积的下采样层将分辨率降低两倍,然后是EPA块。
在我们提出的unetr++框架中,每个EPA块包括两个注意模块,通过使用共享关键字查询方案对空间和通道维度的信息进行编码,有效地学习丰富的空间通道特征表示。
在我们提出的unetr++框架中,每个EPA块包括两个注意模块,通过使用共享keys-queries方案对空间和通道维度的信息进行编码,有效地学习丰富的空间通道特征表示。编码器级通过skip-connection 与解码器级连接以合并不同分辨率的输出。这可以恢复下采样操作期间丢失的空间信息,从而预测更精确的输出。与编码器类似,解码器也包括四个阶段,其中每个解码器阶段包括一个上采样层,使用反卷积将特征图的分辨率提高两倍,然后是EPA块(最后一个解码器除外)。Channel 的数量在每两个解码器阶段之间减少2倍。因此,最后一个解码器的输出与卷积特征映射融合,以恢复空间信息并增强特征表示。然后将结果输出馈送到3x3x3和1x1x1个卷积块中以生成voxel-wise的最终掩码预测。
Efficient PairedAttention Block
空间注意模块将自注意的复杂度从二次型降低到线性型。另一方面,通道注意模块有效地学习了通道特征映射之间的相互依赖关系。EPA块基于两个注意模块之间的共享keys-queries查询方案,以相互通知,以产生更好和更有效的特征表示。这可能是由于通过共享keys-queries来学习互补特性,但使用不同的value layer。
如图所示,输入特征映射x被馈送到EPA块的通道和空间注意模块。
Q和K线性层的权值是在两个注意模块之间共享的,每个注意模块使用不同的V层。两个注意模块计算为:
其中,^X s和^X c分别表示空间和通道注意图。SA为空间注意模块,CA为通道注意模块。Qshared、Kshared、Vspatial和Vchannel分别是共享查询、共享键、空间值层和通道值层的矩阵。!就是这里的QK都是共享的但是值做单独注意
Spatial attention
我们用这个模块把获取空间信息的复杂度从O(n^2)降低到O(np) (所以到底和原先的相比怎么降的呢🤔),其中n为记号的个数,p为投影向量的维数,其中p << n。
给定shape为 HW DXC的归一化张量x,我们使用三个线性层计算Qshared, Kshared和vspace投1影,收益率Qshared = WQX, Kshared=WKX, vspace =WVX,其中,WQ、WK、WV分别为Qshared、Kshared、Vspatial的投影权值。
1)Kshared和Vspatial层从HWD XC投影到形状为p C的低维矩阵中。(坏了我怎么记得是把channel压瘪,我再回去看看先)
2)其次,通过将Qshared层乘以投影Kshared的转置来计算空间注意图,然后使用softmax来度量每个特征与其他空间特征之间的相似性。
3)这些相似度乘以投影的vspace层,生成shapeHWDxC的最终空间注意图。空间注意的定义如下:
(我记忆中的空间注意力是CBAM的这个↓
)
Channel attention
该模块通过在通道值层和通道注意图之间的通道维度中应用点积运算来捕获特征通道之间的相互依赖关系。
利用空间注意模块相同的Qshared和Kshared,计算通道的值层,利用线性层学习互补特征,得到Vchannel = WVX,维数为 HWDxC,其中wv为Vchannel的投影权值。
定义如下
式中,Vchannel、Qshared、Kshared分别表示通道值层、共享查询、共享键,d为每个向量的大小。
最后,我们对两个关注模块的输出进行和融合,并通过卷积块对其进行变换,以获得丰富的特征表示。EPA块的最终输出^X为:
其中,^X s和^X c表示空间和通道注意图,Conv1和Conv3分别为1x1x1和3x3x3卷积块。
dbq我到时候在琢磨一下CBAM的通道和空间注意力和这个什么关系好了,感觉不太一样
损失函数
soft dice loss + cross-entropy loss
式中,I为类数;V为体素数;Yv;i和Pv;i分别表示类i在体素v处的真实情况和输出概率。
实验
数据集
Synapse 多器官CT分割
BTCV 多器官CT分割
ACDC 心脏自动诊断
BraTS 脑肿瘤分割
Decathlon-Lung
实现细节
Pytorch v1.10.1, MONAI库(可恶这个也用的是那个库,我有空直接进行一个学!)
硬件:A100 40GB GPU
1k epochs
learning rate :0.01 , weight decay :3e^5.
评估指标
Dice Similarity Coefficient (DSC
95% Hausdorff Distance (HD95
结果
Synapse
BTCV
ACDC
BRATS
Lungs
展望
为了观察UNETR++的潜在局限性,我们分析了Synapse的不同异常情况。虽然我们的预测比现有的方法更好,更接近真实情况,但我们发现,在一些情况下,我们的模型和现有的方法一样,难以分割某些器官。当一些切片中器官的几何形状异常(由细边界描绘)时,我们的模型和现有的模型很难准确地分割它们。原因可能是与正常样本相比,具有这种异常形状的训练样本的可用性有限。我们计划在预处理阶段应用几何数据增强技术来解决这个问题。