计算病理学(computational pathology)下的深度学习方法需要手动注释大型 WSI 数据集,并且通常存在领域适应性和可解释性较差的问题。作者报告了一种可解释的弱监督深度学习方法,只需要WSI级标签。将该方法命名为聚类约束注意力多实例学习 (CLAM,clustering-constrained-attention multiple-instance learning),它使用注意力来识别具有高诊断价值的子区域,以准确对整个WSI进行分类,并在已识别的代表性区域上进行实例级聚类以约束和细化特征空间。通过将 CLAM 应用于肾细胞癌和非小细胞肺癌的亚型分类以及淋巴结转移的检测,表明它可用于定位 WSI 上的形态特征,其性能优于标准弱监督分类算法。
来自:Data-efficient and weakly supervised computational pathology on whole-slide images, Nature Biomedical Engineering, 2021
工程地址:https://github.com/mahmoodlab/CLAM
目录
- CLAM概述
- 方法
- Instance-level clustering
- Smooth SVM loss
- 训练细节
CLAM概述
- 图1a:分割后,我们可以从WSI中提取patches。
- 图1b:patches被预训练的CNN编码成特征表示,在训练和推理过程中,每个WSI中提取的patch作为特征向量传递给CLAM。使用注意力网络将patch信息聚合为WSI级表示,用于最终的诊断预测。
- 图1c:对于每个类,注意力网络对WSI中的每个patch进行排名,根据其对WSI诊断的重要性分配注意力分数(左)。注意力pooling根据每个patch的注意力得分对其进行加权,并将patch级别的特征总结为WSI级别的表示(右下)。在训练过程中,给定GT标签,强参与(红色)和弱参与(蓝色)patch可以额外用作代表性样本以监督聚类层,聚类层学习丰富的patch级特征空间,可在不同类别的正实例和负实例之间分离(右上)。
- 图1d:注意力得分可以可视化为热图,以识别ROI(解释用于诊断的重要形态学)。
方法
CLAM是一个高通量的深度学习工具箱,旨在解决计算病理学中的弱监督分类任务,其中训练集中的每个WSI是具有已知WSI级别的单个数据点,但对于WSI中的任何像素或patch都没有类别特定的信息或注释。CLAM建立在MIL框架之上,该框架将每个WSI(称为bag)视为由许多(多达数十万)较小的区域或patch(称为instance)组成的集合。MIL框架通常将其范围限制在一个正类和一个负类的二元分类问题上,并基于这样的假设:如果至少有一个patch属于正类,那么整个WSI应该被分类为正类(阳性),而如果所有patch都属于负类,则WSI应该被分类为负类(阴性)。这一假设体现在max-pooling聚合函数上,它简单地使用正类预测概率最高的patch进行WSI级预测,这也使得MIL不适合多类分类问题。
除了Max-pooling之外,虽然可以使用其他聚合函数,但它们依然不能提供简单、直观的模型可解释性机制。相比之下,CLAM通常适用于多类别分类,它是围绕可训练和可解释的基于注意力的pooling函数构建的,从patch级表示中聚合每个类别的WSI级表示。在多分类注意力pooling设计中,注意力网络预测了一个多类分类问题中对应于 N N N个类别的 N N N个不同的注意力分数集。这使得网络能够明确地了解哪些形态学特征应该被视为每个类的积极证据(类相关的特征)和消极证据(非信息性的,缺乏类定义的特征),并总结WSI级表示。
具体来说,对于表示为 K K K个实例(patch)的WSI,我们将对应于第 k k k个patch的实例级嵌入表示为 z k z_{k} zk。在CLAM中,第一个全连接层 W 1 ∈ R 512 × 1024 W_{1}\in\R^{512\times 1024} W1∈R512×1024进一步将每个固定的patch级表示 z k ∈ R 1024 z_{k}\in\R^{1024} zk∈R1024压缩为 h k ∈ R 512 h_{k}\in\R^{512} hk∈R512。注意力网络由几个堆叠的全连接层组成;如果将注意力网络的前两层 U a ∈ R 256 × 512 U_{a}\in\R^{256\times 512} Ua∈R256×512+ V a ∈ R 256 × 512 V_{a}\in\R^{256\times 512} Va∈R256×512和 W 1 W_{1} W1共同视为所有类共享的注意力主干的一部分,注意力网络将分为 N N N个平行分支: W a , 1 , . . . , W a , N ∈ R 1 × 256 W_{a,1},...,W_{a,N}\in\R^{1\times 256} Wa,1,...,Wa,N∈R1×256。同样, N N N个并行独立分类器 W c , 1 , . . . , W c , N W_{c,1},...,W_{c,N} Wc,1,...,Wc,N对每个特定类的WSI表示进行评分。
因此,第 i i i类的第 k k k个patch的注意力分数记为 a i , k a_{i,k} ai,k,并且根据第 i i i类注意力分数聚合WSI表示记为 h s l i d e , i ∈ R 512 h_{slide,i}\in\R^{512} hslide,i∈R512:
分类层 W c , i W_{c,i} Wc,i给出相应的非归一化WSI级分数 s s l i d e , i s_{slide,i} sslide,i: s s l i d e , i = W c , i h s l i d e , i s_{slide,i}=W_{c,i}h_{slide,i} sslide,i=Wc,ihslide,i。我们在模型的注意力主干的每一层后使用dropout( P = 0.25 P=0.25 P=0.25)进行正则化。
对于推理,通过对WSI级预测分数应用softmax函数来计算每个类的预测概率分布。
Instance-level clustering
为了进一步鼓励学习特定于类的特征,我们在训练期间加入一个额外的二值聚类目标。对于 N N N个类中的每一个,在第一个层 W 1 W_{1} W1之后加一个全连接层。将第 i i i个类对应的聚类层权重记为 W i n s t , i ∈ R 2 × 512 W_{inst,i}\in\R^{2\times 512} Winst,i∈R2×512,则第 k k k个patch预测的聚类分数为 p i , k p_{i,k} pi,k: p i , k = W i n s t , i h k p_{i,k}=W_{inst,i}h_{k} pi,k=Winst,ihk。
鉴于我们无法访问patch级标签,我们使用注意力网络的输出在每次训练迭代中为每张WSI生成伪标签,以监督聚类。聚类中只优化最强参与和最弱参与的区域。为了避免混淆,对于给定的WSI,对于GT标签 Y ∈ { 1 , . . . , N } Y\in\left\{1,...,N\right\} Y∈{1,...,N},我们将GT类别对应的注意力分支 W a , Y W_{a,Y} Wa,Y称为"in-the-class",其余的 N − 1 N-1 N−1个注意力分支称为"out-the-class"。如果将in-the-class的注意力分数的排序列表(升序)表示为 a ~ Y , 1 , . . . , a ~ Y , K \widetilde{a}_{Y,1},...,\widetilde{a}_{Y,K} a Y,1,...,a Y,K,我们将注意力得分最低的 B B B个patch分配给负类标签 ( y Y , b = 0 ) (y_{Y,b}=0) (yY,b=0),其中, 1 ≤ b ≤ B 1\leq b\leq B 1≤b≤B。注意力得分最高的 B B B个patch分配给正类标签 ( y Y , b = 1 ) (y_{Y,b}=1) (yY,b=1),其中, B + 1 ≤ b ≤ 2 B B+1\leq b\leq 2B B+1≤b≤2B。直观地说,由于在训练过程中每个注意力分支都受到WSI级别标签的监督,因此高注意力分数的 B B B个patch被期望成为 Y Y Y类别的强参与阳性证据,而低注意分数的 B B B个patch被期望成为 Y Y Y类别的强参与阴性证据。聚类任务可以直观地解释为约束patch级特征空间 h k h_k hk,使每个类别的强参与特征证据与其阴性证据线性可分。
对于癌症亚型问题,所有类别通常被认为是互斥的(也就是说,它们不能出现在同一张WSI中),因为将in-the-class注意力分支中最受关注和最不受关注的片段分别聚类为正证据和负证据,因此对N−1个out-the-class注意力分支施加额外的监督是有意义的。也就是说,给定GT的WSI标签 Y Y Y,任取类别 i i i不属于 Y Y Y,如果我们假设WSI上的patch都不属于 i i i类,那么注意力得分最高的 B B B个patch就不能成为 i i i类的正证据(由于互斥性)。
因此,除了对从in-the-class注意力分支中选择的 2 B 2B 2B个patch进行聚类外,还将所有out-the-class注意力分支中最受关注的前 B B B个patch分配为负聚类标签,因为它们被认为是假阳性证据。另一方面,如果互斥性假设不成立(例如,癌症与非癌症问题,其中一张WSI可以包含来自肿瘤组织和正常组织的patch),那么就不会监督来自out-the-class分支的高注意力的patch的聚类,因为我们不知道它们是否为假阳性。
实例级聚类算法如下:
Smooth SVM loss
对于实例级聚类任务,我们使用平滑的top-1 SVM loss,它是基于多分类SVM loss的,神经网络模型输出一个预测分数向量 s s s,其中 s s s中的每个条目对应于模型对单个类的预测。给定所有GT标签 y ∈ { 1 , . . . , N } y\in\left\{1,...,N\right\} y∈{1,...,N},多类别SVM loss对分类器进行线性惩罚,仅当该差值大于指定的裕度 α α α时,对GT类的预测分数与其余类的最高预测分数之间的差值进行惩罚。Smooth变体(公式5)在多分类SVM损失中加入了温度标度 τ τ τ,它已被证明具有非稀疏梯度的无限可微性,并且在有效实现算法时适用于深度神经网络的优化。平滑支持向量机损失可以看作是广泛使用的交叉熵分类损失的一种推广,适用于不同的边界有限值选择和不同的温度尺度。
经验表明,当数据标签有噪声或数据有限时,向损失函数引入margin可以减少过拟合。在训练过程中,创建的用于监督实例级聚类任务的伪标签必然是有噪声的。也就是说,强参与的patch可能不一定对应于GT类,同样,弱参与的patch也不能保证是该类的负证据。因此,代替广泛使用的交叉熵损失,将二进制top-1平滑SVM损失应用于网络聚类层的输出。在所有的实验中, α α α和 τ τ τ都被设置为1.0。
训练细节
在训练过程中,WSI被随机采样。每张WSI的多项采样概率与GT类的频率成反比(来自代表性不足的类的WSI相对于其他类更有可能被采样),以减轻训练集中的类不平衡。注意力模块的权重参数随机初始化,并使用WSI标签和模型其余部分端到端训练,总的损失是WSI级损失 L s l i d e L_{slide} Lslide和instance-level损失 L p a t c h L_{patch} Lpatch之和。
为了计算 L s l i d e L_{slide} Lslide,使用标准交叉熵损失将 s s l i d e s_{slide} sslide与真实的WSI级标签进行比较,为了计算 L p a t c h L_{patch} Lpatch,使用二元Smooth SVM损失将每个采样patch的实例级聚类预测分数 p k p_k pk与相应的伪聚类标签进行比较(回想一下,对于非亚型问题,从in-the-class分支中总共采样了 2 B 2B 2B个patch。而对于亚型问题,从in-the-class分支中采样 2 B 2B 2B个patch,从 N − 1 N−1 N−1个out-the-class注意力分支各采样 B B B个patch)。
数据集摘要见补充表8: