PyTorch交叉熵理解

PyTorch 中的交叉熵损失

CrossEntropyLoss

PyTorch 中使用CrossEntropyLoss 计算交叉熵损失,常用于分类任务。交叉熵损失衡量了模型输出的概率分布与实际标签分布之间的差异,目标是最小化该损失以优化模型。

我们通过一个具体的案例来详细说明 CrossEntropyLoss 的计算过程。

假设我们有一个简单的分类任务,共有 3 个类别。我们有 2 个样本的预测和实际标签。

输入

  • 模型的预测(logits,未经过 softmax 激活)

  • 实际标签

import torch
import torch.nn as nn# 模型的预测(logits)
logits = torch.tensor([[2.0, 1.0, 0.1],[0.5, 2.0, 0.3]])# 实际标签
labels = torch.tensor([0, 2])

计算步骤

  • 步骤 1: Softmax 激活

首先,将 logits 通过 softmax 激活函数转换为概率分布。

softmax = nn.Softmax(dim=1)
probabilities = softmax(logits)
print(probabilities)

输出

tensor([[0.6590, 0.2424, 0.0986],[0.1587, 0.7113, 0.1299]])
  • 步骤 2: 计算交叉熵

交叉熵损失的计算公式为:

C r o s s E n t r o p y L o s s = − ∑ i = 1 N log ⁡ ( p i , y i ) CrossEntropyLoss=-\sum_{i=1}^{N}{\log{(}}{{p}_{i,{{y}_{i}}}}) CrossEntropyLoss=i=1Nlog(pi,yi)

其中 N 是样本数量, p i , y i p_{i,y_i} pi,yi是第 i个样本在实际标签  y i y_i yi 位置上的预测概率。

我们手动计算每个样本的交叉熵损失:

  • 对于第一个样本,实际标签为 0,预测概率为 0.6590

l o s s 1 = − log ⁡ ( 0.6590 ) ≈ 0.4171 {{loss}_{1}}=-\log{(}0.6590)\approx 0.4171 loss1=log(0.6590)0.4171

  • 对于第二个样本,实际标签为 2,预测概率为 0.1299

l o s s 2 = − log ⁡ ( 0.1299 ) ≈ 2.0406 {{loss}_{2}}=-\log{(}0.1299)\approx 2.0406 loss2=log(0.1299)2.0406

平均损失为:

m e a n = 0.4171 + 2.0406 2 ≈ 1.2288 mean=\frac{0.4171+2.0406}{2}\approx 1.2288 mean=20.4171+2.04061.2288

  • 步骤 3: 使用 PyTorch 的 CrossEntropyLoss 计算

我们使用 PyTorch 的 CrossEntropyLoss 函数来验证计算结果:

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss.item())

输出

1.2288230657577515
  • 步骤4:依据公式使用 PyTorch 计算

依据前面的公式使用 PyTorch 计算来验算结果

neg_log_p = -torch.log(probabilities)
loss_cal = neg_log_p[torch.arange(neg_log_p.shape[0]), labels].mean()
print(loss_cal.item())

输出

1.228823184967041

结果基本一致。

总结

  1. CrossEntropyLoss 接受未经过 softmax 的 logits 作为输入。

  2. 内部首先对 logits 应用 softmax,将其转换为概率分布。

  3. 然后根据实际标签计算交叉熵损失。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/849298.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity编辑器扩展-番外篇-Gizmos基础-物体如何在球面上移动

目录 一、本节目标效果展示 二、先画出素材 1.先新建一个普通的代码 2.画素材(一个头,两个耳朵,一个鼻子) a.关于贴心的Unity b.开始画素材 三、了解移动的原理 四、辅助物体的建立 五、画左耳朵 六、全部代码 七、作者的…

Ceph入门到精通-Ceph OSD 磁盘在系统重启后无法识别处理步骤

如果Ceph OSD磁盘在系统重启后无法识别,你可以按照以下步骤进行检查和解决: 1. 检查硬件状态 物理检查:首先进行物理检查,确保磁盘没有物理损坏,数据线和电源线连接正常。S.M.A.R.T状态:使用smartctl命令检查磁盘的S.M.A.R.T状态,以确定是否有硬件问题。2. 确认磁盘识别…

OpenCv之简单的人脸识别项目(特征标注页面)

人脸识别 准备八、特征标注页面1.导入所需的包2.设置窗口2.1定义窗口外观和大小2.2设置窗口背景2.2.1设置背景图片2.2.2创建label控件 3.定义两个全局变量4.定义选择图片的函数4.1函数定义和全局变量声明4.2打开文件对话框并获取文件路径4.3处理图片并创建标签4.4显示图像 5.定…

MK米客方德 SD NAND与文件系统:技术解析与应用指南

随着数字存储技术的飞速发展,SD NAND(贴片式T卡)已成为我们日常生活中不可或缺的存储工具。我们将深入探讨SD NAND的文件系统,特别是SD 3.0协议支持的文件系统类型,以及它们在实际应用中的作用和用户可能遇到的问题。 MK米客方德的…

kafka-集群-主题创建

文章目录 1、集群主题创建1.1、查看 efak1.2、创建 主题 my_topic1 并建立6个分区并给每个分区建立3个副本1.2.1、查看 my_topic1 的详细信息 1.3、停止 kafka-01实例,端口号为 9095 1、集群主题创建 1.1、查看 efak 已经有三个kafka实例 1.2、创建 主题 my_topic1…

【面试干货】索引的作用

【面试干货】索引的作用 1、索引的作用 💖The Begin💖点点关注,收藏不迷路💖 1、索引的作用 索引 可以协助 快速查询、更新数据库表中数据。 通过使用索引,数据库系统能够快速定位到符合查询条件的数据,提…

为什么需要对政府工作绩效进行第三方评估?

政府工作绩效的第三方评估具有重要意义,能够在多个方面对政府运作和公共管理产生积极影响。以下是第三方评估的主要意义: 一、提升政府透明度和公信力 通过第三方独立评估,政府的工作绩效和决策过程变得更加公开透明,有助于增强…

人工智能--Foxmail邮箱使用方法

目录 🍉Foxmail全面指南 🍉下载与安装 🍈下载软件 🍈安装软件 🍉配置邮箱 🍈启动 Foxmail 🍈添加邮箱账户 🍈手动配置邮箱 🍍接收邮件服务器 (IMAP/POP3) &…

Elastic Platform 8.14:ES|QL 正式发布、静态加密和向量搜索优化

作者:来自 Elastic Gilad Gal, Tyler Perkins, Alex Chalkias, Trevor Blackford, Ninoslav Miskovic, Fabio Busatto, Aris Papadopoulos Elastic Platform 8.14 提供了 Elasticsearch 查询语言 (ES|QL) 的正式发行版 (GA) — Elastic 中数据探索和操作的未来。它还…

# ROS 获取激光雷达数据 (Python实现)

ROS 获取激光雷达数据 (Python实现) 实现思路 构建一个新的软件包,包名叫做lidar_pkg在软件包中新建一个节点,节点名叫做lidar_node.py在节点中,向ROS大管家rospy申请订阅话题/scan,并设置回调函数为Lidarcallback()构建回调函数…

java:使用shardingSphere访问mysql的分库分表数据

# 创建分库与分表 创建两个数据库【order_db_1、order_db_2】。 然后在两个数据库下分别创建三个表【orders_1、orders_2、orders_3】。 建表sql请参考: CREATE TABLE orders_1 (id bigint NOT NULL,order_type varchar(255) NULL DEFAULT NULL,customer_id bigi…

Docker:技术架构演进

文章目录 基本概念架构演进单机架构应用数据分离架构应用服务集群架构读写分离/主从分离架构冷热分离架构垂直分库微服务容器编排架构 本篇开始进行对于Docker的学习,Docker是一个陌生的词汇,那么本篇开始就先从技术架构的角度出发,先对于技术…

【51单片机】智能百叶窗项目

文章目录 功能演示:前置要求:主要功能:主要模块:主函数代码: 具体的仿真程序和代码程序已经免费放置在资源中,如有需要,可以下载进行操作。 功能演示: 前置要求: 编译软…

【NoSQL】Redis练习

1、redis的编译安装 systemctl stop firewalld systemctl disable firewalld setenforce 0 yum install -y gcc gcc-c make wget cd /opt wget https://download.redis.io/releases/redis-5.0.7.tar.gz tar zxvf redis-5.0.7.tar.gz -C /opt/cd /opt/redis-5.0.7/ # 编译 make…

【全开源】CMS内容管理系统(ThinkPHP+FastAdmin)

基于ThinkPHPFastAdmin的CMS内容管理系统,自定义内容模型、自定义单页、自定义表单、专题、统计报表、会员发布等 提供全部前后台无加密源代码和数据库私有化部署,UniAPP版本提供全部无加密UniAPP源码​ 🔍 解锁内容管理新境界:C…

Typesense-开源的轻量级搜索引擎

Typesense-开源的轻量级搜索引擎 Typesense是一个快速、允许输入错误的搜索引擎,用于构建愉快的搜索体验。 开源的Algolia替代方案& 易于使用的弹性搜索替代方案 官网: https://typesense.org/ github: https://github.com/typesense/typesense 目前已有18.4k…

阅读笔记:Multi-threaded Rasterization in the Chromium Compositor

Multi-threaded Rasterization in the Chromium Compositor PPT 原始链接: https://docs.google.com/presentation/d/1nPEC4YRz-V1m_TsGB0pK3mZMRMVvHD1JXsHGr8I3Hvc/edit?uspsharing PPT主要介绍了Chromium浏览器中使用多线程光栅化(Impl-side painting)的机制&a…

Python自动化发送邮件如何实现?怎么配置?

Python自动化发送邮件需要注意什么?邮件群发的技巧? 无论是个人使用还是企业需求,电子邮件的发送都是必不可少的。而Python作为一门功能强大的编程语言,可以通过自动化脚本实现批量发送邮件,从而提高工作效率。AokSen…

【外汇天眼】选择外汇EA的关键:策略适配、风险控制与稳定性评估

外汇EA(Expert Advisor)是外汇交易市场中广泛使用的自动化交易系统。它们通过预定义的规则和算法自动执行交易,旨在为交易者提供便捷的交易体验,同时提高交易效率和准确性。本文将从策略选择、风险控制和稳定性评估三个方面&#…

更改晶振后如何修改配置

GD32官方提供的固件库中使用的晶振配置一般为8M或25M,如果读者使用其他频率的晶振如何修改配置呢?本文为大家讲解如何修改。 以GD32F303为例,官方固件库中的晶振及时钟配置代码如下,改配置代码为使用外部8M晶振倍频到120M时钟。 …