Pytorch实用教程:nn.CrossEntropyLoss()的用法

在 PyTorch 中,nn.CrossEntropyLoss() 是一个非常常用且功能强大的损失函数,特别适合用于多类分类问题。这个损失函数结合了 nn.LogSoftmax()nn.NLLLoss() (Negative Log Likelihood Loss) 两个操作,从而在一个模块中提供完整的交叉熵损失计算功能。这不仅方便使用,也提高了数值稳定性。

功能说明

nn.CrossEntropyLoss() 计算模型输出实际标签之间的交叉熵损失。它自动完成softmax 概率分布的计算和对数似然损失的计算,这意味着你应该直接将网络的原始输出(logits,即未经 softmax 层处理的输出)作为 CrossEntropyLoss 的输入。

上面这句话非常重要,这就是为什么在用交叉熵损失函数的时候,在模型的输出部分见不到softmax的原因。

参数详解

nn.CrossEntropyLoss 主要有以下几个参数:

  • weight (Tensor, optional): 一个手动指定的权重,用于平衡类别间的损失贡献。这在类别不平衡的情况下非常有用。
  • size_average (bool, deprecated): 这个参数已经被弃用,用 reduction 参数代替。
  • ignore_index (int, optional): 指定一个类别索引,对于这个类别的目标(target),损失将不会被计算。这常用于忽略特定的类别。
  • reduce (bool, deprecated): 这个参数也已经被弃用,用 reduction 参数代替。
  • reduction (str, optional): 指定损失的计算模式。可以是 ‘none’(无操作),‘mean’(计算损失的均值,是默认设置)或 ‘sum’(计算损失的总和)。

使用示例

下面是一个使用 nn.CrossEntropyLoss 的简单例子。假设我们有一个分类问题,目标是将输入分类到三个类别中的一个:

import torch
import torch.nn as nn# 假设我们有3个类别,batch_size为4
data = torch.randn(4, 3)  # 输入,来自某个神经网络的原始输出,形状为(batch_size, num_classes)
targets = torch.tensor([0, 2, 1, 0])  # 实际的标签,形状为(batch_size,)# 创建交叉熵损失函数实例
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(data, targets)
print(loss) # 输出:tensor(1.6401)

数学原理

对于每个样本 (i),假设 (C) 是类别总数,交叉熵损失定义为:

在这里插入图片描述

这里 (x[class_i]) 是模型输出的第 (i) 个样本对应其真实类别 (class_i) 的 logit。交叉熵损失将这些 logits 转换为正规化的概率分布,然后计算其对数似然。

应用场景

这个损失函数是处理多类分类问题的标准选择之一,特别是当你有一个多类的标签目标时。由于其数学上的稳定性,它在训练深度学习模型时非常受欢迎。使用它可以直接处理 logits,无需单独计算 softmax,从而在实际应用中减少计算量和增加数值稳定性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/547.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[qiankun]: Target container with #container not existed while childOne loading!

主应用container容器不存在导致无法挂载子应用 解决&#xff1a;不要将<div id"container"></div>放在Router标签内&#xff0c;跟Router同级即可

通过一系列vue-demo入门vue2

一、创建简单vue实例 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><meta http-equiv"X-UA-Compatible&…

2024年核科学与地球化学国际会议 (ICNSG 2024)

2024年核科学与地球化学国际会议 (ICNSG 2024) 2024 International Conference on Nuclear Science and Geochemistry 【会议简介】 2024年核科学与地球化学国际会议即将在北京召开。本次会议旨在汇聚全球核科学与地球化学领域的专家学者&#xff0c;共同探讨核科学的最新进展…

不是我说,这玩意也叫高可用?

背景&#xff1a;有人求助说数据库起不来了。原因是某个文件有问题&#xff08;可以理解为无法访问或者读写&#xff09;。我问有从库吗&#xff1f;说没有。这里的高可用架构是通过存储复制做的。然后高可用那端的文件也一样。听到这里随着这个系统不是我的&#xff0c;我都忍…

Django模型的字段类型

Django模型中最重要并且也是唯一必须执行的就是字段定义。字段在类中进行定义&#xff0c;对应于实体数据库的字段。另外&#xff0c;定义模型字段名时为了避免冲突&#xff0c;不建议使用模型API中已经定义的关键字。 字段类型用以指定数据库的数据类型&#xff0c;例如Integ…

美团外卖10元无门槛通用券怎么领取10元外卖通用红包?

词令公众号美团外卖红包天天领入口&#xff0c;首次使用的外卖新客可领取10元无门槛通用券&#xff0c;点餐使用即可享受优惠&#xff1b; 美团外卖10元无门槛通用券怎么领取&#xff1f; 1、关注「词令」公众号&#xff0c;回复「外卖红包」&#xff1b; 2、打开后立即领取外…

R语言 并行计算makeCluster报错

问题&#xff1a;使用parallel包进行并行计算&#xff0c; cl <- makeCluster(detectCores()) 出现以下问题&#xff1a; 解决方式&#xff1a;用makeClusterPSOCK命令代替即可 library("future") cl <- makeClusterPSOCK(124, revtunnel TRUE, outfile &…

华为OD-C卷-查找接口成功率最优时间段[100分]Python3-100%

题目描述 服务之间交换的接口成功率作为服务调用关键质量特性,某个时间段内的接口失败率使用一个数组表示, 数组中每个元素都是单位时间内失败率数值,数组中的数值为0~100的整数, 给定一个数值(minAverageLost)表示某个时间段内平均失败率容忍值,即平均失败率小于等于m…

日志记录不再烦恼!Python开发利器Logbook模块带你飞!

在Python开发中&#xff0c;日志记录是一项至关重要的功能。通过记录应用程序的运行状态、错误信息和调试信息&#xff0c;可以帮助开发人员更好地理解程序的运行情况&#xff0c;快速定位问题并进行调试。 Python标准库中的logging模块提供了基本的日志记录功能&#xff0c;但…

基于粒子群算法改进三隐含层BP神经网络的回归预测,基于粒子群算法改进的多输入多输出BP神经网络回归分析

目录 摘要 BP神经网络的原理 BP神经网络的定义 BP神经网络的基本结构 BP神经网络的神经元 BP神经网络的激活函数, BP神经网络的传递函数 粒子群算法的原理及步骤 粒子群算法优化三隐含层BP神经网络回归分析,粒子群优化多输入多输出BP神经网络 matlab代码下载链接:粒子群算法…

华为ensp中Hybrid接口原理和配置命令

作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月19日14点03分 Hybrid接口是ENSP虚拟化中的一种重要技术&#xff0c;它既可以连接普通终端的接入链路&#xff0c;又可以连接交换机间的干道链路。Hybrid接口允许多…

德鲁伊参数踩坑之路

上文说到 Druid德鲁伊参数调优实战&#xff0c;也正因此次优化&#xff0c;为后续问题埋下了伏笔 背景 2024/04/16日&#xff0c;业务反馈某个定时统计的数据未出来&#xff0c;大清早排查定位是其统计任务跑批失败&#xff0c;下面给一段伪代码 // 无事务执行 public void …

Linux 基于 UDP 协议的简单服务器-客户端应用

目录 一、socket编程接口 1、socket 常见API socket()&#xff1a;创建套接字 bind()&#xff1a;将用户设置的ip和port在内核中和我们的当前进程关联 listen() accept() 2、sockaddr结构 3、inet系列函数 二、UDP网络程序—发送消息 1、服务器udp_server.hpp initS…

git rebase回退到根

项目初始有2个commit&#xff0c;git rebase -i 合并提交记录只能看到一个最新的&#xff0c; 需要git rebase -i --root才能看到第一个提交 git rebase -i -root以后&#xff0c;编辑提交信息&#xff0c;然后就可以了。 之前本地调试的时候经过多次实验性操作&#xff0c;导致…

探索“人工智能+”战略下的企业切入点

在“人工智能”的大战略框架下&#xff0c;企业正面临着巨大的发展机遇与挑战。本文将深入探讨在这一战略框架下&#xff0c;企业可以采取的具体切入点&#xff0c;以实现技术创新、提升竞争力和实现可持续发展。 --- 随着人工智能技术的不断发展和应用&#xff0c;以“人工智能…

java spring 05 图灵 启动性能优化

一.doscan方法的补充&#xff1a; 01.在findCandidateComponents(basePackage)方法中&#xff1a;优化&#xff0c;因为扫描package 如果存在有索引的文件&#xff0c;使用索引文件来加载bean public Set<BeanDefinition> findCandidateComponents(String basePackage)…

docker-004-搭建本地镜像库

背景 1 官方Docker Hub地址:https:/hub.docker.com,中国大陆访问太慢了且有被阿里云取代的趋势,不太主流 2 Dockerhub、阿里云这样的公共镜像仓库可能不太方便,涉及机密的公司不可能提供镜像给公网,所以需要创建一个本地私人仓库供给团队使用,基于公司内部项目构建镜像。…

python教学入门:字典和集合

字典&#xff08;Dictionary&#xff09;&#xff1a; 定义&#xff1a; 字典是 Python 中的一种数据结构&#xff0c;用于存储键值对&#xff08;key-value pairs&#xff09;。字典使用花括号 {} 定义&#xff0c;键值对之间用冒号 : 分隔&#xff0c;每对键值对之间用逗号 …

动态规划——记忆化搜索

数字三角形 找一条最大路径。发现从上面往下一步步走很麻烦&#xff0c;直接搜索肯定超时&#xff0c;我们可以逆向求解。从下往上看。从倒数第二行开始看&#xff0c;2可以选4和5&#xff0c;因为找最大&#xff0c;所以我们选5&#xff0c;把2加上5更新为7&#xff0c;以此类…

vs2022断点空心加感叹号 解决方案

有时会出现设置的调试时&#xff0c;断点红色断点出现黄色的感叹号&#xff0c;并提示与原版本不同&#xff0c;现两种解决办法。 1、“工具”&#xff0c;“选项”&#xff0c;“调试”&#xff0c;“要求源文件与原始版本完成匹配”去掉勾。