【深度学习】pytorch计算KL散度、kl_div

使用pytorch进行KL散度计算,可以使用pytorch的kl_div函数

假设y为真实分布,x为预测分布。

import torch
import torch.nn.functional as F# 定义两组数据
tensor1 = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],[0.2, 0.1, 0.2, 0.3, 0.2],[0.2, 0.3, 0.1, 0.2, 0.2],[0.2, 0.2, 0.3, 0.1, 0.2],[0.2, 0.2, 0.2, 0.3, 0.1],[0.1, 0.2, 0.2, 0.2, 0.3],[0.3, 0.2, 0.1, 0.2, 0.2],[0.2, 0.3, 0.2, 0.1, 0.2],[0.1, 0.2, 0.2, 0.3, 0.2],[0.2, 0.1, 0.3, 0.2, 0.2],[0.2, 0.3, 0.2, 0.2, 0.1],[0.1, 0.1, 0.2, 0.3, 0.3],[0.3, 0.2, 0.2, 0.1, 0.2],[0.2, 0.3, 0.1, 0.2, 0.2],[0.1, 0.3, 0.2, 0.2, 0.2],[0.2, 0.2, 0.1, 0.3, 0.2]])tensor2 = torch.tensor([[0.2, 0.1, 0.3, 0.2, 0.2],[0.3, 0.2, 0.2, 0.1, 0.2],[0.2, 0.3, 0.2, 0.2, 0.1],[0.1, 0.2, 0.3, 0.2, 0.2],[0.2, 0.2, 0.1, 0.2, 0.3],[0.3, 0.2, 0.2, 0.3, 0.0],[0.2, 0.3, 0.1, 0.2, 0.2],[0.1, 0.2, 0.2, 0.3, 0.2],[0.2, 0.1, 0.3, 0.2, 0.2],[0.2, 0.3, 0.2, 0.1, 0.2],[0.1, 0.2, 0.3, 0.2, 0.2],[0.2, 0.3, 0.2, 0.2, 0.1],[0.2, 0.1, 0.2, 0.3, 0.2],[0.3, 0.2, 0.2, 0.1, 0.2],[0.2, 0.2, 0.3, 0.2, 0.1],[0.1, 0.3, 0.2, 0.2, 0.2]])# 计算两组张量之间的 KL 散度
logp_x = F.log_softmax(tensor1, dim=-1)
p_y = F.softmax(tensor2, dim=-1)kl_divergence = F.kl_div(logp_x, p_y, reduction='batchmean')
kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
print("KL散度(batchmean)值为:", kl_divergence.item())
print("KL散度(sum)值为:", kl_sum.item())

打印结果:

KL散度(batchmean)值为: 0.00508523266762495
KL散度(sum)值为: 0.0813637226819992  

其中kl_div接收三个参数,第一个为预测分布,第二个为真实分布,第三个为reduction。(其实还有其他参数,只是基本用不到)

这里有一些细节需要注意,第一个参数与第二个参数都要进行softmax(dim=-1),目的是使两个概率分布的所有值之和都为1,若不进行此操作,如果x或y概率分布所有值的和大于1,则可能会使计算的KL为负数。

softmax接收一个参数dim,dim=-1表示在最后一维进行softmax操作。

除此之外,第一个参数还要进行log()操作(至于为什么,大概是为了方便pytorch的代码组织,pytorch定义的损失函数都调用handle_torch_function函数,方便权重控制等),才能得到正确结果。还有说是因为要用y指导x,所以求x的对数概率,y的概率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/803742.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot使用Zxing生成二维码

SpringBoot使用Zxing生成二维码 什么是Zxing具体实现1. 在pom文件中导入依赖2. 二维码生成工具类3. 控制层和服务层4. 前端 总结参考 什么是Zxing ZXing,一个支持在图像中解码和生成条形码(如二维码、PDF 417、EAN、UPC、Aztec、Data Matrix、Codabar)的库。ZXing(…

spring boot admin搭建,监控springboot程序运行状况

新建一个spring boot web项目&#xff0c;添加以下依赖 <dependency><groupId>de.codecentric</groupId><artifactId>spring-boot-admin-starter-server</artifactId><version>2.3.0</version></dependency> <dependency&…

微信小程序自定义tabbar,页面切换存在闪动【解决方案】

需求&#xff1a; 自定义tabbar&#xff0c;在需要的几个主页面都加入这么一个组件&#xff0c;但是有个情况&#xff1b;而组件中使用照片&#xff08;svg或png&#xff09;和文字;在切换tabbar的时候&#xff0c;跳转相应的页面&#xff0c;运行到真机或是模拟器&#xff0c;…

Docker搭建Nginx+keepalived高可用负载均衡服务器

一、背景 1.nginx高可用 在生产环境下&#xff0c;Nginx作为流量的入口&#xff0c;如果Nginx不能正常工作或服务器宕机&#xff0c;将导致整个微服务架构的不可用。所以负责负载均衡、反向代理的服务&#xff08;Nginx&#xff09;为了提高处理性能&#xff0c;高可用&#…

python-pytorch NLP中处理中文的步骤0.5.002

python-pytorch NLP中处理中文的步骤0.5.001 1. 导入包2. 准备停用词3. 把需要处理的文本切词4. 将切的词放入list中5. 获取vocab、vocab_size6. 获取word_to_idx、idx_to_word7. 告一段落8. 其他&#xff08;1800的停用词&#xff09; 1. 导入包 import jieba import torch i…

Go —— channel (二)

一个空的 channel 会产生哪些问题 读写nil管道均会阻塞触发死锁。关闭的管道仍然可以读取数据&#xff0c;向关闭的管道写数据会触发panic。 问&#xff1a;如果有多个协程同时读取一个channel&#xff0c;channel会如何选择消费者 channel 会按照维护的 recvq 等待读消息的…

239. 奇偶游戏(带权值并查集,邻域并查集,《算法竞赛进阶指南》)

239. 奇偶游戏 - AcWing题库 小 A 和小 B 在玩一个游戏。 首先&#xff0c;小 A 写了一个由 0 和 1 组成的序列 S&#xff0c;长度为 N。 然后&#xff0c;小 B 向小 A 提出了 M 个问题。 在每个问题中&#xff0c;小 B 指定两个数 l 和 r&#xff0c;小 A 回答 S[l∼r] 中…

苍穹外卖11(Apache ECharts前端统计,营业额统计,用户统计,订单统计,销量排名Top10)

目录 一、Apache ECharts【前端】 1. 介绍 2. 入门案例 二、营业额统计 1. 需求分析和设计 1 产品原型 2 业务规则 3 接口设计 2. 代码开发 3. 功能测试 三、用户统计 1. 需求分析和设计 1 产品原型 2 业务规则 3 接口设计 2. 代码开发 3. 功能测试 四、订单统…

0.开篇:SSM+Spring Boot导学

1. 为什么要使用框架 Spring是一个轻量级Java开发框架&#xff0c;最早有Rod Johnson创建&#xff0c;目的是为了解决企业级应用开发的业务逻辑层和其他各层的耦合问题。 几乎当下所有企业级JavaEE开发都离不开SSM&#xff08;Spring SpringMVC MyBatis&#xff09;Spring B…

什么是企业邮箱?如何选择合适的企业邮箱?

企业邮箱和个人邮箱不同&#xff0c;它的邮箱后缀是企业自己的域名。企业邮箱供应商一般都提供手机app、桌面端、web浏览器访问等邮箱使用途径。那么什么是企业邮箱&#xff1f;如何选择合适的企业邮箱&#xff1f;好用的企业邮箱应具备无缝迁移、协作、多邮箱管理等功能。 企…

【hive】单节点搭建hadoop和hive

一、背景 需要使用hive远程debug&#xff0c;尝试使用无hadoop部署hive方式一直失败&#xff0c;无果&#xff0c;还是使用有hadoop方式。最终查看linux内存占用6GB&#xff0c;还在后台运行docker的mysql(bitnami/mysql:8.0)&#xff0c;基本满意。 版本选择&#xff1a; &a…

全面深入学习Java中的字符串类

二、字符串类 &#xff08;一&#xff09;String String&#xff1a;字符串 concat() --- 在末尾追加字符串&#xff0c;返回新的字符串 substring(int begindex) --- 从指定下标处截取到字符串末尾&#xff0c;并返回新的字符串 substring(int begindex,endindex) --- 从开始…

OceanBase 中一个关于 NOT IN 子查询的 SQL 优化案例

通过一个案例了解 not in 对 NULL 值敏感的处理逻辑和优化方法。 作者&#xff1a;胡呈清&#xff0c;爱可生 DBA 团队成员&#xff0c;擅长故障分析、性能优化&#xff0c;个人博客&#xff1a;[简书 | 轻松的鱼]&#xff0c;欢迎讨论。 爱可生开源社区出品&#xff0c;原创内…

Linux中账号登陆报错access denied

“Access denied” 是一个权限拒绝的错误提示&#xff0c;意味着用户无法获得所请求资源的访问权限。出现 “Access denied” 错误的原因可以有多种可能性&#xff0c;包括以下几种常见原因&#xff1a; 错误的用户名或密码&#xff1a;输入的用户名或密码不正确&#xff0c;导…

机器学习—1.快速入门

机器学习步骤 确定与问题相关的输入&#xff08;明确输入&#xff09;收集与问题相关的数据&#xff08;数据准备&#xff0c;学&#xff09;分析预测结果的类型&#xff08;分类&#xff1f;回归&#xff1f;是判断题还是应用题&#xff09;根据预测记过的类型&#xff0c;选…

http添加SSL证书后打开变成另外一个网站是怎么回事

当在使用http的网站上添加了SSL证书后&#xff0c;如果打开该网站时出现了另外一个网站&#xff0c;可能是由以下几种情况引起的&#xff1a; 错误的证书配置 证书配置可能存在错误&#xff0c;导致SSL连接时服务器返回了错误的证书&#xff0c;或者证书与网站域名不匹配。这…

MySQL-系统及自定义变量

详情系统变量信息参考MySQL官方文档 系统变量分类&#xff1a; 全局系统变量&#xff08;global&#xff09; 全局系统变量针对于所有会话&#xff08;连接&#xff09;有效&#xff0c;但 不能跨重启 会话系统变量&#xff08;session&#xff09; 仅针对当前连接有效&am…

STM32-模数转化器

ADC(Analog-to-Digital Converter) 指模数转换器。是指将连续变化的模拟信号转换 为离散的数字信号的器件。 ADC相关参数说明&#xff1a; 分辨率&#xff1a; 分辨率以二进制&#xff08;或十进制&#xff09;数的位数来表示&#xff0c;一般有 8 位、10 位、12 位、16 位…

Transformer模型-decoder解码器,target mask目标掩码的简明介绍

今天介绍transformer模型的decoder解码器&#xff0c;target mask目标掩码 背景 解码器层是对前面文章中提到的子层的包装器。它接受位置嵌入的目标序列&#xff0c;并将它们通过带掩码的多头注意力机制传递。使用掩码是为了防止解码器查看序列中的下一个标记。它迫使模型仅使用…

WPF 多语言切换及ResourceDictionary的Source路径填写

WPF 多语言切换 1. 添加资源字典 新增两个资源字典&#xff0c;里面分别存储不同语言的文本 <ResourceDictionary xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s…