【交叉熵损失torch.nn.CrossEntropyLoss详解-附代码实现】

CrossEntropyLoss

  • 什么是交叉熵
  • softmax
  • 损失计算
  • 验证
    • CrossEntropyLoss 输入输出介绍
    • 验证代码

什么是交叉熵

交叉熵有很多文章介绍,此处不赘述。只需要知道它是可以衡量真实值和预测值之间的差距的,因而用交叉熵来计算损失的时候,损失是越小越好,它用数学公式表示是:

-P(x) log Q(x)

其中P(x)是真实值,Q(x)是预测值
当p(x)和Q(x)是矩阵的时候,就分别对其计算,然后求和即可

在pytorch中的交叉熵损失CrossEntropyLoss 包含了 两部分,softmax和交叉熵计算,下面分别介绍这两部分

softmax

一句话理解,是将预测值转成概率。通常经过神经网络计算出来的预测数据不是一个,举个例子:
比如一个二分类问题,一个输入计算出来的结果总是两个值(a, b)其中 a 表示1分类的得分,b 表示2分类的得分,多分类同样
比如一个翻译模型,每个时间步的输出是词表大小(a, b,…) 其中每个值表示词表中每个词的得分

而我们需要的是概率,不是分数,因此需要一个转换,要保证所有分类的概率和为1 softmax的做法:
在这里插入图片描述

即:exp(某分数)/所有分类的exp后的分数

损失计算

计算完softmax,就可以用文中刚开始的 -P(x) log Q(x) 计算损失了,通常情况下,我们的真实值 p(x),也就是target 通常是one-hot编码的,举个例子:
比如二分类类的时候,target通常是(0,1)(1,0)
比如翻译模型,target通常是(0,…1…0)等
我们计算的时候不难发现target中为0经过乘法都是0了,因此最后只剩下正确类型的这个损失差距 最后公式可以演变成 - log Q(x)

一句话来说,交叉熵的损失值只关注了正确分类的差距

验证

自己实现了一下softmax和cross_loss,验证下上述理论的正确性,那就要介绍下torch.nn.CrossEntropyLoss

CrossEntropyLoss 输入输出介绍

可以翻看官网介绍

CLASStorch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’, label_smoothing=0.0)

reduction是指损失计算方式,默认取平均mean,同时支持none,sum ,分别表示每一个损失不做其他操作、所有损失求求和

计算是target 的shape支持直接输入具体值,或者是索引形式,举个例子:
预测值: [0.8, 0.5, 0.2, 0.5]
target可以是 [1, 0, 0, 0] 或者索引形式 0

多样本也同样:
预测值:
[[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]]
target 可以是:

  • 列表形式 torch.tensor([[1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]], dtype=torch.float)
  • 索引形式: torch.tensor([0,1, 1, 3], dtype=torch.long)

验证代码

def soft_max(x):x_exp = torch.exp(x)partition = x_exp.sum(1, keepdim=True)# 广播partitionreturn x_exp / partition
def cross_entropy(y, y_hat):x = y_hat[range(len(y_hat)), y]print("取出对应元素:", x, '真实label:', y)return -torch.log(x)y = torch.tensor([0, 2])y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])y_hat_softmax = soft_max(y_hat)print(y_hat_softmax)out = cross_entropy(y, y_hat_softmax)print('手动计算的损失', out)cr_loss = torch.nn.CrossEntropyLoss(reduction="none")out = cr_loss(y_hat, y)print('公式计算的损失', out)

输出如下:

手动计算的损失 tensor([1.3533, 0.9398])
公式计算的损失 tensor([1.3533, 0.9398])

结果一致,可以验证无问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入了解HTTP代理的工作原理

HTTP代理是一种常见的网络代理方式,它可以帮助用户隐藏自己的IP地址,保护个人隐私和安全。了解HTTP代理的工作原理对于使用HTTP代理的用户来说非常重要。本文将深入介绍HTTP代理的工作原理。 代理服务器的作用 HTTP代理的工作原理基于代理服务器的作用。…

Android常用的工具“小插件”——Widget机制

Widget俗称“小插件”,是Android系统中一个很常用的工具。比如我们可以在Launcher中添加一个音乐播放器的Widget。 在Launcher上可以添加插件,那么是不是说只有Launcher才具备这个功能呢? Android系统并没有具体规定谁才能充当“Widget容器…

2023年“羊城杯”网络安全大赛 Web方向题解wp 全

团队名称:ZhangSan 序号:11 不得不说今年本科组打的是真激烈,初出茅庐的小后生没见过这场面QAQ~ D0n’t pl4y g4m3!!! 简单记录一下,实际做题踩坑很多,尝试很多。 先扫了个目录,扫出start.sh 内容如下…

Linux CentOS7 系统中添加用户

在linux centOS7系统中,添加用户是管理员的基本操作。作为学习linux系统的基本操作,对添加用户应该多方面了解。 添加用户的命令useradd,跟上用户名,就可以快速创建一个用户。添加一些选项,可以设置更人性化的用户信息…

【论文阅读】Pay Attention to MLPs

作者:Google Research, Brain Team 泛读:只关注其中cv的论述 提出了一个简单的网络架构,gMLP,基于门控的MLPs,并表明它可以像Transformers一样在关键语言和视觉应用中发挥作用 提出了一个基于MLP的没有self-attentio…

docker 笔记11: Docker容器监控之CAdvisor+InfluxDB+Granfana

1.原生命令 docker stats命令的结果 是什么 2.是什么 容器监控3剑客 CAdvisor监控收集InfluxDB存储数据Granfana展示图表 3.CAdvisor 4.InfluxDB 5.Granfana 6.总结 7.compose容器编排,一套带走 新建目录 7.1新建3件套组合的 docker-compose.yml version: 3.1vo…

网络原理

网络原理 传输层 UDP 特点 特点:无连接,不可靠,面向数据报,全双工 格式 怎么进行校验呢? 把UDP数据报中的源端口,目的端口,UDP报文长度的每个字节,都依次进行累加 把累加结果&a…

人脸识别技术,如何解决学校门禁安全?

在当今社会,学校安全已经成为一个备受关注的议题,而门禁监控系统已经成为学校管理和保障学生安全的重要工具之一。随着社会的不断发展和技术的不断进步,学校不再只是知识传授的场所,它们也成为了数百、数千甚至数万学生和教职员工…

Elasticsearch——Docker单机部署安装

文章目录 1 简介2 Docker安装与配置2.1 安装Docker2.2 配置Docker镜像加速器2.3 调整Docker资源限制 3 准备Elasticsearch Docker镜像3.1 下载Elasticsearch镜像3.2 自定义镜像配置3.3执行Docker Compose 4 运行Elasticsearch容器4.1 创建Elasticsearch容器4.2 修改配置文件4.3…

入门人工智能 —— 使用 Python 进行文件读写,并完成日志记录功能(4)

入门人工智能 —— 使用 Python 进行文件读写(4) 入门人工智能 —— 使用 Python 进行文件读写打开文件读取文件内容读取整个文件逐行读取文件内容读取所有行并存储为列表 写入文件内容关闭文件 日志记录功能核心代码:完整代码:运…

UE5、CesiumForUnreal实现瓦片坐标信息图层效果

文章目录 1.实现目标2.实现过程2.1 原理简介2.2 cesium-native改造2.3 CesiumForUnreal改造2.4 运行测试3.参考资料1.实现目标 参考CesiumJs的TileCoordinatesImageryProvider,在CesiumForUnreal中也实现瓦片坐标信息图层的效果,便于后面在调试地形和影像瓦片的加载调度等过…

超详细最新PyCharm+Python环境安装,多图,逐步骤

PyCharmPython环境安装 前言一、pycharm下载安装1. 安装地址2. 安装详细步骤 二、Python下载安装1. 安装地址2. 安装详细步骤3. 环境变量忘记添加4. python安装成功测试 三. PyCharm上配置Python总结推荐文章 前言 文章会详细介绍PyCharmPython详细安装步骤,接下来…

node.js笔记

首先:浏览器能执行 JS 代码,依靠的是内核中的 V8 引擎(C 程序) 其次:Node.js 是基于 Chrome V8 引擎进行封装(运行环境) 区别:都支持 ECMAScript 标准语法,Node.js 有独立…

网络安全-IP地址信息收集

本文为作者学习文章,按作者习惯写成,如有错误或需要追加内容请留言(不喜勿喷) 本文为追加文章,后期慢慢追加 IP反查域名 http://stool.chinaz.com/same https://tools.ipip.net/ipdomain.php 如果渗透目标为虚拟主机…

FPGA基本算术运算

FPGA基本算术运算 FPGA基本算术运算1 有符号数与无符号数2 浮点数及定点数I、定点数的加减法II、定点数的乘除法 3 仿真验证i、加减法验证ii、乘除法验证 FPGA基本算术运算 FPGA相对于MCU有并行计算、算法效率较高等优势,但同样由于没有成型的FPU等MCU内含的浮点数运…

合宙Air724UG LuatOS-Air LVGL API控件-图片(Gif)

图片(Gif) GIF图片显示,core版本号要>3211 示例代码 方法一 -- 创建GIF图片控件 glvgl.gif_create(lvgl.scr_act()) -- 设置显示的GIF图像 lvgl.gif_set_src(g,"/lua/test.gif") -- gif图片居中 lvgl.obj_align(g, nil, lvgl…

GaussDB技术解读系列:高级压缩之OLTP表压缩

8月16日,第14届中国数据库技术大会(DTCC2023)在北京国际会议中心顺利举行。在GaussDB“五高两易”核心技术,给世界一个更优选择的专场,华为云数据库GaussDB首席架构师冯柯对华为云GaussDB数据库的高级压缩技术进行了详…

Vue3-devtools开发者工具安装方法

因为最近在学习Vue3,但是之前找到的Vue3-Devtools失效了,那就来下载安装下 下载安装 Github下载地址:Vue3-Devtools 这个链接快点:Vue3-Devtools 点击链接后页面如下 点击main选项,下拉列表往下拉,找到你想要的版…

JAVA设计模式第七讲:设计模式在 Spring 源码中的应用

设计模式(design pattern)是对软件设计中普遍存在的各种问题,所提出的解决方案。本文以面试题作为切入点,介绍了设计模式的常见问题。我们需要掌握各种设计模式的原理、实现、设计意图和应用场景,搞清楚能解决什么问题…

渗透测试漏洞原理之---【业务安全】

文章目录 1、业务安全概述1.1业务安全现状1.1.1、业务逻辑漏洞1.1.2、黑客攻击目标 2、业务安全测试2.1、业务安全测试流程2.1.1、测试准备2.1.2、业务调研2.1.3、业务建模2.1.4、业务流程梳理2.1.5、业务风险点识别2.1.6 开展测试2.1.7 撰写报告 3、业务安全经典场景3.1、业务…