深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

flyfish

本系列的主要内容是在2017年所写,GPT使用了交叉熵损失函数,所以就温故而知新,文中代码又用新版的PyTorch写了一遍,在看交叉熵损失函数遇到问题时,可先看链接提供的基础知识,可以有更深的理解。

深入理解交叉熵损失 CrossEntropyLoss - one-hot 编码
深入理解交叉熵损失 CrossEntropyLoss - 对数
深入理解交叉熵损失 CrossEntropyLoss - 概率基础
深入理解交叉熵损失 CrossEntropyLoss - 概率分布
深入理解交叉熵损失 CrossEntropyLoss - 损失函数
深入理解交叉熵损失 CrossEntropyLoss - 归一化
深入理解交叉熵损失 CrossEntropyLoss - 信息论(交叉熵)
深入理解交叉熵损失 CrossEntropyLoss - Softmax
深入理解交叉熵损失 CrossEntropyLoss - nn.LogSoftmax

深入理解交叉熵损失 CrossEntropyLoss - 似然
深入理解交叉熵损失CrossEntropyLoss - 乘积符号在似然函数中的应用

深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss
深入理解交叉熵损失 CrossEntropyLoss - nn.CrossEntropyLoss

深入理解交叉熵损失CrossEntropyLoss

  • 深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss
    • LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss
      • 解释
      • 直观解释 Softmax和负对数似然
    • 二分类问题
      • 手动计算步骤
      • 代码实现
    • 多分类问题
      • 手动计算步骤
      • 代码验证

在 PyTorch 中, torch.nn.CrossEntropyLoss 是一个常用的 损失函数,主要用于多分类任务。它结合了 nn.LogSoftmaxnn.NLLLoss,并且内部进行了优化以避免 数值稳定性问题。

具体来说,torch.nn.CrossEntropyLoss 计算的是预测值与目标值之间的交叉熵损失。对于多分类问题,交叉熵损失是最常用的损失函数,因为它直接衡量了两个概率分布(预测概率分布和实际分布)之间的差异。

LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss

nn.CrossEntropyLoss 在内部已经包含了 LogSoftmax 和 NLLLoss 的操作。
编写代码验证,分别是 LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss。

import torch
import torch.nn as nn# 输入张量 (batch_size=2, num_classes=3)
input_tensor = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
# 目标张量 (batch_size=2)
target_tensor = torch.tensor([2, 0])# 使用 nn.LogSoftmax 和 nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
log_probs = log_softmax(input_tensor)
nll_loss = nn.NLLLoss()
loss = nll_loss(log_probs, target_tensor)
print(f'Loss using LogSoftmax and NLLLoss: {loss.item()}')# 使用 nn.CrossEntropyLoss
cross_entropy_loss = nn.CrossEntropyLoss()
loss_ce = cross_entropy_loss(input_tensor, target_tensor)
print(f'Loss using CrossEntropyLoss: {loss_ce.item()}')

输出结果
Loss using LogSoftmax and NLLLoss: 1.4076058864593506
Loss using CrossEntropyLoss: 1.4076058864593506

解释

对于单个样本,交叉熵损失的定义如下:

CrossEntropyLoss = − ∑ i = 1 C y i log ⁡ ( y ^ i ) \text{CrossEntropyLoss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) CrossEntropyLoss=i=1Cyilog(y^i)

其中:

  • C C C 是类别的数量。
  • y i y_i yi 是真实标签的一个one-hot编码(若样本属于类别 i i i,则 y i = 1 y_i = 1 yi=1,否则 y i = 0 y_i = 0 yi=0)。
  • y ^ i \hat{y}_i y^i 是模型预测的第 i i i 类的概率。

直观解释 Softmax和负对数似然

交叉熵损失结合了两个概念:

  1. Softmax
    首先将模型输出的原始分数(logits)通过 softmax 函数转换成概率分布,Softmax 函数将 logits 转换为概率分布。对于一个有 C C C 个类别的分类问题,Softmax 公式如下:

y ^ i = exp ⁡ ( z i ) ∑ j = 1 C exp ⁡ ( z j ) \hat{y}_i = \frac{\exp(z_i)}{\sum_{j=1}^{C} \exp(z_j)} y^i=j=1Cexp(zj)exp(zi)

其中 z i z_i zi 是第 i i i 类的 logit。

  1. 负对数似然
    计算这些概率分布与真实标签之间的负对数似然。在获得概率分布后,交叉熵损失计算真实标签的负对数概率。如果真实标签对应的类别概率很高,损失就小;如果概率很低,损失就大。这驱动模型在训练过程中提高真实标签类别的预测概率。

以下是一个简单的示例,展示如何计算交叉熵损失:

import torch
import torch.nn as nn# 假设我们有两个样本,每个样本属于3个类别中的一个
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.3]])
# 真实标签
labels = torch.tensor([0, 1])# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')

Cross Entropy Loss: 0.37882310152053833
在这个示例中:

  • logits 是模型输出的原始分数。
  • labels 是真实的类别标签。
  • nn.CrossEntropyLoss 会先将 logits 转换为概率分布,然后计算真实标签的负对数似然损失。

二分类问题

二分类交叉熵损失的公式为:

CrossEntropyLoss = − ( y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ) \text{CrossEntropyLoss} = - (y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})) CrossEntropyLoss=(ylog(y^)+(1y)log(1y^))

手动计算步骤

  1. 计算 Sigmoid 激活值

假设:

  • 真实标签 y = 1 y = 1 y=1
  • 模型输出的logits为 z = 1.5 z = 1.5 z=1.5
    计算过程:
    σ ( z ) = 1 1 + exp ⁡ ( − 1.5 ) \sigma(z) = \frac{1}{1 + \exp(-1.5)} σ(z)=1+exp(1.5)1

我们使用更高精度来计算:
exp ⁡ ( − 1.5 ) ≈ 0.22313016014842982 \exp(-1.5) \approx 0.22313016014842982 exp(1.5)0.22313016014842982
σ ( z ) = 1 1 + 0.22313016014842982 ≈ 1 1.22313016014842982 ≈ 0.8175744761936437 \sigma(z) = \frac{1}{1 + 0.22313016014842982} \approx \frac{1}{1.22313016014842982} \approx 0.8175744761936437 σ(z)=1+0.2231301601484298211.2231301601484298210.8175744761936437

  1. 计算交叉熵损失

CrossEntropyLoss = − ( y log ⁡ ( σ ( z ) ) + ( 1 − y ) log ⁡ ( 1 − σ ( z ) ) ) \text{CrossEntropyLoss} = - (y \log(\sigma(z)) + (1 - y) \log(1 - \sigma(z))) CrossEntropyLoss=(ylog(σ(z))+(1y)log(1σ(z)))
CrossEntropyLoss = − log ⁡ ( 0.8175744761936437 ) \text{CrossEntropyLoss} = - \log(0.8175744761936437) CrossEntropyLoss=log(0.8175744761936437)
log ⁡ ( 0.8175744761936437 ) ≈ − 0.2014132779827524 \log(0.8175744761936437) \approx -0.2014132779827524 log(0.8175744761936437)0.2014132779827524
CrossEntropyLoss ≈ 0.2014132779827524 \text{CrossEntropyLoss} \approx 0.2014132779827524 CrossEntropyLoss0.2014132779827524

代码实现

import torch
import torch.nn as nn
import math# 真实标签和 logits
labels = torch.tensor([1.0])
logits = torch.tensor([1.5])# 使用 BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, labels)
print(f'Binary Classification Cross Entropy Loss: {loss.item()}')# 手动计算 sigmoid 和交叉熵损失
sigmoid = 1 / (1 + math.exp(-1.5))
manual_loss = - (1 * math.log(sigmoid) + (1 - 1) * math.log(1 - sigmoid))
print(f'Manually Computed Cross Entropy Loss: {manual_loss}')

输出结果

Binary Classification Cross Entropy Loss: 0.20141397416591644
Manually Computed Cross Entropy Loss: 0.2014132779827524

多分类问题

假设有3个类别:

  • 真实标签为第3类,所以one-hot编码 y = [ 0 , 0 , 1 ] y = [0, 0, 1] y=[0,0,1]
  • 模型预测的logits为 logits = [ 0.1 , 0.2 , 0.7 ] \text{logits} = [0.1, 0.2, 0.7] logits=[0.1,0.2,0.7]

手动计算步骤

  1. 计算Softmax
    y ^ i = exp ⁡ ( z i ) ∑ k = 1 C exp ⁡ ( z k ) \hat{y}_i = \frac{\exp(z_i)}{\sum_{k=1}^{C} \exp(z_k)} y^i=k=1Cexp(zk)exp(zi)

具体计算:

y ^ 1 = exp ⁡ ( 0.1 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_1 = \frac{\exp(0.1)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^1=exp(0.1)+exp(0.2)+exp(0.7)exp(0.1)
y ^ 2 = exp ⁡ ( 0.2 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_2 = \frac{\exp(0.2)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^2=exp(0.1)+exp(0.2)+exp(0.7)exp(0.2)
y ^ 3 = exp ⁡ ( 0.7 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_3 = \frac{\exp(0.7)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^3=exp(0.1)+exp(0.2)+exp(0.7)exp(0.7)

计算得到:

exp ⁡ ( 0.1 ) ≈ 1.1052 \exp(0.1) \approx 1.1052 exp(0.1)1.1052
exp ⁡ ( 0.2 ) ≈ 1.2214 \exp(0.2) \approx 1.2214 exp(0.2)1.2214
exp ⁡ ( 0.7 ) ≈ 2.0138 \exp(0.7) \approx 2.0138 exp(0.7)2.0138

总和:

exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) ≈ 1.1052 + 1.2214 + 2.0138 = 4.3404 \exp(0.1) + \exp(0.2) + \exp(0.7) \approx 1.1052 + 1.2214 + 2.0138 = 4.3404 exp(0.1)+exp(0.2)+exp(0.7)1.1052+1.2214+2.0138=4.3404

各个概率:

y ^ 1 = 1.1052 4.3404 ≈ 0.2546 \hat{y}_1 = \frac{1.1052}{4.3404} \approx 0.2546 y^1=4.34041.10520.2546
y ^ 2 = 1.2214 4.3404 ≈ 0.2814 \hat{y}_2 = \frac{1.2214}{4.3404} \approx 0.2814 y^2=4.34041.22140.2814
y ^ 3 = 2.0138 4.3404 ≈ 0.4639 \hat{y}_3 = \frac{2.0138}{4.3404} \approx 0.4639 y^3=4.34042.01380.4639

  1. 计算交叉熵损失
    CrossEntropyLoss = − ( 0 ⋅ log ⁡ ( 0.2546 ) + 0 ⋅ log ⁡ ( 0.2814 ) + 1 ⋅ log ⁡ ( 0.4639 ) ) \text{CrossEntropyLoss} = - (0 \cdot \log(0.2546) + 0 \cdot \log(0.2814) + 1 \cdot \log(0.4639)) CrossEntropyLoss=(0log(0.2546)+0log(0.2814)+1log(0.4639))
    CrossEntropyLoss = − log ⁡ ( 0.4639 ) ≈ 0.769 \text{CrossEntropyLoss} = - \log(0.4639) \approx 0.769 CrossEntropyLoss=log(0.4639)0.769

代码验证

import torch
import torch.nn as nn
import torch.nn.functional as F# 模拟输入的 logits 和真实标签
logits = torch.tensor([[0.1, 0.2, 0.7]], requires_grad=True)
labels = torch.tensor([2])# 使用 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f'Computed Cross Entropy Loss (using nn.CrossEntropyLoss): {loss.item()}')# 手动计算 softmax 和交叉熵损失
softmax_probs = F.softmax(logits, dim=1)
manual_loss = -torch.log(softmax_probs[0, labels])
print(f'Manually Computed Cross Entropy Loss: {manual_loss.item()}')

输出结果

Computed Cross Entropy Loss (using nn.CrossEntropyLoss): 0.7679495811462402
Manually Computed Cross Entropy Loss: 0.7679495811462402

注意在多分类问题的代码中,我们提供了logits而不是softmax后的概率,因为nn.CrossEntropyLoss会在内部应用softmax。

在二分类问题中,我们可以使用 nn.BCEWithLogitsLoss,它会在内部应用 Sigmoid 激活函数,并计算二分类的交叉熵损失。
在多分类问题中,我们可以使用 nn.CrossEntropyLoss,它会在内部应用 Softmax 激活函数,并计算多分类的交叉熵损失

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/25535.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小冬瓜AIGC 手撕LLM 拼课

小冬瓜aigc手撕LLM学习 官方认证 手撕LLMRLHF速成班-(附赠LLM加速分布式训练超长文档) 帮助多名同学上岸LLM方向,包括高校副教授,北美PhD,大厂等 课程名称【手撕LLMRLHF】 授课形式:在线会议直播讲解课后录播 时间&…

oppo手机精简包名列表

oppo广告机,coloros为13.0,测试机为oppo a1x 5g。 手机第一次开机后就全屏广告,被恶心了好几个月。现使用universal Android debolater进行卸载测试,其中: 不可卸载的: 开机广告:com.coloros.…

RK3568笔记三十一:ekho 6.3 文本转语音移植

若该文为原创文章,转载请注明原文出处。 移植的目的是在在OCR识别基础上增加语音播放,把识别到的文字直接转TTS播报出来,形成类似点读机的功能。 1、下载文件 libsndfile-1.0.28.tar.gz ekho-6.3.tar.xz 2、解压 tar zxvf libsndfile-1.0…

7-6 sdut-C语言实验-爬楼梯

7-6 sdut-C语言实验-爬楼梯 分数 20 全屏浏览 切换布局 作者 马新娟 单位 山东理工大学 小明是个非常无聊的人,他每天都会思考一些奇怪的问题,比如爬楼梯的时候,他就会想,如果每次可以上一级台阶或者两级台阶,那么…

LangChain基础知识入门

LangChain的介绍和入门 1 什么是LangChain LangChain由 Harrison Chase 创建于2022年10月,它是围绕LLMs(大语言模型)建立的一个框架,LLMs使用机器学习算法和海量数据来分析和理解自然语言,GPT3.5、GPT4是LLMs最先进的代…

【源码】Spring Data JPA原理解析之事务注册原理

Spring Data JPA系列 1、SpringBoot集成JPA及基本使用 2、Spring Data JPA Criteria查询、部分字段查询 3、Spring Data JPA数据批量插入、批量更新真的用对了吗 4、Spring Data JPA的一对一、LazyInitializationException异常、一对多、多对多操作 5、Spring Data JPA自定…

Docker 基础使用 (4) 网络管理

文章目录 Docker 网络管理需求Docker 网络架构认识Docker 常见网络类型1. bridge 网络2. host 网络3. container 网络4. none 网络5. overlay 网络 Docker 网路基础指令Docker 网络管理实操 其他相关链接 Docker 基础使用(0)基础认识 Docker 基础使用(1)…

文件操作(Python和C++版)

一、C版 程序运行时产生的数据都属于临时数据&#xff0c;程序—旦运行结束都会被释放通过文件可以将数据持久化 C中对文件操作需要包含头文件< fstream > 文件类型分为两种: 1. 文本文件 - 文件以文本的ASCII码形式存储在计算机中 2. 二进制文件- 文件以文本的二进…

Spring运维之boo项目表现层测试匹配响应执行状态响应体JSON和响应头

匹配响应执行状态 我们创建了测试环境 而且发送了虚拟的请求 我们接下来要进行验证 验证请求和预期值是否匹配 MVC结果匹配器 匹配上了 匹配失败 package com.example.demo;import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Auto…

Transformer动画讲解:Softmax函数

暑期实习基本结束了&#xff0c;校招即将开启。 不同以往的是&#xff0c;当前职场环境已不再是那个双向奔赴时代了。求职者在变多&#xff0c;HC 在变少&#xff0c;岗位要求还更高了。提前准备才是完全之策。 最近&#xff0c;我们又陆续整理了很多大厂的面试题&#xff0c…

读AI未来进行式笔记07量子计算

1. AI审讯技术 1.1. 发明者最初的目的是发明一种能够替代精神药物&#xff0c;为人类带来终极快乐的技术 1.1.1. 遗憾的是&#xff0c;他找到的只是通往反方向的大门 1.2. 通过非侵入式的神经电磁干扰大脑边缘系统&#xff0c;诱发受审者最…

VRRP基础配置(华为)

#交换设备 VRRP基础配置 VRRP (Virtual Router Redundancy Protocol) 全称是虚拟路由规元余协议&#xff0c;它是一种容错协议。该协议通过把几台路由设备联合组成一台虚拟的路由设备&#xff0c;该虚拟路由器在本地局域网拥有唯一的一个虚拟 ID 和虚拟 IP 地址。实际上&…

UV胶的均匀性对产品质量有什么影响吗?

UV胶的均匀性对产品质量有什么影响吗? UV胶的均匀性对产品质量具有显著的影响&#xff0c;主要体现在以下几个方面&#xff1a; 粘合强度&#xff1a;UV胶的均匀性直接影响其粘合强度。如果UV胶分布不均匀&#xff0c;可能导致部分区域粘接力不足&#xff0c;从而影响产品的…

报错:CMake Error OpenCVConfig.cmake opencv-config.cmake

1、编译过程中&#xff0c;出现OpenCV 报错问题 报错&#xff1a;CMake Error OpenCVConfig.cmake opencv-config.cmake 解决思路&#xff1a;参考此链接

Python 标准库中常用的模块

Python 标准库中包含了很多常用的模块&#xff0c;以下是一些常用的模块&#xff1a; math&#xff1a;提供了数学运算函数&#xff0c;如三角函数、对数函数、指数函数等。random&#xff1a;提供了生成随机数的函数。datetime&#xff1a;提供了处理日期和时间的函数&#x…

LangChain + ChatGLM 实现本地知识库问答

基于LangChain ChatGLM 搭建融合本地知识的问答机器人 1 背景介绍 近半年以来&#xff0c;随着ChatGPT的火爆&#xff0c;使得LLM成为研究和应用的热点&#xff0c;但是市面上大部分LLM都存在一个共同的问题&#xff1a;模型都是基于过去的经验数据进行训练完成&#xff0c;无…

函数知识点

基本概念 函数&#xff08;方法&#xff09; 本质是一块具有名称的代码块。 可以使用函数&#xff08;方法&#xff09;的名称俩执行该代码块。 函数&#xff08;方法&#xff09;是封装代码进行重复谁用的一种机制。 函数&#xff08;方法&#xff09;的主要作用&#xf…

Python进阶-部署Flask项目(以TensorFlow图像识别项目WSGI方式启动为例)

本文详细介绍了如何通过WSGI方式部署一个基于TensorFlow图像识别的Flask项目。首先简要介绍了Flask框架的基本概念及其特点&#xff0c;其次详细阐述了Flask项目的部署流程&#xff0c;涵盖了服务器环境配置、Flask应用的创建与测试、WSGI服务器的安装与配置等内容。本文旨在帮…

linux-du指令

目录 du 命令 du 是 Linux 系统中用于估算文件和目录磁盘使用空间的命令。以下是 du 命令的完整使用说明文档&#xff1a; du 命令 名称&#xff1a;du 简介&#xff1a;du&#xff08;disk usage&#xff09;命令用于估算文件和目录的磁盘使用空间。 语法&#xff1a; du…

JAVA-LeetCode 热题 100 第56.合并区间

思路&#xff1a; class Solution {public int[][] merge(int[][] intervals) {if(intervals.length < 1) return intervals;List<int[]> res new ArrayList<>();Arrays.sort(intervals, (o1,o2) -> o1[0] - o2[0]);for(int[] interval : intervals){if(res…