PyTorch 中的距离函数深度解析:掌握向量间的距离和相似度计算

目录

Pytorch中Distance functions详解

pairwise_distance

用途

用法

参数

数学理论公式

示例代码

cosine_similarity

用途

用法

参数

数学理论

示例代码 

输出结果

pdist

用途

用法

参数

数学理论

示例代码

总结 


Pytorch中Distance functions详解

pairwise_distance

torch.nn.functional.pairwise_distance 是 PyTorch 中的一个函数,用于计算两组向量之间的成对距离。这个函数广泛应用于机器学习和深度学习中,尤其是在处理距离相关的任务,如聚类、相似度计算等。

用途

  • 计算两组向量间的成对距离,常用于度量向量间的相似性或差异性。
  • 用于机器学习中的距离度量,如k-最近邻 (k-NN)、聚类等。

用法

torch.nn.functional.pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False)

 

  • x1, x2: 输入的两组向量,必须有相同的维度。
  • p: 距离计算的幂指数,默认为2,即欧几里得距离。
  • eps: 一个小的数值,用于保证数值稳定性。
  • keepdim: 是否保持输出的维度。

参数

  • x1: 第一组向量的张量。
  • x2: 第二组向量的张量。
  • p: 距离度量的幂指数,默认为2(欧几里得距离)。
  • eps: 避免除零错误的小数,默认为1e-6。
  • keepdim: 在输出中保持原始输入的维度结构。

数学理论公式

对于向量 x1_{i}​ 和 x2_{i}pairwise_distance 计算的是 p 范数下的距离:

d(x1_{i},x2_{i})=(\sum_{j}|x1_{ij}-x2_{ij}|^{p}+eps)^{\frac{1}{p}}

 其中,x1_{ij} 和 x2_{ij} 分别是x1_{i}x1_{i}x2_{i} ,的第j个元素。

示例代码

import torch
import torch.nn.functional as F# 定义两个向量组
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
x2 = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float32)# 计算成对距离
dist = F.pairwise_distance(x1, x2, p=2)# 输出结果  tensor([2.2361, 2.4495]) 这里,输出的是每一对向量之间的欧几里得距离。print(dist)

cosine_similarity

torch.nn.functional.cosine_similarity 是 PyTorch 中的一个函数,用于计算两个张量之间的余弦相似度。这个函数在机器学习和深度学习领域中非常有用,尤其是在处理文本、图像或任何类型的特征向量时,用于度量它们之间的相似性。

用途

  • 计算两个向量或向量组之间的余弦相似度。
  • 广泛应用于自然语言处理、计算机视觉、推荐系统等领域。

用法

torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
  • x1, x2: 输入的两个张量,必须能够广播到相同的形状。
  • dim: 计算相似度的维度。
  • eps: 避免除零错误的小数值。

参数

  • x1 (Tensor): 第一个输入张量。
  • x2 (Tensor): 第二个输入张量。
  • dim (int, 可选): 计算相似度的维度,默认为1。
  • eps (float, 可选): 用于避免除零的小数值,默认为1e-8。

数学理论

余弦相似度的计算公式为:

similarity = \frac{x1}{max(||x1||_{2},\varepsilon )\times max(||x2||_{2},\varepsilon )}

 

  • x1⋅x2 表示两个张量的点积。
  • ||x1||_{2} 和 ||x2||_{2}​ 分别是 x1 和 x2 的2范数。
  • ε 是一个小的数值,用来保证除数不为零。

示例代码 

import torch
import torch.nn.functional as F# 随机生成两个张量
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)# 计算余弦相似度
output = F.cosine_similarity(input1, input2)# 打印结果
print(output)

输出结果

此代码将计算 input1input2 每行之间的余弦相似度,并输出一个长度为100的张量,每个元素对应于两个输入张量相应行的余弦相似度值。由于输入是随机生成的,输出也会随机变化。

pdist

torch.nn.functional.pdist 是 PyTorch 中的一个函数,它用于计算输入张量中每对行向量之间的 p 范数距离。此函数在统计分析、机器学习和数据科学中非常有用,尤其是在涉及距离度量和空间关系的场景中。

用途

  • 计算给定张量中每对行向量之间的距离。
  • 应用于聚类分析、多维缩放和其他需要距离度量的算法。

用法

torch.nn.functional.pdist(input, p=2)
  • input: 输入张量,其形状为 N×M,其中 N 是行数,M 是列数(特征数)。
  • p: 用于计算的 p 范数,默认为 2,即欧几里得距离。

参数

  • input (Tensor): 形状为 N×M 的输入张量。
  • p (float): p 范数的值,用于计算向量对之间的距离。可取值为 0 到 ∞ 之间的任何实数。

数学理论

对于输入张量的每一对行向量 x_{i}x_{j}pdist 计算它们之间的 p 范数距离:​d(x_{i},x_{j})=(\sum_{k}|x_{ik}-x_{jk}|^{p})^{\frac{1}{p}} 其中,x_{ik}​ 和 x_{jk} 分别是 x_{i} 和x_{j}的第 k 个元素。

示例代码

import torch
import torch.nn.functional as F# 定义输入张量
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)# 计算 p 范数距离
distances = F.pdist(input_tensor, p=2)# 输出结果 tensor([5.1962, 10.3923, 5.1962]) 这里,输出的是输入张量中每一对行向量之间的欧几里得距离。print(distances)

总结 

本文解析了 PyTorch 中三个关键的距离函数:pairwise_distancecosine_similaritypdist。这些函数在深度学习和机器学习中非常重要,用于计算向量之间的距离和相似度,从而支持各种算法如聚类、k-最近邻、特征相似度度量等。每个函数都有其特定的应用场景和数学原理。pairwise_distance 计算两组向量间的成对欧几里得距离,cosine_similarity 计算两个张量间的余弦相似度,而 pdist 则计算一个张量内各行向量间的 p 范数距离。通过这些函数,我们能有效地分析和处理数据,特别是在高维空间中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/632771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM工作原理与实战(十九):运行时数据区-方法区

专栏导航 JVM工作原理与实战 RabbitMQ入门指南 从零开始了解大数据 目录 专栏导航 前言 一、运行时数据区 二、方法区 1.方法区介绍 2.方法区在Java虚拟机的实现 3.类的元信息 4.运行时常量池 5.字符串常量池 6.静态变量的存储 总结 前言 JVM作为Java程序的运行环境…

App分发测试平台:改变应用开发的游戏规则

App分发测试平台是一个提供应用开发者上传、测试、分享和发布应用的在线服务平台。它为开发者提供了一个高效的测试环境,并为用户提供了一个方便的应用获取渠道。其中,测试环节尤为关键,因为它能够确保应用在上线之前达到预期的功能和性能。 …

问题:Feem无法发送信息OR无法连接(手机端无法发给电脑端)

目录 前言 问题分析 资源、链接 其他问题 前言 需要在小米手机、华为平板、Dell电脑之间传输文件,试过安装破解的华为电脑管家、小米的MIUI文件传输等,均无果。(小米“远程管理”ftp传输倒是可以,但速度太慢了,且…

Java JVM 堆、栈、方法区详解

目录 1. 栈 2. 堆 3. 方法区 4. 本地方法栈 5. 程序计数器 首先来看一下JVM运行时数据区有哪些。 1. 栈 在介绍JVM栈之前,先了解一下 栈帧 概念。 栈帧:一个栈帧随着一个方法的调用开始而创建,这个方法调用完成而销毁。栈帧内存放者方…

ROS学习笔记8——实现ROS通信时的常用命令

机器人系统中启动的节点少则几个,多则十几个、几十个,不同的节点名称各异,通信时使用话题、服务、消息、参数等等都各不相同,一个显而易见的问题是: 当需要自定义节点和其他某个已经存在的节点通信时,如何获取对方的话…

gitgud.io+Sapphire注册账号教程

gitgud.io是一个仓库,地址 https://gitgud.io/,点进去之后会看到注册页面。 意思是需要通过注册这个Sapphire账户来登录。点击右边的Sapphire,就跳转到Sapphire的登陆页面,点击创建新账号,就进入注册页面。&#xff0…

SpiderFlow爬虫平台漏洞利用分析(CVE-2024-0195)

1. 漏洞介绍 SpiderFlow爬虫平台项目中spider-flow-web\src\main\java\org\spiderflow\controller\FunctionController.java文件的FunctionService.saveFunction函数调用了saveFunction函数,该调用了自定义函数validScript,该函数中用户能够控制 functi…

Spring | Spring中的Bean--下

Spring中的Bean: 4.Bean的生命周期5.Bean的配装配式 ( 添加Bean到IOC容器的方式 依赖注入的方式 )5.1 基于XML的配置5.2 基于Annotation (注解) 的装配 (更常用)5.3 自动装配 4.Bean的生命周期 Spring容器可以管理 singleton作用域的Bean的生命周期,在此…

go语言(七)----slice的声明方式

1、声明方式一 //声明一个slice1是一个切片,但是并没有给slice分配空间var slice1 []intslice1 make([]int,3)2、声明方式二 声明一个slice切片,同时给slice分配空间,3个空间,初始化值是0var slice1 []int make([]int,3)3、声…

ICCV2023 | PTUnifier+:通过Soft Prompts(软提示)统一医学视觉语言预训练

论文标题:Towards Unifying Medical Vision-and-Language Pre-training via Soft Prompts 代码:https://github.com/zhjohnchan/ptunifier Fusion-encoder type和Dual-encoder type。前者在多模态任务中具有优势,因为模态之间有充分的相互…

从临床和科研场景分析ChatGPT在医疗健康领域的应用可行性

2023年4月发表在Journal Medical Systems的文献《Evaluating the Feasibility of ChatGPT in Healthcare: An Analysis of Multiple Clinical and Research Scenarios》(评估 ChatGPT 在医疗健康领域的可行性:对多种临床和研究场景的分析)介绍…

IPv6自动隧道---6to4中继

6to4中继 普通IPv6网络需要与6to4网络通过IPv4网络互通,这可以通过6to4中继路由器方式实现。所谓6to4中继,就是通过6to4隧道转发的IPv6报文的目的地址不是6to4地址,但转发的下一跳是6to4地址,该下一跳为路由器我们称之为6to4中继。隧道的IPv4目的地址依然从下一跳的6to4地…

PPT 编辑模式滚动页面不居中

PPT 编辑模式滚动页面不居中 目标:编辑模式下适应窗口大小、切换页面居中显示 调整视图大小,编辑模式通过Ctrl 鼠标滚轮 或 在视图菜单中点击适应窗口大小。 2. 翻页异常,调整视图大小后,PPT翻页但内容不居中或滚动&#xff0c…

『MySQL快速上手』-⑩-索引特性

文章目录 1.索引的作用2.索引的理解建立测试表插入多条记录查看结果 2.1 MySQL与磁盘交互的基本单位2.1 为何IO交互要是 Page2.3 理解单个Page2.4 理解多个Page2.5 页目录2.6 单页情况2.7 多页情况2.8 B vs B2.9 聚簇索引 vs 非聚簇索引非聚簇索引聚簇索引 3.索引操作3.1 创建主…

pytest + allure(windows)安装

背景 软硬件环境: windows11,已安装anaconda,python,pycharm用途:使用pytest allure 生成报告allure 依赖java,点击查看java安装教程 allure 下载与安装 从 allure下载网址下载最新版本.zip文件 放在自…

基于YOLOv8深度学习的葡萄簇目标检测系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

【llm 微调code-llama 训练自己的数据集 一个小案例】

这也是一个通用的方案,使用peft微调LLM。 准备自己的数据集 根据情况改就行了,jsonl格式,三个字段:context, answer, question import pandas as pd import random import jsondata pd.read_csv(dataset.csv) train_data data…

pyspark 笔记:窗口函数window

窗口函数相关的概念和基本规范可以见:pyspark笔记:over-CSDN博客 1 创建Pyspark dataFrame from pyspark.sql.window import Window import pyspark.sql.functions as F employee_salary [("Ali", "Sales", 8000),("Bob&qu…

USACO介绍 报名流程 成绩查询方式详解(文末有备赛资料)

USACO美国计算机奥林匹克活动 2023-2024新赛季的时间线安排是怎么样的? 2023-2024USACO竞赛时间 一般来说,USACO竞赛时间在12月-3月期间,每月都有一场比赛每次3-5小时,并在规定时间内完成3-4道题。23-24年USACO竞赛时间安排如下&a…

uniapp h5 生成 ubuntu桌面程序 并运行方法

uniapp h5 生成 ubuntu桌面程序 并运行方法,在window环境下开发,发布到ubuntu桌面,并运行 1、安装Nodejs 安装包官方下载地址:https://www.nodejs.com.cn/ 安装完后cmd,如图,即安装成功 2、通过Nodejs安装 electron…