深度学习 - softmax交叉熵损失计算

示例代码

import torch
from torch import nn# 多分类交叉熵损失,使用nn.CrossEntropyLoss()实现。nn.CrossEntropyLoss()=softmax + 损失计算
def test1():# 设置真实值: 可以是热编码后的结果也可以不进行热编码# y_true = torch.tensor([[0, 1, 0], [0, 0, 1]], dtype=torch.float32)# 注意的类型必须是64位整型数据y_true = torch.tensor([1, 2], dtype=torch.int64)y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)# 实例化交叉熵损失loss = nn.CrossEntropyLoss()# 计算损失结果my_loss = loss(y_pred, y_true).numpy()print('loss:', my_loss)

输入数据

y_true = torch.tensor([1, 2], dtype=torch.int64)
y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)
  • y_true:真实标签,包含两个样本,分别属于类别 1 和类别 2。
  • y_pred:预测的概率分布,包含两个样本,每个样本有三个类别的预测值。

Step 1: Softmax 变换

Softmax 函数将原始的预测值转换为概率分布。Softmax 的公式如下:

Softmax ( x i ) = e x i ∑ j e x j \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} Softmax(xi)=jexjexi

对于第一个样本 y_pred = [0.2, 0.6, 0.2]

  1. 计算指数:

e 0.2 ≈ 1.221 , e 0.6 ≈ 1.822 , e 0.2 ≈ 1.221 e^{0.2} \approx 1.221, \quad e^{0.6} \approx 1.822, \quad e^{0.2} \approx 1.221 e0.21.221,e0.61.822,e0.21.221

  1. 计算 Softmax 分母:

∑ j e x j = 1.221 + 1.822 + 1.221 = 4.264 \sum_{j} e^{x_j} = 1.221 + 1.822 + 1.221 = 4.264 jexj=1.221+1.822+1.221=4.264

  1. 计算 Softmax 分子并得到结果:

Softmax ( 0.2 ) = 1.221 4.264 ≈ 0.286 \text{Softmax}(0.2) = \frac{1.221}{4.264} \approx 0.286 Softmax(0.2)=4.2641.2210.286

Softmax ( 0.6 ) = 1.822 4.264 ≈ 0.427 \text{Softmax}(0.6) = \frac{1.822}{4.264} \approx 0.427 Softmax(0.6)=4.2641.8220.427

Softmax ( 0.2 ) = 1.221 4.264 ≈ 0.286 \text{Softmax}(0.2) = \frac{1.221}{4.264} \approx 0.286 Softmax(0.2)=4.2641.2210.286

Softmax 结果为 [[0.286, 0.427, 0.286]]

对于第二个样本 y_pred = [0.1, 0.8, 0.1]

  1. 计算指数:

e 0.1 ≈ 1.105 , e 0.8 ≈ 2.225 , e 0.1 ≈ 1.105 e^{0.1} \approx 1.105, \quad e^{0.8} \approx 2.225, \quad e^{0.1} \approx 1.105 e0.11.105,e0.82.225,e0.11.105

  1. 计算 Softmax 分母:

∑ j e x j = 1.105 + 2.225 + 1.105 = 4.435 \sum_{j} e^{x_j} = 1.105 + 2.225 + 1.105 = 4.435 jexj=1.105+2.225+1.105=4.435

  1. 计算 Softmax 分子并得到结果:

Softmax ( 0.1 ) = 1.105 4.435 ≈ 0.249 \text{Softmax}(0.1) = \frac{1.105}{4.435} \approx 0.249 Softmax(0.1)=4.4351.1050.249

Softmax ( 0.8 ) = 2.225 4.435 ≈ 0.502 \text{Softmax}(0.8) = \frac{2.225}{4.435} \approx 0.502 Softmax(0.8)=4.4352.2250.502

Softmax ( 0.1 ) = 1.105 4.435 ≈ 0.249 \text{Softmax}(0.1) = \frac{1.105}{4.435} \approx 0.249 Softmax(0.1)=4.4351.1050.249

Softmax 结果为 [[0.249, 0.502, 0.249]]

Step 2: 计算交叉熵损失

交叉熵损失的公式为:

CrossEntropyLoss ( p , y ) = − ∑ i = 1 N y i log ⁡ ( p i ) \text{CrossEntropyLoss}(p, y) = -\sum_{i=1}^{N} y_i \log(p_i) CrossEntropyLoss(p,y)=i=1Nyilog(pi)

对于第一个样本,真实标签为 1(y_true = 1),Softmax 后的预测概率分布为 [0.286, 0.427, 0.286]

CrossEntropyLoss = − [ 0 ⋅ log ⁡ ( 0.286 ) + 1 ⋅ log ⁡ ( 0.427 ) + 0 ⋅ log ⁡ ( 0.286 ) ] \text{CrossEntropyLoss} = - [0 \cdot \log(0.286) + 1 \cdot \log(0.427) + 0 \cdot \log(0.286)] CrossEntropyLoss=[0log(0.286)+1log(0.427)+0log(0.286)]

由于 (0 \cdot \log(0.286) = 0),忽略后我们得到:

log ⁡ ( 0.427 ) ≈ 0.851 \log(0.427) \approx 0.851 log(0.427)0.851

对于第二个样本,真实标签为 2(y_true = 2),Softmax 后的预测概率分布为 [0.249, 0.502, 0.249]

CrossEntropyLoss = − [ 0 ⋅ log ⁡ ( 0.249 ) + 0 ⋅ log ⁡ ( 0.502 ) + 1 ⋅ log ⁡ ( 0.249 ) ] \text{CrossEntropyLoss} = - [0 \cdot \log(0.249) + 0 \cdot \log(0.502) + 1 \cdot \log(0.249)] CrossEntropyLoss=[0log(0.249)+0log(0.502)+1log(0.249)]

由于 (0 \cdot \log(0.249) = 0) 和 (0 \cdot \log(0.502) = 0),忽略后我们得到:

log ⁡ ( 0.249 ) ≈ 1.390 \log(0.249) \approx 1.390 log(0.249)1.390

Step 3: 平均损失

计算平均损失:

平均损失 = 0.851 + 1.390 2 ≈ 2.241 2 ≈ 1.1205 \text{平均损失} = \frac{0.851 + 1.390}{2} \approx \frac{2.241}{2} \approx 1.1205 平均损失=20.851+1.39022.2411.1205

因此,最终的交叉熵损失 my_loss 约为 1.1205。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/24988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql使用中的性能优化——批量插入的规模对比

在《Mysql使用中的性能优化——单次插入和批量插入的性能差异》中,我们观察到单次批量插入的数量和耗时呈指数型关系。 这个说明,不是单次批量插入的数量越多越好。本文我们将通过实验测试出本测试案例中最佳的单次批量插入数量。 结论 本案例中约每次…

【云岚到家】-day00-开发环境配置

文章目录 1 开发工具版本2 IDEA环境配置2.1 编码配置2.2 自动导包设置2.3 提示忽略大小写2.4 设置 Java 编译级别 3 Maven环境3.1 安装Maven3.2 配置仓库3.3 IDEA中配置maven 4 配置虚拟机4.1 导入虚拟机4.2 问题 5 配置数据库环境5.1 启动mysql容器5.2 使用MySQL客户端连接数据…

【YOLO系列】YOLOv1学习(PyTorch)原理加代码

论文网址:https://arxiv.org/pdf/1506.02640 训练集博客链接:目标检测实战篇1——数据集介绍(PASCAL VOC,MS COCO)-CSDN博客 代码文件:在我资源里,但是好像还在审核,大家可以先可以,如果没有的…

从 Android 恢复已删除的备份录

本文介绍了几种在 Android 上恢复丢失和删除的短信的方法。这些方法都不能保证一定成功,但您可能能够恢复一些短信或其中存储的文件。 首先要尝试什么 首先,尝试保留数据。如果你刚刚删除了信息,请立即将手机置于飞行模式,方法是…

【Linux】信号(二)

上一章节我们进行了信号产生的讲解。 本节将围绕信号保存展开,并会将处理部分开一个头。 目录 信号保存:信号的一些概念:关于信号保存的接口:sigset_t的解释:对应的操作接口:sigprocmask:sigp…

SwiftUI中Preference的理解与使用(ScrollView偏移量示例)

在 SwiftUI 中,Preference用于从视图层次结构的较深层次向上传递信息到较浅层次。这通常用于在父视图中获取子视图的属性或状态,而不需要使用状态管理工具如State或 ObservableObject。Preference特别用于自定义布局或组件,其中子视图需要向父…

Git 和 Github 的使用

补充内容:EasyHPC - Git入门教程【笔记】 文章目录 常用命令配置信息分支管理管理仓库 概念理解SSH 密钥HTTPS 和 SSH 的区别在本地生成 SSH key在 Github 上添加 SSH key 使用的例子同步本地仓库的修改到远程仓库拉取远程仓库的修改到本地仓库拉取远程仓库的分支并…

千益畅行:合法合规的旅游卡服务,真实可靠的旅游体验

近期,关于千益畅行旅游卡服务的讨论引起了广泛关注。然而,网络上出现了一些对其误解和质疑的声音。为了澄清事实,我们深入了解了千益畅行的运营模式和业务特点,发现它是一家合法合规的旅游卡服务提供商,为消费者提供真…

Eslint配置指南

1. Eslint配置指南 1.1. 安装 ESLint1.2. 生成配置文件1.3. 修改配置文件1.4. 创建 .eslintignore 文件1.5. 运行 ESLint1.6. 整合到编辑器/IDE1.7. 自动修复 2. 配置prettier 2.1. 安装依赖包2.2. .prettierrc.json添加规则2.3. .prettierignore忽略文件2.4. 保存自动格式化 3…

深度学习中2D分割

深度学习中的2D图像分割 2D图像分割是深度学习中的一个重要任务,旨在将图像划分为不同的区域,每个区域对应于图像中的不同对象或背景。该任务广泛应用于医学影像分析、自动驾驶、卫星图像分析等领域。以下是对深度学习中2D图像分割的详细介绍&#xff0…

实战 | 通过微调SegFormer改进车道检测效果(数据集 + 源码)

背景介绍 SegFormer:实例分割在自动驾驶汽车技术的快速发展中发挥了关键作用。对于任何在道路上行驶的车辆来说,车道检测都是必不可少的。车道是道路上的标记,有助于区分道路上可行驶区域和不可行驶区域。车道检测算法有很多种,每…

vue2实现将el-table表格数据导出为长图片

方法一、 el-table数据导出为长图片 将el-table数据导出为图片不是一个直接的功能,但可以通过以下步骤实现: 使用html2canvas库将表格区域转换为画布(canvas)。 使用canvas的toDataURL方法将画布导出为图片格式(例如PNG)。 创建…

数据结构--实验

话不多说,直接启动!👌🤣 目录 一、线性表😎 1、建立链表 2、插入元素 3、删除特定位置的元素 4、输出特定元素值的位置 5、输出特定位置的元素值 6、输出整个链表 实现 二、栈和队列😘 栈 顺序栈 …

Mybatis缓存的生命周期、使用的特殊情况

以下场景均在Spring Boot程序中,并非手动创建SqlSession使用。 在回答这个问题之前,我们先来回顾一下,Mybatis的一级二级缓存是啥。 一级二级缓存 是什么 一级缓存(本地缓存):一级缓存是SqlSession级别的…

将web项目打包成electron桌面端教程(一)vue3+vite+js

说明:后续项目需要web端和桌面端,为了提高开发效率,准备直接将web端的代码打包成桌面端,在此提前记录一下demo打包的过程,需要注意的是vue2或者vue3的打包方式各不同,如果你的项目不是vue3vitejs&#xff0…

数字孪生技术体系和核心能力整理

最近对数字孪生技术进行了跟踪调研学习,整理形成了调研成果,供大家参考。通过学习,发现数字孪生技术的构建过程其实就是数字孪生体的构建与应用过程,数字孪生体的构建是一个体系化的系统工程,数字化转型的最终形态应该就是数实融合互动互联的终极状态。数实融合是每个行业…

MSA算法满足条件后续之Blum定理

文章目录 引言Blum定理关于MSA算法的讨论 此文章属于文献研读内容,文章内容来源于以下文献 Warren B. Powell, Yosef Sheffi , (1982) The Convergence of Equilibrium Algorithms with Predetermined Step Sizes. Transportation Science ,16(1):45-55. http://dx.…

赶紧收藏!2024 年最常见 20道分布式、微服务面试题(三)

上一篇地址:赶紧收藏!2024 年最常见 20道分布式、微服务面试题(二)-CSDN博客 五、微服务架构有哪些优点和缺点? 微服务架构是一种设计方法,它将应用程序分解为一组小型、独立、松散耦合的服务。每个服务都…

[每周一更]-(第100期):介绍 goctl自动生成代码

​ 在自己组件库中,由于部分设计会存在重复引用各个模板的文件,并且基础架构中需要基础模块内容,就想到自动生成代码模板,刚好之前有使用过goctl,以下就简单描述下gozero中goctl场景和逻辑,后续自己借鉴将自…

英语学习笔记32——What‘s he/she/it doing?

What’s he/she/it doing? 他/她/它 正在做什么? 词汇 Vocabulary type /taɪp/ v. 打字 n. 类型,签字 ing形式:typeing 用法:this type of …    这种类型的…… 例句:我喜欢这种苹果。    I like this type…