(表征学习论文阅读)A Simple Framework for Contrastive Learning of Visual Representations

Chen T, Kornblith S, Norouzi M, et al. A simple framework for contrastive learning of visual representations[C]//International conference on machine learning. PMLR, 2020: 1597-1607.

1. 前言

本文作者为了了解对比学习是如何学习到有效的表征,对本文所提出的三大组件进行了全面的研究:

  1. 各种数据增强手段的组合在表征学习中起到了重要作用;
  2. 在表征和对比损失之间引入非线性变换能够有效提高表征质量;
  3. 对比学习相较于监督学习需要更大的batch size和更多的训练步数。

在没有人类标注或者监督的情况下学习数据的有效表征是一个长期存在的难题,目前的主要工作可以分为两类:

  1. 基于生成模型的方法
    例如VQ-VAE,MAE,BERT
  2. 基于判别模型的方法
    例如MoCo,CLIP

2. 方法

本文提出了一个框架SimCLR,通过最大化同一数据的不同数据增强处理后的两个视角之间的相似度来学习有效表征。
在这里插入图片描述

  1. 如图所示,本文首先将数据 x x x进行两个不同的增强,这里作者使用了三种简单的数据增强方法:随机裁剪后再调整到原始大小、随机颜色失真、高斯模糊。
  2. f ( ∙ ) f(\bullet) f()代表编码器,这里作者使用的是同一个编码器来对两个视角数据进行编码
  3. 最后编码器输出的结果通过非线性变换 g ( ∙ ) g(\bullet) g()得到 z i z_i zi z j z_j zj,两个向量构成了一组正例,进行相似度计算,也就是简单的单位向量内积计算出余弦相似度。目标就是最大化两者的余弦相似度。同时,一个batch中其他的数据构成了负例,最小化与负例的相似度。注意最终训练完成的编码器我们是需要舍弃掉非线性变换的。
    本文使用的损失函数就是最基本的InfoNCE损失,具体可以参考我的另一篇讲解InfoNCE的博文。
    在这里插入图片描述
    在这里插入图片描述

3. 代码

这里仅提供文章提到的两个点的代码:

  1. 数据增强
    高斯模糊
import numpy as np
import torch
from torch import nn
from torchvision.transforms import transformsnp.random.seed(0)class GaussianBlur(object):"""blur a single image on CPU"""def __init__(self, kernel_size):radias = kernel_size // 2kernel_size = radias * 2 + 1self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),stride=1, padding=0, bias=False, groups=3)self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),stride=1, padding=0, bias=False, groups=3)self.k = kernel_sizeself.r = radiasself.blur = nn.Sequential(nn.ReflectionPad2d(radias),self.blur_h,self.blur_v)self.pil_to_tensor = transforms.ToTensor()self.tensor_to_pil = transforms.ToPILImage()def __call__(self, img):img = self.pil_to_tensor(img).unsqueeze(0)sigma = np.random.uniform(0.1, 2.0)x = np.arange(-self.r, self.r + 1)x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))x = x / x.sum()x = torch.from_numpy(x).view(1, -1).repeat(3, 1)self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))with torch.no_grad():img = self.blur(img)img = img.squeeze()img = self.tensor_to_pil(img)return img

组合各类增强手段

class ContrastiveLearningDataset:def __init__(self, root_folder=r"D:\pyproject\representation_learning\data"):self.root_folder = root_folder@staticmethoddef get_simclr_pipeline_transform(size, s=1):"""Return a set of data augmentation transformations as described in the SimCLR paper."""color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),transforms.RandomHorizontalFlip(),transforms.RandomApply([color_jitter], p=0.8),transforms.RandomGrayscale(p=0.2),GaussianBlur(kernel_size=int(0.1 * size)),transforms.ToTensor()])return data_transformsdef get_dataset(self, name, n_views):valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(32),n_views),download=True),'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(96),n_views),download=True)}try:dataset_fn = valid_datasets[name]except KeyError:raise InvalidDatasetSelection()else:return dataset_fn()
  1. 非线性变换
class ResNetSimCLR(nn.Module):def __init__(self, base_model, out_dim):super(ResNetSimCLR, self).__init__()self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}self.backbone = self._get_basemodel(base_model)dim_mlp = self.backbone.fc.in_features# add mlp projection head# 修改resnet最后一层的全连接层即可self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)def _get_basemodel(self, model_name):try:model = self.resnet_dict[model_name]except KeyError:raise InvalidBackboneError("Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")else:return modeldef forward(self, x):return self.backbone(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/800341.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode题练习与总结:螺旋矩阵Ⅱ--59

一、题目描述 给你一个正整数 n ,生成一个包含 1 到 n^2 所有元素,且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 示例 1: 输入:n 3 输出:[[1,2,3],[8,9,4],[7,6,5]]示例 2: 输入&#xff1…

VMware启动显示“打开虚拟机时出错: 获取该虚拟机的所有权失败”

提示框(忘截图了)里提示目录C:\Users\mosep\Documents\Virtual Machines\VM-Win10 x64\中的某个文件(在我这里好像是VM-Win10 x64.vmx,VM-Win10 x64是我给虚拟机取的名字)在被使用中。 找到这个目录,删除.…

【面试题】如何在亿级别用户中检查用户名是否存在?

前言 不知道大家有没有留意过,在使用一些app或者网站注册的时候,提示你用户名已经被占用了,比如我们熟知的《英雄联盟》有些人不知道取啥名字,干脆就叫“不知道取啥名”。 但是有这样困惑的可不止他一个,于是就出现了“…

如何从应用商店Microsoft Store免费下载安装HEVC视频扩展插件

在电脑上打开一张HEIC类型的图片提示缺少HEVC解码器,无法打开查看,现象如下: 这种情况一般会提示我们需要下载安装HEVC解码器,点击“立即下载并安装”会跳转到应用商店,但是我们发现需要付费7元才能下载安装 免费安装…

6. Z 字形变换(Java)

目录 题目描述:输入:输出:代码实现: 题目描述: 将一个给定字符串 s 根据给定的行数 numRows ,以从上往下、从左到右进行 Z 字形排列。 比如输入字符串为 “PAYPALISHIRING” 行数为 3 时,排列如…

mac | Windows 本地部署 Seata2.0.0,Nacos 作为配置中心、注册中心,MySQL 存储信息

1、本人环境介绍 系统 macOS sonama 14.1.1 MySQL 8.2.0 (官方默认是5.7版本) Seata 2.0.0 Nacos 2.2.3 2、下载&数据库初始化 默认你已经有 Nacos、MySQL,如果没有 Nacos 请参考我的文章 : Docker 部署 Nacos(单机…

订阅edk2社区邮件列表

给社区发邮件步骤 UEFI订阅邮件列表 开发者订阅邮箱 develedk2.groups.io | Home 点击Join This Group,按照步骤填写自己邮箱地址(该地址是edk2,发送邮件到该邮箱的地址) 自己邮箱确认就可以自动收到邮件了 比如:

虚拟串口工具vspd.exe的使用

关于vspd虚拟串口工具的获取: 工具下载 (1、 虚拟串口工具官方下载链接 2、通过本文资源下载)工具按照步骤(过于简单,此处省略) 关于vspd虚拟串口工具的使用: 打开软件,如下&…

Ethernet 汇总

Ethernet系统 硬件最小系统 CPU:可以是复杂的芯片,也可以是小的单片机DMA:用于减轻CPU负担,搬运数据系统Memory<->FIFOMAC:可以集成在芯片里面,用于CPU和PHY之间的通信MII:接口用于MAC和PHY的通信,包括控制MDIO和数据DataPHY:模拟器件,最底层,数据收发源头软件…

本地电脑渲染不行怎么解决?自助式渲染助你渲染无忧

有时候&#xff0c;即使购买了昂贵的新电脑&#xff0c;我们也可能会遇到渲染速度缓慢、画质不佳或渲染失败等问题。这些问题可能由多种因素引起。针对该问题&#xff0c;为大家推荐了自助式的渲染&#xff0c;解决你本地电脑渲染不佳问题。 电脑渲染不行原因 新电脑渲染效果不…

为什么企业推广需要品牌故事?媒介盒子分享

从时代来看&#xff0c;我们正处“信息超载的商业时代”&#xff0c;品牌传播面临着“产品同质化”和“信息超载化”的困境。近日小米SU7的出圈除了汽车本身的话题度外&#xff0c;离不开小米的品牌故事、创始人雷军的话题等等。今天媒介盒子就来和大家聊聊&#xff1a;为什么企…

postgresql发布和订阅

一、发布订阅介绍 发布和订阅使用了pg的逻辑复制的功能&#xff0c;通过发布端创建publication与表绑定&#xff0c;订阅端创建subscription同时会在发布端创建逻辑复制槽实现逻辑复制功能 逻辑复制基于 发布&#xff08;Publication&#xff09; 与 订阅&#xff08;Subscri…

【go】模板展示不同k8s命名空间的deployment

gin模板展示k8s命名空间的资源 这里学习如何在前端单页面&#xff0c;调用后端接口展示k8s的资源 技术栈 后端 -> go -> gin -> gin模板前端 -> gin模板 -> html jsk8s -> k8s-go-client &#xff0c;基本资源(deployment等) 环境 go 1.19k8s 1.23go m…

Centos7 安装GitLab

安装环境: 虚拟机:Centos7 最小安装 4核8G 下载GitLab 本次实验下载的是 gitlab-ce-14.1.0-ce.0.el7.x86_64.rpm 官网截图 清华源截图 安装包下载地址(官网;下载CE版本,EE是收费版本):https://packages.gitlab.com/gitlab/gitlab-ce国内镜像源下载地址(清华源):htt…

Linux函数学习 fork

1、Linux fork 函数 pid_t fork(void); pid_t &#xff1a; 对于子进程&#xff0c;返回0 pid_t &#xff1a; 对于父进程进程&#xff0c;返回子进程进程号 int pipe(int pipefd[2]); pipefd[0] 为读取管道 pipefd[1] 为写入管道 返回值&#xff1a;-1失败 0 成功 2、函…

springboot实现上传文件接口(简单版)

使用springboot实现一个最简单版本的上传文件接口 private String uploadPath "C:/imageFiles";RequestMapping(value "/upload", method RequestMethod.POST)private Result upload( RequestParam("modelName") String modelName,RequestPar…

HTML5+CSS3+JS小实例:圣诞按钮

实例:圣诞按钮 技术栈:HTML+CSS+JS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0&…

【4月最新】低至50/年,4G 618/3年 云服务器价格即将回调 ,搭建网站 博客 Linux练习 比虚拟机方便 附阿里云 京东云 腾讯云对比表

更新日期&#xff1a;4月8日&#xff08;半年档 价格回调&#xff0c;京东云采购季持续进行&#xff09; 本文纯原创&#xff0c;侵权必究 《最新对比表》已更新在文章头部—腾讯云文档&#xff0c;文章具有时效性&#xff0c;请以腾讯文档为准&#xff01; 【腾讯文档实时更…

CorelDRAW2024全网最详细独家讲解新版本新功能

各位粉丝大家好&#xff0c;为了让大家更深入的了解CorelDRAW2024新版的各项新功能&#xff0c;我们独家邀请到了Corel中国专家名师张苏老师&#xff0c;策划并录制30分钟全中文讲解栏目&#xff01;干货满满&#xff0c;全程演示&#xff0c;一览CorelDRAW2024新版的各项新功能…

rabbitmq的介绍和交换机类型

rabbitmq的介绍和交换机类型 1.流程 首先先介绍一个简单的一个消息推送到接收的流程&#xff0c;提供一个简单的图 黄色的圈圈就是我们的消息推送服务&#xff0c;将消息推送到 中间方框里面也就是 rabbitMq的服务器&#xff0c;然后经过服务器里面的交换机、队列等各种关系…