python自定义交叉熵损失,再和pytorch api对比

背景

我们知道,交叉熵本质上是两个概率分布之间差异的度量,公式如下

 其中概率分布P是基准,我们知道H(P,Q)>=0,那么H(P,Q)越小,说明Q约接近P。

损失函数本质上也是为了度量模型和完美模型的差异,因此可以用交叉熵作为损失函数,公式如下

其中

的部分不过是考虑到每次都是输入一批样本,因此把每个样本的交叉熵求出来以后要再求个平均。

注意,我的代码没有考虑标签是soft embedding的情况,如果遇到标注Y是[[0.1,0.1,0.8],[0.1,0.8,0.1],[0.1,0.1,0.8]],那么你需要把代码再推广一下。

自定义交叉熵损失

from typing import List
import mathdef my_softmax(x:List[List[float]])->List[List[float]]:new_x:List[List[float]] = []for i in range(len(x)):sum:float = 0new_x_i = []for j in range(len(x[0])):sum += math.exp(x[i][j])for j in range(len(x[0])):new_x_i.append(math.exp(x[i][j])/sum)new_x.append(new_x_i)return new_xdef my_cross_entropy(x:List[List[float]],y:List[int])->float:res:float = 0x = my_softmax(x)for i in range(len(x)):res += -math.log(x[i][y[i]]) # 根号外面的1和底数e省去了res /= len(x) # meanreturn res# 假设有一个简单的三分类问题,批量大小为2
# 预测输出(通常是模型的原始输出,没有经过softmax)
logits = [[1.5, 0.5, -0.5], [1.2, 0.2, 3.0]]
# 0 和 2 分别表示第一个和第三个类别是正确的
targets = [0, 2]
print(my_cross_entropy(logits,targets))

Pytorch交叉熵损失

import torch
import torch.nn as nnlogits = torch.tensor([[1.5, 0.5, -0.5],[1.2, 0.2, 3.0]])targets = torch.tensor([0, 2])  criterion = nn.CrossEntropyLoss()loss = criterion(logits, targets)print(loss.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/3713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网御星云防火墙策略配置

网御星云防火墙配置 1. 初始设定2. 网络配置3. 安全规则和策略4. 监控和维护零基础入门学习路线视频配套资料&国内外网安书籍、文档网络安全面试题 1. 初始设定 接入网络: 在开始配置之前,确保你的网御星云防火墙正确连接到网络。这通常涉及将WAN接…

07 流量回放实现自动化回归测试

在本模块的前四讲里,我向你介绍了可以直接落地的、能够支撑百万并发的读服务的系统架构,包含懒加载缓存、全量缓存,以及数据同步等方案的技术细节。 基于上述方案及细节,你可以直接对你所负责的读服务进行架构升级,将…

【Redis 开发】一人一单,超卖问题(悲观锁,乐观锁,分布式锁)

锁 悲观锁乐观锁第一种:版本号法第二种:CAS法实现乐观锁 悲观锁与乐观锁的比较 一人一单分布式锁Redis实现分布式锁 悲观锁 认为线程问题一定会发生,因此在操作数据库之前先获取锁,确保线程串行执行,例如Synchronized…

51单片机使用两个按钮控制LED灯不同频率的闪烁

#include <reg52.h>sbit button1 P1^1; // 间隔2秒的按钮 sbit button2 P1^5; // 间隔0.6秒的按钮sbit led P1^3;unsigned int cnt1 0; // 设置LED1灯的定时器溢出次数 unsigned int cnt2 0; // 设置LED2灯的定时器溢出次数 unsigned int flg1 0; // 模式1的标识值…

x86 64位的ubuntu环境下汇编(无优化)及函数调用栈的详解

1. 引言 为了深入理解c&#xff0c;决定学习一些简单的汇编语言。使用ubuntu系统下g很容易将一个c的文件编译成汇编语言。本文使用此方法&#xff0c;对一个简单的c文件编译成汇编语言进行理解。 2.示例 文件名&#xff1a;reorder_demo.cpp #include<stdio.h>typede…

强固型车载电脑在智能轨道安全解决方案的应用

智能轨道安全解决方案 信迈提供一系列具有传感、诊断、人工智能和无线功能的车载列车解决方案。它们提供全面的可扩展性和面向未来的车辆、路旁、信号、电力、障碍物检测和数据收集功能。 应用程序: 铁路供电监控车载列车安全保护铁路轨道监控驾驶行为分析 智能车载解决方案…

Django连接数据库

数据库登录命令 mysql -u root -p show databases; Django连接数据库 在settings.py文件中进行配置和修改 DATABASES {default: {ENGINE: django.db.backends.mysql,HOST: 127.0.0.1, # 数据库主机PORT: 3306, # 数据库端口USER: root, # 数据库用户名PASSWORD: 12345…

flutter release 报错 Error: SocketException: Failed host lookup:

flutter 的 debug 模式没有任何问题 &#xff0c;打了release 包后一直报下面的错&#xff0c;查了一下是 因为没有网络权限 Error: SocketException: Failed host lookup: yomi-test-aws-sg.yomigame.games (OS Error: No address associated with hostname, errno 7) 按照下…

win10加入域环境

win10加入域环境 导航 文章目录 win10加入域环境导航一、关闭防火墙二、使客户端的电脑指向于域控服务器三、检验是否加入了域 一、关闭防火墙 在进行加入域服务之前,我们需要先关闭防火墙(为了不必要的麻烦) 按 winr调出运行窗口,输入 control打开控制面板 点击系统和安全点…

python基础之元组、集合和函数的定义与返回值

1.元祖 1.元祖的定义 元组的数据结构跟列表相似 特征&#xff1a;有序、 有序&#xff1a;有&#xff08;索引/下标/index&#xff09; 正序、反序标识符&#xff1a; ( ) 里面的元素是用英文格式的逗号分割开来关键字&#xff1a;tuple 列表和元组有什么区别&#xff1f; 元组…

异常风云:解码 Java 异常机制

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一个人虽可以走的更快&#xff0c;但一群人可以走的更远。 我是一名后…

C语言数据类型的介绍,类型的基本归类,整型在内存中的存储,原码、反码、补码,大小端等介绍

文章目录 前言一、数据类型的介绍类型的意义 1. 类型的基本归类&#xff08;1&#xff09;. 整型家族&#xff08;2&#xff09;. 浮点数家族&#xff08;3&#xff09;. 构造类型&#xff08;4&#xff09;. 指针类型&#xff08;5&#xff09;. 空类型 二、整型在内存中的存储…

[Collection与数据结构] PriorityQueue与堆

1. 优先级队列 1.1 概念 前面介绍过队列&#xff0c;队列是一种先进先出(FIFO)的数据结构&#xff0c;但有些情况下&#xff0c;操作的数据可能带有优先级&#xff0c;一般出队列时&#xff0c;可能需要优先级高的元素先出队列&#xff0c;该中场景下&#xff0c;使用队列显然…

自动化机器学习流水线:基于Spring Boot与AI机器学习技术的融合探索

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

Python常用包介绍

数据处理 1.numpy&#xff08;数据处理和科学计算&#xff09; import numpy as np np.set_printoptions(precision2, suppressTrue) # 设置打印选项&#xff0c;保留两位小数&#xff0c;禁止科学计数法arr np.arange(1, 6) # 使用arange函数创建数组 print(arr)# 输出&…

深度学习下的视觉SLAM综述

作者&#xff1a;黄泽霞&#xff0c;邵春莉 来源&#xff1a;《机器人》 编辑&#xff1a;东岸因为一点人工一点智能 深度学习下的视觉SLAM综述到目前为止&#xff0c;深度学习与SLAM的结合已经在视觉里程计、场景识别与全局优化等各种任务中取得了显著的成果。同时&#xf…

Prompt Engineering,提示工程

什么是提示工程&#xff1f; 提示工程也叫【指令工程】。 Prompt发送给大模型的指令。比如[讲个笑话]、[用Python编个贪吃蛇游戏]、[给男/女朋友写情书]等看起来简单&#xff0c;但上手简单精通难 [Propmpt]是AGI时代的[编程语言][Propmpt]是AGI时代的[软件工程][提示工程]是…

线上申报开放时间!2024年阜阳市大数据企业培育认定申报条件、流程和材料

2024年阜阳市大数据企业培育认定申报条件、流程和材料&#xff0c;线上申报开放时间整理如下 一、2024年阜阳市大数据企业培育认定申报要求 &#xff08;一&#xff09;经营范围 申请认定的企业应当从事以下生产经营活动&#xff1a; 1.从事数据收集、存储、使用、加工、传输、…

使用linux,c++,创作一个简单的五子棋游戏

#include <iostream> #include <vector> #include <unordered_map> using namespace std; // 棋盘大小 const int BOARD_SIZE 15; // 棋子类型 enum ChessType { EMPTY, BLACK, WHITE }; // 棋盘类 class ChessBoard { private: vect…

CHiME-8多通道远场语音识别Baseline介绍

语音领域每年都有很多比赛&#xff0c;每个比赛都有自己的侧重点&#xff0c;其中CHiME系列比赛的侧重点就是多通道远场语音识别&#xff0c;与其他的语音识别比赛有所区别的是&#xff0c;CHiME提供分布式麦克风和麦克风阵列数据&#xff0c;这样可以选择合适的前端算法以降低…