浅谈 EMP-SSL + 代码解读:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/38733.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

08 qt进程和网络编程(cs模型)

一 、qt进程 qt中进程最主要的任务就是启动额外应用程序 并且跟他们之间通信。进程类为QProcess 定义用途Header:#include qmake:QT += coreInherits:QIODevice//继承于IO设备类1.1 QProcess基本使用 第一步:创建一个QProcess对象 // process = new QProcess(this); //说明…

资讯速递 | ArkUI-X 预览版已正式开源!

OpenHarmony项目群技术指导委员会(以下简称“TSC”)-跨平台应用开发框架TSG所孵化项目 —— ArkUI-X,近期已正式开源 ,开发者基于一套主代码,就可以将在OpenHarmony上开发的精美、高性能应用同时运行在Android、iOS等其…

LNMP环境搭建wordpress以及跳转后台报404解决

基于上文配置好的LNMP环境继续搭建wordpress 目录 一.到官网下载tar.gz包,并上传到Linux上,也可以通过复制链接地址进行下载 二. 将wordpress中的所有文件移动到你nginx.conf中指定目录中 三.为wordpress配置数据库 四.到浏览器进行注册 1.刚开始…

maven编译始终提示无效的目标发行版的解决方法

摘自个人印象笔记2021-05-07:https://app.yinxiang.com/fx/55e1d5f4-aeea-446a-a768-0f1a48195f5b(图显示不完整可查看原笔记内容)1:确保IDE中的编译版本正确 在idea中,主要看项目属性中和setting的java compiler中对应的jdk版本是否正确&…

好用的安卓手机投屏到mac分享

工具推荐:scrcpy github地址:https://github.com/Genymobile/scrcpy/tree/master mac使用方式 安装环境,打开terminal,执行以下命令,没有brew的先安装brew brew install scrcpy brew install android-platform-too…

学习 Iterator 迭代器

今天看到一个面试题, 让下面解构赋值成立。 let [a,b] {a:1,b:2} 如果我们直接在浏览器输出这行代码,会直接报错,说是 {a:1,b:2} 不能迭代。 看了es6文档后,具有迭代器的就一下几种类型,没有Object类型,…

404. 左叶子之和

给定二叉树的根节点 root ,返回所有左叶子之和。 示例 1: 输入: root [3,9,20,null,null,15,7] 输出: 24 解释: 在这个二叉树中,有两个左叶子,分别是 9 和 15,所以返回 24示例 2: 输入: root [1] 输出: 0提示: 节点…

【NetCore】09-中间件

文章目录 中间件:掌控请求处理过程的关键1. 中间件1.1 中间件工作原理1.2 中间件核心对象 2.异常处理中间件:区分真异常和逻辑异常2.1 处理异常的方式2.1.1 日常错误处理--定义错误页的方法2.1.2 使用代理方法处理异常2.1.3 异常过滤器 IExceptionFilter2.1.4 特性过…

go web框架 gin-gonic源码解读02————router

go web框架 gin-gonic源码解读02————router 本来想先写context,但是发现context能简单讲讲的东西不多,就准备直接和router合在一起讲好了 router是web服务的路由,是指讲来自客户端的http请求与服务器端的处理逻辑或者资源相映射的机制。&…

react实现对数组做增删改操作自定义hook

需求 实现对数组的增删改操作。 实现 import { useState } from react;const useArray (currList) > {const [list, setList] useState(currList);// 增const addItem (item) > {setList([...list, item]);};// 删const removeItem (idx) > {const _arr [...l…

实战指南,SpringBoot + Mybatis 如何对接多数据源

系列文章目录 MyBatis缓存原理 Mybatis plugin 的使用及原理 MyBatisSpringboot 启动到SQL执行全流程 数据库操作不再困难,MyBatis动态Sql标签解析 从零开始,手把手教你搭建Spring Boot后台工程并说明 Spring框架与SpringBoot的关联与区别 Spring监听器…

轻松解决docker容器启动闪退

docker run -p 3306:3306 --name mysql8 \ -v /usr/local/mysql/log:/var/log/mysql \ -v /usr/local/mysql/data:/var/lib/mysql \ -v /usr/local/mysql/conf:/etc/mysql \ -e MYSQL_ROOT_PASSWORD666 -d mysql:8.0.32执行这个命令的时候闪退,其实这个是命令是对你…

[cv] stable diffusion——2、公式

背景: 在图像生成领域中,最常见的生成模型是GAN和VAE。然而,在2020年,提出了一种新的模型,即DDPM(Denoising Diffusion Probabilistic Model),也被称为扩散模型(Diffusi…

基于eBPF技术构建一种应用层网络管控解决方案

引言 随着网络应用的不断发展,在linux系统中对应用层网络管控的需求也日益增加,而传统的iptables、firewalld等工具难以针对应用层进行网络管控。因此需要一种创新的解决方案来提升网络应用的可管理性。 本文将探讨如何使用eBPF技术构建一种应用层网络…

【CSS】禁用元素鼠标事件(例如实现元素禁用效果)

文章目录 基本用法 基本用法 pointer-events 属性指定在什么情况下 (如果有) 某个特定的图形元素可以成为鼠标事件。实际运用中可以通过对auto 和none动态控制,来动态实现元素的禁用效果。 属性描述auto与pointer-events属性未指定时的表现效果相同,对…

【笔试题心得】排序算法总结整理

排序算法汇总 常用十大排序算法_calm_G的博客-CSDN博客 以下动图参考 十大经典排序算法 Python 版实现(附动图演示) - 知乎 冒泡排序 排序过程如下图所示: 比较相邻的元素。如果第一个比第二个大,就交换他们两个。对每一对相邻…

【LeetCode-简单】剑指 Offer 29. 顺时针打印矩阵(详解)

题目 输入一个矩阵,按照从外向里以顺时针的顺序依次打印出每一个数字。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出:[1,2,3,6,9,8,7,4,5]示例 2: 输入:matrix [[1,2,3,4],[5,6,7,8],[9,10,1…

互联网发展历程:速度与效率,交换机的登场

互联网的演进就像一场追求速度与效率的竞赛,每一次的技术升级都为我们带来更快、更高效的网络体验。然而,在网络的初期阶段,人们面临着数据传输速度不够快的问题。一项关键的技术应运而生,那就是“交换机”。 速度不足的困境&…

CloudEvents—云原生事件规范

我们的系统中或多或少都会用到如下两类业务技术: 异步任务,用于降低接口时延或削峰,提升用户体验,降低系统并发压力;通知类RPC,用于微服务间状态变更,用户行为的联动等场景; 以上两种…

Go和Java实现解释器模式

Go和Java实现解释器模式 下面通过一个四则运算来说明解释器模式的使用。 1、解释器模式 解释器模式提供了评估语言的语法或表达式的方式,它属于行为型模式。这种模式实现了一个表达式接口,该接口 解释一个特定的上下文。这种模式被用在 SQL 解析、符…