论文地址:https://arxiv.org/pdf/2403.07705
源码地址:https://github.com/jiaw-z/DKT-Stereo
概述
通过在合成数据上预训练的模型在未见领域上表现出强大的鲁棒性。然而,在现实世界场景中对这些模型进行微调时,其领域泛化能力可能会严重下降。本文探讨了在不损害模型对未见领域泛化能力的前提下,如何微调立体匹配网络。研究动机来源于比较真实标签(GT)与伪标签(PL)在微调过程中的差异:GT会退化,但PL能够保持领域泛化能力。通过实验发现,GT与PL之间的差异包含了有价值的信息,这些信息可以在微调过程中对网络进行正则化。文章还提出了一种框架,该框架包括一个冻结的教师网络、一个指数移动平均(EMA)教师网络和一个学生网络。核心思想是利用EMA教师网络来衡量学生网络学到的内容,并动态改进GT和PL以进行微调。作者将该框架与最先进的网络集成,并在多个真实世界数据集上评估了其有效性。本文的贡献如下:
- 首次尝试解决微调立体匹配网络时领域泛化能力下降的问题。我们基于真实标注和伪标注之间的差异将像素分为一致和不一致区域,并展示了它们在微调期间的不同作用。我们进一步分析了它们的作用,确定了导致领域泛化能力下降的两个主要原因:在没有足够正则化的情况下学习新知识和过度拟合真实标注细节。
- 提出了F&E模块来解决这两个原因,过滤掉不一致区域以避免正则化不足,并在一致区域集成视差以防止过度拟合真实标注细节。
- 引入了通过结合指数移动平均教师来动态调整不同区域的方法,实现了在保留领域泛化能力和学习目标域知识之间的平衡。
- 开发了DKT微调框架,可以轻松应用于现有网络,显著提高了它们对未见领域的鲁棒性,并同时实现了有竞争力的目标域性能。
方法
定义
文中将像素划分为三类区域:
一致区域 X c ( τ ) X_c(\tau) Xc(τ):伪标签 D ^ ( x i ) \hat{D}(x_i) D^(xi)与真实标签 D ( x i ) D^(x_i) D(xi)差异小于阈值 τ \tau τ的区域
X c ( τ ) = x ∣ ∣ D ^ ( x i ) − D ( x i ) ∣ < τ X_c(\tau) = {x \mid |\hat{D}(x_i) - D^(x_i)| < \tau} Xc(τ)=x∣∣D^(xi)−D(xi)∣<τ
该区域代表GT与PL高度对齐。
不一致区域 X i n c ( τ ) X_{inc}(\tau) Xinc(τ):GT与PL差异大于等于 τ \tau τ的区域
X i n c ( τ ) = x ∣ ∣ D ^ ( x i ) − D ∗ ( x i ) ∣ ≥ τ X_{inc}(\tau) = {x \mid |\hat{D}(x_i) - D^*(x_i)| \geq \tau} Xinc(τ)=x∣∣D^(xi)−D∗(xi)∣≥τ
网络在该区域可能遇到预训练未见的新挑战。
无效区域 X i n v a l i d X_{invalid} Xinvalid:因GT稀疏性导致无标注的区域
关键发现
GT微调的问题:
- 不一致区域:网络学习新知识但缺乏正则化,导致域泛化能力下降。
- 一致区域:网络可能过拟合GT的细节。
PL的优势:
- 一致区域:使用 X c ( 3 ) X_c(3) Xc(3)的PL微调可保留域泛化能力。
- 无效区域:PL在无标注区域的预测能提升泛化能力。
联合训练:直接联合GT和PL效果不佳,但通过Filter and Ensemble (F&E)模块动态优化标签后,可平衡目标域性能和模型原有泛化能力。
DKT Framework
冻结教师:
生成初始伪标签,微调过程中参数冻结,保留预训练模型的原始知识。
EMA教师:
通过学生网络权重动态更新
θ T ′ = m ⋅ θ T ′ + ( 1 − m ) ⋅ θ S ( m ∈ [ 0 , 1 ] ) \theta_{T'} = m \cdot \theta_{T'} + (1-m) \cdot \theta_S \quad (m \in [0,1]) θT′=m⋅θT′+(1−m)⋅θS(m∈[0,1])
EMA教师模型可以量化Student已掌握的知识,作为区域划分依据来衡量一致/不一致区域。
学生模型:
使用改进后的GT和PL进行训练,最终用于推理,通过动态调整学习区域防止过拟合。
F&E模块:
F&E-GT(处理真实标注)
区域划分:基于阈值τ(默认τ=3)将GT划分为:
1)不一致区域( X i n c X_{inc} Xinc):|GT - EMA预测| ≥ τ。
2)一致区域( X c X_c Xc):|GT - EMA预测| < τ
动态处理机制:
1) X i n c X_{inc} Xinc区域:以概率 p = 1 − ∣ X i n c ∣ ∣ X v a l i d ∣ p=1-\frac{|X_{inc}|}{|X_{valid}|} p=1−∣Xvalid∣∣Xinc∣随机保留,减少高难度区域对学习过程的干扰。
2) X c X_c Xc区域:通过随机权重α对GT和EMA预测进行线性插值:
D ˉ c ∗ = α ⋅ D ∗ + ( 1 − α ) ⋅ D ^ T ′ ( α ∼ U ( 0 , 1 ) ) \bar{D}^*_c = \alpha \cdot D^* + (1-\alpha) \cdot \hat{D}^{T'} \quad (\alpha \sim U(0,1)) Dˉc∗=α⋅D∗+(1−α)⋅D^T′(α∼U(0,1))
此外,将其限制输出与GT的偏差在±1像素内,且添加细粒度扰动防止细节过拟合。
F&E-PL(处理伪标签)
区域筛选:通过掩码 M ^ = ∣ D ^ T − D ^ T ′ ∣ < τ \hat{M} = \vert \hat{D}^{T} - \hat{D}^{T'} \vert < \tau M^=∣D^T−D^T′∣<τ 过滤不一致区域。
精度提升机制:在一致区域使用随机权重β集成两个Teacher的预测,渐进式提升PL质量:
D T = β ⋅ D ^ T + ( 1 − β ) ⋅ D ^ T ′ ( β ∼ U ( 0 , 1 ) ) D_T = \beta \cdot \hat{D}T + (1-\beta) \cdot \hat{D}{T'} \quad (\beta \sim U(0,1)) DT=β⋅D^T+(1−β)⋅D^T′(β∼U(0,1))
训练策略
最终损失函数结合改进后的GT和PL监督:
L = L d i s p ( D ^ , D ˉ ∗ , M ∗ ) + λ L d i s p ( D ^ , D ˉ T , M ^ ) L = L_{disp}(\hat{D}, \bar{D}^* ,M^*) + \lambda L_{disp}(\hat{D}, \bar{D}^T, \hat{M}) L=Ldisp(D^,Dˉ∗,M∗)+λLdisp(D^,DˉT,M^)
EMA重置机制:每5k步将EMA Teacher权重重置为当前Student, 使区域划分随学习进度动态更新。
实验