BYOL(NeurIPS 2020)原理解读

paper:Bootstrap your own latent: A new approach to self-supervised Learning

third-party implementation:https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/selfsup/byol.py

本文的创新点

本文提出了一种新的自监督学习方法,Bootstrap Your Own Latent(BYOL),和以往需要大量负样本的对比学习方法如SimCLR不同,BYOL不依赖于负样本对。此外,和之前需要精心设计增强策略的对比方法相比,BYOL对图像增强的敏感度较低。BYOL在ImageNet上的linear evaluation取得了新的SOTA,并且在迁移学习和半监督学习的基准测试中表现优异。

方法介绍

BYOL的目标是学习一个可以用于下游任务的表示 \(y_{\theta}\)。BYOL使用online和target两个神经网络来学习。在线网络由一组权重 \(\theta\) 定义并由三个阶段组成:encoder \(f_{\theta}\)、projector \(g_{\theta}\)、predictor \(q_{\theta}\)。如图2所示。目标网络和在线网络的架构相同,但使用不同的权重 \(\xi\)。目标网络提供了用于训练在线网络的regression target,其参数 \(\xi\) 是在线网络参数 \(\theta\) 的指数移动平均值。给定一个衰减率 \(\tau \in[0,1]\),在每个训练step后,我们执行如下更新

给定一组图片 \(\mathcal{D}\),从 \(\mathcal{D}\) 中均匀采样得到一张图片 \(x\sim\mathcal{D}\),以及两组增强分布 \(\mathcal{T}\) 和 \(\mathcal{T}'\),BYOL通过对 \(x\) 分别应用增强 \(t\sim\mathcal{T}\) 和 \(t'\sim\mathcal{T}'\) 得到两个增强视图 \(v \triangleq t(x)\) 和 \(v' \triangleq t'(x)\)。从第一个增强视图 \(v\) 来看,在线网络输出一个表示 \(y \triangleq f_{\theta}(v)\) 和一个映射 \(z_{\theta} \triangleq g_{\theta}(y)\)。目标网络从第二个增强视图 \(v'\) 输出 \(y_{\xi} \triangleq f_{\xi}(v')\) 和目标映射 \(z_{\xi}' \triangleq g_{\xi}(y')\)。然后输出一个 \(z'_{\xi}\) 的预测 \(q_{\theta}(z_{\theta})\) 并对 \(q_{\theta}(z_{\theta})\) 和 \(z'_{\xi}\) 进行 \(\ell_2\) 标准化得到 \(\overline{q_\theta}\left(z_\theta\right) \triangleq q_\theta\left(z_\theta\right) /\left\|q_\theta\left(z_\theta\right)\right\|_2\) 和 \(\bar{z}_{\xi}^{\prime} \triangleq z_{\xi}^{\prime} /\left\|z_{\xi}^{\prime}\right\|_2\)。注意这个预测网络只应用在在线分支,使得在线和目标分支是非对称的结构。最后我们定义标准化预测和目标映射之间的均方根误差

然后我们再将 \(v'\) 输入在线网络,将 \(v\) 输入目标网络得到式(2)中 \(\mathcal{L}_{\theta, \xi}\) 的对称损失 \(\widetilde{\mathcal{L}}_{\theta, \xi}\)。在每个训练step,通过梯度下降根据 \(\mathcal{L}^{BYOL}_{\theta, \xi}=\mathcal{L}_{\theta, \xi}+\widetilde{\mathcal{L}}_{\theta, \xi}\) 来优化 \(\theta\),但不更新 \(\xi\)。如图2中的stop-gradient所示,BYOL的整体优化如下

在训练完成后,我们只保留 \(f_{\theta}\)。

BYOL的伪代码如下

 

实验结果

在ImageNet上的linear evaluation如下,可以看到BYOL取得了SOTA的表现。

 

作者还比较了减小batch size以及减少增强方法时BYOL和SimCLR性能。可以看到,BYOL比SimCLR性能下降的更慢,尤其是在减少数据增强时,表明BYOL对batch size和数据增强的不敏感。

 

代码解析

class BYOL(BaseSelfSupervisor):"""BYOL.Implementation of `Bootstrap Your Own Latent: A New Approach toSelf-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.Args:backbone (dict): Config dict for module of backbone.neck (dict): Config dict for module of deep featuresto compact feature vectors.head (dict): Config dict for module of head functions.base_momentum (float): The base momentum coefficient for the targetnetwork. Defaults to 0.004.pretrained (str, optional): The pretrained checkpoint path, supportlocal path and remote path. Defaults to None.data_preprocessor (dict, optional): The config for preprocessinginput data. If None or no specified type, it will use"SelfSupDataPreprocessor" as type.See :class:`SelfSupDataPreprocessor` for more details.Defaults to None.init_cfg (Union[List[dict], dict], optional): Config dict for weightinitialization. Defaults to None."""def __init__(self,backbone: dict,neck: dict,head: dict,base_momentum: float = 0.004,pretrained: Optional[str] = None,data_preprocessor: Optional[dict] = None,init_cfg: Optional[Union[List[dict], dict]] = None) -> None:super().__init__(backbone=backbone,neck=neck,head=head,pretrained=pretrained,data_preprocessor=data_preprocessor,init_cfg=init_cfg)# create momentum modelself.target_net = CosineEMA(nn.Sequential(self.backbone, self.neck), momentum=base_momentum)def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],**kwargs) -> Dict[str, torch.Tensor]:"""The forward function in training.Args:inputs (List[torch.Tensor]): The input images.data_samples (List[DataSample]): All elements requiredduring the forward function.Returns:Dict[str, torch.Tensor]: A dictionary of loss components."""assert isinstance(inputs, list)img_v1 = inputs[0]img_v2 = inputs[1]# compute online featuresproj_online_v1 = self.neck(self.backbone(img_v1))[0]proj_online_v2 = self.neck(self.backbone(img_v2))[0]# compute target featureswith torch.no_grad():# update the target netself.target_net.update_parameters(nn.Sequential(self.backbone, self.neck))proj_target_v1 = self.target_net(img_v1)[0]proj_target_v2 = self.target_net(img_v2)[0]loss_1 = self.head.loss(proj_online_v1, proj_target_v2)loss_2 = self.head.loss(proj_online_v2, proj_target_v1)losses = dict(loss=2. * (loss_1 + loss_2))return losses

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/826179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

获取会话公钥

----------------------------------------------------举例 签到 接口开始--------------------------------------------------- 第一步&#xff1a;-----请求报文明文:{"body":{},"head":{"ywId":"GY0001"}} ODxdq2/WhHlCKoLIGNV2j…

uniapp picker 多列选择器用法

uniapp picker 多列选择器联动筛选器交互处理方法&#xff0c; uniapp 多列选择器 mode"multiSelector" 数据及筛选联动交互处理&#xff0c; 通过接口获取数据&#xff0c;根据用户选择当前列选项设置子列数据&#xff0c;实现三级联动效果&#xff0c; 本示例中处…

SEW减速机参数查询 2-2 实践

首先说说结论&#xff1a;在不和SEW官方取得沟通之前&#xff0c;你几乎无法直接通过查阅SEW官方文档得到相关减速机的所有技术参数&#xff1a;比如轴的模数和齿数&#xff0c;轴承的参数。我在周一耗费了一个上午&#xff0c;最终和SEW方面确认后才知晓相关技术参数需要凭借销…

Jenkins的安装和部署

文章目录 概述Jenkins部署项目的流程jenkins的安装启动创建容器进入容器浏览器访问8085端口 Jenkins创建项目创建example项目 概述 Jenkins&#xff1a;是一个开源的、提供友好操作界面的持续集成&#xff08;CLI&#xff09;工具&#xff0c;主要用于持续、自动构建的一些定时…

什么是Rust语言?探索安全系统编程的未来

&#x1f680; 什么是Rust语言&#xff1f;探索安全系统编程的未来 文章目录 &#x1f680; 什么是Rust语言&#xff1f;探索安全系统编程的未来摘要引言正文&#x1f4d8; Rust语言简介&#x1f31f; 发展历程&#x1f3af; Rust的技术意义和优势&#x1f4e6; Rust解决的问题…

GlobalRouting - FastRoute代码框架和功能(三)

文章目录 一、 顶层代码框架和功能(一)、总结(二)、各文件代码的概述&#xff1a;1. FastRoute\\src\\Box.h2. FastRoute\\src\\Coordinate.h3. FastRoute\\src\\DBWrapper.h4. FastRoute\\src\\FastRouteKernel.h5. FastRoute\\src\\Grid.h成员变量成员函数 6. FastRoute\\src…

python零基础入门 (9)-- 模块与包

文章目录 前言1. 什么是模块&#xff1f;1.1 模块的定义和作用1.2 内置模块和第三方模块 2. 如何使用模块&#xff1f;2.1 导入模块2.2 使用模块中的函数和变量 3. 什么是包&#xff1f;3.1 包的定义和作用3.2 包的结构和组织方式 4. 如何创建自定义模块&#xff1f;4.1 创建一…

Conmi的正确答案——ESP32获取MAC地址

ESP-IDF版本&#xff1a;v5.2.1 ESP32芯片型号&#xff1a;ESP32C3&#xff08;4M flash版本&#xff09; ESP支持的MAC地址有&#xff1a; typedef enum {ESP_MAC_WIFI_STA, /**< MAC for WiFi Station (6 bytes) */ESP_MAC_WIFI_SOFTAP, /**< MAC for WiFi Sof…

电商技术揭秘三十:知识产权保护浅析

电商技术揭秘相关系列文章&#xff08;上&#xff09; 相关系列文章&#xff08;中&#xff09; 电商技术揭秘二十&#xff1a;能化供应链管理 电商技术揭秘二十一:智能仓储与物流优化(上) 电商技术揭秘二十二:智能仓储与物流优化(下) 电商技术揭秘二十三&#xff1a;智能…

REACT+PHP课程项目血泪史

PHP php??老师让用php写后端。什么&#xff1f;写惯了java、python。这个看起来像html标签语言的东西写后端是个什么鬼&#xff0c;看起来想落后几千年的原始语言&#xff08;手动滑稽&#xff09;。 大概介绍一下&#xff0c;php主要是后端语言&#xff0c;用来连接数据库…

Day17-Java基础之综合案例

练习一&#xff1a;飞机票 需求: 机票价格按照淡季旺季、头等舱和经济舱收费、输入机票原价、月份和头等舱或经济舱。 按照如下规则计算机票价格&#xff1a;旺季&#xff08;5-10月&#xff09;头等舱9折&#xff0c;经济舱8.5折&#xff0c;淡季&#xff08;11月到来年4月…

deepinV23 Beta3安装cuda

文章目录 下载CUDA安装,以cuda11.6为例运行.run文件安装选项配置环境变量查看cuda版本重启计算机 卸载cuda deepinV23 Beta3对应的debian版本是12&#xff1a; bookworm指的是debian12&#xff0c; sid代表不稳定版。 下载CUDA 官网&#xff1a;https://developer.nvidia.com…

你为什么会成为一名程序员?

在当今数字化时代&#xff0c;程序员这一职业越来越受到人们的关注和追捧。许多人选择成为一名程序员&#xff0c;不仅是因为这个职业的前景广阔&#xff0c;还因为他们对编程和技术的兴趣。那么&#xff0c;选择成为一名程序员的原因究竟是出于兴趣还是职业发展呢&#xff1f;…

中华环保联合会获得国家“绿色制造体系” 第三方评价机构资格

近日&#xff0c;中华环保联合会成功获得工业和信息化部“绿色制造体系”第三方评价机构资格&#xff0c;可为企业、园区及相关机构提供全面的绿色制造体系评价服务&#xff0c;包括绿色工厂、绿色园区、绿色供应链等方面。 “绿色制造体系建设”是由工业和信息化部负责统筹推进…

Python3中的时间应用 (代码)

直接上python3代码 # 对时间类型的转换 from datetime import datetime import localelocale.setlocale(locale.LC_CTYPE, "chinese")# 字符串 -> datetime类型 text "2024年-4月-1日" res datetime.strptime(text, "%Y年-%m月-%d日") pr…

redis异常:OOM command not allowed when used memory > ‘maxmemory‘

redis存储数据太多,内存溢出,导致异常 1.查看redis内存使用情况 登录redis后 info memory2.查看分配给redis的最大内存 config get maxmemory3.处理方式:拓展redis的最大内存 打开redis.conf文件,修改maxmemory 4.删掉键值重启redis后,发现删掉的数据又恢复了? redis根目录…

Midjourney是什么?Midjourney怎么用?怎么注册Midjourney账号?国内怎么使用Midjourney?多人合租Midjourney拼车

Midjourney是什么 OpenAI发布的ChatGPT4引领了聊天机器人的竞争浪潮&#xff0c;随后谷歌推出了自己的AI聊天机器人Bard&#xff0c;紧接着微软推出了Bing Chat&#xff0c;百度也推出了文心一言&#xff0c;这些聊天机器人的推出&#xff0c;标志着对话式AI技术已经达到了一个…

月球地形数据介绍(LOLA)

月球地形数据介绍 LOLA介绍LOLA数据的处理与发布数据类型和格式投影坐标系SIMPLE CYLINDRICALPOLAR STEREOGRAPHIC 数据下载与浏览 LOLA介绍 目前最新的月球地形高程数据来源于美国2009年发射的LRO探测器。 “月球勘测轨道器”(Lunar Reconnaissance Orbiter&#xff0c;LRO)…

7.2 跳跃表(skiplist)

文章目录 前言一、跳跃表——查找操作二、跳跃表——插入操作三、代码演示3.1 输出结果3.2 代码细节 四、总结&#xff1a;参考文献&#xff1a; 前言 本章内容参考海贼宝藏胡船长的数据结构与算法中的第七章——查找算法&#xff0c;侵权删。 查找的时间复杂度能从原来链表的…

getopt函数使用

1. 函数介绍 用来解析命令行参数的。 函数原型 #include <unistd.h>int getopt(int argc, char * const argv[],const char *optstring);第三个参数就是设置解析的参数。 假设 o p t s t r i n g " a b : c : : " optstring "ab:c::" optstring…