[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析

[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析

论文:https://arxiv.org/abs/2104.00323

代码:https://github.com/dvlab-research/JigsawClustering

总结

本文提出了一种单批次(single-batch)的自监督任务pretext task Jigsaw Cluster,相比于双批次(dual-batches)的方法降低了计算量,同时利用了图像内的信息和图像间的信息。

本文提出的任务构造的主要流程如下如图1所示,首先在一整个batch内将 nnn 张图像每张分为 m×mm\times mm×m 份图块,则共有 n×m×mn\times m\times mn×m×m 个图块。再将这些图块打乱(注意是一个batch内所有的图块进行打乱,而非某单张图像内打乱)后,再拼接为图像。

本文设计的网络(如图2所示)在backbone提取特征之后有两个分支:聚类分支和定位分支。聚类分支会完成一个有监督聚类的任务,将来自同一张原图的不同图块(已被打乱)聚集到一簇(cluster,类)。作者使用了最近比较火的对比学习来完成这个有监督聚类任务。而对于定位分支,则是要预测出图块在原图中的位置,具体是由一个分类任务来完成,损失函数直接选用交叉熵损失。

算法细节如有重叠分块、插值加池化等可见下面的原文翻译。

源码简析

以下是源码中JigClu模型的关键几步操作,笔者在进行实验后将其中信号流的形状等信息注释在代码中,希望能够帮助大家理解,或者能够为想要复现并改进本文的读者提供一些参考。

    @torch.no_grad()def _batch_gather_ddp(self, images):        # images是长度为4的列表,其中每个元素是形状为 (n, 3, 112, 112)的tensor"""gather images from different gpus and shuffle between them*** Only support DistributedDataParallel (DDP) model. ***"""images_gather = []for i in range(4):batch_size_this = images[i].shape[0]images_gather.append(concat_all_gather(images[i]))batch_size_all = images_gather[i].shape[0]num_gpus = batch_size_all // batch_size_thisn,c,h,w = images_gather[0].shapepermute = torch.randperm(n*4).cuda()torch.distributed.broadcast(permute, src=0)images_gather = torch.cat(images_gather, dim=0)images_gather = images_gather[permute,:,:,:]col1 = torch.cat([images_gather[0:n], images_gather[n:2*n]], dim=3)col2 = torch.cat([images_gather[2*n:3*n], images_gather[3*n:]], dim=3)images_gather = torch.cat([col1, col2], dim=2)bs = images_gather.shape[0] // num_gpusgpu_idx = torch.distributed.get_rank()return images_gather[bs*gpu_idx:bs*(gpu_idx+1)], permute, ndef forward(self, images, progress):images_gather, permute, bs_all = self._batch_gather_ddp(images)     # bs=16双卡,  len(images) 4, images_gather.shape  (8, 3, 224, 224), permute.shape 64(即16*4), bs_all = 16# compute featuresq = self.encoder(images_gather)         #  bs=16双卡, q.shape  (8, 2048, 2, 2) q_gather = concat_all_gather(q)         #  bs=16双卡, q_gather.shape   (16, 2048, 2, 2)     # 插值后池化,得到这个形状n,c,h,w = q_gather.shapec1,c2 = q_gather.split([1,1],dim=2)         # bs=16双卡, c.shape (16, 2048, 1, 2)f1,f2 = c1.split([1,1],dim=3)               # bs=16双卡, f.shape (16, 2048, 1, 1)f3,f4 = c2.split([1,1],dim=3)q_gather = torch.cat([f1,f2,f3,f4],dim=0)   # bs=16双卡, q_gather.shape (64, 2048, 1, 1)q_gather = q_gather.view(n*4,-1)            # bs=16双卡, q_gather.shape (64, 2048)# clustering branchlabel_clu = permute % bs_all            # permute: 0-(4*bs) 之间的随机值, 取余则label_clu: 4组 0-bs之间的随机值,即同一个值label_clu值是来自同一图片的q_clu = self.encoder.fc_clu(q_gather)       # bs=16双卡,q_clu.shape (64, 128) 即(4*bs, dim)q_clu = nn.functional.normalize(q_clu, dim=1)loss_clu = self.criterion_clu(q_clu, label_clu)# location branchlabel_loc = torch.LongTensor([0]*bs_all+[1]*bs_all+[2]*bs_all+[3]*bs_all).cuda()label_loc = label_loc[permute]q_loc = self.encoder.fc_loc(q_gather)loss_loc = self.criterion_loc(q_loc, label_loc)return loss_clu, loss_loc

笔者使用双卡进行实验,batchsize设为16。

源码中一些gather操作是为了适应dp或者ddp训练,对理解算法本身没有影响。

以下是笔者对原文部分进行的翻译,一些算法细节和实现细节可以从中找到,配合源码注释基本可以理解全文的算法思想。有疑惑或者异议欢迎留言讨论。

原文部分翻译

abstract

使用对比学习的无监督表示学习取得了巨大的成功,该方法将每一训练批次复制来构建对比对,使每一训练批及其扩增版本同时进行前向传播,导致额外计算。本文提出了一种新的jigsaw聚类 pretext task,该任务只需要将每个训练批次本身进行前向传播,并降低训练损失。我们的方法同时利用了图像内的和图像间的信息,极大地超越了之前的基于单训练批次(single batch based)的方法。甚至得到了与使用对比训练的方法接近的结果,而相比之下本文方法只用了一半的训练批次。
我们的方法表明多批次训练是不必要的,并为未来的单批次无监督的研究打开了大门

在这里插入图片描述

introduction

无监督的视觉表示学习,或者说自监督学习,是一个存在已久的问题,试图在没有人类监督信号的情况下,得到一个通用特征提取器。这个目标可以通过精心设计不带有标注的pretext task来训练特征提取器来达成。

根据pretext task的定义,大多数主流的方法分两类:图像内(intra-image)的任务和图像间(inter-image)的任务。图像内的任务,包括colorization和jigsaw puzzle,设计一种一张图像的变换,并训练一个网络学习这种变换。由于每次只有训练批次本身需要前向传播计算,所以我们将这些方法称作单批次方法(single-batch methods)。这类任务只使用了一张图片的信息就可以完成,这限制了特征提取器的学习能力。

最近几年图像间任务迅猛发展,要求网络能够辨别不同的图像。对比学习现在很流行,因为它可以降低正对的特征表示之间的距离,并扩大负对的特征表示之间的距离。为了建构正对,训练过程需要使用经过不同的数据扩增的另一批次的数据。由于每个训练批次和它的扩增过的版本要同时进行前向传播,我们将这些方法称作双批次方法(dual-batches methods)。这种方法在训练过程中大大提升了对资源的需求,如何能够设计一种有效的基于单批次的方法,达到与双批次相仿的性能仍旧是个问题。

本文中,我们提出了一个使用Jigsaw聚类(Jig-Clu)来有效训练无监督模型的框架。该方法结合了拼图和对比学习的优点,利用图像内部和图像间的信息指导特征提取。它学习更全面的表达。

该方法在训练过程中只需要一个单批,但与其他单批方法相比,结果有很大提高。它甚至可以达到类似的结果与双批次方法,但相比只有一半的训练批次。

jigsaw clustring task

在本文提出的JigClu任务中,同一批次内的每张图片被分成不同的块,它们被随机打乱在被接在一起,来形成一个新的批次用作训练。目标就是将这个被打乱的恢复为原图,如图一所示。不同于以往的Jigsaw Puzzle任务,原图分成的块是在整个批次内被打乱的,而非在单张图像内。我们需要去预测的事每个块属于哪张图片和每个块在原图中的位置

我们使用蒙太奇(montage)图像而非单个块作为网络的输入。这个改动大幅提升了任务的难度,并为网络提供了更多的有用的信息供学习。网络需要辨识出一张图像的不同部分,并识别出它们原来的位置从而从多蒙太奇(multiple montage)输入图像中恢复原图。

这个任务使得网络能够图像内和图像间的信息,只需要通过对拼接后的图像进行前向传播,与其他对比学习的任务相比只使用了一半的训练批次。

为了恢复来自交叉图像的图块,我们设计了一个聚类分支和一个定位分支。如图二所示,具体来说,我们首先将来自拼接图像的全局特征图解耦为每个图块的表示。然后这两个分支对每个图块的特征表示进行操作。聚类分支是将这些图块分为几簇,每个簇只包含来自同一张图像的图块。另一方面,定位分支,以图像不可知的方式(image agnostic manner)预测每个图块的位置。

有了这两个分支的预测结果,JigClu问题就得以解决。聚类分支作为一个有监督聚类任务进行训练,因为我们知道图块是否来自同一张图像。定位分支可以看作是一个分类任务,其中每个图块会被分配一个标签,以此来表示其在原图中的位置。定位分支预测所有图块的这个标签。

我们的方法得到了不错的结果,是因为我们提出的任务会使模型学习到不同种类的信息。一开始,从一张拼接的图像中辨识出不同的图块迫使模型去捕捉图像内不实例级别(instance-level)的信息。这一级别的特征在其他的对比学习方法中是丢失了的。

进一步,从多个输入图像中聚类到不同的图块有助于模型在图像中学习图像级别(image-level)的特征。这时最近的一些方法得到高质量结果的关键。我们的方法保持了这一重要属性。最后,将每个图块摆放到正确的位置又要求细节的定位信息,这时之前的单批次方法考虑到的。但是在最近的一些方法中被忽略了。我们认为这种信息对于进一步提升结果来说仍旧是重要的。

performance of our method

通过我们的方法进行学习,可以产生图像内的和图像间的信息。这样综合的学习可以带来一些优势(spectrum of superiority)。首先,我们的方法在训练阶段只有一个批次,在Imagenet-1k的线性评估阶段比其他单批次方法高了2.6%。 。。。

related work

handcrafted pretext tasks

训练无监督模型的pretext task的方法有很多种。 将破坏过的图像进行恢复是一个重要主题,有with tasks of descriminating synthetic artifacts [18], colorization [20, 43], image inpainting [31], and denoising auto-encoders [37], 等。另外,许多方法通过一些变换生成persuade labels(?)来训练网络。应用包括预测两个块的关系,解决jigsaw puzzle,还有识别被替代的类。[]是一个进阶版的jigsaw puzzle,利用更复杂的方法选择图块。视频信息在训练无监督模型时也很常用。

contrastive learning

我们的方法和对比学习也高度相关,首先由[]提出,根据[]可以得到更好的性能。最近[],使用不同的扩增方法构建对比对取得了巨大的成功。尤其是,[]在pixel水平上利用图像间和图像内的信息。我们注意到训练多批次图像的对比学习方法需要大量的训练资源。通过新颖的在单批次内设计对比对,我们的工作解决了这个问题。

jigsaw clustering

本章,我们会给出本文所提出的任务的定义。我们使用一个很简单的网络,只需要对原始的骨干网络进行一点点调整。最后,我们设计了一个新颖的损失函数来更好地适应我们的聚类任务。

在这里插入图片描述

the jigsaw clustering task

在一个批次 X\bf{X}X=x1,x2,…,xn=x_1,x_2,\dots,x_n=x1,x2,,xn 内,有 nnn 个随机选择的图像。每张图像 xix_ixi 被分为 m×mm\times mm×m 个图块。共有 n×m×mn\times m\times mn×m×m 个图块。所有这些图块会被随机重新排列来形成一组有蒙太奇图像X′\bf{X'}X=x1′,x2′,…,xn′=x'_1,x'_2,\dots,x'_n=x1,x2,,xn 形成的新的批次。每张新图同样包含 m×mm\times mm×m 个图块,这些图块来自不同的原批次 X\bf{X}X 中的图像。

任务就是对新批次 X\bf{X}X 中的这 n×m×mn\times m\times mn×m×m 个图块进行聚类为 nnn 个簇,并且对同一簇的 $ m\times m$ 个图块预测位置来恢复出 nnn 张原图,整个过程见图1。

本文提出的任务的关键是使用蒙太奇图像作为输入而不是每单独一个图块。值得注意的是,直接使用小图块作为输入会导致solution只有全局信息。此外,小尺寸的输入图像在许多应用中并不常见。仅在此处使用它们会引发pretext task和其他下游任务之间的图像分辨率差异问题。这也可能导致性能下降。而简单地直接扩展小图块将极大地提升训练资源。

我们将蒙太奇图像作为输入完美地避免了这些问题。首先,来自一个批次的输入图像与原批次有着相同的尺寸,这和最近的方法相比只消耗了一半的资源。更重要的是,为了更好地完成本任务,网络需要学习细节的图像内的特征,来辨别一张图像中的不同图块,和全局的图像间的特征来将来自同一张原图的不同图块聚集在一起。我们观察到全面特征的学习大幅加速了特征提取了的训练。更多实验结果见下一节。

在本方法中,分图像的方法是很关键的。mmm 的选择影响到任务的难度。我们的在ImageNet子集上的消融实验显示 m=2m=2m=2 时得到最好的结果。我们推测 mmm 过大会呈指数级地增加复杂度,使得网络不能高效地学习。另外,我们观察到将图像切割为不连接的图块(disjoint pathches)并不是最优的。如图3所示,随着交叉点的延伸,网络学习到更好的特征。这时可以解释的,因为某些图像的不同区域过于多样化。如果没有任何重叠的迹象,它们会给学习带来困难。第5节会有更多解释。

network design

我们为本任务设计了一个新的解耦网络。首先是特征提取器,可以是任何网络[]。然后有一个无参数的解耦网络来将特征分为 m×mm\times mm×m 个部分,对应同一个输入图像的不同的块。然后用一个MLP来嵌入每个块的特征,用作聚类任务;一个全连接层用来做定位任务。

解耦模块首先将主干的特征映射插值为边长为 mmm 的倍数的新特征映射。我们是扩大特征图而非缩小从而避免信息丢失。举个例子,比如ImageNet,输入尺寸都是224x224.如果用ResNet-50作骨干网络,则提取到的特征是空间尺寸是 7x7的。如果 m=2m=2m=2 ,我们就将特征图用双线性插值搭配8x8。这样特征图的长度就是 mmm 的倍数,我们可以使用平均池化,来对特征图进行降采样到 n×m×m×c^n\times m\times m\times \hat{c}n×m×m×c^ 。这样,一个batch的就被分解为 (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ ,即有 (n×m×m)(n\times m\times m)(n×m×m) 个维度为 c^\hat{c}c^ 的向量。

然后每个向量都经过两层MLP嵌入到长度为 ccc ,来形成一组向量 Z=z1,z2,…,znmm\mathbf{Z}=z_1,z_2,\dots,z_{nmm}Z=z1,z2,,znmm 用作聚类任务。同时, (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ 的向量还会被送到一个作为分类器的全连接层,产生logits L=l1,l2,…,lnmm\mathbf{L}=l_1,l_2,\dots,l_{nmm}L=l1,l2,,lnmm,来完成定位任务。

我们的网络是相当高效的,这个额外的解耦模块是不需要参数的。与近期的工作相比,取一批的计算方法基本相同,训练时只需取一批。这大大降低了训练成本。

loss functions

聚类分支是一个有监督聚类任务,因为 m×mm\times mm×m 个块来自同一类。有监督聚类任务很方便,我们使用对比学习来实现。我们将聚类的目标是将来自同一类的物体(块)拉到一起,将来自不同类的图块推开。我们使用余弦相似度来测量块之间的距离。这样来自同一簇的每一对块,损失函数如下:
ℓi,j=−logexp(cos(zi,zj)/τ)∑k=1nmm1k≠iexp(cos(zi,zj)/τ)\ell_{i,j}=-log\frac{exp(cos(z_i,z_j)/\tau)}{\sum_{k=1}^{nmm}\mathbb{1}_{k\neq i}exp(cos(z_i,z_j)/\tau)} i,j=logk=1nmm1k=iexp(cos(zi,zj)/τ)exp(cos(zi,zj)/τ)
其中 1\mathbb{1}1 表示指示函数(indicator function),τ\tauτ 是温度系数,用来平滑或者加剧距离。最终的所有来自同一簇的图块对的损失函数可写作:
Lclu=1nmm∑i(1mm−1∑j∈Ciℓi,j)\mathcal{L}_{clu}=\frac{1}{nmm}\sum_i(\frac{1}{mm-1}\sum_{j\in C_i\ell_{i,j}}) Lclu=nmm1i(mm11jCii,j)
其中 CiC_iCi 表示同一簇 iii 内的图块的索引 。

定位分支被视作是一个分类任务,损失函数是简单的交叉熵损失,写作:
Lloc=CrossEntropy(L,Lgt)\mathcal{L}_{loc}=CrossEntropy(\mathbf{L,L_{gt}}) Lloc=CrossEntropy(L,Lgt)
我们提出的Ji个C路的总体损失则为:
L=αLclu+βLloc\mathcal{L}=\alpha\mathcal{L}_{clu}+\beta\mathcal{L}_{loc} L=αLclu+βLloc
在我们的实验中,α=β=1\alpha=\beta=1α=β=1 即可得到好的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532852.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

matlab legend 分块,matlab legend 分块!

matlab legend 分块!(2013-03-26 18:07:38)%%%压差clc;clear all;figure(55);set (gcf,Position,[116 123 275 210],color,w);P[25 26 27 28 29 30 31 32 33 34 35];%理论q0.00006*pi*28*P*10^(6)*0.03^3/(12*0.028448*5);q1110.00006*pi*28*P*10^(6)*0.03^3/(12*0.…

利用opencv-python绘制多边形框或(半透明)区域填充(可用于分割任务mask可视化)

利用opencv-python绘制多边形框或(半透明)区域填充(可用于分割任务mask可视化) 本文主要就少opencv中两个函数polylines和fillPoly分别用于绘制多边形框或区域填充,并会会以常见用途分割任务mask(还是笔者…

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。对其做各种改进的顶会论文也是层出不穷,本文将聚焦于各种最新的视觉trans…

mysql 分析查询语句,MySQL教程之SQL语句分析查询优化

怎么获取有功能问题的SQL1、经过用户反应获取存在功能问题的SQL2、经过慢查询日志获取功能问题的SQL3、实时获取存在功能问题的SQL运用慢查询日志获取有功能问题的SQL首要介绍下慢查询相关的参数1、slow_query_log 发动定制记载慢查询日志设置的办法,能够经过MySQL指…

树莓派摄像头基础配置及测试

树莓派摄像头基础配置 step 1 硬件连接 硬件连接,注意不要接反了,排线蓝色一段朝向网口的方向。(笔者的设备是树莓派4B) step 2 安装raspi-config 安装 raspi-config raspi-config在raspbian中是预装的,而在kali、…

使用百度云智能SDK和树莓派搭建简易的人脸识别系统 Python语言版

硬件 树莓派4B一个CSI摄像头一个 笔者使用的是树莓派4B和CSI摄像头,但是树莓派3和USB摄像头等相似设备均可。 百度云智能设置 Step 1 登录 百度云智能 网址https://cloud.baidu.com/ 首先登录百度账号,与百度云、百度贴吧等互通,可直接…

xp搭建 php环境,windows xp 下 LAMP环境搭建

1. apache安装步骤如下图在浏览器中输入:localhost,出现下面页面说明已成功安装apache。2. mysql安装如下图显示在运行里面输入cmd ,然后连接测试mysql ,如图所示:3. php安装(1)将php压缩包解压到安装路径中的php目录…

C++中的虚函数(表)实现机制以及用C语言对其进行的模拟实现

C中的虚函数(表)实现机制以及用C语言对其进行的模拟实现 声明:本文非博主原创,转自https://blog.twofei.com/496/,博主读后受益良多,特地转载,一是希望好文能有更多人看到,二是为了日后自己查阅。 前言 …

C++中数组和指针的关系(区别)详解

C中数组和指针的关系(区别)详解 本文转自:http://c.biancheng.net/view/1472.html 博主在阅读后将文中几个知识点提出来放在前面: 没有方括号和下标的数组名称实际上代表数组的起始地址,这意味着数组名称实际上就是…

安装php独立环境,0507-php独立环境的安装与配置 Web程序 - 贪吃蛇学院-专业IT技术平台...

1.在一个纯英文目录下新建三个文件夹2.安装apache(选择好版本)过程中该填的按格式填好,其余的只更改安装目录即可如果报错1901是安装版本的问题。检查:安装完成后localhost打开为It works!添加到电脑属性环境变量:3.将php文件解压文档放到AMP…

linux中PATH变量-详细介绍

转自:https://blog.csdn.net/haozhepeng/article/details/100584451 转载者勘误 原文最后提到的 echo 命令对于环境变量的修改无影响。这是肯定的,echo 命令相当于只是一个打印的函数(比如 Python 中的 print)。这里要修改环境变…

php assert eval,代码执行函数之一句话木马

前言大家好,我是阿里斯,一名IT行业小白。非常抱歉,昨天的内容出现瑕疵比较多,今天重新整理后再次发出,修改并添加了细节,另增加了常见的命令执行函数如果哪里不足,还请各位表哥指出。eval和asse…

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理 转自:https://www.cnblogs.com/marsggbo/p/11838823.html#nvccnvidia-smi GPU型号含义 显卡: 简单理解这个就是我们前面说的GPU,尤其指NVIDIA公司生产的GPU系列,因为后面介绍的…

VS Code的Error: Running the contributed command: ‘_workbench.downloadResource‘ failed解决

VS Code的Error: Running the contributed command: _workbench.downloadResource failed解决 转自:https://blog.csdn.net/ibless/article/details/118610776 1 问题描述 此前,本人参考网上教程在VS Code中配置了“Remote SSH”插件(比如这…

Oracle闪回报错,oracle 闪回区满了,ORA-19815

oracle 闪回区满了,查看日志报错:ORA-19815,命令行输入:sqlplus / as sysdbastartup mount //如果你的数据库出现了无法连接的情况时,可以加上这句select file_type, percent_space_used as used,percent_space_rec…

[2021-ICCV] MUSIQ Multi-scale Image Quality Transformer 论文简析

[2021-ICCV] MUSIQ: Multi-scale Image Quality Transformer 论文简析 论文:https://arxiv.org/abs/2108.05997 代码:https://github.com/google-research/google-research/tree/master/musiq 概述 当前SOTA的IQA(图像质量评估&#xff0…

安装oracle不动了,windows2008安装ORACLE到2%不动的问题 | 信春哥,系统稳,闭眼上线不回滚!...

最近又有网友遇到在windows2008服务器上安装ORACLE软件时到2%就卡住不动的问题,下面是该网友的描述:oralce 11g r2 windows server 2008 R2安装到最后一步复制数据文件时卡到2% 不走了内存一直飙升求解决这个问题前段时间也有人遇到过,但是他…

手把手教你入门Git --- Git使用指南(Linux)

手把手教你入门Git — Git使用指南(Linux) 系统:ubuntu 18.04 LTS 本文所有git命令操作实验具有连续性,git小白完全可以从头到尾跟着本文所有给出的命令走一遍,就会对git有一个初步的了解,应当能做到会用并…

php数据关系图,如何利用navicat查看数据表的ER关系图

文章背景:(相关推荐:navicat)由于工作需要,现在要分析一个数据库,然后查看各个表之间的关系,所以需要查看表与表之间的关系图,专业术语叫做ER关系图。默认情况下,Navicat显示的界面是这样的&…

Linux中g++与gcc的区别

转自:https://blog.csdn.net/bit_clearoff/article/details/53965514 Windows中我们常用vs来编译编写好的C和C代码;vs把编辑器,编译器和调试器等工具都集成在这一款工具中,在Linux下我们能用什么工具来编译所编写好的代码呢&#…