论文笔记
资料
1.代码地址
https://github.com/BIT-DA/RoTTA
2.论文地址
https://arxiv.org/abs/2303.13899
3.数据集地址
coming soon
1论文摘要的翻译
测试时间自适应(TTA)旨在使预先7训练的模型适用于仅具有未标记测试数据流的测试分布。大多数以前的TTA方法已经在简单的测试数据流上取得了很大的成功,例如来自单个或多个分布的独立采样数据。然而,在自动驾驶等现实世界应用的动态场景中,这些尝试可能会失败,其中环境逐渐变化,测试数据随着时间的推移进行相关采样。在这项工作中,我们探索了这样的实际测试数据流来动态部署该模型,即实际测试时间适应(PTTA)。为此,针对PTTA中复杂的数据流,提出了一种健壮的测试时间适配(ROTTA)方法。更具体地说,我们提出了一种稳健的批归一化方案来估计归一化统计量。同时,在考虑时效性和不确定性的基础上,利用内存库对类别平衡数据进行采样。此外,为了稳定训练过程,我们开发了一种教师-学生模型的时间感知重权策略。大量的实验证明,ROTTA算法能够在相关采样数据流上实现连续的测试时间自适应。我们的方法易于实现,是快速部署的一个很好的选择。
1 介绍
面对不断变化的分布,随着误差梯度的累积,伪标记法或 熵最小值法等传统算法变得更加不可靠。此外,测试样本之间的高度相关性导致了对批量归一化统计量的错误估计和模型的崩溃。在这种分析的驱动下,适应这样的数据流将会遇到两大障碍
1)批次归一化统计中的错误估计导致测试样本的错误预测,从而导致无效的适应;
2)模型很容易或很快地对相关抽样造成的分布过度拟合。因此,这种动态情景迫切需要一种新的TTA范式来实现稳健的适应。
我们推出了一个更现实的TTA设置,在测试阶段,分布变化和相关采样同时发生。我们称这种实用的测试时间适应,或简称为PTTA。为了更清楚地了解PTTA和以前的设置之间的异同,我们在图1中将它们可视化,并在表1中进行总结。
本方法实现的大致思路:
- 首先用指数移动平均维护的全局统计来替换当前批次的错误统计。它在BatchNorm层中估计统计量是一种更稳定的方式。
- 考虑buffered样本的时效性和不确定性的情况下,用类别平衡抽样模拟了一批内存中的类独立数据。较新且不太确定的样本以更高的优先级保存在内存中。有了这批类别均衡、及时、有信心的样本,我们就可以获得当前分布的快照。
- 我们引入了一种时间感知的重加权策略,该策略考虑了记忆库中样本的时效性,并利用师生模型进行了稳健的自适应。
2论文的创新点
- 提出了一种新的更适合实际应用的测试时间自适应机制,即实际测试时间自适应(PTTA)。PTTA既考虑了分布变化,又考虑了相关抽样。
- 我们在PTTA中对现有方法的性能进行了基准测试,发现它们只考虑了问题的一个方面,导致了无效的适应。
*我们提出了一种健壮的测试时间自适应方法(ROTTA),它更全面地考虑了PTTA挑战。实施的简便性和有效性使其成为一个实用的部署选项。 - 我们在常见的TTA基准,即CIFAR-10-C和CIFAR-100C以及大规模DomainNet数据集上广泛展示了PTTA的实用性和ROTTA的有效性。ROTA获得了最先进的结果,大大超过了最佳基准(分别将CIFAR-10-C、CIFAR-100-C和DomainNet的平均分类错误分别减少了5.9%、5.5%和2.2%)。
3 Robust Test-Time Adaptation方法的概述
3.1 问题定义
给定在源域 D S = { ( X s , Y s ) } DS=\{(Xs,Ys)\} DS={(Xs,Ys)}上预先训练的参数为 θ 0 θ_0 θ0的模型 f θ 0 f_{θ_0} fθ0,所提出的实用测试时间自适应旨在使 f θ 0 f_{θ_0} fθ0适应在线未标记样本流 X 0 , X 1 , . . . , X t X_0,X_1,...,X_t X0,X1,...,Xt,其中 X t X_t Xt是分布 P t e s t P_{test} Ptest中随时间t连续变化的一批高度相关的样本。更具体地说,在模型测试的时候,随着时间的推移,测试分布 P t e s t P_{test} Ptest作为 P 0 , P 1 , … , P ∞ P_0,P_1,…,P_∞ P0,P1,…,P∞连续变化。在时间步长 t t t,我们将收到一批未标记和相关的样本来自 P t e s t P_{test} Ptest的 X t X_t Xt。接下来,将 X t X_t Xt输入到模型 f θ t f_{θ_t} fθt中,并且该模型需要使其自身适应当前的测试数据流并动态地调整 f θ t ( X t ) f_{θ_t}(X_t) fθt(Xt)。
事实上,这种设置在很大程度上是由动态场景中部署模型的实际需求驱动的。以§1中提到的自动驾驶为例,测试样本高度相关,数据分布随着天气或位置的变化而不断变化。另一个例子是智能监控的情况,相机会在一定的时间连续捕捉到更多的人,比如下班后,但在工作时间会越来越少。同时,白天和晚上的光照条件也在不断变化。
部署的模型应该在这样的动态场景中稳健地进行调整。总之,在现实世界中,分布变化和数据关联往往是同时发生的。因此,现有的TTA方法在从这样的动态场景中采样测试流时可能会变得不稳定。
ROTTA的概述如图2所示。
3.2Robust Test-Time Adaptation 描述
Robust batch normalization (RBN)
批归一化(batch normalization,BN)是一种广泛使用的训练技术,它可以加快网络的训练和收敛速度,并通过降低梯度爆炸和消失的风险来稳定训练过程。在训练时,给定特征图 F ∈ R B × C × H × W F\in\mathbb{R}^{B\times C\times H\times W} F∈RB×C×H×W作为BN层的输入,按通道方式计算平均 µ ∈ R C µ\in\mathbb{R}^{C} µ∈RC和方差 σ 2 ∈ R C σ^2\in\mathbb{R}^{C} σ2∈RC如下: μ c = 1 B H W ∑ b = 1 B ∑ h = 1 H ∑ w = 1 W F ( b , c , h , w ) , (1) σ c 2 = 1 B H W ∑ b = 1 B ∑ h = 1 H ∑ w = 1 W ( F ( b , c , h , w ) − μ c ) 2 . (2) \mu_{c}=\frac{1}{BHW}\sum_{b=1}^{B}\sum_{h=1}^{H}\sum_{w=1}^{W}F_{(b,c,h,w)} ,\text{(1)}\\\sigma_{c}^{2}=\frac{1}{BHW}\sum_{b=1}^{B}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(F_{(b,c,h,w)}-\mu_{c}\right)^{2}.\text{(2)} μc=BHW1b=1∑Bh=1∑Hw=1∑WF(b,c,h,w),(1)σc2=BHW1b=1∑Bh=1∑Hw=1∑W(F(b,c,h,w)−μc)2.(2)
然后,以通道方式标准化和细化特征图,如下所示
B N ( F ( b , c , h , w ) ; μ , σ 2 ) = γ c F ( b , c , h , w ) − μ c σ c 2 + ϵ + β c , ( 3 ) BN(F_{(b,c,h,w)};\mu,\sigma^2)=\gamma_c\frac{F_{(b,c,h,w)}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}}+\beta_c ,\quad(3) BN(F(b,c,h,w);μ,σ2)=γcσc2+ϵF(b,c,h,w)−μc+βc,(3)
其中 γ , β ∈ R c γ,β\in\mathbb{R}^{c} γ,β∈Rc是层中的可学习参数, ϵ > 0 ϵ\gt0 ϵ>0,是数值稳定性的常量。同时,在训练过程中,BN层维护一组全局运行均值和运行方差 ( µ s , σ s 2 ) (µ_s,σ^2_s) (µs,σs2)以供推理。
由于测试时会发生域间数据shift,导致全局统计量 ( µ s , σ s 2 ) (µ_s,σ^2_s) (µs,σs2)对测试特征归一化不准确,导致性能显著下降。为了解决上述问题,一些方法使用当前批次的统计数据进行归一化。不幸的是,当测试样本在PTTA设置下具有很高的相关性时,当前批次的统计信息也无法正确地规格化特征映射,如图c所示。具体地说,BN的性能随着数据相关性的增加而迅速降低。
基于以上分析,我们提出了一种稳健的批归一化模块,该模块维护一组全局统计量 ( µ g , σ g 2 ) (µ_g,σ^2_g) (µg,σg2)来稳健地归一化特征映射。在整个测试时间自适应之前, ( µ g , σ g 2 ) (µ_g,σ^2_g) (µg,σg2)被初始化为预训练模型的运行均值和方差 ( µ s , σ s 2 ) (µ_s,σ^2_s) (µs,σs2)。在调整模型时,我们首先用指数移动平均来更新全局统计量,即: μ g = ( 1 − α ) μ g + α μ , (4) σ g 2 = ( 1 − α ) σ g 2 + α σ 2 , (5) \mu_{g}=(1-\alpha)\mu_{g}+\alpha\mu ,\text{(4)}\\\sigma_{g}^{2}=(1-\alpha)\sigma_{g}^{2}+\alpha\sigma^{2},\text{(5)} μg=(1−α)μg+αμ,(4)σg2=(1−α)σg2+ασ2,(5),(5)其中 ( µ, σ 2 ) (µ,σ^2) (µ,σ2)是memory bank中buffer samples的统计。然后我们将特征归一化并仿射为等式(3)配合 ( µ g , σ g 2 ) (µ_g,σ^2_g) (µg,σg2)。在对测试样本进行推断时,我们直接使用 ( µ g , σ g 2 ) (µ_g,σ^2_g) (µg,σg2)来计算输出公式为Eq(3)。虽然简单,但RBN足够有效地解决了PTTA测试流上的归一化问题。
3.2.2 Category-balanced sampling with timeliness and uncer-tainty (CSTU).
在PTTA设置中,时间 t t t时候的测试样本 X t X_t Xt之间的相关性导致观察到的分布 P ^ t e s t \widehat{\mathcal{P}}_{test} P test和测试分布 P t e s t \mathcal{P_{test}} Ptest之间的偏差。具体地说,边缘标签分布 p ( y ∣ t ) p(y|t) p(y∣t)往往不同于 p ( Y ) p(Y) p(Y)。随着时间 t t t的推移,随着 X t X_t Xt的不断学习,可能会导致模型适应不可靠的分布 P ^ t e s t \widehat{\mathcal{P}}_{test} P test,从而导致无效的适应和增加模型崩溃的风险。
为了解决这个问题,我们提出了一种容量为 N N N的类别平衡memory bank M M M,该存储库在更新时考虑了样本的及时性和不确定性。特别是,我们采用测试样本的预测作为伪标签来指导 M M M的更新。同时,为了保证类别之间的平衡,我们将 M M M的容量平均分配给每个类别,并首先替换主要类别的样本(参见算法1中的第5-9行)。此外,由于测试分布的不断变化,模型中的旧样本价值有限,甚至可能削弱模型适应当前分布的能力。此外,正如所建议的那样,高不确定性的样本总是产生错误的梯度信息,这可能会阻碍模型适应。
考虑到这一点,我们将M中的每个样本附加一组启发式 ( A , U ) ({\mathcal{A}},{\mathcal{U}}) (A,U),其中 A {\mathcal{A}} A被初始化为0,并随着时间 t t t增加, A {\mathcal{A}} A是样本的存在的时间, U {\mathcal{U}} U是作为预测的熵计算的不确定性。接下来,我们结合及时性和不确定性来计算一个启发式分数,即带有及时性和不确定性的类别平衡抽样,如下: H = λ t 1 1 + exp ( − A / N ) + λ u U log C , ( 6 ) \mathcal{H}=\lambda_t\frac{1}{1+\exp(-\mathcal{A}/\mathcal{N})}+\lambda_u\frac{\mathcal{U}}{\log\mathcal{C}} ,\quad(6) H=λt1+exp(−A/N)1+λulogCU,(6),(6)其中 λ t 和 λ u λ_t和λ_u λt和λu权衡了实时性和不确定性,为了简单起见,所有实验的 λ t 和 λ u λ_t和λ_u λt和λu都设置为1.0。 C C C是类别的数量。
我们在算法1中总结了我们的抽样算法。使用CSTU,我们可以获得当前测试分布 P t e s t \mathcal{P_{test}} Ptest的健壮快照,并有效地使模型适应于它。
3.2.3 Robust training with timeliness.
实际上,在用我们的RBN替换BN层并获得CSTU抽样选择的memory bank后,我们可以直接采用广泛使用的伪标签或熵最小化技术来进行测试时间适配。然而,我们注意到,太旧或不可靠的实例仍然有机会留在M中,因为保持类别平衡是重中之重。此外,过于激进的模型更新会使不可靠的类别平衡,导致不稳定的适应。同时,分布变化引起的误差累积也使得上述方法不可行。
为了进一步降低来自旧的和不可靠的实例的误差梯度信息的风险并稳定自适应,我们使用稳健的无监督学习方法,提出了teacher-student模型,并提出了时效性重加权策略。此外,为了时间效率和稳定性,在自适应过程中只训练RBN中的仿射参数。
在时间步 t t t时,在用教师模型 f θ t T f_{θ^T_t} fθtT推断相关数据 X t X_t Xt并用 X t X_t Xt更新Memory bank M之后,我们开始更新学生模型 f θ t S f_{θ^S_t} fθtS和教师模型 f θ t T f_{θ^T_t} fθtT。首先,我们通过最小化以下损失来更新学生模型 θ t S θ^S_t θtS→ θ t + 1 S θ^S_{t+1} θt+1S的参数:
L r = 1 Ω ∑ i = 1 Ω L ( x i M , A i ; θ t T , θ t S ) , ( 7 ) \mathcal{L}_{r}=\frac{1}{\Omega}\sum_{i=1}^{\Omega}\mathcal{L}(x_{i}^{\mathcal{M}},\mathcal{A}_{i};\theta_{t}^{T},\theta_{t}^{S}) ,\quad(7) Lr=Ω1i=1∑ΩL(xiM,Ai;θtT,θtS),(7)
其中 Ω = ∣ M ∣ \mathcal{Ω}=|\mathcal{M}| Ω=∣M∣是内存块的总占用量, x i M x_{i}^{\mathcal{M}} xiM和 A i ( i = 1 , . . . ,Ω ) A_i(i=1,...,Ω) Ai(i=1,...,Ω)分别是内存库中的实例及其使用时长。随后,通过指数移动平均将教师模型更新为
θ t + 1 T = ( 1 − ν ) θ t T + ν θ t + 1 S . , ( 8 ) \theta_{t+1}^{T}=(1-\nu)\theta_{t}^{T}+\nu\theta_{t+1}^{S} . ,\quad(8) θt+1T=(1−ν)θtT+νθt+1S.,(8)
为了从内存库中计算实例 x i M x_{i}^{\mathcal{M}} xiM的损失值,时效性重新加权项计算如下
E ( A i ) = exp ( − A i / N ) 1 + exp ( − A i / N ) , ( 9 ) E(\mathcal{A}_i)=\frac{\exp(-\mathcal{A}_i/\mathcal{N})}{1+\exp(-\mathcal{A}_i/\mathcal{N})} ,\quad(9) E(Ai)=1+exp(−Ai/N)exp(−Ai/N),(9)
其中 A i A_i Ai是 x i M x_{i}^{\mathcal{M}} xiM的年龄, N N N是内存库的存储能力。然后,我们计算来自学生模型的强增广视点 x i ′ ′ x^{''}_i xi′′的软最大预测 P S ( y ∣ x ‘’ i ) P_S(y|x‘’i) PS(y∣x‘’i)和来自教师模型的弱增广视点x^'_i的软最大预测PS(y|x’i)之间的交叉熵如下:
ℓ ( x i ′ , x i ′ ′ ) = − 1 C ∑ c = 1 c p T ( c ∣ x i ′ ) log p S ( c ∣ x i ′ ′ ) . ( 10 ) \ell(x_i',x_i'')=-\frac{1}{\mathcal C}\sum_{c=1}^{c}p_{T}(c|x_{i}')\log p_{S}(c|x_{i}'') .\quad(10) ℓ(xi′,xi′′)=−C1c=1∑cpT(c∣xi′)logpS(c∣xi′′).(10)最后,配备了公式(9)和公式(10),公式的右侧。公式(7)约化为损失 L ( x i M , A i ; θ t T , θ t S ) = E ( A i ) ℓ ( x i ′ , x i ′ ′ ) . \mathcal{L}(x_i^{\mathcal{M}},\mathcal{A}_i;\theta_t^T,\theta_t^S)=E(\mathcal{A}_i)\ell(x_i',x_i'') . L(xiM,Ai;θtT,θtS)=E(Ai)ℓ(xi′,xi′′).
综上所述,由于配备了RBN、CSTU和具有时效性的健壮训练,我们的ROTA能够有效地使任何预先训练的模型适应动态场景。
4 论文实验
数据集
CIFAR10-C
CIFAR100-C
DomainNet
是迄今为止用于领域适应的最大和最难处理的数据集,由345个类别的约60万张图像组成。它由六个不同的域组成,包括剪贴画(CLP)、信息图(INF)、绘画(PNT)、快速绘制(QDR)、真实(REL)和素描(Sketch)。我们首先在六个领域中的一个领域的训练集上预训练源模型,并在其余五个领域的测试集上验证所有基线方法。
补充细节
所有实验都是在PyTorch框架下进行的。在对腐败具有稳健性的情况下,遵循前面的方法,我们从RobustBench基准中获得了预训练模型,包括用于CIFAR10→CIFAR10-C的WildResnet-28和用于CIFAR100→CIFAR100-C的ResneXt-29。然后,我们逐一更改最高严重程度为5的测试损坏,以模拟PTTA中测试分布随时间的连续变化。在大空隙下泛化的情况下,我们对DomainNet中的每个域通过标准分类损失来训练ResNet-101,并不断地使它们适应除源域之外的不同域。同时,我们利用Dirichlet分布来模拟所有数据集的相关采样测试流。优化采用学习速率为1×10−3的ADAM优化器,1弱化为RESIZE+CenterCrop。强增强是Clip、ColorJitter和RandomAffine等九种操作的组合。β=0.9.为了进行公平的比较,我们将所有方法的批处理大小设置为,将Rotta的内存库容量设置为N=64。关于超参数,我们在所有实验中采用了一组统一的ROTTA值,包括α=0.001、ν=0.001、λt=1.0、λu=1.0和δ=0.1。附录中提供了更多详细信息。
4.1和现有方法比较
在腐败情况下的健壮性。CIFAR10→CIFAR10-C和CIFAR100→CIFAR100-C的分类错误分别见表2和表3。随着时间的推移,我们改变了当前最高严重性为5的腐败类型,并同时将样本数据关联起来用于推理和适应。相同的测试流在所有比较的方法之间共享。从表2和表3可以看出,与以前的方法相比,ROTTA获得了最好的性能。此外,ROTA算法比次优算法CIFAR10→CIFAR10-C和CIFAR100→CIFAR100-C分别提高了5.9%和5.5%的性能,验证了ROTTA算法在PTTA下适应模型的有效性。
4.2消融实验
4.2.1 Effect of each component.
(1)ROTTA w/o RBN,在TEAT[70]中用测试时间BN替换RBN;
(2)ROTTA w/o CSTU,直接在测试流上适应模型;
(3)ROTTA w/o鲁棒训练(RT),直接调整模型,仅采用最小熵。
结果如表5所示
我们可以观察到所有变体都出现了显著的性能下降,证明了我们所提出的方法的每一部分对PTTA都是有效的。以一个分量为例,在没有对特征映射进行稳健归一化的情况下,ROTTA算法在CIFAR10-C和CIFAR100-C上的性能分别下降了50.2%和16.3%,证明了RBN具有足够的鲁棒性来解决相关采样数据流的归一化问题。CSTU通过维护测试分发版本的及时且可靠的快照,使ROTTA能够适应更稳定的分发版本。同时,具有时效性的稳健训练大大减少了错误的积累。在PTTA下,每个组成部分都发挥着重要的作用,以实现有效的适应。
4.2.2 Effect of the distribution changing order.
分布变化顺序的影响。为了排除固定顺序的分布变化的影响,我们对CIFAR10-C和CIFAR100-C分别进行了10种不同的变化序列的实验
如图4a和图4b所示,无论采用何种设置,ROTA都能取得优异的效果。相关采样测试流的详细结果如表6所示,ROTTA在CIFAR10C和CIFAR100-C上分别取得了4.3%和4.7%的进度。这表明ROTTA可以在长期场景中稳健而有效地适应模型,其中分布不断变化,测试流被独立或相关地采样,使其成为模型部署的良好选择。
4.2.3 Effect of Dirichlet concentration parameter δ.
我们在CIFAR100-C上改变δ的值,并将ROTTA与图4C中的其他方法进行比较。随着δ值的增加,BN、PL、TANT和COTTA的性能迅速下降,因为它们没有考虑测试样本之间日益增加的相关性。注[19]对相关采样的测试流是稳定的,但没有考虑分布的变化,导致无效的适应。同时,测试样本之间较高的相关性将使标签的传播更加准确,这就是为什么LAME的结果略有改进。最后,优异稳定的结果再次证明了ROTA的稳定性和有效性。
4.2.4 Effect of batch size.
在实际场景中,考虑到部署环境可能使用不同的测试批次大小,我们使用不同的测试批次大小值进行实验,结果如图4d所示。为了进行公平的比较,我们控制了ROTTA模型的更新频率,以便反向传播中涉及的样本数量相同。随着批量的增加,我们可以看到,除了LAME略有下降外,所有比较的方法都有显著的改善。
这是因为批次中的类别数量随着批次大小增加,导致总体相关性变得较低,但标签的传播变得更加困难。最重要的是,ROTTA在不同的批次大小上取得了最好的结果,再次证明了它在动态场景中的健壮性。
5总结
这项工作提出了一种更现实的TTA设置,即在测试阶段同时进行分布变化和相关采样,即实际测试时间适应(PTTA)。针对PTTA算法存在的问题,提出了一种针对复杂数据流的稳健测试时间自适应方法(ROTTA)。更具体地说,通过稳健的批归一化来估计一组用于特征地图归一化的稳健统计。同时,在考虑时效性和不确定性的情况下,采用内存库对测试分布进行分类均衡抽样,以获取测试分布的快照。此外,我们还提出了一种基于师生模型的时间感知重加权策略,以稳定适应过程。