对比损失的PyTorch实现详解

对比损失的PyTorch实现详解

本文以SiT代码中对比损失的实现为例作介绍。

论文:https://arxiv.org/abs/2104.03602
代码:https://github.com/Sara-Ahmed/SiT

对比损失简介

作为一种经典的自监督损失,对比损失就是对一张原图像做不同的图像扩增方法,得到来自同一原图的两张输入图像,由于图像扩增不会改变图像本身的语义,因此,认为这两张来自同一原图的输入图像的特征表示应该越相似越好(通常用余弦相似度来进行距离测度),而来自不同原图像的输入图像应该越远离越好。来自同一原图的输入图像可做正样本,同一个batch内的不同输入图像可用作负样本。如下图所示(粗箭头向上表示相似度越高越好,向下表示越低越好)。
在这里插入图片描述

论文中的公式

lcontrxi,xj(W)=esim(SiTcontr(xi),SiTcontr(xj))/τ∑k=1,k≠i2Nesim(SiTcontr(xi),SiTcontr(xk))/τ(1)l^{x_i,x_j}_{contr}(W)=\frac{e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_j))/\tau}}{\sum_{k=1,k\ne i}^{2N}e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_k))/\tau}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) lcontrxi,xj(W)=k=1,k=i2Nesim(SiTcontr(xi),SiTcontr(xk))/τesim(SiTcontr(xi),SiTcontr(xj))/τ                  (1)

L=−1N∑j=1Nloglxj,xjˉ(W)(2)\mathcal{L}=-\frac{1}{N}\sum_{j=1}^Nlogl^{x_j,x_{\bar{j}}}(W) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) L=N1j=1Nloglxj,xjˉ(W)                  (2)

SiT论文中的对比损失公式如上所示。其中xix_ixixjx_jxj分别表示两个不同的输入图像,sim(⋅,⋅)sim(\cdot,\cdot)sim(,)表示余弦相似度,即归一化之后的点积,τ\tauτ是超参数温度,xjx_jxjxjˉx_{\bar{j}}xjˉ是来自同一原图的两种不同数据增强的输入图像, SiTcontr(⋅)SiT_{contr}(\cdot)SiTcontr() 表示从对比头中得到的图像表示,没看过原文的话,就直接理解为输入图像经过一系列神经网络,得到一个dimdimdim 维度的特征向量作为图像的特征表示,网络不是本文的重点,重点是怎样根据得到的特征向量计算对比损失

与最近很火的infoNCE对比损失基本一样,只是写法不同。

代码实现

class ContrastiveLoss(nn.Module):def __init__(self, batch_size, device='cuda', temperature=0.5):super().__init__()self.batch_size = batch_sizeself.register_buffer("temperature", torch.tensor(temperature).to(device))			# 超参数 温度self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())		# 主对角线为0,其余位置全为1的mask矩阵def forward(self, emb_i, emb_j):		# emb_i, emb_j 是来自同一图像的两种不同的预处理方法得到z_i = F.normalize(emb_i, dim=1)     # (bs, dim)  --->  (bs, dim)z_j = F.normalize(emb_j, dim=1)     # (bs, dim)  --->  (bs, dim)representations = torch.cat([z_i, z_j], dim=0)          # repre: (2*bs, dim)similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)      # simi_mat: (2*bs, 2*bs)sim_ij = torch.diag(similarity_matrix, self.batch_size)         # bssim_ji = torch.diag(similarity_matrix, -self.batch_size)        # bspositives = torch.cat([sim_ij, sim_ji], dim=0)                  # 2*bsnominator = torch.exp(positives / self.temperature)             # 2*bsdenominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)             # 2*bs, 2*bsloss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))        # 2*bsloss = torch.sum(loss_partial) / (2 * self.batch_size)return loss

以下是SiT论文的对比损失代码实现,笔者已经将debug过程中得到的张量形状在注释中标注了出来,供大家参考,其中dim是得到的特征向量的维度,bs是批尺寸batch size。

笔者简单画了一张similarity_matrix的图示来说明整个过程。本图以bs==4为例,a,b,c,da,b,c,da,b,c,d分别代表同一个batch内的不同样本,下表0和1表示两种不同的图像扩增方法。图中每个方格则是对应行列的图像特征(dim维的向量)表示计算相似度的结果值。

在这里插入图片描述

  1. emb_i,emb_j 是来自同一图像的两种不同的预处理方法得到的输入图像的特征表示。首先是通过F.normalize()emb_iemb_j进行归一化。

  2. 然后将二者拼接起来的到维度为2*bs的representations。再将representations分别转换为列向量和行向量计算相似度矩阵similarity_matrix(见图)。

  3. 在通过偏移的对角线(图中蓝线)的到sim_ijsim_ji,并拼接的到positives。请注意蓝线对应的行列坐标,分别是a0,a1a_0,a_1a0,a1b0,b1b_0,b_1b0,b1等,即蓝线对应的网格即是来自同一张原图的不同处理的输入图像。这在损失的设计中即是我们的正样本。

  4. 然后nominator(分子)即可根据公式计算的到。

  5. 而在计算denominator时需注意要乘上self.negatives_mask。该变量在__init__中定义,是对2*bs的方针对角阵取反,即主对角线全是0,其余位置全是1 。这是为了在负样本中屏蔽自己与自己的相似度结果(图中红线),即使得similarity_matrix的主对角钱全为0。因为自己与自己的相似度肯定是1,加入到计算中没有意义。

  6. 再到后面loss_partial的计算(第22行)其实是计算出公式(1),torch.sum()计算的是(1)中分母上的∑\sum符号。

  7. 第23行就是计算公式(2),其中与公式相比分母上多了除了个2,是因为本实现为了方便将similarity_matrix的维度扩展为2*bs。即相当于将公式(2)中的lcontrxj,xjˉl_{contr}^{x_j,x_{\bar{j}}}lcontrxj,xjˉlcontrxjˉ,xjl_{contr}^{x_{\bar{j}},x_j}lcontrxjˉ,xj 分别计算了一遍。所以要多除个2。

自行验证

大家可以将上面的ContrastiveLoss类复制到自己的测试的文件中,并构造几个输入进行测试,打印中间结果,验证自己是否真正地理解了对比损失的代码实现计算过程。

loss_func = losses.ContrastiveLoss(batch_size=4)
emb_i = torch.rand(4, 512).cuda()
emb_j = torch.rand(4, 512).cuda()loss_contra = loss_func(emb_i, emb_j)
print(loss_contra)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532773.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

android 融云浏览大图,融云 Android sdk kit 头像昵称更新机制

先申明笔者的实现方式不是唯一 也不一定是最优化的方案 如果您看到此篇博文 有不同看法 或者 更好的优化 更高的效率 欢迎在评论发表意见 融云官网点我融云头像机制相关视频详解首先跟大家说一下 kit 跟 lib 的头像机制 kit 是已经包含融云已经给开发者定制好的界面 诸如 会话界…

Linux中的awk、sed、grep及正则表达式详解

Linux中的awk、sed、grep及正则表达式详解 简介 awk、sed和grep是Linux中文本操作的三大利器。 其中awk适用于取列,sed适用于取行,grep适用于过滤。 正则表达式 首先我们来介绍一下正则表达式,正则表达式(regular expression)描述了一种…

android聚焦时如何给控件加边框,edittext设置获得焦点时的边框颜色

第一步:为了更好的比较,准备两个一模一样的EditText(当Activity启动时,焦点会在第一个EditText上,如果你不希望这样只需要写一个高度和宽带为0的EditText即可避免,这里就不这么做了),代码如下:a…

xargs 命令教程

xargs 命令教程 转自:http://www.ruanyifeng.com/blog/2019/08/xargs-tutorial.html 作者: 阮一峰 日期: 2019年8月 8日 xargs是 Unix 系统的一个很有用的命令,但是常常被忽视,很多人不了解它的用法。 本文介绍如…

android strictmode有什么作用,Android 性能优化 之 StrictMode

8种机械键盘轴体对比本人程序员,要买一个写代码的键盘,请问红轴和茶轴怎么选?StrictMode概述StrictMode 是用来检测程序中违例情况的开发者工具。使用StrictMode,系统检测出主线程违例的情况会做出相应的反应,如日志打…

curl 的用法指南

curl 的用法指南 转自:http://www.ruanyifeng.com/blog/2019/09/curl-reference.html 作者: 阮一峰 日期: 2019年9月 5日 简介 curl 是常用的命令行工具,用来请求 Web 服务器。它的名字就是客户端(client&#xf…

怎么在html显示已登录状态,jQuery Ajax 实现在html页面实时显示用户登录状态

当网站是全静态的html页面时,而又希望网站会员在登录之后并在所有页面头部显示登录状态,如用户名等,如果未登录就是未登录状态,下面给大家来分享实现的方法。一、在html静态页面中加入div,并指定ID如:二、新…

xpwifi热点设置android,教你在XP电脑中开启设置WiFi热点使用的步骤

对于系统中网络的连接问题是最重要的,那在处理不同的错误的情况中,对于无线网络的设置也就是我们说的WiFi的使用也是会遇到问题的,那在操作的时候对于电脑中是怎么实现设置WiFi热点的的,对于这个问题今天小编就来跟大家分享一下教…

C/C++ 指针详解

指针详解 参考视频:https://www.bilibili.com/video/BV1bo4y1Z7xf/,感谢Bilibilifengmuzi2003的搬运翻译及后续勘误,也感谢已故原作者Harsha Suryanarayana的讲解,RIP。 学习完之后,回看找特定的知识点,善…

android双联动列表,Android Fragment实现列表和内容联动

在平板上经常能看到这种的情况:左边是一个列表,右边是列表项对应的内容,当点击某一个列表时,右边内容区也会随之改变。下面使用fragment简单的demo:思路:在mainactivity定义一个回调接口,并在列…

android模拟器太卡,安卓模拟器安装之后太卡怎么解决

用安卓模拟器玩游戏原理就是在电脑上安装了一部手机,如果你的电脑配置不是非常高,能不卡顿吗?遇到卡顿怎么解决?1、安装最新版本的显卡驱动。逍遥模拟器对于显卡的性能要求很高,因此升级至最新版本的显卡驱动,是确保逍遥模拟器流…

编程环境中Runtime(运行时)的三个含义

编程环境中Runtime(运行时)的三个含义 转自:https://www.zhihu.com/question/20607178 知乎答主doodlewind 三个含义 实际上编程语境中的 runtime 至少有三个含义,分别是: 指「程序运行的时候」,即程序…

非常不错的一款html5【404页面】,不含js脚本可以左右摆动,原生JavaScript实现日历功能代码实例(无引用Jq)...

这篇文章主要介绍了原生JavaScript实现日历功能代码实例(无引用Jq),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下成品显示,可左右切换月份html 代码移动端日历日一二三四五六css代码*{margin: 0;pa…

计算机应用与基础实践怎么考,自考计算机基础应用科目笔试和实践性考试怎么考...

自考计算机基础应用科目笔试和实践性考试怎么考? 报考自考的考生有些专业的考生会在自己的课程科目中发现计算机基础应用不仅有理论知识考试还有实践性考试,那么自考计算机基础应用科目的笔试和实践性考试怎么考?自考计算机基础应用科目笔试怎…

14 [虚拟化] 虚存抽象;Linux进程的地址空间

14 [虚拟化] 虚存抽象;Linux进程的地址空间 南京大学操作系统课蒋炎岩老师网络课程笔记。 视频:https://www.bilibili.com/video/BV1N741177F5?p14 讲义:http://jyywiki.cn/OS/2021/slides/10.slides#/ 本讲概述 程序 状态机;…

Ubuntu 18.04 安装OpenCV C++

Ubuntu 18.04 安装OpenCV C 构建并安装 仅构建核心模块 # 更新并安装依赖 # 更新并安装依赖 sudo apt update && sudo apt install -y cmake g wget unzip# 下载并解压包 wget -O opencv.zip https://github.com/opencv/opencv/archive/master.zip unzip opencv.zip…

html计算x的y,HTML5画布:旋转时计算x,y点

我开发了一个HTML5 Canvas应用程序,它涉及到读取一个xml文件,该文件描述了需要在画布上绘制的箭头,直形和其他形状的位置。的XML布局的HTML5画布:旋转时计算x,y点实施例:如果对象被旋转它涉及计算一个点的位…

(2021) 20 [虚拟化] 进程调度

(2021) 20 [虚拟化] 进程调度 南京大学操作系统课蒋炎岩老师网络课程笔记。 视频:https://www.bilibili.com/video/BV1HN41197Ko?p20 讲义:http://jyywiki.cn/OS/2021/slides/11.slides#/ 背景 — 机制与策略分离 机制:一个通用的、可定制…

局域网中计算机网络密码查看,Win10怎么查看电脑上已知的wifi网络密码

方法一:网络和共享中心查询1、在Windows 10桌面最左下角的【Windwos开始图标上右键】,在弹出的菜单中点击打开【网络连接】,如下图所示。2、在打开的网络连接设置中,双击已经连接的【无线网络名称】,在弹出的【WLAN状态…

(2021) 22 [持久化] 1-Bit的存储

(2021) 22 [持久化] 1-Bit的存储 南京大学操作系统课蒋炎岩老师网络课程笔记。 视频:https://www.bilibili.com/video/BV1HN41197Ko?p22 讲义:http://jyywiki.cn/OS/2021/slides/12.slides#/ 背景 回顾 操作系统是什么?一组对象 一组API…