BYOL(NeurIPS 2020)原理解读

paper:Bootstrap your own latent: A new approach to self-supervised Learning

third-party implementation:https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/selfsup/byol.py

本文的创新点

本文提出了一种新的自监督学习方法,Bootstrap Your Own Latent(BYOL),和以往需要大量负样本的对比学习方法如SimCLR不同,BYOL不依赖于负样本对。此外,和之前需要精心设计增强策略的对比方法相比,BYOL对图像增强的敏感度较低。BYOL在ImageNet上的linear evaluation取得了新的SOTA,并且在迁移学习和半监督学习的基准测试中表现优异。

方法介绍

BYOL的目标是学习一个可以用于下游任务的表示 \(y_{\theta}\)。BYOL使用online和target两个神经网络来学习。在线网络由一组权重 \(\theta\) 定义并由三个阶段组成:encoder \(f_{\theta}\)、projector \(g_{\theta}\)、predictor \(q_{\theta}\)。如图2所示。目标网络和在线网络的架构相同,但使用不同的权重 \(\xi\)。目标网络提供了用于训练在线网络的regression target,其参数 \(\xi\) 是在线网络参数 \(\theta\) 的指数移动平均值。给定一个衰减率 \(\tau \in[0,1]\),在每个训练step后,我们执行如下更新

给定一组图片 \(\mathcal{D}\),从 \(\mathcal{D}\) 中均匀采样得到一张图片 \(x\sim\mathcal{D}\),以及两组增强分布 \(\mathcal{T}\) 和 \(\mathcal{T}'\),BYOL通过对 \(x\) 分别应用增强 \(t\sim\mathcal{T}\) 和 \(t'\sim\mathcal{T}'\) 得到两个增强视图 \(v \triangleq t(x)\) 和 \(v' \triangleq t'(x)\)。从第一个增强视图 \(v\) 来看,在线网络输出一个表示 \(y \triangleq f_{\theta}(v)\) 和一个映射 \(z_{\theta} \triangleq g_{\theta}(y)\)。目标网络从第二个增强视图 \(v'\) 输出 \(y_{\xi} \triangleq f_{\xi}(v')\) 和目标映射 \(z_{\xi}' \triangleq g_{\xi}(y')\)。然后输出一个 \(z'_{\xi}\) 的预测 \(q_{\theta}(z_{\theta})\) 并对 \(q_{\theta}(z_{\theta})\) 和 \(z'_{\xi}\) 进行 \(\ell_2\) 标准化得到 \(\overline{q_\theta}\left(z_\theta\right) \triangleq q_\theta\left(z_\theta\right) /\left\|q_\theta\left(z_\theta\right)\right\|_2\) 和 \(\bar{z}_{\xi}^{\prime} \triangleq z_{\xi}^{\prime} /\left\|z_{\xi}^{\prime}\right\|_2\)。注意这个预测网络只应用在在线分支,使得在线和目标分支是非对称的结构。最后我们定义标准化预测和目标映射之间的均方根误差

然后我们再将 \(v'\) 输入在线网络,将 \(v\) 输入目标网络得到式(2)中 \(\mathcal{L}_{\theta, \xi}\) 的对称损失 \(\widetilde{\mathcal{L}}_{\theta, \xi}\)。在每个训练step,通过梯度下降根据 \(\mathcal{L}^{BYOL}_{\theta, \xi}=\mathcal{L}_{\theta, \xi}+\widetilde{\mathcal{L}}_{\theta, \xi}\) 来优化 \(\theta\),但不更新 \(\xi\)。如图2中的stop-gradient所示,BYOL的整体优化如下

在训练完成后,我们只保留 \(f_{\theta}\)。

BYOL的伪代码如下

 

实验结果

在ImageNet上的linear evaluation如下,可以看到BYOL取得了SOTA的表现。

 

作者还比较了减小batch size以及减少增强方法时BYOL和SimCLR性能。可以看到,BYOL比SimCLR性能下降的更慢,尤其是在减少数据增强时,表明BYOL对batch size和数据增强的不敏感。

 

代码解析

class BYOL(BaseSelfSupervisor):"""BYOL.Implementation of `Bootstrap Your Own Latent: A New Approach toSelf-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.Args:backbone (dict): Config dict for module of backbone.neck (dict): Config dict for module of deep featuresto compact feature vectors.head (dict): Config dict for module of head functions.base_momentum (float): The base momentum coefficient for the targetnetwork. Defaults to 0.004.pretrained (str, optional): The pretrained checkpoint path, supportlocal path and remote path. Defaults to None.data_preprocessor (dict, optional): The config for preprocessinginput data. If None or no specified type, it will use"SelfSupDataPreprocessor" as type.See :class:`SelfSupDataPreprocessor` for more details.Defaults to None.init_cfg (Union[List[dict], dict], optional): Config dict for weightinitialization. Defaults to None."""def __init__(self,backbone: dict,neck: dict,head: dict,base_momentum: float = 0.004,pretrained: Optional[str] = None,data_preprocessor: Optional[dict] = None,init_cfg: Optional[Union[List[dict], dict]] = None) -> None:super().__init__(backbone=backbone,neck=neck,head=head,pretrained=pretrained,data_preprocessor=data_preprocessor,init_cfg=init_cfg)# create momentum modelself.target_net = CosineEMA(nn.Sequential(self.backbone, self.neck), momentum=base_momentum)def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],**kwargs) -> Dict[str, torch.Tensor]:"""The forward function in training.Args:inputs (List[torch.Tensor]): The input images.data_samples (List[DataSample]): All elements requiredduring the forward function.Returns:Dict[str, torch.Tensor]: A dictionary of loss components."""assert isinstance(inputs, list)img_v1 = inputs[0]img_v2 = inputs[1]# compute online featuresproj_online_v1 = self.neck(self.backbone(img_v1))[0]proj_online_v2 = self.neck(self.backbone(img_v2))[0]# compute target featureswith torch.no_grad():# update the target netself.target_net.update_parameters(nn.Sequential(self.backbone, self.neck))proj_target_v1 = self.target_net(img_v1)[0]proj_target_v2 = self.target_net(img_v2)[0]loss_1 = self.head.loss(proj_online_v1, proj_target_v2)loss_2 = self.head.loss(proj_online_v2, proj_target_v1)losses = dict(loss=2. * (loss_1 + loss_2))return losses

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/826179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp picker 多列选择器用法

uniapp picker 多列选择器联动筛选器交互处理方法&#xff0c; uniapp 多列选择器 mode"multiSelector" 数据及筛选联动交互处理&#xff0c; 通过接口获取数据&#xff0c;根据用户选择当前列选项设置子列数据&#xff0c;实现三级联动效果&#xff0c; 本示例中处…

SEW减速机参数查询 2-2 实践

首先说说结论&#xff1a;在不和SEW官方取得沟通之前&#xff0c;你几乎无法直接通过查阅SEW官方文档得到相关减速机的所有技术参数&#xff1a;比如轴的模数和齿数&#xff0c;轴承的参数。我在周一耗费了一个上午&#xff0c;最终和SEW方面确认后才知晓相关技术参数需要凭借销…

Jenkins的安装和部署

文章目录 概述Jenkins部署项目的流程jenkins的安装启动创建容器进入容器浏览器访问8085端口 Jenkins创建项目创建example项目 概述 Jenkins&#xff1a;是一个开源的、提供友好操作界面的持续集成&#xff08;CLI&#xff09;工具&#xff0c;主要用于持续、自动构建的一些定时…

什么是Rust语言?探索安全系统编程的未来

&#x1f680; 什么是Rust语言&#xff1f;探索安全系统编程的未来 文章目录 &#x1f680; 什么是Rust语言&#xff1f;探索安全系统编程的未来摘要引言正文&#x1f4d8; Rust语言简介&#x1f31f; 发展历程&#x1f3af; Rust的技术意义和优势&#x1f4e6; Rust解决的问题…

电商技术揭秘三十:知识产权保护浅析

电商技术揭秘相关系列文章&#xff08;上&#xff09; 相关系列文章&#xff08;中&#xff09; 电商技术揭秘二十&#xff1a;能化供应链管理 电商技术揭秘二十一:智能仓储与物流优化(上) 电商技术揭秘二十二:智能仓储与物流优化(下) 电商技术揭秘二十三&#xff1a;智能…

deepinV23 Beta3安装cuda

文章目录 下载CUDA安装,以cuda11.6为例运行.run文件安装选项配置环境变量查看cuda版本重启计算机 卸载cuda deepinV23 Beta3对应的debian版本是12&#xff1a; bookworm指的是debian12&#xff0c; sid代表不稳定版。 下载CUDA 官网&#xff1a;https://developer.nvidia.com…

中华环保联合会获得国家“绿色制造体系” 第三方评价机构资格

近日&#xff0c;中华环保联合会成功获得工业和信息化部“绿色制造体系”第三方评价机构资格&#xff0c;可为企业、园区及相关机构提供全面的绿色制造体系评价服务&#xff0c;包括绿色工厂、绿色园区、绿色供应链等方面。 “绿色制造体系建设”是由工业和信息化部负责统筹推进…

redis异常:OOM command not allowed when used memory > ‘maxmemory‘

redis存储数据太多,内存溢出,导致异常 1.查看redis内存使用情况 登录redis后 info memory2.查看分配给redis的最大内存 config get maxmemory3.处理方式:拓展redis的最大内存 打开redis.conf文件,修改maxmemory 4.删掉键值重启redis后,发现删掉的数据又恢复了? redis根目录…

Midjourney是什么?Midjourney怎么用?怎么注册Midjourney账号?国内怎么使用Midjourney?多人合租Midjourney拼车

Midjourney是什么 OpenAI发布的ChatGPT4引领了聊天机器人的竞争浪潮&#xff0c;随后谷歌推出了自己的AI聊天机器人Bard&#xff0c;紧接着微软推出了Bing Chat&#xff0c;百度也推出了文心一言&#xff0c;这些聊天机器人的推出&#xff0c;标志着对话式AI技术已经达到了一个…

月球地形数据介绍(LOLA)

月球地形数据介绍 LOLA介绍LOLA数据的处理与发布数据类型和格式投影坐标系SIMPLE CYLINDRICALPOLAR STEREOGRAPHIC 数据下载与浏览 LOLA介绍 目前最新的月球地形高程数据来源于美国2009年发射的LRO探测器。 “月球勘测轨道器”(Lunar Reconnaissance Orbiter&#xff0c;LRO)…

7.2 跳跃表(skiplist)

文章目录 前言一、跳跃表——查找操作二、跳跃表——插入操作三、代码演示3.1 输出结果3.2 代码细节 四、总结&#xff1a;参考文献&#xff1a; 前言 本章内容参考海贼宝藏胡船长的数据结构与算法中的第七章——查找算法&#xff0c;侵权删。 查找的时间复杂度能从原来链表的…

线上真实案例之执行一段逻辑后报错Communications link failure

1.问题发现 在开发某个项目的一个定时任务计算经销商返利的功能时&#xff0c;有一个返利监测的调度&#xff0c;如果某一天返利计算调度失败了&#xff0c;需要重新计算&#xff0c;这个监测的调度就会重新计算某天的数据。 在UAT测试通过&#xff0c;发布生产后&#xff0c…

CSS动画(css、js动画库:各种动画效果)

第一种方法&#xff1a;文字从上到下显示动画&#xff1b; <div class"text-container"><div class"text">文字从上到下显示</div></div><style scoped> /*确保 keyframes 规则在引用它的任何选择器之前定义&#xff0c;以避…

Android开发:应用百度智能云中的身份证识别OCR实现获取个人信息的功能

百度智能云&#xff1a; 百度智能云是百度提供的公有云平台&#xff0c;于2015年正式开放运营。该平台秉承“用科技力量推动社会创新”的愿景&#xff0c;致力于将百度在云计算、大数据、人工智能的技术能力向社会输出。 百度智能云为金融、城市、医疗、客服与营销、能源、制造…

C语言数据结构之顺序表

目录 1.线性表2.顺序表2.1顺序表相关概念及结构2.2增删查改等接口的实现 3.数组相关例题 1.线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性&#xff08;数据类型相同&#xff09;的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构&#xff…

2024年阿里云服务器明细报价整理总结

2024年阿里云服务器租用费用&#xff0c;云服务器ECS经济型e实例2核2G、3M固定带宽99元一年&#xff0c;轻量应用服务器2核2G3M带宽轻量服务器一年61元&#xff0c;ECS u1服务器2核4G5M固定带宽199元一年&#xff0c;2核4G4M带宽轻量服务器一年165元12个月&#xff0c;2核4G服务…

Zynq 7000 SoC器件的复位系统

Zynq7000 SoC器件中的复位系统包括由硬件、看门狗定时器、JTAG控制器和软件生成的复位。每个模块和系统都包括一个由复位系统驱动的复位。硬件复位由上电复位信号&#xff08;PS_POR_B&#xff09;和系统复位信号&#xff08;PS_SRST_B&#xff09;驱动。 在PS中&#xff0c;有…

JAVA基础面试题(第九篇)中! 集合与数据结构

JAVA集合和数据结构也是面试常考的点&#xff0c;内容也是比较多。 在看之前希望各位如果方便可以点赞收藏&#xff0c;给我点个关注&#xff0c;创作不易&#xff01; JAVA集合 11. HashMap 中 key 的存储索引是怎么计算的&#xff1f; 首先根据key的值计算出hashcode的值…

隧道代理的优势与劣势分析

“随着互联网的快速发展&#xff0c;网络安全已经成为一个重要的议题。为了保护个人和组织的数据&#xff0c;隧道代理技术逐渐成为网络安全的重要工具。隧道代理通过在客户端和服务器之间建立安全通道&#xff0c;加密和保护数据的传输&#xff0c;有效地防止黑客入侵和信息泄…

15-partition table (分区表)

ESP32-S3的分区表 什么是分区表&#xff1f;&#x1f914; ESP32-S3的分区表是用来确定在ESP32-S3的闪存中数据和应用程序的布局。每个ESP32-S3的闪存可以包含多个应用程序&#xff0c;以及多种不同类型的数据&#xff08;例如校准数据、文件系统数据、参数存储数据等&#x…