摘要
基于对ViT在监督学习领域的表现质疑,探究自监督方法下的ViT是否具有更好的特征提取能力,进而发现:
- 自监督ViT特征包含场景布局、对象边界。这些信息可以在最后一自注意力模块中直接访问。
- 自监督ViT特征结合最近邻分类器(k-NN)分类头中表现很好,无需任何微调、线性分类器或数据增强,在ImageNet上实现了78.3%的最高精度。
- 自监督下ViT可以比卷积网络获得更多有用的语义信息。
算法框架
SSL with Knowledge Distillation
input
定义学生网络 g θ s g_{\theta_s} gθs和教师网络 g θ t g_{\theta_t} gθt两个网络,参数表示为 θ s \theta_s θs、 θ t \theta_t θt,两者体系结构相同,但参数不同。在教师网络上应用停止梯度(sg)算子。具体如下:
给定输入图像 x x x,经过两种数据增强方式局部随机裁剪(local crop)和全局随机裁剪(global views)得到 ( x 1 , x 2 ) (x_1,x_2) (x1,x2), x 1 x_1 x1(local crop)传递给学生网络 g θ s g_{\theta_s} gθs, x 2 x_2 x2(global views)传递给教师网络 g θ t g_{\theta_t} gθt。用来鼓励局部到全局的通信。
数据增强
DINO 中最核心的数据采样策略便是图像裁剪,将裁剪后的图像分为两种:
- Local views: 即局部视角,也称为small crops,指抠图面积小于原始图像的 50%
- Global views: 即全局视角,也称为large crops,指抠图面积大于原始图像的 50%
DINO 中也采用一些其它的随机增强,包括:颜色扰动(color jittering) 、高斯模糊(Gaussian blur) 、曝光增强(solarization)。
网络架构
神经网络 g g g由backbone f f f和投影头 h h h组成,故有 g = h ◦ f g = h ◦ f g=h◦f。backbone f f f为ViT或者ResNet, 可应用于下游任务。投影头 h h h为3层的MLP,隐层维度为2048(遵循L2 normalization)和一个 K K K维归一化的全连接。与标准的深度网络不同,该系统完全没有BN(batch normalization)模块。
L = L ( W ) + λ ∑ i = 1 n w i 2 L=L(W)+\lambda\sum^n_{i=1}w^2_i L=L(W)+λi=1∑nwi2
上式右侧项为L2正则化。
loss和反向传播
教师网络和学生网络的最终输出 P s P_s Ps和 P t P_t Pt都为 K K K维概率分布。概率 P P P通过网络输出 g g gsoftmax归一化得到。 τ > 0 \tau >0 τ>0是一个控制输出分布锐度的temperature参数。
P s ( x ) ( i ) = exp ( g θ s ( x ) ( i ) / τ s ) ∑ k = 1 K exp ( g θ s ( x ) ( k ) / τ s ) P_s(x)^{(i)}=\frac {\exp(g_{\theta_s}(x)^{(i)}/\tau_s)} {\sum^K_{k=1}\exp(g_{\theta_s}(x)^{(k)}/\tau_s)} Ps(x)(i)=∑k=1Kexp(gθs(x)(k)/τs)exp(gθs(x)(i)/τs)
通过最小化交叉熵损失匹配教师网络 θ t \theta_t θt学生网络 θ s \theta_s θs的参数。其中 H ( a , b ) = − a log b H(a,b)=-a\log b H(a,b)=−alogb。
min θ s H ( P t ( x ) , P s ( x ) ) \min_{\theta_s}H(P_t(x),P_s(x)) θsminH(Pt(x),Ps(x))
具体,从给定的图像,生成一个集合 V V V,集合包含两个全局视图 x 1 g x_1^g x1g和 x 2 g x_2^g x2g以及几个分辨率较小的局部视图。学生网络接收所有视图,教师网络只接收全局视图,因此鼓励“局部到全局”的通信。两个网络得到最终输出后,计算交叉熵。
min θ s ∑ x ∈ { x 1 g , x 2 g } ∑ x ′ ∈ V , x ′ ≠ x H ( P t ( x ) , P s ( x ′ ) ) \min_{\theta_s}\sum_{x\in\{x^g_1,x^g_2\}}\sum_{x'\in V , x' \not = x}H(P_t(x),P_s(x')) θsminx∈{x1g,x2g}∑x′∈V,x′=x∑H(Pt(x),Ps(x′))
学生网络通过反向传播更新参数,教师网络不参与反向传播,通过对学生网络的参数进行EMA更新,教师网络的权重更新自学生网络。 λ λ λ在训练期间,遵循余弦学习率衰减策略从0.996到1之间变化。
θ t ← λ θ t + ( 1 − λ ) θ s \theta_t←\lambda\theta_t+(1-\lambda)\theta_s θt←λθt+(1−λ)θs
Centering and Sharpening
在自监督学习中,mode collapse是指网络的学习过程中出现了多样性减少的现象。当网络学习到一组特征表示时,会出现多个输入数据映射到相同特征的情况,即模式坍塌。这种现象通常由于网络在优化过程中陷入了局部最优解,只考虑到一部分数据的特征表示,忽略了其它数据的模式和特征,从而导致了多样性缺失的现象,因此会对模型的鲁棒性产生很大的负面影响,故而引入Centering和Sharpening。
教师网络输出 g θ s g_{\theta_s} gθs后,会接一个centering操作,可以看成教师网络的输出再加上一个偏置项 c c c。
g t ( x ) ← g t ( x ) + c g_t(x)←g_t(x)+c gt(x)←gt(x)+c
偏置 c c c随着教师网络参数的EMA更新而更新,更新策略如下,其中 m > 0 m > 0 m>0是一个速率参数, B B B是批量大小。
c ← m c + ( 1 − m ) 1 B ∑ i = 1 B g θ t ( x i ) c←mc+(1-m)\frac 1 B \sum^B_{i=1}g_{\theta_t}(x_i) c←mc+(1−m)B1i=1∑Bgθt(xi)
从上述公式看出偏置 c c c实则为一个均值项,centering操作的目的是使得激活值高于平均值时为正,低于平均值时为负。由于softmax函数在处理负数时会给出较小的概率值,而在处理正数时会给出较大的概率值,因此这种操作能够防止任何一个特征占据统治地位。
Sharpening即概率 P P P中的 τ \tau τ参数。这种技巧通过在softmax函数中加入一个temperature参数,让模型将概率分布更加尖锐化。由于小差异会被夸大,可防止所有激活值相同。这个技巧和中心化操作搭配使用,可以使得激活值不断变化,从而引导学生模型更好地了解哪些特征应该变得更大。
算法伪代码
模型版本
实验
对比实验
上图为在ImageNet上应用使用不同自监督算法应用线性头和k-NN头分类,DINO在ResNet-50上达到75.3%的最先进效果。在ViT架构时,在线性分类方面比BYOL、MoCov2和SwAV高出3.5+%,在k-NN评估方面高出7.9+%。
用DINO训练的8 × 8 patch的ViT在线性分类中达到了80.1%的top-1,用k-NN分类器达到了77.4%,比SCLRv2的参数少10倍,运行时间快1.4倍。
上图比较预训练vit、resnet作为监督模型和DINO算法骨干网络下在不同数据集检测的性能。可以看到DINO优于监督学习算法。
上图做复制检测实验。在Copydays数据集“strong”子集上计算复制检测的mAP性能。对比multigrain算法,DINO的优于该算法。
上图在平均区域相似度 J m J_m Jm和平均contour-based精度 F m F_m Fm两个指标下比较现有的监督、自监督方法在DAVIS 2017视频对象分割的性能。DINO的优于其余算法
上图DINIO在不同数据集中迁移学习后进行微调实验,在top-1准确率下DINO优于监督方法。
消融实验
上图为对DINO的消融实验,可以看到缺失动量编码器(行2)即教师网络的EMA参数更新,模型将无法训练;添加Sinkhorn-Knopp批量归一化方法(行3)对算法提升效果不大;缺少MC数据增强方法(行4),效果会下降;采用MSE损失(行5),模型效果会大幅下降;最优搭配为动量编码器与多组增强和交叉熵损失(行1)。
上图比较了使用不同patch大小(16 × 16, 8 × 8和5 × 5)训练的ViT-S模型的k-NN分类性能,所有模型都训练了300个epoch。可以看到随着patch的减小,性能大大提高,但同时吞吐量会随之下降。
上图为不同MC方法下,分别在100和300epochs中达到不同top1精度的训练时间,可以看到行1中没有使用MC方法性能达到72.5%耗时45.9小时,而行4中的设置仅24小时达到74.6%。这是+2%的改进,同时只需一半的时间,但内存使用更高(15.4G与9.3G)。观察到在行1中设置中进行更多的训练无法赶上MC方法带来的性能提升,这显示了“局部到全局”增强的价值。
上图为在不同batch size下分别进行100次没有MC方法的top1精度,证实了可以小批量地将模型训练到高性能。
可视化
上图为从DINO训练的ViT-S/8的最后一块中可视化自注意力模块。
用DINO表示的ImageNet类的t-SNE可视化。对于每个类别,通过在验证集中取该类的所有图像的平均特征来获得嵌入。
投影头实验
上图为DINO添加投影头与否的计算框架,并做下列实验:
上图为对投影头模块添加BN层的实验,可以看到添加BN后效果反而会变差,故DINO系统中不使用任何BN层。
上图对有或没有l2-归一化瓶颈的情况下训练的准确性,改变了投影头中线性层的数量对准确率的影响实验。没有l2-归一化瓶颈时,增加投影头的深度DINO训练失败。说明l2归一化瓶颈稳定了深度投影头的训练,且该种情况下增加投影头的深度可以提高精度。
上图为投影头输出输出维度K对准确率的影响实验,可以看到大的输出维度可以提高性能。
ViT中使用的激活是高斯误差线性单位(GELU)。故而为了架构内的一致性,在投影头中也选择使用GELU。上图中评估了使用ReLU代替GELU的效果,观察到将激活单元改为ReLU的影响相对较小。
避免坍塌实验
由于自监督学习容易陷入特征上图研究了Centering和Sharpening的互补作用,以避免坍塌。坍缩有两种形式,无论输入是什么,模型输出沿所有维度均匀或由一个维度支配。Centering避免了一个维度支配导致的坍塌,但鼓励了均匀的输出,削尖会产生相反的效果。通过将交叉熵 H H H分解为熵 h h h和KL散度 D K L D_{KL} DKL展示这种互补性:
H ( P t , P s ) = h ( P t ) + D K L ( P t ∣ P s ) H(P_t,P_s)=h(P_t)+D_{KL}(P_t|P_s) H(Pt,Ps)=h(Pt)+DKL(Pt∣Ps)
KL等于0时,意味着学生和教师网络的概率分布一致,表示坍塌。从上图右侧看出,缺失Centering和Sharpening任何一项,都会导致特征坍塌。
从上图左侧看,0表示没有中心化和 − l o g ( 1 / K ) −log(1/K) −log(1/K)表示没有锐化,表明这两种操作导致了不同的崩溃形式。应用这两种操作可以平衡这些效果。故而temperature参数 τ \tau τ的取值很关键。
上图对 τ \tau τ的消融实验观察到:
- 需要低于0.06的temperature才能避免坍塌。当temperature高于0.06时,训练损失始终收敛到 l n ( K ) ln(K) ln(K)。
- 如果从一个较小的值开始训练,并在第一个epoch期间增加,则使用比0.06更高的温度不会崩溃。实践中,在训练的第一个30阶段对 τ \tau τ使用从0.04到0.07的线性退火。
- τ → 0 τ → 0 τ→0(极端锐化)对应于argmax操作,将导致one-hot硬分布。
reference
Caron, M. , Touvron, H. , Misra, I. , Hervé Jégou, Mairal, J. , & Bojanowski, P. , et al. (2021). Emerging properties in self-supervised vision transformers.