浅谈 EMP-SSL + 代码解读:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/38733.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

资讯速递 | ArkUI-X 预览版已正式开源!

OpenHarmony项目群技术指导委员会(以下简称“TSC”)-跨平台应用开发框架TSG所孵化项目 —— ArkUI-X,近期已正式开源 ,开发者基于一套主代码,就可以将在OpenHarmony上开发的精美、高性能应用同时运行在Android、iOS等其…

LNMP环境搭建wordpress以及跳转后台报404解决

基于上文配置好的LNMP环境继续搭建wordpress 目录 一.到官网下载tar.gz包,并上传到Linux上,也可以通过复制链接地址进行下载 二. 将wordpress中的所有文件移动到你nginx.conf中指定目录中 三.为wordpress配置数据库 四.到浏览器进行注册 1.刚开始…

好用的安卓手机投屏到mac分享

工具推荐:scrcpy github地址:https://github.com/Genymobile/scrcpy/tree/master mac使用方式 安装环境,打开terminal,执行以下命令,没有brew的先安装brew brew install scrcpy brew install android-platform-too…

学习 Iterator 迭代器

今天看到一个面试题, 让下面解构赋值成立。 let [a,b] {a:1,b:2} 如果我们直接在浏览器输出这行代码,会直接报错,说是 {a:1,b:2} 不能迭代。 看了es6文档后,具有迭代器的就一下几种类型,没有Object类型,…

404. 左叶子之和

给定二叉树的根节点 root ,返回所有左叶子之和。 示例 1: 输入: root [3,9,20,null,null,15,7] 输出: 24 解释: 在这个二叉树中,有两个左叶子,分别是 9 和 15,所以返回 24示例 2: 输入: root [1] 输出: 0提示: 节点…

【NetCore】09-中间件

文章目录 中间件:掌控请求处理过程的关键1. 中间件1.1 中间件工作原理1.2 中间件核心对象 2.异常处理中间件:区分真异常和逻辑异常2.1 处理异常的方式2.1.1 日常错误处理--定义错误页的方法2.1.2 使用代理方法处理异常2.1.3 异常过滤器 IExceptionFilter2.1.4 特性过…

react实现对数组做增删改操作自定义hook

需求 实现对数组的增删改操作。 实现 import { useState } from react;const useArray (currList) > {const [list, setList] useState(currList);// 增const addItem (item) > {setList([...list, item]);};// 删const removeItem (idx) > {const _arr [...l…

实战指南,SpringBoot + Mybatis 如何对接多数据源

系列文章目录 MyBatis缓存原理 Mybatis plugin 的使用及原理 MyBatisSpringboot 启动到SQL执行全流程 数据库操作不再困难,MyBatis动态Sql标签解析 从零开始,手把手教你搭建Spring Boot后台工程并说明 Spring框架与SpringBoot的关联与区别 Spring监听器…

基于eBPF技术构建一种应用层网络管控解决方案

引言 随着网络应用的不断发展,在linux系统中对应用层网络管控的需求也日益增加,而传统的iptables、firewalld等工具难以针对应用层进行网络管控。因此需要一种创新的解决方案来提升网络应用的可管理性。 本文将探讨如何使用eBPF技术构建一种应用层网络…

【CSS】禁用元素鼠标事件(例如实现元素禁用效果)

文章目录 基本用法 基本用法 pointer-events 属性指定在什么情况下 (如果有) 某个特定的图形元素可以成为鼠标事件。实际运用中可以通过对auto 和none动态控制,来动态实现元素的禁用效果。 属性描述auto与pointer-events属性未指定时的表现效果相同,对…

【笔试题心得】排序算法总结整理

排序算法汇总 常用十大排序算法_calm_G的博客-CSDN博客 以下动图参考 十大经典排序算法 Python 版实现(附动图演示) - 知乎 冒泡排序 排序过程如下图所示: 比较相邻的元素。如果第一个比第二个大,就交换他们两个。对每一对相邻…

【LeetCode-简单】剑指 Offer 29. 顺时针打印矩阵(详解)

题目 输入一个矩阵,按照从外向里以顺时针的顺序依次打印出每一个数字。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出:[1,2,3,6,9,8,7,4,5]示例 2: 输入:matrix [[1,2,3,4],[5,6,7,8],[9,10,1…

互联网发展历程:速度与效率,交换机的登场

互联网的演进就像一场追求速度与效率的竞赛,每一次的技术升级都为我们带来更快、更高效的网络体验。然而,在网络的初期阶段,人们面临着数据传输速度不够快的问题。一项关键的技术应运而生,那就是“交换机”。 速度不足的困境&…

规划性和可扩展性,助力企业全面预算管理的推进

对于当今社会经济市场的不稳定状况和不断变化的消费者行为,企业业务也从未像今天这样不可预测过。面对变化和变革,企业需要具备规划性的预测能力,才能使得自身在竞争中保持领先地位。那些具备前瞻性的企业都尝试在现阶段通过更好的规划不断提…

基于Mysqlrouter+MHA+keepalived实现高可用半同步 MySQL Cluster项目

目录 项目名称: 基于Mysqlrouter MHA keepalived实现半同步主从复制MySQL Cluster MySQL Cluster: 项目架构图: 项目环境: 项目环境安装包: 项目描述: 项目IP地址规划: 项目步骤: 一…

windows11下配置vscode中c/c++环境

本文默认已经下载且安装好vscode,主要是解决环境变量配置以及编译task、launch文件的问题。 自己尝试过许多博客,最后还是通过这种方法配置成功了。 Linux(ubuntu 20.04)配置vscode可以直接跳转到配置task、launch文件,不需要下载mingw与配…

localhost:8080 is already in use

报错原因:本机的8080端口号已经被占用。因为机器的空闲端口号是随机分配的,而idea默认启动的端口号是8080,所以是存在这种情况。 对于这个问题,我们只需要重启idea或者修改项目的启动端口号即可。 更推荐第二种。对于修改项目启动端口号&…

ZDH-wemock模块

本次介绍基于版本v5.1.1 目录 项目源码 预览地址 安装包下载地址 wemock模块 wemock模块前端 配置首页 配置mock wemock服务 下载地址 打包 运行 效果展示 项目源码 zdh_web: https://github.com/zhaoyachao/zdh_web zdh_mock: https://github.com/zhaoyachao/z…

TCGA数据下载推荐:R语言easyTCGA包

#使用easyTCGA获取数据 #清空 rm(listls()) gc() # 安装bioconductor上面的R包 options(BioC_mirror"https://mirrors.tuna.tsinghua.edu.cn/bioconductor") if(!require("BiocManager")) install.packages("BiocManager") if(!require("TC…

怎样让音频速度变慢?请跟随以下方法进行操作

怎样让音频速度变慢?在会议录音过程中,经常会遇到主讲人语速过快,导致我们无法清晰听到对方说的内容。如果我们能够减慢音频速度,就能更好地记录对方的讲话内容。此外,在听到快速播放的外语或方言时,我们也…