MoCo v1(CVPR 2020)原理与代码解读

paper:Momentum Contrast for Unsupervised Visual Representation Learning

official implementation:https://github.com/facebookresearch/moco

背景

最近的一些研究提出使用对比损失相关的方法进行无监督视觉表征学习并取得了不错的结果。尽管是受到不同motivation的启发,这些方法都可以看做是在构建一个动态字典。字典中的"keys"(tokens)从数据(图片或图片的patch)中采样并用一个编码器encoder网络来表示。无监督学习训练encoder来执行字典查找:一个encoded "query"应该与它匹配的key相似,而与其它的key不同。学习过程表述为最小化对比损失的过程。

存在的问题

从构建动态字典的角度来看,作者假设构建的字典应该具备两个特点:

  1. large即字典要足够大
  2. 在训练期间字典要保持一致性

从直觉上来说,一个更大的字典可以更好地对连续的、高维的视觉空间进行采样。而字典中的键应该由相同或相似的编码器表示,以便它们与query的比较是一致的。然而,一些使用对比损失的现有方法受限于这两个方面中的一个(具体将在后续的方法介绍中讨论)。

本文的创新点

本文提出了动量对比(Momentum Contrast,MoCo)作为一种构建大型和一致的字典的方法,用于对比损失的无监督学习,如图1所示。

作者维护了一个数据样本的队列作为字典,当前mini-batch的encoded representation进队,队列中最老的表示出队。队列将字典大小和batch size进行解耦从而使得字典可以非常大。此外由于字典的key来源于之前若干个mini-batch,作者提出了一个缓慢变化的key encoder,具体实现为query encoder的基于动量的移动平均值,从而保持一致性。 

无监督学习的一个主要目的是得到一个预训练表示,通过微调可以tranfer到下游任务中。作者通过实验表明,在7个与检测和分割相关的下游任务中,MoCo无监督预训练可以超过ImageNet有监督预训练。

方法介绍

Contrastive Learning as Dictionary Look-up

对比学习可以用来为字典查找任务训练一个编码器。对于一个encoded query \(q\) 和一组encoded样本 \(\{k_0,k_1,k_2,...\}\),后者是字典的keys。假设字典中有一个单独的key(表示为 \(k_+\))与 \(q\) 匹配,对比损失作为一个函数,当 \(q\) 和的positive key \(k_+\) 相似并与所有其它的key(被认为是 \(q\) 的negative keys)不相似时对比损失的值很小。用点积来表示相似性,本文采用了对比损失的一种形式,InfoNCE,如下

其中 \(\tau\) 是是温度超参,结果对一个正样本和 \(K\) 个负样本求和。从直觉上来说,这个损失是一个 \((K+1)\) 类基于softmax分类器的log损失,这个分类器试图将 \(q\) 分为 \(k_+\) 类。对比损失还有其它形式,比如margin-based loss和NCE loss的一些变种。

对比损失作为无监督的目标函数用来训练encoder network来表示queries和keys。一般来说,query representation是 \(q=f_q(x_q)\) 其中 \(f_q\) 是encoder网络,\(x_q\) 是一个query样本(同样,\(k=f_k(x_k)\))。初始化取决于具体的代理任务,输入 \(x_q\) 和 \(x_k\) 可以是图像、patches、或包含一组patches的context。网络 \(f_q\) 和 \(f_k\) 可以是相同的、部分共享的、或不同的。

Momentum Contrast

Dictionary as a queue

本文方法的核心是维护一个数据样本的队列作为字典,这使得我们可以重用前面mini-batch中的encoded keys,队列的引入将字典大小与batch大小进行了解耦,我们的字典可以比普通的batch size大得多,并且可以灵活独立的作为一个超参来设置。

字典中的样本被逐步替换掉,当前mini-batch进入队列,而队列中最老的mini-batch被删除。字典总是表示所有数据的一个采样子集,而维护字典的额外计算是可控的。此外删除最早的mini-batch也是有好处的,因为它的encoded keys是最老的,与最新的编码最不一致。

Momentum update

使用队列可以使字典更大,但也使得通过反向传播更新key encoder变得困难(梯度应该传播到队列中的所有样本)。一个天真的解决方法是忽略key encoder \(f_k\) 的梯度直接拷贝query encoder \(f_q\),但这种解决方案在实验中得到的结果很差,作者推测这是由于快速变化的encoder减少了key representation的一致性导致的。因此提出了动量更新来解决这个问题。

我们将 \(f_k\) 的参数表示为 \(\theta_k\),\(f_q\) 的参数表示为 \(\theta_q\),然后通过下式更新 \(\theta_k\)

其中 \(m\in[0,1)\) 是动量系数,只有参数 \(\theta_q\) 通过反向传播更新,式(2)中的动量更新使得 \(\theta_k\) 比 \(\theta_q\) 的更新更平滑。因此,尽管队列中的keys是通过不同的encoder编码的(不同的mini-batch),这些encoder之间的差异非常小。后续实验表明,一个更大的动量(例如 \(m=0.999\))比更小的动量(例如 \(m=0.9\))表现得更好,表明一个缓慢更新的key encoder是使用队列的核心。

Relations to previous mechanisms

MoCo是使用对比损失的一种机制,作者将其与其它两种机制进行了对比,如图2所示,它们在字典大小和一致性上表现出不同的属性。

图2(a)是通过反向传播进行end-to-end更新的一种机制,它使用当前mini-batch中的样本作为字典,因此key的编码是一致的(通过相同的一组encoder参数)。但是字典的大小和mini-batch的大小耦合,受限于GPU的内存。同时也受到大mini-batch优化问题的挑战。

另外一种机制是采用memory bank,如图2(b)所示。memory back包含了数据集中所有样本的representation,每个mini-batch的字典是从memory bank中随机采样得到的,且没有反向传播,因此字典的size可以很大。但是,memory bank中一个样本的表示在它最后一次被看到时就更新了,因此采样的keys是过去一个epoch中不同step的encoder得到的,从而缺乏了一致性。

Pretext Task

对比学习可以使用不同的代理任务,由于本文的重点不是设计一个新的代理任务,本文遵循instance discrimination任务使用了一个简单的代理任务。如果一个query和一个key来源于同一张图像,则将它们视为positive pair,否则视为negative pair。我们对同一张图像进行两次随机数据增强得到一个postive pair,queries和keys分别由各自的encoder \(f_q\) 和 \(f_k\) 进行编码,encoder可以是任何的卷积网络。

MoCo的伪代码如下所示,对当前的mini-batch,我们对postive pair分别进行编码得到queries和对应的keys,负样本来源于队列。

Shuffling BN

编码器 \(f_q\) 和 \(f_k\) 中都使用了BN,作者在实验中发现使用BN会阻止模型学习好的表示,模型似乎“欺骗”了代理任务并很容易地找到了一种low-loss的解决方法。这可能是样本之间的batch内的通信(BN引起的)泄露了信息。

作者通过shuffle BN来解决这个问题。具体训练是在多个GPU上进行的,每个GPU独立的对样本执行BN。对于key encoder \(f_k\),在将当前mini-batch分配到不同GPU之前打乱样本顺序(并在编码之后还原顺序),query encoder \(f_q\) 不进行打乱顺序。这保证了用于计算query和对应的positve key的统计信息来自于不同的子集,有效解决了欺骗问题。

代码解析

下面是官方实现,基本上和文章中的伪代码一致,没有什么难以理解的地方。其中encoder_k的参数更新顺序和伪代码不一样,伪代码是f_q和f_k分别forward,然后f_q的loss反向传播,更新f_q的参数,最后f_k进行动量更新。而代码中是f_q先forward,然后f_k更新参数,接着f_k进行forward,最后再根据反向传播更新f_q。

另外,这里包含了MoCo v2的代码,主要的区别就是v2借鉴SimCLR的做法,在encoder的avg pooling层后多加了一层projection layer,即一个MLP。

# Copyright (c) Meta Platforms, Inc. and affiliates.# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.import torch
import torch.nn as nnclass MoCo(nn.Module):"""Build a MoCo model with: a query encoder, a key encoder, and a queuehttps://arxiv.org/abs/1911.05722"""def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):"""dim: feature dimension (default: 128)K: queue size; number of negative keys (default: 65536)m: moco momentum of updating key encoder (default: 0.999)T: softmax temperature (default: 0.07)"""super(MoCo, self).__init__()self.K = Kself.m = mself.T = T# create the encoders# num_classes is the output fc dimensionself.encoder_q = base_encoder(num_classes=dim)self.encoder_k = base_encoder(num_classes=dim)if mlp:  # hack: brute-force replacementdim_mlp = self.encoder_q.fc.weight.shape[1]  # 2048self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):param_k.data.copy_(param_q.data)  # initializeparam_k.requires_grad = False  # not update by gradient# create the queueself.register_buffer("queue", torch.randn(dim, K))# 将张量或缓冲区注册为 nn.Module 的一部分,但不会被视为模型的可学习参数。# 通常情况下,这用于存储模型中的固定参数或状态,例如均值、方差等,这些参数在训练过程中不会被更新。self.queue = nn.functional.normalize(self.queue, dim=0)self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))@torch.no_grad()def _momentum_update_key_encoder(self):"""Momentum update of the key encoder"""for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)@torch.no_grad()def _dequeue_and_enqueue(self, keys):# gather keys before updating queuekeys = concat_all_gather(keys)batch_size = keys.shape[0]ptr = int(self.queue_ptr)assert self.K % batch_size == 0  # for simplicity# replace the keys at ptr (dequeue and enqueue)self.queue[:, ptr: ptr + batch_size] = keys.Tptr = (ptr + batch_size) % self.K  # move pointerself.queue_ptr[0] = ptr@torch.no_grad()def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda()# 打乱索引顺序,比如batch_size_all=8, idx_shuffle=[1,3,5,2,0,4,7,6]# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0)# 将生成的随机索引序列从GPU 0(src=0)广播到所有其他的GPU设备上,以便在分布式训练时,每个GPU都能够获得相同的随机索引序列,以保持数据的同步性。# index for restoringidx_unshuffle = torch.argsort(idx_shuffle)  # tensor([4, 0, 3, 1, 5, 2, 7, 6])# shuffled index for this gpugpu_idx = torch.distributed.get_rank()idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]return x_gather[idx_this], idx_unshuffle@torch.no_grad()def _batch_unshuffle_ddp(self, x, idx_unshuffle):"""Undo batch shuffle.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# restored index for this gpugpu_idx = torch.distributed.get_rank()idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]return x_gather[idx_this]def forward(self, im_q, im_k):"""Input:im_q: a batch of query imagesim_k: a batch of key imagesOutput:logits, targets"""# compute query featuresq = self.encoder_q(im_q)  # queries: NxCq = nn.functional.normalize(q, dim=1)# compute key featureswith torch.no_grad():  # no gradient to keysself._momentum_update_key_encoder()  # update the key encoder# 和论文中伪代码的顺序不一样,论文中encoder_k是先forward后更新参数,这里是先更新参数后forward# shuffle for making use of BNim_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)k = self.encoder_k(im_k)  # keys: NxCk = nn.functional.normalize(k, dim=1)# undo shufflek = self._batch_unshuffle_ddp(k, idx_unshuffle)# compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()# dequeue and enqueueself._dequeue_and_enqueue(k)return logits, labels# utils
@torch.no_grad()
def concat_all_gather(tensor):"""Performs all_gather operation on the provided tensors.*** Warning ***: torch.distributed.all_gather has no gradient."""tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]torch.distributed.all_gather(tensors_gather, tensor, async_op=False)output = torch.cat(tensors_gather, dim=0)return output

实验结果

无监督模型的常见评估方法是将训练好的encoder的权重freeze,后面接一层全连接层和softmax,然后在目标数据上只训练全连接层,最后在测试集上评估得到的模型效果。下面是MoCo和之前的无监督模型的结果对比,可以看到MoCo取得了最优的结果。

无监督模型的另一个作用是当做下游任务的预训练权重。在VOC目标检测任务上和监督预训练的对比如下,可以看到MoCo比监督预训练权重的效果更好。

 

下面是在COCO数据的目标检测任务和实例分割任务上与随机初始化权重、监督预训练权重的结果对比

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/814007.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【DM8】临时表空间

临时表分类 事务级临时表会话级临时表 临时表,可以像普通表一样插入、更新和删除数据临时表的DML操作产生较少的redo日志临时表支持创建索引,以提高查询性能在一个会话或事务结束之后,数据将自动从临时表中删除不同的用户访问相同的临时表&a…

吴恩达深度学习 (week3,4)

文章目录 一、神经网络概述二、神经网络的表示三、神经网络的输出四、多个例子的向量化五、向量化实现的解释六、深度学习激活函数七、激活函数导数八、神经网络的梯度下降法九、深度学习随机初始化十、上述学习总结1、第一题2、第二题3、第三题4、第四题5、第五题6、第六题7、…

关于Transformer的面试题

文章目录 前言1、Transformer模型1. 1 基本要点1.2 提升 2、BN和LN的区别和联系2.1 基本要点2.2 扩展 3、PreNorm和PostNorm的区别[链接](https://www.zhihu.com/question/519668254)4、Multi-head self-attention中为什么要用三个不同的矩阵 前言 Transformer大模型的一些问题…

【GEE实践应用】哨兵1号和2号数据叠加

目录 1.数据叠加代码 2.代码逐句解释 1.数据叠加代码 var geometry table; //table是我们提前导入的矢量数据 // 加载Sentinel-2影像 var sentinel2 ee.ImageCollection("COPERNICUS/S2").filterBounds(geometry) // geometry是你感兴趣区域的几何对象.filte…

Redis:发布和订阅

文章目录 一、介绍二、发布订阅命令 一、介绍 Redis的发布和订阅功能是一种消息通信模式,发送者(pub)发送消息,订阅者(sub)接收消息。这种功能使得消息发送者和接收者不需要直接建立连接,而是通…

商标没有去注册有哪些不好的影响!

有些商家咨询普推知产老杨,商标没有去注册有哪些不好的影响,其实对企业来说还有许多实际不利的影响,有时代价比注册一个商标要大很多。 想的商标名称没去注册商标,如果别人抢注拿下商标注册证,那就会涉及侵权&#xf…

工厂方法模式:解锁灵活的对象创建策略

在软件设计中,工厂方法模式是一种非常实用的创建型设计模式,它不仅提升了系统的灵活性,还简化了对象的创建过程。本文将详细探讨工厂方法模式的核心概念、实现方式、应用场景以及与其他设计模式的对比,旨在提供一份全面且实用的指…

磁悬浮鼓风机市场规模不断增长 我国行业发展面临挑战

磁悬浮鼓风机市场规模不断增长 我国行业发展面临挑战 磁悬浮鼓风机又称磁悬浮高速离心鼓风机,指基于磁悬浮技术制成的气体输送设备。磁悬浮鼓风机综合性能优良,属于高效节能磁悬浮动力装备,在众多领域需求旺盛。未来随着国家节能环保政策逐渐…

阿里云优惠口令2024最新

2024年阿里云域名优惠口令,com域名续费优惠口令“com批量注册更享优惠”,cn域名续费优惠口令“cn注册多个价格更优”,cn域名注册优惠口令“互联网上的中国标识”,阿里云优惠口令是域名专属的优惠码,可用于域名注册、续…

01—JavaScript概述

一、初识Javascript JavaScript一种直译式脚本语言,是一种动态类型、弱类型、基于原型的语言,内置支持类型。它的解释器被称为JavaScript引擎,为浏览器的一部分,广泛用于客户端的脚本语言,最早是在 HTML(标…

jsoncpp 编译和使用

原文链接: jsoncpp的编译和使用 jsoncpp 编译出库文件 1.从github仓库下载 2.下载 cmake 工具 3.生成VS项目 4.编译得到需要的库文件 jsoncpp 的使用 查看原文

基于Springboot的自习室预订系统

基于SpringbootVue的自习室预订系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录页 网站首页 公告信息 留言反馈 后台管理 学生信息管理 公告信息管理 留言…

入门:多层感知器Multiple-Layer Perceiver, MLP

本文将简单介绍多层感知器(MLP)的基本概念、原理和应用。MLP是一种前馈人工神经网络,由多层节点组成,每层节点通过权重和偏置与下一层节点相连。MLP在许多领域都有广泛的应用,如分类、回归、自然语言处理等。 本文将分…

SRNIC、选择性重传、伸缩性、连接扩展性、RoCEv2优化(六)

参考论文SRDMA(A Scalable Architecture for RDMA NICs ):https://download.csdn.net/download/zz2633105/89101822 借此,对论文内容总结、加以思考和额外猜想,如有侵权,请联系删除。 如有描述不当之处&…

04异常Lambda算法正则

异常 异常是什么? 异常是代码在编译或者执行的过程中可能出现的错误。避免异常的出现,同时处理可能出现的异常,让代码更稳健。 异常分为几类? 编译时异常、运行时异常。编译时异常:没有继承RuntimeExcpetion的异常…

Linux: 工具: tshark 抓到了收方向的ESP明文包?

根据这个描述,看着是正常的, 抓到包之后,可以方便的分析问题,省去在wireshark里解码的问题。 经过调查发现是内核将ESP解开之后,如果是tunnel模式,内核又重新将skb丢给了interface去做处理。这样tshark/tcp…

Java基础(三)--常用工具类

文章目录 第三章、常用工具类一、Java异常1、什么是异常2、异常处理3、常见的异常类型4、throws5、throw6、自定义异常7、异常链 二、包装类1、包装类2、字符串与基本数据类型转换3、包装类的比较 三、String 类1、创建String对象的方法2、String的常用方法3、字符串的存储4、字…

360安全卫士去除广告方法

大安全时代,360 安全卫士为您提供全面安全服务,电脑端下载: https://urlqh.cn/orQqc 在当今数字化时代,网络安全已成为人们日常生活中的重要关切。在这片浩瀚的网络海洋中,360安全卫士犹如一座坚不可摧的灯塔&#xf…

基于微信公众号,搭建一套简单的电商支付环境(下)-- 微信公众号的对接

一、接着上文 上文把部署情况介绍了,侧重于网络及代理,本文选择把微信公众号的对接实现介绍一下。 还是那句话,微信官方的文档已非常详细,这里先摘抄一些重要的概念。 其次,待对接微信公众号的接口众多,…

Qt | 视频播放器(multimedia、multimediawidgets)

QT +=multimedia 通俗解释: 此代码行告诉编译器在构建应用程序时包含多媒体库。这意味着您的应用程序将能够播放和显示音频和视频文件。 使用分步说明构建模型: 创建一个新的 Qt 项目。 在 .pro 文件中添加以下行: QT += multimedia 导入必要的多媒体头文件: #include &l…