[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析
论文:https://arxiv.org/abs/2104.00323
代码:https://github.com/dvlab-research/JigsawClustering
总结
本文提出了一种单批次(single-batch)的自监督任务pretext task Jigsaw Cluster,相比于双批次(dual-batches)的方法降低了计算量,同时利用了图像内的信息和图像间的信息。
本文提出的任务构造的主要流程如下如图1所示,首先在一整个batch内将 nnn 张图像每张分为 m×mm\times mm×m 份图块,则共有 n×m×mn\times m\times mn×m×m 个图块。再将这些图块打乱(注意是一个batch内所有的图块进行打乱,而非某单张图像内打乱)后,再拼接为图像。
本文设计的网络(如图2所示)在backbone提取特征之后有两个分支:聚类分支和定位分支。聚类分支会完成一个有监督聚类的任务,将来自同一张原图的不同图块(已被打乱)聚集到一簇(cluster,类)。作者使用了最近比较火的对比学习来完成这个有监督聚类任务。而对于定位分支,则是要预测出图块在原图中的位置,具体是由一个分类任务来完成,损失函数直接选用交叉熵损失。
算法细节如有重叠分块、插值加池化等可见下面的原文翻译。
源码简析
以下是源码中JigClu模型的关键几步操作,笔者在进行实验后将其中信号流的形状等信息注释在代码中,希望能够帮助大家理解,或者能够为想要复现并改进本文的读者提供一些参考。
@torch.no_grad()def _batch_gather_ddp(self, images): # images是长度为4的列表,其中每个元素是形状为 (n, 3, 112, 112)的tensor"""gather images from different gpus and shuffle between them*** Only support DistributedDataParallel (DDP) model. ***"""images_gather = []for i in range(4):batch_size_this = images[i].shape[0]images_gather.append(concat_all_gather(images[i]))batch_size_all = images_gather[i].shape[0]num_gpus = batch_size_all // batch_size_thisn,c,h,w = images_gather[0].shapepermute = torch.randperm(n*4).cuda()torch.distributed.broadcast(permute, src=0)images_gather = torch.cat(images_gather, dim=0)images_gather = images_gather[permute,:,:,:]col1 = torch.cat([images_gather[0:n], images_gather[n:2*n]], dim=3)col2 = torch.cat([images_gather[2*n:3*n], images_gather[3*n:]], dim=3)images_gather = torch.cat([col1, col2], dim=2)bs = images_gather.shape[0] // num_gpusgpu_idx = torch.distributed.get_rank()return images_gather[bs*gpu_idx:bs*(gpu_idx+1)], permute, ndef forward(self, images, progress):images_gather, permute, bs_all = self._batch_gather_ddp(images) # bs=16双卡, len(images) 4, images_gather.shape (8, 3, 224, 224), permute.shape 64(即16*4), bs_all = 16# compute featuresq = self.encoder(images_gather) # bs=16双卡, q.shape (8, 2048, 2, 2) q_gather = concat_all_gather(q) # bs=16双卡, q_gather.shape (16, 2048, 2, 2) # 插值后池化,得到这个形状n,c,h,w = q_gather.shapec1,c2 = q_gather.split([1,1],dim=2) # bs=16双卡, c.shape (16, 2048, 1, 2)f1,f2 = c1.split([1,1],dim=3) # bs=16双卡, f.shape (16, 2048, 1, 1)f3,f4 = c2.split([1,1],dim=3)q_gather = torch.cat([f1,f2,f3,f4],dim=0) # bs=16双卡, q_gather.shape (64, 2048, 1, 1)q_gather = q_gather.view(n*4,-1) # bs=16双卡, q_gather.shape (64, 2048)# clustering branchlabel_clu = permute % bs_all # permute: 0-(4*bs) 之间的随机值, 取余则label_clu: 4组 0-bs之间的随机值,即同一个值label_clu值是来自同一图片的q_clu = self.encoder.fc_clu(q_gather) # bs=16双卡,q_clu.shape (64, 128) 即(4*bs, dim)q_clu = nn.functional.normalize(q_clu, dim=1)loss_clu = self.criterion_clu(q_clu, label_clu)# location branchlabel_loc = torch.LongTensor([0]*bs_all+[1]*bs_all+[2]*bs_all+[3]*bs_all).cuda()label_loc = label_loc[permute]q_loc = self.encoder.fc_loc(q_gather)loss_loc = self.criterion_loc(q_loc, label_loc)return loss_clu, loss_loc
笔者使用双卡进行实验,batchsize设为16。
源码中一些gather操作是为了适应dp或者ddp训练,对理解算法本身没有影响。
以下是笔者对原文部分进行的翻译,一些算法细节和实现细节可以从中找到,配合源码注释基本可以理解全文的算法思想。有疑惑或者异议欢迎留言讨论。
原文部分翻译
abstract
使用对比学习的无监督表示学习取得了巨大的成功,该方法将每一训练批次复制来构建对比对,使每一训练批及其扩增版本同时进行前向传播,导致额外计算。本文提出了一种新的jigsaw聚类 pretext task,该任务只需要将每个训练批次本身进行前向传播,并降低训练损失。我们的方法同时利用了图像内的和图像间的信息,极大地超越了之前的基于单训练批次(single batch based)的方法。甚至得到了与使用对比训练的方法接近的结果,而相比之下本文方法只用了一半的训练批次。
我们的方法表明多批次训练是不必要的,并为未来的单批次无监督的研究打开了大门
introduction
无监督的视觉表示学习,或者说自监督学习,是一个存在已久的问题,试图在没有人类监督信号的情况下,得到一个通用特征提取器。这个目标可以通过精心设计不带有标注的pretext task来训练特征提取器来达成。
根据pretext task的定义,大多数主流的方法分两类:图像内(intra-image)的任务和图像间(inter-image)的任务。图像内的任务,包括colorization和jigsaw puzzle,设计一种一张图像的变换,并训练一个网络学习这种变换。由于每次只有训练批次本身需要前向传播计算,所以我们将这些方法称作单批次方法(single-batch methods)。这类任务只使用了一张图片的信息就可以完成,这限制了特征提取器的学习能力。
最近几年图像间任务迅猛发展,要求网络能够辨别不同的图像。对比学习现在很流行,因为它可以降低正对的特征表示之间的距离,并扩大负对的特征表示之间的距离。为了建构正对,训练过程需要使用经过不同的数据扩增的另一批次的数据。由于每个训练批次和它的扩增过的版本要同时进行前向传播,我们将这些方法称作双批次方法(dual-batches methods)。这种方法在训练过程中大大提升了对资源的需求,如何能够设计一种有效的基于单批次的方法,达到与双批次相仿的性能仍旧是个问题。
本文中,我们提出了一个使用Jigsaw聚类(Jig-Clu)来有效训练无监督模型的框架。该方法结合了拼图和对比学习的优点,利用图像内部和图像间的信息指导特征提取。它学习更全面的表达。
该方法在训练过程中只需要一个单批,但与其他单批方法相比,结果有很大提高。它甚至可以达到类似的结果与双批次方法,但相比只有一半的训练批次。
jigsaw clustring task
在本文提出的JigClu任务中,同一批次内的每张图片被分成不同的块,它们被随机打乱在被接在一起,来形成一个新的批次用作训练。目标就是将这个被打乱的恢复为原图,如图一所示。不同于以往的Jigsaw Puzzle任务,原图分成的块是在整个批次内被打乱的,而非在单张图像内。我们需要去预测的事每个块属于哪张图片和每个块在原图中的位置。
我们使用蒙太奇(montage)图像而非单个块作为网络的输入。这个改动大幅提升了任务的难度,并为网络提供了更多的有用的信息供学习。网络需要辨识出一张图像的不同部分,并识别出它们原来的位置从而从多蒙太奇(multiple montage)输入图像中恢复原图。
这个任务使得网络能够图像内和图像间的信息,只需要通过对拼接后的图像进行前向传播,与其他对比学习的任务相比只使用了一半的训练批次。
为了恢复来自交叉图像的图块,我们设计了一个聚类分支和一个定位分支。如图二所示,具体来说,我们首先将来自拼接图像的全局特征图解耦为每个图块的表示。然后这两个分支对每个图块的特征表示进行操作。聚类分支是将这些图块分为几簇,每个簇只包含来自同一张图像的图块。另一方面,定位分支,以图像不可知的方式(image agnostic manner)预测每个图块的位置。
有了这两个分支的预测结果,JigClu问题就得以解决。聚类分支作为一个有监督聚类任务进行训练,因为我们知道图块是否来自同一张图像。定位分支可以看作是一个分类任务,其中每个图块会被分配一个标签,以此来表示其在原图中的位置。定位分支预测所有图块的这个标签。
我们的方法得到了不错的结果,是因为我们提出的任务会使模型学习到不同种类的信息。一开始,从一张拼接的图像中辨识出不同的图块迫使模型去捕捉图像内不实例级别(instance-level)的信息。这一级别的特征在其他的对比学习方法中是丢失了的。
进一步,从多个输入图像中聚类到不同的图块有助于模型在图像中学习图像级别(image-level)的特征。这时最近的一些方法得到高质量结果的关键。我们的方法保持了这一重要属性。最后,将每个图块摆放到正确的位置又要求细节的定位信息,这时之前的单批次方法考虑到的。但是在最近的一些方法中被忽略了。我们认为这种信息对于进一步提升结果来说仍旧是重要的。
performance of our method
通过我们的方法进行学习,可以产生图像内的和图像间的信息。这样综合的学习可以带来一些优势(spectrum of superiority)。首先,我们的方法在训练阶段只有一个批次,在Imagenet-1k的线性评估阶段比其他单批次方法高了2.6%。 。。。
related work
handcrafted pretext tasks
训练无监督模型的pretext task的方法有很多种。 将破坏过的图像进行恢复是一个重要主题,有with tasks of descriminating synthetic artifacts [18], colorization [20, 43], image inpainting [31], and denoising auto-encoders [37], 等。另外,许多方法通过一些变换生成persuade labels(?)来训练网络。应用包括预测两个块的关系,解决jigsaw puzzle,还有识别被替代的类。[]是一个进阶版的jigsaw puzzle,利用更复杂的方法选择图块。视频信息在训练无监督模型时也很常用。
contrastive learning
我们的方法和对比学习也高度相关,首先由[]提出,根据[]可以得到更好的性能。最近[],使用不同的扩增方法构建对比对取得了巨大的成功。尤其是,[]在pixel水平上利用图像间和图像内的信息。我们注意到训练多批次图像的对比学习方法需要大量的训练资源。通过新颖的在单批次内设计对比对,我们的工作解决了这个问题。
jigsaw clustering
本章,我们会给出本文所提出的任务的定义。我们使用一个很简单的网络,只需要对原始的骨干网络进行一点点调整。最后,我们设计了一个新颖的损失函数来更好地适应我们的聚类任务。
the jigsaw clustering task
在一个批次 X\bf{X}X=x1,x2,…,xn=x_1,x_2,\dots,x_n=x1,x2,…,xn 内,有 nnn 个随机选择的图像。每张图像 xix_ixi 被分为 m×mm\times mm×m 个图块。共有 n×m×mn\times m\times mn×m×m 个图块。所有这些图块会被随机重新排列来形成一组有蒙太奇图像X′\bf{X'}X′=x1′,x2′,…,xn′=x'_1,x'_2,\dots,x'_n=x1′,x2′,…,xn′ 形成的新的批次。每张新图同样包含 m×mm\times mm×m 个图块,这些图块来自不同的原批次 X\bf{X}X 中的图像。
任务就是对新批次 X\bf{X}X 中的这 n×m×mn\times m\times mn×m×m 个图块进行聚类为 nnn 个簇,并且对同一簇的 $ m\times m$ 个图块预测位置来恢复出 nnn 张原图,整个过程见图1。
本文提出的任务的关键是使用蒙太奇图像作为输入而不是每单独一个图块。值得注意的是,直接使用小图块作为输入会导致solution只有全局信息。此外,小尺寸的输入图像在许多应用中并不常见。仅在此处使用它们会引发pretext task和其他下游任务之间的图像分辨率差异问题。这也可能导致性能下降。而简单地直接扩展小图块将极大地提升训练资源。
我们将蒙太奇图像作为输入完美地避免了这些问题。首先,来自一个批次的输入图像与原批次有着相同的尺寸,这和最近的方法相比只消耗了一半的资源。更重要的是,为了更好地完成本任务,网络需要学习细节的图像内的特征,来辨别一张图像中的不同图块,和全局的图像间的特征来将来自同一张原图的不同图块聚集在一起。我们观察到全面特征的学习大幅加速了特征提取了的训练。更多实验结果见下一节。
在本方法中,分图像的方法是很关键的。mmm 的选择影响到任务的难度。我们的在ImageNet子集上的消融实验显示 m=2m=2m=2 时得到最好的结果。我们推测 mmm 过大会呈指数级地增加复杂度,使得网络不能高效地学习。另外,我们观察到将图像切割为不连接的图块(disjoint pathches)并不是最优的。如图3所示,随着交叉点的延伸,网络学习到更好的特征。这时可以解释的,因为某些图像的不同区域过于多样化。如果没有任何重叠的迹象,它们会给学习带来困难。第5节会有更多解释。
network design
我们为本任务设计了一个新的解耦网络。首先是特征提取器,可以是任何网络[]。然后有一个无参数的解耦网络来将特征分为 m×mm\times mm×m 个部分,对应同一个输入图像的不同的块。然后用一个MLP来嵌入每个块的特征,用作聚类任务;一个全连接层用来做定位任务。
解耦模块首先将主干的特征映射插值为边长为 mmm 的倍数的新特征映射。我们是扩大特征图而非缩小从而避免信息丢失。举个例子,比如ImageNet,输入尺寸都是224x224.如果用ResNet-50作骨干网络,则提取到的特征是空间尺寸是 7x7的。如果 m=2m=2m=2 ,我们就将特征图用双线性插值搭配8x8。这样特征图的长度就是 mmm 的倍数,我们可以使用平均池化,来对特征图进行降采样到 n×m×m×c^n\times m\times m\times \hat{c}n×m×m×c^ 。这样,一个batch的就被分解为 (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ ,即有 (n×m×m)(n\times m\times m)(n×m×m) 个维度为 c^\hat{c}c^ 的向量。
然后每个向量都经过两层MLP嵌入到长度为 ccc ,来形成一组向量 Z=z1,z2,…,znmm\mathbf{Z}=z_1,z_2,\dots,z_{nmm}Z=z1,z2,…,znmm 用作聚类任务。同时, (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ 的向量还会被送到一个作为分类器的全连接层,产生logits L=l1,l2,…,lnmm\mathbf{L}=l_1,l_2,\dots,l_{nmm}L=l1,l2,…,lnmm,来完成定位任务。
我们的网络是相当高效的,这个额外的解耦模块是不需要参数的。与近期的工作相比,取一批的计算方法基本相同,训练时只需取一批。这大大降低了训练成本。
loss functions
聚类分支是一个有监督聚类任务,因为 m×mm\times mm×m 个块来自同一类。有监督聚类任务很方便,我们使用对比学习来实现。我们将聚类的目标是将来自同一类的物体(块)拉到一起,将来自不同类的图块推开。我们使用余弦相似度来测量块之间的距离。这样来自同一簇的每一对块,损失函数如下:
ℓi,j=−logexp(cos(zi,zj)/τ)∑k=1nmm1k≠iexp(cos(zi,zj)/τ)\ell_{i,j}=-log\frac{exp(cos(z_i,z_j)/\tau)}{\sum_{k=1}^{nmm}\mathbb{1}_{k\neq i}exp(cos(z_i,z_j)/\tau)} ℓi,j=−log∑k=1nmm1k=iexp(cos(zi,zj)/τ)exp(cos(zi,zj)/τ)
其中 1\mathbb{1}1 表示指示函数(indicator function),τ\tauτ 是温度系数,用来平滑或者加剧距离。最终的所有来自同一簇的图块对的损失函数可写作:
Lclu=1nmm∑i(1mm−1∑j∈Ciℓi,j)\mathcal{L}_{clu}=\frac{1}{nmm}\sum_i(\frac{1}{mm-1}\sum_{j\in C_i\ell_{i,j}}) Lclu=nmm1i∑(mm−11j∈Ciℓi,j∑)
其中 CiC_iCi 表示同一簇 iii 内的图块的索引 。
定位分支被视作是一个分类任务,损失函数是简单的交叉熵损失,写作:
Lloc=CrossEntropy(L,Lgt)\mathcal{L}_{loc}=CrossEntropy(\mathbf{L,L_{gt}}) Lloc=CrossEntropy(L,Lgt)
我们提出的Ji个C路的总体损失则为:
L=αLclu+βLloc\mathcal{L}=\alpha\mathcal{L}_{clu}+\beta\mathcal{L}_{loc} L=αLclu+βLloc
在我们的实验中,α=β=1\alpha=\beta=1α=β=1 即可得到好的结果。