【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围:

1. nn.CrossEntropyLoss

这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 softmax 函数和交叉熵损失的计算。它假设最后一层的输出没有经过归一化处理(不是概率形式),而是直接给出了各个类别的得分。该函数会自动计算每一样本对各类别的得分,应用softmax函数,然后计算交叉熵损失。

import torch
import torch.nn as nn# 假设 outputs 是模型的最后一层输出,shape 为 (batch_size, num_classes),targets 是 ground truth labels
outputs = torch.randn(100, 10)  # 对于10分类问题的100个样本的不归一化的预测值
targets = torch.randint(0, 10, (100,))  # 对应的真实类别loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)
print(loss.item())

2. F.cross_entropy 函数

torch.nn.functional.cross_entropy 函数也是为了多分类问题设计的,但它接受的是 logits 或者已经经过 softmax 的概率。如果你的输出已经是经过 softmax 的概率,可以直接使用;否则,它会默认内部先执行 log_softmax

import torch.nn.functional as F# 假设 outputs 是未经 softmax 的 logits
outputs = torch.randn(100, 10)# 使用 F.cross_entropy 直接计算损失,无需单独进行 softmax
loss = F.cross_entropy(outputs, targets)
print(loss.item())

3. nn.BCEWithLogitsLoss 类(二分类问题)

对于二分类问题,尤其是sigmoid激活函数之后的结果,可以使用带Sigmoid的二元交叉熵损失函数,它同时完成 sigmoid 和 二元交叉熵损失的计算。

# 二分类问题,输出维度为 (batch_size, 1)
outputs = torch.randn(100, 1) # targets 是介于 [0, 1] 或 {-1, 1} 的值,表示正负样本
targets = torch.rand(100, 1) > 0.5  # 或者其他的二进制标签bce_loss = nn.BCEWithLogitsLoss()
loss = bce_loss(outputs, targets.float())
print(loss.item())

4. 手动计算交叉熵损失

当然,也可以手动组合 log_softmaxnll_loss 函数来计算交叉熵损失,这在特殊情况下可能会有用,比如需要对损失函数进行修改或者自定义的时候:

# 多分类问题,手动组合 log_softmax 和 nll_loss
output_logits = torch.randn(100, 10)
softmax_outputs = F.log_softmax(output_logits, dim=1)  # 计算 log_softmax
loss_manual = -torch.mean(torch.gather(softmax_outputs, 1, targets.unsqueeze(1)).squeeze())  # 使用 gather 和 mean 计算 NLL
assert torch.allclose(loss_manual, F.nll_loss(softmax_outputs, targets, reduction='mean'))  # 应该与 nll_loss 结果一致

在上述代码中,gather 函数用于从预测概率矩阵中按照目标标签索引出相应的对数概率,然后求平均得到最终的交叉熵损失。在多分类任务中,直接使用 F.nll_loss(log_softmax_outputs, targets) 是更加简洁的做法,等价于手动计算。而在二分类问题中,对应的手动计算方式则会涉及 sigmoidbinary_cross_entropy_with_logits 函数。

5. 补充说明

在交叉熵损失计算函数中:
L = − ∑ i = 1 n y i l o g ( S ( f θ ( x i ) ) ) L = -\sum_{i=1}^{n}{y_i}log(S(f_\theta(x_i))) L=i=1nyilog(S(fθ(xi)))
真实值 y i y_i yi可以是热编码后的结果,也可以不进行热编码。
虽然在Pytorch架构中,神经网络内流动的数据类型必须是float类型,但是Pytorch也提供了自动处理整数(int类型)标签的交叉熵损失函数(这里的“整数标签”指的是每个样本所属的真实类别,通常是一个从0开始的整数索引,对应着类别数量中的一个),这些函数会自动将整数标签转换为内部使用的one-hot编码格式,并计算交叉熵损失。
nn.CrossEntropyLoss为例,当输入给定的output是未经归一化的类别得分(logits),而target是整数标签时,这个损失函数会自动将整数标签转换为one-hot格式,然后再进行交叉熵损失的计算。这意味着用户不需要预先将目标标签转换为one-hot编码,损失函数内部会处理这样的转换过程。

import torch
import torch.nn as nn# 假设我们有一个批次的输出和对应的类别标签
outputs = torch.randn(64, 10)  # 这是一个批次的输出,共64个样本,10个类别
labels = torch.tensor([2, 7, 0, ..., 4], dtype=torch.long)  # 这是对应的整数类别标签loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, labels)print(f'Cross-entropy loss: {loss.item()}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/770248.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端理论总结(css3)——页面布局方法

瀑布流 优点:节省空间,外表美观,更有艺术性 对于触屏设备非常友好,通过向上滑动浏览 用户浏览时的观赏和思维不容易被打断,留存更容易 缺点:用户…

feign demo

直接上代码 AscendKing/springcloud-feign

Saltstack 最大打开文件数问题之奇怪的 8192

哈喽大家好,我是咸鱼。 今天分享一个在压测过程中遇到的问题,当时排查这个问题费了我们好大的劲,所以我觉得有必要写一篇文章来记录一下。 问题出现 周末在进行压测的时候,测试和开发的同事反映压测有问题,请求打到…

一键实现数据采集和存储:Python爬虫、Pandas和Excel的应用技巧

作为一名互联网技术爱好者,我对数据的探索充满热情。在本文中,我将以豆瓣读书为案例,详细介绍如何利用Python爬虫、Pandas和Excel这三大工具,一键化地实现数据采集和存储。豆瓣读书作为一个备受推崇的图书评价平台,拥有…

亮剑AIGC,紫光云能否胜人一筹?

【全球云观察 | 科技热点关注】 扎实创新每一步, 先人一步快人一步。 2023年全球科技行业最火的莫过于生成式AI,即Artificial Intelligence Generated Content。在迈向生成式AI的道路上,虽然说不上千军万马,但是国内…

一文看懂算法交易(一)

国内T0算法哪家强?算法交易费用是多少?算法交易哪些平台好? 最近算力、AI一直都比较火,现在KiMi和语料都已经出世,随着中国金融市场专业化程度和复杂度的提高,交易环节的每一步都对应巨大的增量空间。智能算…

装饰器 篇

文章目录 装饰器classmethod()property()staticmethod() 装饰器 在Python中,装饰器(Decorator)是一个高级功能,它允许你在不修改函数或类本身代码的情况下,给函数或类动态地添加额外的功能。装饰器本质上是一个接受函…

Python学习笔记------文件操作

编码 编码就是一种规则集合,记录了内容和二进制间进行相互转换的逻辑。 编码有许多中,我们最常用的是UTF-8编码 计算机只认识0和1,所以需要将内容翻译成0和1才能保存在计算机中。同时也需要编码,将计算机保存的0和1&#xff0c…

2.4 如何运行Python程序

如何运行Python程序? Python是一种解释型的脚本编程语言,这样的编程语言一般支持两种代码运行方式: 1) 交互式编程 在命令行窗口中直接输入代码,按下回车键就可以运行代码,并立即看到输出结果;执行完一行…

ReentrantLock加锁分析

1、ReentrantLock中其实是有一个AQS的子类实例的成员变量sync; 2、实际是调用的Sync中的lock;Sync是AQS的子类;Sync有两个子类,公平与非公平;默认为非公平;如下是非公平加锁分析; public Reentr…

Visual Basic6.0零基础教学(4)—编码基础,数据类型与变量

编码基础,数据类型与变量 文章目录 编码基础,数据类型与变量前言一、VB中的编程基础二、VB的基本字符集和词汇集1、字符集2、词汇集 VB中的数据类型VB中的变量与常量一.变量和常量的命名规则二.变量声明1.用Dim语句显式声明变量三. 常量 运算符和表达式一. 运算符 1. 算术运算符…

获取Book里所有sheet的名字,且带上超链接

应用背景: 当一个excel有很多sheet的时候,来回切换sheet会比较复杂,所以我希望excel的第一页有目录,可以随着sheet的增加,减少,改名而随时可以去更新,还希望有超链接可以直接跳到该sheet。 可以…

06-验证浮点数输入

鉴于shell脚本的限制和本事,浮点数(或“实数”)的验证过程乍一看似乎让人望而生畏,不过考虑到浮点数只不过是由小数点分隔的两个整数,再配合能够在脚本中引用其他脚本的能力(validint)&#xff…

13、Spring CLI中的特殊命令

特殊命令(Special Commands) 特殊命令是一个名为 . 的命令组的一部分。 操作系统 Shell 命令(OS Shell command) .! 命令在你启动 shell 的目录中运行一个操作系统命令。这个命令只在交互模式下工作。 如果在运行命令时遇到困难,你可能想尝试用双引号将其包围。然而,…

【爬取网易财经文章】

引言 在信息爆炸的时代,获取实时的财经资讯对于投资者和金融从业者来说至关重要。然而,手动浏览网页收集财经文章耗时费力,为了解决这一问题,本文将介绍如何使用Python编写一个爬虫程序来自动爬取网易财经下关于财经的文章 1. 爬…

前端基础 Vue -组件化基础

1.全局组件 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><script src&…

分布式任务队列:cppq

文章目录 简介FeaturesQuickstartExampleWeb UI命令行 参考 简介 cppq 是一个简单、可靠、高效的 C17 分布式任务队列。 cppq 是一个 C 库&#xff0c;用于对任务进行排队并与工作线程异步处理它们。它由 Redis 支持&#xff0c;旨在可扩展且易于入门。 cppq 工作原理&#x…

【QA】MySQL导出某数据库的所有数据为sql文件,包含建库命令、建表命令。

文章目录 前言Windows系统下 | mysqldump导出数据库数据Docker中导入初始化数据【补充】通过命令行&#xff0c;执行sql文件&#xff0c;将数据导入到数据库在MySQL外面执行在MySQL中执行 前言 我们在用docker部署mysql项目的时候&#xff0c;往往需要对数据库进行数据初始化。…

ARM 和 龙芯上 Arch Linux 安装手记

背景 今天尝试安装龙芯版 Linux,本来希望能安装 Debian 版,但只找到一些文档,没找到可安装版的 ISO。 后来顺着这篇文章找到了Arch Linux,就尝试安装了一下。 安装后发现竟然不会配置网络 😂。而且龙芯版由于是在 QEMU 虚拟机里,运行速度也较慢。所以,我想我需要先学…

Java-SSM电影在线播放系统

Java-SSM电影在线播放系统 1.服务承诺&#xff1a; 包安装运行&#xff0c;如有需要欢迎联系&#xff08;VX:yuanchengruanjian&#xff09;。 2.项目所用框架: 前端:JSP、layui等 后端:SSM,即Spring、SpringMvc、Mybatis等。 3.项目功能点: 3-1.后端功能: - 所有后台管理展…