【pytorch学习】交叉熵损失函数 nn.CrossEntropyLoss() = nn.LogSoftmax(dim=1) + nn.NLLLoss()

结论: Pytorch中CrossEntropyLoss()函数的主要是将softmax-log-NLLLoss合并到一块得到的结果。

from: https://zhuanlan.zhihu.com/p/98785902

import torch
import torch.nn as nn
x_input=torch.randn(3,3)#随机生成输入 
print('x_input:\n',x_input) 
y_target=torch.tensor([1,2,0])#设置输出具体值 print('y_target\n',y_target)#计算输入softmax,此时可以看到每一行加到一起结果都是1
softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)#在softmax的基础上取log
log_output=torch.log(soft_output)
print('log_output:\n',log_output)#对比softmax与log的结合与nn.LogSoftmaxloss(负对数似然损失)的输出结果,发现两者是一致的。
logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)#pytorch中关于NLLLoss的默认参数配置为:reducetion=True、size_average=True
nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)#直接使用pytorch中的loss_func=nn.CrossEntropyLoss()看与经过NLLLoss的计算是不是一样
crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)

最后计算得到的结果为:

x_input:tensor([[ 2.8883,  0.1760,  1.0774],[ 1.1216, -0.0562,  0.0660],[-1.3939, -0.0967,  0.5853]])
y_targettensor([1, 2, 0])
soft_output:tensor([[0.8131, 0.0540, 0.1329],[0.6039, 0.1860, 0.2102],[0.0841, 0.3076, 0.6083]])
log_output:tensor([[-0.2069, -2.9192, -2.0178],[-0.5044, -1.6822, -1.5599],[-2.4762, -1.1790, -0.4970]])
logsoftmax_output:tensor([[-0.2069, -2.9192, -2.0178],[-0.5044, -1.6822, -1.5599],[-2.4762, -1.1790, -0.4970]])
nlloss_output:tensor(2.3185)
crossentropyloss_output:tensor(2.3185)

通过上面的结果可以看出,直接使用pytorch中的loss_func=nn.CrossEntropyLoss()计算得到的结果与softmax-log-NLLLoss计算得到的结果是一致的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733282.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YUNBEE云贝:3月9日-PostgreSQL中级工程师PGCE认证培训

课程介绍 根据学员建议和市场需求,规划和设计了《PostgreSQL CE 认证课程》,本课程以内部原理、实践实战为主,理论与实践相结合。课程包含PG 简介、安装使用、服务管理、体系结构等基础知识。同时结合一线实战案例, 面向 PG 数据库的日常维护管理、服务和…

Vue | 基于 vue-admin-template 项目的跨域问题解决方法

目录 一、现存问题 二、解决方法 2.1 修改的第一个地方 2.2 修改的第二个地方 2.3 修改的第三个地方 自存 一、现存问题 报错截图如下: 二、解决方法 2.1 修改的第一个地方 在 .env.development 文件中: # base api # VUE_APP_BASE_API /d…

springboot整合shiro的实战教程(一)

文章目录 1.权限的管理1.1 什么是权限管理1.2 什么是身份认证1.3 什么是授权 2.什么是shiro3.shiro的核心架构3.1 Subject3.2 SecurityManager3.3 Authenticator3.4 Authorizer3.5 Realm3.6 SessionManager3.7 SessionDAO3.8 CacheManager3.9 Cryptography 4. shiro中的认证4.1…

我的 4096 创作纪念日

作者:明明如月学长, CSDN 博客专家,大厂高级 Java 工程师,《性能优化方法论》作者、《解锁大厂思维:剖析《阿里巴巴Java开发手册》》、《再学经典:《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

header组件编写和Vuex store创建

src\components\Header\index.vue <template><header class"header"><h1><slot></slot></h1></header> </template><script> export default {name: Header } </script> src\main.js 引入全局样式 imp…

YOLOv8+DeepSort/ByteTrack-PyQt-GUI / yolov5 deepsort 行人/车辆(检测 +计数+跟踪+测距+测速)

YoloV8结合可视化界面和GUI&#xff0c;实现了交互式目标检测与跟踪&#xff0c;为用户提供了一体化的视觉分析解决方案。通过YoloV8算法&#xff0c;该系统能够高效准确地检测各类目标&#xff0c;并实时跟踪它们的运动轨迹。 用户可以通过直观的可视化界面进行操作&#xff…

Unity性能优化篇(七) UI优化注意事项以及使用Sprite Atlas打包精灵图集

UI优化注意事项 1.尽量避免使用IMGUI(OnGUI)来做游戏时的UI&#xff0c;因为IMGUI的开销比较大。 2.如果一个UGUI的控件不需要进行射线检测&#xff0c;则可以取消勾选Raycast Target 3.尽量避免使用完全透明的图片和UI控件。因为即使完全透明&#xff0c;我们看不见它&#xf…

常见BUG如何在测试过程中分析定位

前言 在测试的日常工作中&#xff0c;相信经常有测试的小伙伴遇到类似的情况&#xff1a;在项目上线时&#xff0c;只要出现问题&#xff08;bug&#xff09;&#xff0c;就很容易成为“背锅侠”。 软件测试人员在工作中是无法避免的要和开发人员和产品经理打交道的&#xff…

c++的chrono总结用法

C11引入了<chrono>头文件&#xff0c;提供了处理时间的功能。以下是<chrono>头文件中一些常用的类和函数的总结用法&#xff1a; std::chrono::duration&#xff1a; 用于表示时间段的类模板。可以指定不同的时间单位&#xff08;如秒、毫秒、微秒等&#xff09;。…

117.龙芯2k1000-pmon(16)- linux下升级pmon

pmon的升级总是有些不方便&#xff0c;至少是要借助串口和串口工具 如果现场不方便连接串口&#xff0c;是不是可以使用网线升级pmon呢&#xff1f; 答案当然是可行的。 环境&#xff1a;2k1000linux3.10麒麟的文件系统 如今我已经把这个工具开发出来了。 GitHub - zhaozhi…

网络工程师笔记10 ( RIP / OSPF协议 )

RIP 学习路由信息的时候需要配认证 RIP规定超过15跳认定网络不可达 链路状态路由协议-OSPF 1. 产生lsa 2. 生成LSDB数据库 3. 进行spf算法&#xff0c;生成最有最短路径 4. 得出路由表

【探索C++容器:set和map的使用】

[本节目标] 1. 关联式容器 2. 键值对 3. 树形结构的关联式容器 1. 关联式容器 在初阶阶段&#xff0c;我们已经接触过STL中的部分容器&#xff0c;比如&#xff1a;vector、list、deque、forward_list(C11)等&#xff0c;这些容器统称为序列式容器&#xff0c;因为其底层为…

Toyota Programming Contest 2024#3(AtCoder Beginner Contest 344)(A~C)

A - Spoiler 竖线里面的不要输出&#xff0c;竖线只有一对&#xff0c;且出现一次。 #include <bits/stdc.h> //#define int long long #define per(i,j,k) for(int (i)(j);(i)<(k);(i)) #define rep(i,j,k) for(int (i)(j);(i)>(k);--(i)) #define debug(a) cou…

设计模式 工厂模式

工厂模式&#xff0c;最重要的是反射。 反射&#xff1a;Class类 java的注释是这样写的&#xff1a; Class没有公共构造函数。相反&#xff0c;Class对象是在类加载时由Java虚拟机自动构造的&#xff0c;并通过调用类加载器中的defineClass方法来实现。

链表|面试题 02.07.链表相交

ListNode *getIntersectionNode(ListNode *headA, ListNode *headB) {ListNode *l NULL, *s NULL;int lenA 0, lenB 0, gap 0;// 求出两个链表的长度s headA;while (s) {lenA ;s s->next;}s headB;while (s) {lenB ;s s->next;}// 求出两个链表长度差if (lenA &…

LeetCode刷题笔记之动态规划(一)

一、动态规划的基础知识 动态规划&#xff08;Dynamic Programming&#xff0c;简称DP&#xff09;&#xff0c;动态规划问题的一般形式就是求最值&#xff0c;求解动态规划的核心问题是穷举&#xff0c;动态规划中每一个状态一定是由上一个状态推导出来的。解题步骤&#xff…

stm32学习笔记:SPI通信协议原理(未完)

一、SPI简介(serial Peripheral Interface&#xff08;串行 外设 接口&#xff09;) 1、电路模式&#xff08;采用一主多从的模式&#xff09;、同步&#xff0c;全双工 1 所有SPI设备的SCK、MOSI、MISO分别连在一起 2 主机另外引出多条SS控制线&#xff0c;分别接到各从机的S…

DetNet论文速读

paper&#xff1a;DetNet: A Backbone network for Object Detection 存在的问题 最近的目标检测模型通常依赖于在ImageNet分类数据集上预训练的骨干网络。由于ImageNet的分类任务不同于目标检测&#xff0c;后者不仅需要识别对象的类别&#xff0c;而且需要对边界框进行空间…

音视频开发_音频基础知识

如何采集声音——模数转换原理 声音模数转换是将声音信号从模拟形式转换为数字形式的过程。它是数字声音处理的基础&#xff0c;常用于语音识别、音频编码等应用中。 音视频通信流程 音视频采集&#xff1a;首先是从麦克风、摄像头等设备中采集音频和视频数据&#xff0c;将现…

编写Linux的SHELL脚本设置环境变量遇到的那些坑

1.背景 最近进行一个rocketmq的单机版安装&#xff0c;发现安装步骤很繁琐&#xff0c;想着写一个shell脚本&#xff0c;一键执行安装。 本文仅在于说明关于JDK的环境变量设置为示例 2.SHELL脚本设置环境变量 #!/bin/sh #定义变量 profile_file~/.bash_profile #1.安装JDK t…