InfoNCE Loss公式及源码理解

InfoNCE Loss公式及源码理解–从交叉熵损失谈起

当谈论到信息论中的损失函数时,InfoNCE(Noise Contrastive Estimation)和交叉熵损失都是两个关键的概念。它们不仅在衡量概率分布之间的差异方面发挥着重要作用,而且在深度学习的自监督学习领域扮演着重要角色。虽然它们的形式和应用环境有所不同,但是我们可以发现它们之间存在着微妙的联系。

交叉熵损失作为衡量两个概率分布之间距离的指标,在分类任务和神经网络训练中广泛使用。而InfoNCE Loss,则是针对自监督学习任务中特征学习的一种损失函数。它通过比较正样本和负样本的相似性来学习模型参数,从而提高特征的区分度。

在这篇博客中,我们将深入探讨交叉熵损失和InfoNCE之间的联系,探究它们在信息论和深度学习中的联系与异同。我们将分析两者的数学形式、应用领域以及它们之间可能的内在关系,以期对这两个重要概念有更深入的理解。

InfoNCE

InfoNCE Loss(Noise Contrastive Estimation Loss)是一种用于自监督学习的损失函数,通常用于学习特征表示或者表征学习。它基于信息论的思想,通过对比正样本和负样本的相似性来学习模型参数。

公式介绍

InfoNCE Loss的公式如下:
InfoNCE Loss = − 1 N ∑ i = 1 N log ⁡ ( exp ⁡ ( q i ⋅ k i + τ ) ∑ j = 1 N exp ⁡ ( q i ⋅ k j − τ ) ) \text{InfoNCE Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log \left( \frac{\exp \left( \frac{q_i \cdot k_{i^+}}{\tau} \right)}{\sum_{j=1}^{N} \exp \left( \frac{q_i \cdot k_{j^-}}{\tau} \right)} \right) InfoNCE Loss=N1i=1Nlog j=1Nexp(τqikj)exp(τqiki+)
其中:

  • N N N是样本的数量
  • q i q_i qi是查询样本 i i i的编码向量
  • k i + k_{i+} ki+是与查询样本 i i i相对应的正样本的编码向量
  • k i − k_{i-} ki是与查询样本 i i i不对应的负样本的编码向量
  • τ \tau τ是温度系数,用于调节相似度得分的分布,后面会详细讨论

算法思想

从INfoNCE的公式中我们可以发现,分子只包含一对正样本,分母则包含一个batch下的 N N N个所有样本,即1个与 q i q_i qi对应的正样本和 ( N − 1 ) (N-1) (N1)个负样本,那么上述公式我们也可以简化为下述形式:
InfoNCE Loss = − 1 N ∑ i = 1 N log ⁡ A + A + + B − \text{InfoNCE Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log\frac{A_+}{A_++B_-} InfoNCE Loss=N1i=1NlogA++BA+
首先,分式部分一定是介于(0,1)之间的,而log在(0,1)之间是单增的且函数值小于0
在损失优化过程中,我们希望达成的结果是 A + A_+ A+尽可能大,也就是正样本之间的距离尽可能尽,其实也隐含着与负样本之间的相似度尽可能低,距离尽可能远。从公式上来看,我们在最小化loss的过程中,需要让公式接近0,也就是让log内部的分式接近1,要达到这个效果,应该使 A > > B A>>B A>>B,可以发现跟我们的训练思路是吻合的,这就达到了对于查询向量而言,推近它和正样本之间的距离,拉远它和负样本的距离

写到这里,基本上把InfoNCE的公式以及公式背后的主要思想讲清楚了,下面就要说Cross Entropy Loss跟它的关系了,其实主要还是InfoNCELoss代码是基于交叉熵损失实现的,看不明白交叉熵损失的代码逻辑也看不懂InfoNCELoss了

Cross Entropy Loss

交叉熵损失是衡量两个概率分布之间差异的一种指标。在分类问题中,我们通常有一个真实的概率分布 P P P(通常是一个独热编码向量,代表了样本的真实标签分布),和一个模型预测的概率分布 Q Q Q。交叉熵损失用于衡量这两个概率分布之间的差异。

其数学公式为:
CrossEntropy ( P , Q ) = − ∑ i P ( i ) ⋅ log ⁡ ( Q ( i ) ) \text{CrossEntropy}(P, Q) = - \sum_i P(i) \cdot \log(Q(i)) CrossEntropy(P,Q)=iP(i)log(Q(i))

  • P ( i ) P(i) P(i) 是真实标签的概率分布,代表了样本属于类别 i i i的概率
  • Q ( i ) Q(i) Q(i)是模型预测的概率分布,代表了模型对样本属于类别 i i i的预测概率
  • l o g log log 是自然对数函数。

交叉熵损失的含义和主要思想是在真实分布和模型预测分布之间衡量误差。当模型的预测与真实情况相符时,交叉熵损失会趋近于0。换句话说,交叉熵损失函数的优化目标是使得模型的预测概率分布尽可能地接近真实标签的概率分布,以最小化误差。

在深度学习中,交叉熵损失通常用作分类任务中的损失函数,在训练过程中用来衡量模型预测与真实标签之间的差异,并通过反向传播来优化模型参数。

结合上述解释,下面来看一下交叉熵损失的代码

'''创建原始数据样例
x:3row x 4col的张量,表示数据中包含三条数据,每条数据预测四个类别
y:3d张量,与三条数据对应;每个元素属于0-3,与四个类别对应'''# 1.创建原始数据
x=torch.rand((3,4))
y=torch.tensor([3,0,2])# 2.计算x_sfm=softmax(x),求出归一化后的每个类别概率值
softmax_func=nn.Softmax()
x_sfm=softmax_func(x)# 3.计算log(x_sfm),由于原来的概率值位于0-1,取对数后一定是负值
# 概率值越大,取对数后的绝对值越小,符合我们的损失目标
x_log=torch.log(x_sfm)# ls = nn.LogSoftmax(dim=1)# 也可以使用nn.LogSoftmax()进行测试,二者结果一致
# print(ls(x))# 4.最后使用nn.NLLLoss求损失
# 思路,按照交叉熵的计算过程,将真值与经过LogSoftmax后的预测值求和取平均
index=range(len(x))
loss=x_log[index,y]
print(abs(sum(loss)/len(x)))

从代码中可以很好理解交叉熵如何发挥作用,并且也能理解交叉熵的真值标签为啥只是一维张量

InfoNCE loss 代码

import torch
import torch.nn.functional as Fdef approx_infoNCE_loss(q, k):# 计算query和key的相似度得分similarity_scores = torch.matmul(q, k.t())  # 矩阵乘法计算相似度得分# 计算相似度得分的温度参数temperature = 0.07# 计算logitslogits = similarity_scores / temperature# 构建labels(假设有N个样本)N = q.size(0)labels = torch.arange(N).to(logits.device)# 计算交叉熵损失loss = F.cross_entropy(logits, labels)return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/158066.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

做进销存什么软件好用

进销存软件是企业管理库存、采购和销售等环节的信息化管理系统,对于企业的运营和管理具有重要的意义。在选择进销存软件时,需要考虑以下因素: 功能需求:选择能够满足企业实际需求的进销存软件。例如,系统是否支持商品…

边缘计算系统设计与实践

💂 个人网站:【 海拥】【神级代码资源网站】【办公神器】🤟 基于Web端打造的:👉轻量化工具创作平台💅 想寻找共同学习交流的小伙伴,请点击【全栈技术交流群】 随着物联网、大数据和人工智能等技术的快速发展…

使用低代码可视化开发平台快速搭建应用

目录 一、JNPF可视化平台介绍 二、搭建JNPF可视化平台 【表单设计】 【报表设计】 【流程设计】 【代码生成器】 三、使用JNPF可视化平台 1.前后端分离: 2.多数据源: 3.预置功能: 4.私有化部署: 四、总结 可视化低代码…

【云原生】Spring Cloud Alibaba 之 Gateway 服务网关实战开发

目录 一、什么是网关 ⛅网关的实现原理 二、Gateway 与 Zuul 的区别? 三、Gateway 服务网关 快速入门 ⛄需求 ⏳项目搭建 ✅启动测试 四、Gateway 断言工厂 五、Gateway 过滤器 ⛽过滤器工厂 ♨️全局过滤器 六、源码地址 ⛵小结 一、什么是网关 Spri…

打包项目报错:程序包javax.servlet不存在

背景: WebService项目在没有配置Tomcat的情况下重新打包,由于是直接导入别人写好的项目,没有配置其他环境,所以报错程序包javax.servlet不存在 解决方法: 找到servlet-api.jar包,导入到现有项目的SDK 重…

Java,数据结构与集合源码,数据结构概述

目录 数据结构概念: 数据结构的研究对象: 研究对象一,数据间逻辑关系: 研究对象二,数据的存储结构(或物理结构): 研究对象三:运算结构 数据结构的相关介绍&#xff…

LeetCode [中等] 49. 字母异位词分组

给你一个字符串数组,请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 思路: 利用字符数组的排序,将字符串数组中的每个字符串转换为字符数组,进行排序…

BGP的基础知识

BGP——边界网关协议 IGP——内部网关协议——OSPF、RIP、ISIS EGP——外部网关协议——EGP、BGP 边界网关协议BGP是一种实现自治系统AS之间的路由可达,并选择最佳路由的路径矢量路由协议。目前在IPV4环境下主要使用BGPV4,目前市场上也存在BGPV4&…

使用ExLlamaV2量化并运行EXL2模型

量化大型语言模型(llm)是减少这些模型大小和加快推理速度的最流行的方法。在这些技术中,GPTQ在gpu上提供了惊人的性能。与非量化模型相比,该方法使用的VRAM几乎减少了3倍,同时提供了相似的精度水平和更快的生成速度。 ExLlamaV2是一个旨在从…

SpringBoot : ch04 整合数据源

前言 Spring Boot 是当今最流行的 Java 开发框架之一,它以简洁、高效的特点帮助开发者快速构建稳健的应用程序。在实际项目中,涉及到数据库操作的需求时,我们需要对数据源进行整合。本文将重点介绍如何在 Spring Boot 中整合数据源&#xff…

NX二次开发UF_CAM_PREF_ask_integer_value 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CAM_PREF_ask_integer_value Defined in: uf_cam_prefs.h int UF_CAM_PREF_ask_integer_value(UF_CAM_PREF_t pref, int * value ) overview 概述 This function provides the …

如何看待程序员领域内的“内卷”现象?

要搞清楚这个问题,我首先就来阐释一下“内卷”的概念。 内卷本身是从一个学术名词演化为网络流行词的,本是指文化模式因达到某种最终形态,既无法保持稳定也不能转化为更高级的新形态,而只能在这种文化模式内部无限变得复杂的现象。…

TVS瞬态抑制二极管的工作原理和特点?|深圳比创达电子EMC

TVS二极管一般是用来防止端口瞬间的电压冲击造成后级电路的损坏。防止端口瞬间的电压冲击造成后级电路的损坏。有单向与双向之分,单向TVS一般应用于直流供电电路,双向TVS应用于交流供电电路。 TVS产品的额定瞬态功率应大于电路中可能出现的最大瞬态浪涌…

【C++】const与类(const修饰函数的三种位置)

目录 const基本介绍 正文 前: 中: 后: 拷贝构造使用const 目录 const基本介绍 正文 前: 中: 后: 拷贝构造使用const const基本介绍 const 是 C 中的修饰符,用于声明常量或表示不可修改的对象、函数或成员函数。 我们已经了解了const基本用法,我们先进行…

【 OpenGauss源码学习 —— (hash_search)】

列存储(hash_search) 概述hash_search 函数hash_search_with_hash_value 函数calc_bucket 函数get_hash_entry 函数 补充知识 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊重他人的知识产权和学术成果,力…

层层剥开Android14升级后异常弹框的神秘面纱

本篇文章将会通过研究源码的方式给您讲述Android系统升级到Android14后出现的两个异常弹框并给出消除它们的方案。闲话少叙,我们开始。 问题描述 在Android 14升级后,出现两个弹窗的异常情况。这里是异常的截图: 接下来,我们对这…

第一个Maven项目

(一)准备工作 1、从官网下载压缩包:apache-maven-3.5.4-bin,然后解压到D盘没有中文的目录。 2、配置环境变量: 在左下角win打开“设置”,搜索“高级系统设置”,点击“高级”,点击“环境变量”&…

酷开科技OS——Coolita,让智能大屏走向国际

10月23日,2023中国—东盟视听传播论坛在南宁举行。作为第五届中国—东盟视听周重要活动之一,本次论坛以“共享新成果、共创新视听、共建新家园”为主题。来自中国和东盟的300余名专家学者、业界代表通过主旨演讲、主题发言、圆桌对话等方式进行深入探讨&…

自学成为android framework高手需要准备哪些装备-千里马车载车机系统开发学习

背景 hi,粉丝朋友们: 大家好!经常有很多学员买课同学都会问到需要准备哪些装备,我也回答了很多学员了,今天就搞一篇文章来统一说明一下,告诉一下大家如果你想从一个framework新手变成一个framework开发的高…

计算机网络实用工具之fping

简介 fping是一个类似ping的程序,它使用互联网控制消息协议(ICMP)回显请求来确定目标主机是否正在响应。fping与ping的不同之处在于,您可以在命令行上指定任意数量的目标,或者指定一个包含要ping的目标列表的文件。fp…