论文笔记
资料
1.代码地址
https://github.com/SakurajimaMaiii/TSD
2.论文地址
https://arxiv.org/abs/2303.10902
3.数据集地址
论文摘要的翻译
TTA在接收训练分布外的测试域样本时对深度神经网络进行自适应。在这样设置下,模型只能访问在线未标记的测试样本和训练域上的预训练模型。由于源域和目标域之间的域差距,我们首先将TTA作为一个特征修正问题来解决。之后,我们根据对齐和一致性两个方面来讨论测试时间特征的修正。对于测试时间特征一致性,我们提出了一种测试时间自蒸馏策略,以确保当前批次和所有先前批次的表示之间的一致性。对于测试时间特征对齐,我们提出了一种记忆的空间局部聚类策略,以对齐即将到来的批次的邻域样本之间的表示。为了解决常见的噪声标签问题,我们提出了熵和一致性滤波器来选择和丢弃可能的噪声标签。为了证明我们的方法的可扩展性和有效性,我们在四个领域泛化基准和四个具有不同骨干的医学图像分割任务上进行了实验。实验结果表明,我们的方法不仅稳定地提高了基线,而且优于现有的最先进的测试时间自适应方法。
1背景
当训练和测试数据从同一分布中采样时,深度学习在计算机视觉任务中取得了巨大成功。然而,在现实世界的应用中,当从不同的分布中收集训练(源)数据和测试(目标)数据时,通常会出现性能下降,即域偏移。在实践中,测试样本可能会遇到不同类型的变化或损坏。深度学习模型对这些变化或损坏很敏感,这可能会导致性能下降。
为了解决这个具有挑战性但实际的问题,已经提出了各种工作来在TTA策略。测试时间训练(TTT)在训练和测试阶段使用自我监督任务来调整模型。这种范式在训练和测试阶段都依赖于额外的模型修改,这在现实世界中是不可行和不可扩展的。
为了解决上述问题,本文将TTA作为一个表示修正问题来处理。在TTA的测试阶段,访问的模型已经学习了专门用于源域的特征表示,并且由于域差异较大,可能会生成目标域的表示不准确。所以有必要校正目标域的特征表示。为了实现目标域的更好表示,我们利用常用的表征质量量度,这些量度可以概括为特征对齐和均匀性。对齐是指相似的图像应该具有相似的表示,而均匀性是指不同类别的图像应该在潜在空间中尽可能均匀地分布。以前关于TTA的大多数工作都可以从所提出的表示修正的角度进行归纳。之前的工作没有一种方法同时从表示对齐和一致性来解决TTA问题。在本文中,我们发现了这一局限性,并提出了一种新的方法,从这两个特性中校正特征表示。我们将TTA中的两个性质公式化为测试时间特征一致性和测试时间特征对齐。
- Test Time Feature Uniformity
根据特征一致性的观点,我们希望来自不同类别的测试图像的表示应该尽可能均匀地分布。为了更好地处理目标域中的所有样本,我们建议为每个到达的测试样本引入历史时间信息。建立了一个内存库来存储所有到达样本的特征表示和logit,以维护来自先前数据的有用信息。然后,我们使用内存库中的logits和特征来计算每个类的伪原型。之后,为了保证当前批次样本的一致性,基于原型的分类和模型预测(线性分类器的输出)的预测分布应该相似,即一类当前图像的特征分布应该与同一类以前所有图像的特征分配一致。这可以减少错误分类的异常值样本的偏差,以形成更均匀的潜在空间。
2论文的创新点
- 我们从特征对齐和一致性的角度提出了测试时间自适应的新视角。所提出的测试时间特征一致性促进了当前批次样本的表示以及所有先前样本的一致性。测试时间特征对齐根据测试样本在潜在空间中的邻居来操纵测试样本的表示,以基于伪标签来对齐表示。
- 为了解决TTA中的在线设置和噪声标签问题,我们提出了两种互补的策略:用于测试时间特征一致性的无监督自蒸馏和用于测试时间特性对齐的记忆空间局部聚类。我们还提出了熵滤波器和一致性滤波器,以进一步减轻噪声标签的影响。
- 实验表明,我们提出的方法在域泛化基准和医学图像分割基准上都优于现有的测试时自适应方法。
3 论文方法的概述
图1展示了方法的总体流程。我们将在本节中描述问题设置和方法的详细信息。
3.1 准备工作
在测试时间自适应(TTA)中,我们只能在线获得目标域的未标记图像,并在源域上预训练模型。使用源域上的标准经验风险最小化来训练源模型,例如图像分类任务的交叉熵损失。给定在 D s {\mathcal D_s} Ds上训练的模型,我们的目标是使用未标记的目标数据 { x i } ∈ D t , i ∈ { 1 … N } \{x_i\}∈\mathcal D_t,i∈\{1…N\} {xi}∈Dt,i∈{1…N}来适配该模型,其中 x i x_i xi表示目标域 D t D_t Dt的第 i i i个图像, N N N表示目标图像的数量, D s D_s Ds表示源域。在测试过程中,我们使用在源域 D s D_s Ds上训练的源模型参数初始化模型 g = f ◦ h g=f◦h g=f◦h,其中 f f f表示主干, h h h表示线性分类头。图像 x i x_i xi的模型g的输出表示为 p i = g ( x i ) ∈ R C p_i=g(x_i)\in\mathbb{R}^C pi=g(xi)∈RC,其中C是类的数量。
3.2 测试时自蒸馏方法
在自适应过程中,给定一批未标记的测试样本,我们可以通过预先训练的模型生成图像预测 z i = f ( x i ) z_i=f(x_i) zi=f(xi)、 logits p i = h ( z i ) \text{logits }p_i=h(z_i) logits pi=h(zi)和伪标签$
\hat{y}_i=\arg\max p_i$。然后,我们维护一个memory bank KaTeX parse error: Undefined control sequence: \machcal at position 1: \̲m̲a̲c̲h̲c̲a̲l̲ ̲B={(z_i,p_i)}来存储图像预测 z i 和 l o g i t s p i z_i和logits p_i zi和logitspi。用线性分类器的权重初始化mermory bank。当目标样本 x i x_i xi到来时,对于每个图像,我们将图像预测 z i 和 l o g i t s p i z_i和logits p_i zi和logitspi添加到存储库中。为了建立当前样本和所有先前样本之间的关系,应为每个类生成伪原型。k类的原型可以公式化为
c k = ∑ i z i 1 [ y ^ i = k ] ∑ i 1 [ y ^ i = k ] (1) c_k=\frac{\sum_iz_i\mathbb{1}[\hat{y}_i=k]}{\sum_i\mathbb{1}[\hat{y}_i=k]}\text {(1)} ck=∑i1[y^i=k]∑izi1[y^i=k](1)其中1(·)是一个指示符函数,如果参数为true,则输出值1,否则输出值0。然而,一些伪标签可能被分配给错误的类,导致不正确的原型计算。我们使用香农熵滤波器来过滤噪声标签。对于预测 p i p_i pi,其熵可以计算为 H ( p i ) = − ∑ σ ( p i ) log σ ( p i ) H\left(p_{i}\right)=-\sum\sigma(p_{i})\log\sigma(p_{i}) H(pi)=−∑σ(pi)logσ(pi),其中 σ σ σ表示softmax运算。我们的目标是用高熵过滤不可靠的特征或预测,因为较低的熵通常意味着较高的准确性。具体而言,对于每个类,将忽略存储器组中具有前M个最高熵的图像嵌入。之后,我们使用过滤嵌入来计算原型,如等式所示1并将基于原型的分类输出定义为与类k的原型的特征相似性上的softmax: y i k = exp ( sin ( z i , c k ) ) ∑ k ′ = 1 C exp ( sin ( z i , c k ′ ) ) , ( 2 ) y_i^k=\frac{\exp\left(\sin(z_i,c_k)\right)}{\sum_{k^{\prime}=1}^C\exp\left(\sin(z_i,c_{k^{\prime}})\right)},\quad(2) yik=∑k′=1Cexp(sin(zi,ck′))exp(sin(zi,ck)),(2)其中 sin ( z i , c k ) \sin(z_i,c_k) sin(zi,ck)表示 z i 和 c k z_i和c_k zi和ck之间的余弦相似性。网络g的基于原型的分类结果yi和输出pi对于相同的输入应该共享相似的分布。因此,维持均匀性的损失建议为 L i ( p i , y i ) = − σ ( p i ) log y i . ( 3 ) \mathcal{L}_i(p_i,y_i)=-\sigma(p_i)\log y_i.\quad(3) Li(pi,yi)=−σ(pi)logyi.(3)
请注意, p i p_i pi是一个软伪标签,而不是硬伪标签。使用软标签的原因是软标签通常提供更多信息。通过使用所提出的测试时间自蒸馏,该网络可以映射当前样本的均匀性,以提高表示质量。
尽管我们在计算原型时使用熵过滤器来删除有噪声的标签,但仍然存在一些不可避免的错误预测。我们建议,为了获得可靠的样本,线性全连接层和基于原型的分类器的输出应该相似。因此,我们采用一致性过滤器来识别错误预测。特别是,如果线性分类器和基于原型的分类器在对logits执行argmax后产生相同的预测,即相同的结果,我们假设这个样本是可靠的。该策略可以使用图像xi的滤波器掩模来实现,如下所示 M i = 1 [ arg max p i = arg max y i ] . ( 4 ) \mathcal{M}_i=\mathbb{1}[\arg\max p_i=\arg\max y_i].\quad(4) Mi=1[argmaxpi=argmaxyi].(4)通过进行一致性过滤,我们进一步过滤不可靠的样本,无监督的自蒸馏损失可以公式化如下
L t s d = ∑ i L i ∗ M i ∑ i M i . ( 5 ) \mathcal{L}_{tsd}=\frac{\sum_i\mathcal{L}_i*\mathcal{M}_i}{\sum_i\mathcal{M}_i}.\quad(5) Ltsd=∑iMi∑iLi∗Mi.(5)
3.3 Memorized Spatial Local Clustering
如前所述,属于同一类的特征应在潜在空间中对齐。然而,由于目标域和源域之间的域间隙,这种情况在TTA中可能会有所不同。我们鼓励使用K近邻特征,而不是所有特征,以减少噪声标签的影响。一个简单的策略是在一批样本中添加一致性正则化。然而,历史时间信息被忽略,对齐效果较差。此外,有一个简单的解决方案,如果我们只使用一批样本进行对齐,该模型可以很容易地将所有图像映射到某个类。为了解决这些问题,我们将空间局部聚类和内存库连接起来。我们从检索图像x的存储体中的K个最近特征开始。基于我们的假设,图像x的logits应该与潜在空间中其最近邻的logits对齐。为了实现这一点,我们根据图像x的图像嵌入与其邻居之间的距离来对齐两种logits。公式如下 L m s l c = 1 K ∑ j = 1 K sin ( z , z j ) ( σ ( p ) − σ ( p j ) ) 2 , ( 6 ) \mathcal{L}_{mslc}=\frac{1}{K}\sum_{j=1}^{K}\sin(z,z_{j})(\sigma(p)-\sigma(p_{j}))^{2},\quad(6) Lmslc=K1j=1∑Ksin(z,zj)(σ(p)−σ(pj))2,(6)
其中 s i m ( z , z j ) sim(z,z_j) sim(z,zj)表示%z%和%z_j%之间的余弦相似性。 z j j = 1 K {z_j}^K_{j=1} zjj=1K表示memory bank B \mathcal B B中 z z z的 K K K个最近的图像嵌入, p j p_j pj表示相应的对数。如果 z j z_j zj和 z z z在特征空间中很接近,即 s i m ( z , z j ) sim(z,z_j) sim(z,zj)很大,则该目标函数将推动 p j 和 p p_j和p pj和p接近。我们分离了 s i m ( z , z j ) sim(z,z_j) sim(z,zj)的梯度,即 s i m ( z , z j ) sim(z,z_j) sim(z,zj)将被视为常数,以避免模型输出恒定结果而不考虑不同样本的琐碎解决方案。
3.4 优化目标函数
L = L t s d + λ L m s l c , ( 7 ) \mathcal{L}=\mathcal{L}_{tsd}+\lambda\mathcal{L}_{mslc},\quad(7) L=Ltsd+λLmslc,(7)其中λ是平衡不同损失函数的权衡参数。在我们的实现中,我们使用余弦相似度作为相似度度量。具体来说,我们定义 s i m ( x , y ) = x T y / ∣ ∣ x ∣ ∣ ∣ ∣ y ∣ ∣ sim(x,y)=x^Ty/||x||||y|| sim(x,y)=xTy/∣∣x∣∣∣∣y∣∣。在测试阶段,自适应以在线方式执行。具体来说,当在时间点T接收到图像 x T x_T xT时,模型状态会使用从最后一张图像 x T − 1 x_{T-1} xT−1更新的参数进行初始化。该模型在接收到新样本 x T x_T xT后产生预测 p T = g ( x T ) p_T=g(x_T) pT=g(xT),并使用方程7仅用一步梯度下降来更新模型。
4 论文实验
4.1 数据设置
- 数据集
- PACS
包含9991个示例和7个类别,这些示例和类别来自4个领域:艺术、漫画、照片和草图。 - OfficeHome
由4个领域组成:艺术、剪贴画、产品和真实,其中包括15588张图片和65个类别。 - VLCS
包括四个域:Caltech101、LabelMe、SUN09和VOC2007,包括10729个图像和5个类别。 - DomainNet
一个大规模的数据集有六个域 d ∈ { 剪贴画、信息图、绘画、快速绘制、真实、草图 } d∈\{剪贴画、信息图、绘画、快速绘制、真实、草图\} d∈{剪贴画、信息图、绘画、快速绘制、真实、草图},有586575幅图像和345个类。
- PACS
- 模型。
在主要实验中,我们在配备BN层的ResNet-18/50上评估了不同的方法,BN在关于DA和DG的文献中得到了广泛的应用。此外,我们在不同的主干上测试了我们的算法,包括视觉变换器(ViT-B/16)、ResNeXt-50(32×4d)、EfficientNet(B4)和MLP混合器(混合器-L16)。 - 实现
实施对于源训练,我们选择一个域作为目标域,其他域作为源域。我们将源域中的所有图像拆分为80%/20%,用于训练和验证。我们使用Adam优化器,学习率为5e-5。除了ViT-B/16和MLP Mixer,当我们使用ImageNet-21K预训练权重时,所有模型都使用ImageNet-1K预训练权重进行初始化。除了ViT-B/16和MLP Mixer之外,我们使用所有模型的torchvision实现;其他模型我们使用timm库中的实现。
对于测试时间自适应,我们使用Adam优化器[26]并将批大小设置为128。我们根据经验设置了权衡参数 λ = 0.1 λ=0.1 λ=0.1。所有可训练层都会更新,我们的方法不需要特殊选择。我们在所有实现中都使用PyTorch。我们报告了整个目标域的准确性以供评估。对于所有实验,我们报告了具有不同权重初始化、随机种子和数据分割的三次重复的平均值。
超参数搜索和模型选择。我们使用训练域验证进行源模型训练。我们在验证集中选择精度最高的模型。对于超参数搜索,我们在 { 1 e − 3 , 1 e − 4 , 1 e − 5 , 1 e − 6 } \{1e-3,1e-4,1e-5,1e-6\} {1e−3,1e−4,1e−5,1e−6}中搜索学习率 l r lr lr,特征滤波器的超参数 M ∈ { 1 , 5 , 20 , 50100 , N A } M∈\{1,5,20,50100,NA\} M∈{1,5,20,50100,NA},其中NA表示没有熵滤波器。我们强调,在访问测试样本之前,应选择TTA设置中的所有超参数。我们在训练域验证集上进行超参数搜索。 - 基线
我们将我们的方法与经验风险最小化(ERM)、Tent、T3A、SHOT-IM、ETA、测试时间批归一化(BN)、拉普拉斯调整最大似然估计(LAME)和伪标签(PL)进行了比较。
4.2 Comparative Study
- Comparison with TTA methods
表1和表2显示了四个不同数据集在ResNet-18/50上的结果。从表1和表2中可以看出,我们的方法通常达到最先进的性能。具体而言,所提出的方法提高了ERM的基线,每个数据集分别提高了4.8%、1.3%、0.5%、2.53%。其他测试时间自适应方法并没有像我们的方法那样稳定地改善基线。我们还可视化了不同方法在整个适应过程中的准确性变化,如图2所示。图2中的“批号”表示模型更新了多少批图像。我们可以看到,我们的方法可以更快地适应数据,并在目标域上实现更高的精度。
- 与DG/SFDA方法的比较
上述实验主要关注测试时间自适应,旨在在测试阶段对模型进行自适应。人们很自然地会问:与领域泛化或无源领域自适应方法相比,我们的方法怎么样?为了回答这个问题,我们首先将我们的方法与一些最近的领域泛化或无源领域自适应方法进行了比较,例如PACS和DomainNet数据集上的SWAD、PCL、DNA和F-mix。从表3中可以看出,我们的方法在领域泛化方面优于现有方法。此外,结合SWAD,我们使用ResNet50骨干网实现了令人印象深刻的91%的准确率。
我们还报告了具有挑战性的DomainNet数据集的结果。结果列于表4中。可以看出,我们的方法优于最先进的DG方法,如SWAD和DNA。此外,我们的方法可以显著改善SWAD,其性能优于当前最先进的SFDA方法。请注意,在线测试时间调整在现实世界中更灵活,因为SFDA以离线方式调整测试数据,这比TTA需要更多的训练循环和资源。
5总结
重点研究了测试时间自适应,并提出了一个新的视角,即测试时间自适应可以被视为一个特征修正问题。此外,特征修正包括两部分:测试时特征均匀性和测试时特征对齐。从特征均匀性的角度来看,我们提出了测试时间自蒸馏,以使目标特征在自适应中尽可能均匀。为了对齐同一类的特征,我们提出了记忆空间局部聚类,以鼓励潜在空间中特征表示之间的距离与伪对数对齐。大量实验证明,我们的方法不仅普遍提高了ERM基线,而且在四个领域泛化基准上优于现有的TTA或SFDA方法。此外,我们的方法可以应用于不同的主干。为了将我们的模型扩展到现实世界的应用中,我们在四个跨域医学图像分割任务上验证了我们的方法。实验结果表明,我们的方法有效、灵活、可扩展。