详细解析Barlow Twins:自监督学习中的创新方法

首先先简单了解一下机器学习中,主要有三种学习范式:监督学习、无监督学习和自监督学习:

  • 监督学习:依赖带标签的数据,通过输入输出映射关系进行训练。
  • 无监督学习:不依赖标签,关注数据的内在结构和模式。
  • 自监督学习:利用数据本身生成标签,通过预训练任务学习有效的特征表示。

Barlow Twins

Barlow Twins是一种基于信息论的自监督学习方法,其目标是减少神经元之间的冗余。该方法要求神经元对数据增强具有不变性,但彼此独立。

在实际训练中,通过反向传播(backpropagation)调整神经网络的参数,使得交叉相关矩阵的对角线元素尽可能大,而非对角线元素尽可能小——接近单位矩阵,从而达到上述目标。

1 例子

假设我们有一张图片 X X X ,经过两个不同的数据增强得到图像 Y A Y^A YA Y B Y^B YB ,其再通过相同的神经网络得到特征表示 Z A Z^A ZA Z B Z^B ZB (假设有RGB三维)。由于是同一张图片, Z A Z^A ZA 的蓝色与 Z B Z^B ZB 的蓝色应该相似(红绿同理),同时为了最大限度减少冗余,我们希望特征彼此本身不同(即 Z A Z^A ZA 中的蓝绿红彼此不同) —— 对数据增强保持不变,但独立于其他

image-20240530211739303

数学上描述即为:计算特征表示 Z A Z^A ZA Z B Z^B ZB 的交叉相关矩阵,目标为使该矩阵接近单位矩阵。

这张图展示了Barlow Twins方法的主要流程。具体步骤如下:

  1. 数据增强
    • 从输入图像 X X X 出发,使用不同的数据增强变换 T T T 生成两组扭曲图像 Y A Y^A YA Y B Y^B YB。这些变换包括随机裁剪、翻转、颜色抖动等。
  2. 特征提取
    • 将扭曲图像 Y A Y^A YA Y B Y^B YB 输入相同的神经网络 f θ f_\theta fθ,生成对应的特征表示 Z A Z^A ZA Z B Z^B ZB
  3. 计算交叉相关矩阵
    • 计算特征表示 Z A Z^A ZA Z B Z^B ZB交叉相关矩阵。目标是使该矩阵接近单位矩阵,从而:
      • 对角线元素:希望在不同数据增强下,相同神经元的特征表示具有高度相关性(接近1)。
      • 非对角线元素:希望不同神经元之间没有冗余(接近0)。

2 Loss计算

交叉相关矩阵 C i j C_{ij} Cij​ 的计算

衡量了不同增强视图下神经元之间的相关性
C i j = ∑ b z b , i A z b , j B ∑ b ( z b , i A ) 2 ∑ b ( z b , j B ) 2 C_{ij} = \frac{\sum_b z^A_{b,i} z^B_{b,j}}{\sqrt{\sum_b (z^A_{b,i})^2} \sqrt{\sum_b (z^B_{b,j})^2}} Cij=b(zb,iA)2 b(zb,jB)2 bzb,iAzb,jB

  • z b , i A z^A_{b,i} zb,iA z b , j B z^B_{b,j} zb,jB 分别表示第 b b b 个样本在增强视图 A A A B B B 中第 i i i 和第 j j j 个神经元的特征表示。
损失函数 L B T \mathcal{L}_{BT} LBT

L B T = ∑ i ( 1 − C i i ) 2 + λ ∑ i ∑ j ≠ i C i j 2 \mathcal{L}_{BT} = \sum_i (1 - C_{ii})^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2 LBT=i(1Cii)2+λij=iCij2

  • 不变性项:
    ∑ i ( 1 − C i i ) 2 \sum_i (1 - C_{ii})^2 i(1Cii)2 这个部分希望对角线上的元素 C i i C_{ii} Cii 尽可能接近1,表示在不同增强视图下,相同神经元的特征表示高度相关。

  • 冗余减少项:
    λ ∑ i ∑ j ≠ i C i j 2 \lambda \sum_i \sum_{j \neq i} C_{ij}^2 λij=iCij2 这个部分希望非对角线上的元素 C i j C_{ij} Cij 尽可能接近0,表示不同神经元之间没有冗余。系数 λ \lambda λ 是一个超参数,用来平衡这两个项的权重。

整个Barlow Twins的关键即损失函数:

返回方阵非对角线元素的扁平(一维)视图函数:

  1. x.flatten()[:-1]:首先,将方阵x扁平化(即将其转换为一维数组),然后删除最后一个元素。扁平化后的数组中,最后一个元素是方阵的最后一个对角线元素。

  2. .view(n - 1, m + 1):然后,将扁平化后的数组重新塑形为一个(n - 1, m + 1)的矩阵。这个矩阵的每一行都包含了原方阵的一行元素。

  3. [:, 1:]:接着,删除矩阵的第一列。这一列包含了原方阵的剩余所有对角线元素。

  4. .flatten():最后,再次将矩阵扁平化。这样,得到的就是一个包含了原方阵所有非对角线元素的一维数组。

def off_diagonal(x):'''返回方阵非对角线元素的扁平(一维)视图'''n, m = x.shapeassert n == mreturn x.flatten()[:-1].view(n - 1, m + 1)[:, 1:].flatten()

barlow_loss计算函数:

def barlow_loss(z1, z2, bn, lambd):'''返回一对特征的Barlow Twins的loss:param z1:第一个输入特征:param z2:第二个输入特征:param bn:应用于 z1 和 z2 的 nn.BatchNorm1d 层:param lambd:权衡超参数 lambda'''# 批量归一化z1_norm = bn(z1)z2_norm = bn(z2)batch_size = z1.size(0)# 计算 z1 和 z2 的协方差矩阵c = torch.mm(z1_norm, z2_norm.t()) / batch_size# lossc_diff = (c - torch.eye(c.size(0), device=c.device)).pow(2)c_diff = off_diagonal(c_diff).mul_(lambd)loss = c_diff.sum()return loss

3 整体流程

整体流程的伪代码如下:

# 训练循环
for x in loader:  # 加载一个批次包含N个样本# 对每个样本生成两个随机增强版本y_a, y_b = augment(x)  # augment函数生成数据增强版本# 计算表征z_a = f(y_a)  # NxDz_b = f(y_b)  # NxD# 沿批次维度标准化表征z_a_norm = (z_a - z_a.mean(dim=0)) / z_a.std(dim=0)  # NxDz_b_norm = (z_b - z_b.mean(dim=0)) / z_b.std(dim=0)  # NxD# 计算交叉相关矩阵c = torch.mm(z_a_norm.T, z_b_norm) / N  # DxD# 计算损失c_diff = (c - torch.eye(D, device=c.device)).pow(2)  # DxD# 将非对角线元素乘以lambdaoff_diagonal(c_diff).mul_(lambda_off_diag)loss = c_diff.sum()# 优化步骤optimizer.zero_grad()loss.backward()optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845159.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pikachu靶场(unsafe upfileupload(文件上传)通关教程)

目录 client check 1.在桌面新建一个文本文档 2.保存为.png格式 3.打开网站 4.按照图中操作 5.点击forward 6.访问 MIME type 1.新建一个php文件,里面写上 2.上传文件,就是我们保存的文件 3.打开抓包工具,点击开始上传 4.修改Conen…

ADC数模转换器

一、ADC(Analog-Digital Converter)模拟-数字转换器 1、ADC可以将引脚上连续变化的模拟电压转换为内存中存储的数字变量,建立模拟电路到数字电路的桥梁 2、12位逐次逼近型ADC,1us转换时间 3、输入电压范围:0~3.3V&a…

【K8s】专题四(2):Kubernetes 控制器之 Deployment

以下内容均来自个人笔记并重新梳理,如有错误欢迎指正!如果对您有帮助,烦请点赞、关注、转发!欢迎扫码关注个人公众号! 目录 一、基本介绍 二、工作原理 三、相关特性 四、资源清单(示例) 五…

【Linux】多线程——线程概念|进程VS线程|线程控制

> 作者:დ旧言~ > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:理解【Linux】多线程——线程概念|进程VS线程|线程控制 > 毒鸡汤:有些事情,总是不明白,所以我不会坚持。早安! &…

产品上市新闻稿怎么写?纯干货

一个产品的上市,想要达到一个非常好的宣传效果,前期的预热造势是必不可少的,投放产品上市新闻稿到权威专业的媒体,潜移默化去影响用户的心智,产品上市新闻稿怎么写?接下来伯乐网络传媒就来给大家分享一下&a…

重生之 SpringBoot3 入门保姆级学习(10、日志基础与使用)

重生之 SpringBoot3 入门保姆级学习(10、日志基础使用) 3.1 日志基础3.2 使用日志3.2.1 基础使用3.2.2 调整日志级别3.2.3 带参数的日志 3.1 日志基础 SpringBoot 默认使用 SLF4j(Simple Logging Facade for Java)和 Logback 实现…

码蹄集部分题目(2024OJ赛17期;二分+差分+ST表+单调队列+单调栈)

1🐋🐋小码哥处理订单(钻石;二分差分) 时间限制:1秒 占用内存:128M 🐟题目描述 🐟题目思路 【码蹄集进阶塔全题解07】算法基础:二分 MT2070 – MT2079_哔哩…

Element ui 快速入门(基础知识点)

element ui官网 前言: 在当今时代,我们在编写计算机程序时,不仅仅是写几个增删改查的简单功能,为了满足广大用户对页面美观的需求,为了让程序员们写一些功能更简便,提高团队协作效率,所以eleme…

python操作mongodb底层封装并验证pymongo是否应该关闭连接

一、pymongo简介 github地址:https://github.com/mongodb/mongo-python-driver mymongo安装命令:pip install pymongo4.7.2 mymongo接口文档:PyMongo 4.7.2 Documentation PyMongo发行版包含Python与MongoDB数据库交互的工具。bson包是用…

【Python】解决Python报错:AttributeError: ‘int‘ object has no attribute ‘xxx‘

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

RLC防孤岛保护装置如何工作的?

什么是RLC防孤岛保护装置? 孤岛保护装置是电力系统中一道强大的守护利器,它以敏锐的感知和迅速的反应,守护着电网的平稳运行。当电网遭遇故障或意外脱离主网时,孤岛保护装置如同一位机警的守门人,立刻做出决断&#xf…

Go微服务: 基于Docker搭建Kong网关环境

概述 在当今的微服务架构中,API网关扮演着至关重要的角色,它作为系统的统一入口负责处理所有内外部请求,实现路由转发、负载均衡、安全控制、限流熔断等多种功能Kong,作为一个开源、高性能、可扩展的API网关,凭借其强…

【机器学习】探索未来科技的前沿:人工智能、机器学习与大模型

文章目录 引言一、人工智能:从概念到现实1.1 人工智能的定义1.2 人工智能的发展历史1.3 人工智能的分类1.4 人工智能的应用 二、机器学习:人工智能的核心技术2.1 机器学习的定义2.2 机器学习的分类2.3 机器学习的实现原理2.4 机器学习的应用2.5 机器学习…

在PostGIS中检查孤线(Find isolated lines in PostGIS)

场景 在PostGIS中有一张线要素表,需要检查该表中的孤线,并且进行自动纠正的计算。 其中孤线定义为两端端点都不在任何其他线的顶点上。 本文介绍在PostGIS中的线要素点,通过函数计算指定线要素表中的孤线,并计算最接近的纠偏位置。 In PostGIS, there is a table of line …

GPT-4o(OpenAI最新推出的大模型)

简介:最近,GPT-4o横空出世。对GPT-4o这一人工智能技术进行评价,包括版本间的对比分析、GPT-4o的技术能力以及个人感受等。 方向一:对比分析 GPT-4o(OpenAI最新推出的大模型)与GPT-4之间的主要区别体现在响应…

268 基于matlab的模拟双滑块连杆机构运动

基于matlab的模拟双滑块连杆机构运动,并绘制运动动画,连杆轨迹可视化输出,并输出杆件质心轨迹、角速度、速度变化曲线。可定义杆长、滑块速度,滑块初始位置等参数。程序已调通,可直接运行。 268 双滑块连杆机构运动 连…

Github单个文件或者单个文件夹下载插件

有时候我们在github上备份了一些资料,比如pdf,ppt,md之类的,需要用到的时候只要某个文件即可,又不要把整个仓库的zip包下载下来,毕竟有时文件太多,下载慢,我们也不需要所有资料,那么就可以使用到…

i-am-a-bot:一款基于多个大语言模型的验证码系统安全评估工具

关于i-am-a-bot i-am-a-bot是一款基于多个大语言模型的验证码安全评估工具,该工具提供了一个使用了多模态大语言模型(LLM)的自动化解决方案,可以帮助广大研究人员测试各种类型验证码机制的安全性。 从底层上看,i-am-a…

renren-fast-vue启动报错

问题描述 拉取人人开源vue项目启动失败 报错信息 版本信息 序号名称版本号1node14.21.3 启动方案 1.拉取项目 git clone https://gitee.com/renrenio/renren-fast-vue.git 2.执行安装依赖命令 npm install 3.此时报错 chromedriver2.27.2 install: node install.js 4.手动…

安装与使用ChatTTS文本转语音模型

非常自然的文本转语音(Text To Speech)TTS,支持中英文混读,还可以穿插笑声,听起来很真实自然。 1、有哪些优点 对话式 TTS: ChatTTS针对对话式任务进行了优化,实现了自然流畅的语音合成,同时支持多说话人。细粒度控制…