「torch.cosine_smilarity() = 0」引发的关于cpu与gpu精度问题的探讨

前言:2023年11月21日下午16:00 许,本篇博客记录由「torch.cosine_smilarity()计算余弦相似度计算结果为0」现象引发的关于 CPU 与 GPU 计算精度的探索。

事情的起因是,本人在使用 torch.cosine_smilarity() 函数计算GPU上两个特征的余弦相似度时,发现得出的结果为 0,百思不得其解。首先排出特征维度的问题,然后尝试5种不同的相似度计算方法:

  • scipy.spatial.distance.cosine
  • torch.cosine_similarity
  • F.cosine_similarity
  • torch.nn.CosineSimilarity
  • 基于余弦相似度公式的torch代码

整体代码如下:

import torch
torch.set_printoptions(profile="full")
import torch.nn.functional as F
from scipy.spatial.distance import cosine
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"import clip
from PIL import Image
clip_model, processor = clip.load("ViT-L/14", device=device)
srcpath = '/newdata/SD/DEFAKE/data_test/9709_glide.png'
despath = '/newdata/SD/outputs/0_9_sd1.5.png'
src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device)).squeeze(0)  
des_feature = clip_model.encode_image(processor(Image.open(despath)).unsqueeze(0).to(device)).squeeze(0)sim_1 = 1 - cosine(src_feature.cpu(), des_feature.cpu())sim_2 = torch.cosine_similarity(src_feature, des_feature, dim=0).item()sim_3 = F.cosine_similarity(src_feature, des_feature, dim=0).item()cos = torch.nn.CosineSimilarity(dim=0)
sim_4 = cos(src_feature, des_feature).item()sim_5 = torch.div(torch.sum(src_feature * des_feature,0),torch.sqrt(torch.sum(torch.pow(src_feature,2),0))* torch.sqrt(torch.sum(torch.pow(des_feature,2),0))).item()print(sim_1, sim_2, sim_3, sim_4, sim_5)
# 0.5302734375 0.0 0.0 0.0 0.53076171875

发现,上述代码在CPU和GPU上运行结果不一致:

上述代码在 CPU 上的运行结果为:

0.5301393270492554
0.5301393270492554
0.5301393270492554
0.5301393270492554
0.5301393270492554

上述代码在 GPU 上的运行结果为:

0.5302734375
0.0
0.0
0.0
0.53076171875

这是一个很有意思的现象,在CPU上计算出的5种相似度结果惊人一致,而在GPU上计算出的5种相似度结果中,中间三种基于torch函数调用的方式计算结果均为0,而第一种首先将特征搬运到CPU上然后使用scipy.spatial.distance.cosine()函数的计算结果和第五种直接在GPU上使用基于余弦相似度公式的torch代码的计算结果又与CPU上计算出的结果各有不同。


然后,我把由 CLIP 预训练模型提取的两张图像的特征(维度为768维) src_featuredes_feature 换成两个随机初始化的张量 ,其余代码不变:


# import clip
# from PIL import Image
# clip_model, processor = clip.load("ViT-L/14", device=device)
# srcpath = '/newdata/SD/DEFAKE/data_test/9709_glide.png'
# despath = '/newdata/SD/outputs/0_9_sd1.5.png'
# src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device)).squeeze(0)  
# des_feature = clip_model.encode_image(processor(Image.open(despath)).unsqueeze(0).to(device)).squeeze(0)# 将上面代码注释掉,换为:src_featre = torch.tensor([1.0, 2.0, 3.0])
des_feature = torch.tensor([4.0, 5.0, 6.0])

可见上述示例代码在CPU和GPU上运行结果是一致的:

上述代码在 CPU 上的运行结果为:

0.9746318459510803
0.9746317863464355
0.9746317863464355
0.9746317863464355
0.9746317863464355

上述代码在 GPU 上的运行结果为:

0.9746318459510803
0.9746317863464355
0.9746317863464355
0.9746317863464355
0.9746317863464355

由上述结果可以发现,第一种基于scipy.spatial.distance.cosine()函数的计算结果与其余四组基于torch的计算结果略有不同,说明后四种方法实现的底层逻辑应该是类似的,但由于给定特征的某些不可知原因,有时会出现中间三种基于torch函数调用的方法结果为0的情况,所以保险起见,如果要使用基于torch的计算方法,首选第5种相似度计算方法,当然,时间允许的情况下,直接在CPU上使用第一种方法无疑是精度最高的计算方法。


接下来放一个时间对比图(如下),可见在GPU上使用最后一种计算方法效率最高,在CPU上使用第一种方法效率最低。

time for spicy(gpu):  0.004714250564575195, [0.5361, 0.5303, 0.5220, 0.5059, 0.5430, 0.5469, 0.5078, 0.5293, 0.5283, 0.5337]
time for torch(gpu): 0.0005323886871337891, [0.5361, 0.5308, 0.5225, 0.5063, 0.5435, 0.5469, 0.5078, 0.5298, 0.5283, 0.5332]
time for spicy(cpu):  0.009323358535766602, [0.5358, 0.5301, 0.5222, 0.5060, 0.5426, 0.5467, 0.5077, 0.5298, 0.5279, 0.5333]
time for torch(cpu): 0.0025298595428466797, [0.5358, 0.5301, 0.5222, 0.5060, 0.5426, 0.5467, 0.5077, 0.5298, 0.5279, 0.5333]

PS:鉴于第一种方法无法进行余弦相似度的批量计算(1vN计算),追求速度的话,还是选择第五种方法吧~👀 附赠批量计算方法

src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device))  # [1,768]
des_features = torch.stack([clip_model.encode_image(processor(Image.open(path)).unsqueeze(0).to(device)) for path in despaths]).squeeze(1)  # [N,768]
sims = torch.div(torch.sum(src_feature * des_features,1),torch.sqrt(torch.sum(torch.pow(src_feature,2),1))* torch.sqrt(torch.sum(torch.pow(des_features,2),1)))  # 长度为N的张量
sims = sims.cpu().detach().numpy().tolist()  # 转化为列表,方便计算

参考资料

  1. GPU和CPU计算上的精度差异_cpu和gpu训练结果不同-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/156826.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【迅搜02】究竟什么是搜索引擎?正式介绍XunSearch

究竟什么是搜索引擎?正式介绍XunSearch 啥?还要单独讲一下啥是搜索引擎?不就是百度、Google嘛,这玩意天天用,还轮的到你来说? 额,好吧,虽然大家天天都在用,但是我发现&am…

移远通信推出六款新型天线,为物联网客户带来更丰富的产品选择

近日,移远通信重磅推出六款新型天线,覆盖5G、非地面网络(NTN)等多种新技术,将为物联网终端等产品带来全新功能和更强大的连接性能。 移远通信COO张栋表示:“当前,物联网应用除了需要高性能的天线…

【libGDX】使用Mesh绘制三角形

1 Mesh 和 ShaderProgram 简介 1.1 创建 Mesh 1)Mesh 的构造方法 public Mesh(boolean isStatic, int maxVertices, int maxIndices, VertexAttribute... attributes) public Mesh(boolean isStatic, int maxVertices, int maxIndices, VertexAttributes attribut…

js ::after简单实战

::after的作用是在元素后面再加个XXX样式 工作中遇到了一个表格,鼠标指到单元格要有个整行编辑态的效果,下面写个简单的demo 有人可能会说了,直接修改某个单元格的hover样式不就行了嘛,问题是如果鼠标指到单元格和单元格直接的…

Android DatePicker(日期选择器)、TimePicker(时间选择器)、CalendarView(日历视图)- 简单应用

示意图&#xff1a; layout布局文件&#xff1a;xml <?xml version"1.0" encoding"utf-8"?> <ScrollView xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http://schemas.android.com/apk/res-auto"…

实验过程中的问题记录

代码&#xff1a; if args.local_rank in [-1, 0] and eval_dataset is not None and args.eval_steps > 0 and global_step % args.eval_steps 0 :metric_cur eval_fn(args, eval_dataset, model, tokenizer, global_stepglobal_step, file_prefix"eval_")当参…

IP-guard Web系统远程命令执行漏洞说明

一、漏洞说明 近期收到反馈,IP-guard Web服务器存在远程命令执行漏洞(RCE),经过分析,确认是因为Web系统的申请审批功能使用了开源插件 flexpaper 实现文件在线预览功能,此插件存在远程代码执行漏洞。 攻击者可利用 flexpaper插件漏洞,在文件预览参数中拼接其它恶意命令…

时序预测 | Pytorch实现TCN-Transformer的时间序列预测

时序预测 | Pytorch实现TCN-Transformer的时间序列预测 目录 时序预测 | Pytorch实现TCN-Transformer的时间序列预测效果一览基本介绍程序设计 效果一览 基本介绍 基于TCN-Transformer模型的时间序列预测&#xff0c;可以用于做光伏发电功率预测&#xff0c;风速预测&#xff0…

管理体系标准

管理体系标准 什么是管理体系&#xff1f; 管理体系是组织管理其业务的相互关联部分以实现其目标的方式。这些目标可能涉及许多不同的主题&#xff0c;包括产品或服务质量、运营效率、环境绩效、工作场所的健康和安全等等。 系统的复杂程度取决于每个组织的具体情况。对于某…

Vue2+Vue3

文章目录 第 1 章&#xff1a;Vue 核心1、 Vue 简介1.官网2.介绍与描述3. Vue 的特点4. 与其它 JS 框架的关联5. Vue 周边库 2、初始Vue3、模板语法1、Vue模板语法有2大类:2、插值语法和指令语法 4、数据绑定1. 单向数据绑定2. 双向数据绑定 5、el与data的两种写法1.e1有2种写法…

社会媒体营销提问常用的ChatGPT通用提示词模板

如何制定有效的社会媒体营销策略&#xff1f; 如何选择适合的社会媒体平台进行营销&#xff1f; 如何创造有吸引力的社会媒体内容&#xff0c;提高用户参与度和分享率&#xff1f; 如何运用社交媒体广告来增加品牌曝光和用户转化&#xff1f; 如何建立和维护社交媒体账号和…

外部 prometheus监控k8s集群资源

prometheus监控k8s集群资源 一&#xff0c;通过CADvisior 监控pod的资源状态1.1 授权外边用户可以访问prometheus接口。1.2 获取token保存1.3 配置prometheus.yml 启动并查看状态1.4 Grafana 导入仪表盘 二&#xff0c;通过kube-state-metrics 监控k8s资源状态2.1 部署 kube-st…

【科技素养】蓝桥杯STEMA 科技素养组模拟练习试卷01

单选题 1、生活中&#xff0c;我们经常说“有机蔬菜”相比普通蔬菜更加健康&#xff0c;这是因为 A、它们没有使用无机肥料 B、它们是有机的 C、它们没有使用肥料 D、人们对蔬菜的错误认知 答案&#xff1a;A 2、甲乙两位工人一起在工厂工作。甲的生产速度是每小时6个鼠标…

网络运维与网络安全 学习笔记2023.11.21

网络运维与网络安全 学习笔记 第二十二天 今日目标 端口隔离原理与配置、路由原理和配置、配置多路由器静态路由 配置默认路由、VLAN间通信之路由器 端口隔离原理与配置 端口隔离概述 实现报文之间的2层隔离&#xff0c;除了使用VLAN技术以后&#xff0c;还可以使用端口隔…

c语言:十进制转任意进制

思路&#xff1a;如十进制转二进制 就是不断除二求余在除二求余&#xff0c;然后将余数从下到写出来&#xff0c;这样&#xff0c;10011100就是156的二进制 这里举例一个六进制的代码&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include<stdio.h>int main() {int …

opencv-简单图像处理

图像像素存储形式  对于只有黑白颜色的灰度图&#xff0c;为单通道&#xff0c;一个像素块对应矩阵中一个数字&#xff0c;数值为0到255, 其中0表示最暗&#xff08;黑色&#xff09; &#xff0c;255表示最亮&#xff08;白色&#xff09; 对于采用RGB模式的彩色图片&#…

「MACOS限定」 如何将文件上传到GitHub仓库

介绍 本期讲解&#xff1a;如何在苹果电脑上上传文件到github远程仓库 注&#xff1a;写的很详细 方便我的朋友可以看懂操作步骤 第一步 在电脑上创建一个新目录&#xff08;文件夹&#xff09; 注&#xff1a;创建GitHub账号、新建github仓库、git下载的步骤这里就不过多赘…

118.184.158.111德迅云安全浅谈如何避免网络钓鱼攻击

随着互联网的不断发展&#xff0c;网络钓鱼攻击也越来越猖獗&#xff0c;给个人和企业带来了巨大的经济损失和安全威胁。本文对如何防范网络钓鱼攻击提出的一些小建议 希望对大家有所帮助。 1.防止XSS&#xff08;跨站脚本攻击&#xff09;攻击 XSS攻击指的是攻击者在网站中注入…

html手势密码解锁插件(附源码)

文章目录 1.设计来源1.1 界面效果 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_43151418/article/details/134534785 html手势密码解锁插件(附源码)&#xff0c;仿手机手势密码&#xff0c;拖动九…

基于SSM的网络财务管理系统设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…