损失函数详细复现(pytorch版本)

什么是损失函数

损失函数(Loss Function)是在机器学习和深度学习中用于评估模型预测结果与实际标签之间差异的函数。它衡量了模型的性能,即模型对训练样本的预测与实际标签的偏差程度。目标是通过调整模型参数,使损失函数的值最小化,从而提高模型的准确性和泛化能力。

常见的损失函数

这里的复现主要是与官方的实现进行对比实验。

L1Loss

它叫做平均绝对误差,定义如下所示:

L_{1} = \frac{1}{N}\sum_{N}^{i=1}\left | y_{i}-\hat{y_{i}} \right |

其中,y_{i}表示样本i的真实标签,\hat{y_{i}}表示模型对于样本i的预测标签。将每个样本的绝对误差取平均值,得到L1 Loss。

class L1Loss(nn.Module):def __init__(self):super(L1Loss, self).__init__()def forward(self, input, target):loss = torch.mean(torch.abs(input - target))return loss

测试代码为以下所示:

if __name__=="__main__":criterion1 = nn.L1Loss()criterion2 = L1Loss()input_data=torch.Tensor([2, 3, 4, 5])target_data=torch.Tensor([4, 5, 6, 7])loss1 = criterion1(input_data, target_data)print(loss1)loss2 = criterion2(input_data, target_data)print(loss2)

测试输出均为 tensor(2.)

L2Loss

它叫做均方误差,定义如下所示:

L_{2} = \frac{1}{N}\sum_{N}^{i=1}(y_{i}--\hat{y_{i}})^{2}

其中,y_{i}表示样本i的真实标签,\hat{y_{i}}表示模型对于样本i的预测标签。测量预测输出中的每个元素与目标或地面实况中的相应元素之间的平均平方差。

class L2Loss(nn.Module):def __init__(self):super(L2Loss, self).__init__()def forward(self, input, target):loss = torch.mean(torch.pow(input - target, 2))return loss

测试代码为以下所示: 

if __name__=="__main__":criterion1 = nn.MSELoss()criterion2 = L2Loss()input_data=torch.Tensor([2, 3, 4, 5])target_data=torch.Tensor([4, 5, 6, 7])loss1 = criterion1(input_data, target_data)print(loss1)loss2 = criterion2(input_data, target_data)print(loss2)

测试输出均为 tensor(4.)

BCELoss

二元交叉熵损失(Binary Cross Entropy Loss),也称为对数损失。

\text{BCELoss} = -\frac{1}{N} \sum_{i=1}^{N} \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right)

其中,y_{i}表示样本i的真实标签,\hat{y_{i}}表示模型对于样本i的预测标签。用于测量预测输出中的每个元素与目标或地面实况中的相应元素之间的对数概率差异。

class BCELoss(nn.Module):def __init__(self):super(BCELoss, self).__init__()def forward(self, input, target):input = torch.sigmoid(input)loss = - (target * torch.log(input) + (1 - target) * torch.log(1 - input))return loss.mean()

测试代码为以下所示: 

if __name__=="__main__":criterion1 = nn.BCELoss()criterion2 = BCELoss()input_data = torch.randn((5,))print(input_data)target_data = torch.randint(0, 2, (5,), dtype=torch.float32)print(target_data)loss1 = criterion1(torch.sigmoid(input_data), target_data)loss2 = criterion2(input_data, target_data)print("PyTorch BCELoss:", loss1.item())print("MY BCELoss:", loss2.item())

tensor([-2.0343, -1.5186,  1.6389,  0.4658,  0.6823])
tensor([1., 0., 1., 1., 1.])

测试输出均为 0.6857892274856567

当实际标签为1时(y_{i}=1),我们希望模型的预测概率越接近1,因为实际上这个样本是正类别。因此,我们希望\hat{y_{i}}越大,这样log(\hat{y_{i}})的值越小。因此,我们的损失项是-log(\hat{y_{i}})

当实际标签为0时(y_{i}=0),我们希望模型的预测概率越接近0,因为实际上这个样本是负类别。因此,我们希望1-\hat{y_{i}}越大,这样log(1-\hat{y_{i}})的值就越小。因此,我们的损失项是-(1-\hat{y_{i}})log(1-\hat{y_{i}})

我们将上述两种情况的损失项相加,并取平均。最终的BCELoss公式是上述两项的求和。

CrossEntropyLoss

交叉熵损失(CrossEntropyLoss)是在深度学习中常用于多分类问题的一种损失函数。它衡量了模型输出的概率分布与真实标签之间的差异。

\text{CrossEntropyLoss}(x, y) = -\frac{1}{N} \sum_{i=1}^{N} \log\left(\frac{\exp(x_{i, y_i})}{\sum_{j=1}^{C} \exp(x_{i, j})}\right)

其中,y_{i}表示样本i的真实标签,\hat{y_{i}}表示模型对于样本i的预测标签。 

 class CrossEntropyLoss(nn.Module):def __init__(self):super(CrossEntropyLoss, self).__init__()def forward(self, input, target):return nn.NLLLoss()(F.log_softmax(input, dim=1), target)

测试代码为以下所示: 

if __name__ == "__main__":criterion1 = nn.CrossEntropyLoss()criterion2 = CrossEntropyLoss()input_data = torch.randn((3, 5))target_data = torch.randint(0, 5, (3,))loss1 = criterion1(input_data, target_data)loss2 = criterion2(input_data, target_data)print("PyTorch CrossEntropyLoss:", loss1.item())print("Custom CrossEntropyLoss:", loss2.item())

测试输出均为 2.0007288455963135

分子部分exp(x_{i,j})这是模型对第 i 个样本正确类别的原始输出的指数形式。这一部分希望越大越好,因为我们希望模型对正确类别有更高的置信度。

分母部分\sum_{j=1}^{C} \exp(x_{i,j})这是模型对第i个样本所有类别原始输出的指数形式的和。这一部分用于归一化,将原始输出转化为概率分布。通过除以这个和,我们得到每个类别的概率,表示模型对每个类别的相对置信度。

对上述概率分布取对数。这个操作将概率空间映射到实数空间,使得我们可以用数值优化的方法来优化模型。这一部分希望越小越好,因为我们希望模型对真实标签的估计概率越接近于1。

参考文章

L1 loss 是什么_l1loss-CSDN博客

【损失函数】(三) NLLLoss原理 & pytorch代码解析_pytorch nll_loss-CSDN博客

损失函数(lossfunction)的全面介绍(简单易懂版)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/647118.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

有向图的拓扑序列——拓扑排序

问题描述 什么是拓扑序列 若一个由图中所有点构成的序列 A 满足:对于图中的每条边 (x,y),x 在 A 中都出现在 y 之前,则称 A 是该图的一个拓扑序列。图中不能有环图中至少存在一个点的入度为0 如何求拓扑序列? 计算出每个节点的…

06 BGP 基础报文状态

06 BGP 基础 报文状态 05 BGP 大纲-CSDN博客 1 BGP 的基础 1.1 为什么要使用 BGP 我们要在不同AS之间实现网络通信,需要使用EGP-BGP协议,当然我们还看重BGP的一些优势 1)非常稳定 2)可以传输大量的路由,支持大规模网络 3)具有非常丰富的路由控制策略,可以实现灵活…

常用通信总线学习——RS232与RS485

RS232概述 RS-232标准接口(又称EIA RS-232)是常用的串行通信接口标准之一,它是由美国电子工业协会(Electronic Industry Association,EIA)联合贝尔系统公司、调制解调器厂家及计算机终端生产厂家于1970年共同制定,其全…

缓存和CDN完整指南

1*JfOWR6ECe92QhH_UTwulrg.png 假设一家公司将其网站托管在芬兰的Google Cloud数据中心的服务器上。对于欧洲的用户,加载可能需要大约100毫秒,但对于墨西哥的用户,可能需要3-5秒。幸运的是,有策略可以最小化远程用户的请求延迟。 …

破解不了WIFI?也许你应该试试社工...

以下案例为虚拟环境,请勿模仿 做什么? 由于工作出差在该某某企业出差,手机和电脑都没办法用流量…流量包1G1块…太贵了…我勒个豆啊…发现WIFI密码难以破解(小kali上过了)。 出去逛逛吧…发现楼道有海康威视摄像头,学过交换机的一般都看得出来这个摄像…

(超全七大错误)Invalid bound statement (not found): com.xxx.dao.xxxDao.add

1.确保你把dao和mapper都在applicationContext.xml中都扫描了 xml文件 <bean id"sqlSessionFactory" class"org.mybatis.spring.SqlSessionFactoryBean"><property name"dataSource" ref"dataSource"/><property nam…

web安全学习笔记【08】——算法1

思维导图在最后 #知识点&#xff1a; 1、Web常规-系统&中间件&数据库&源码等 2、Web其他-前后端&软件&Docker&分配站等 3、Web拓展-CDN&WAF&OSS&反向&负载均衡等 ----------------------------------- 1、APP架构-封装&原生态&…

FastDeploy项目简介,使用其进行(图像分类、目标检测、语义分割、文本检测|orc部署)

FastDeploy是一款全场景、易用灵活、极致高效的AI推理部署工具&#xff0c; 支持云边端部署。提供超过 &#x1f525;160 Text&#xff0c;Vision&#xff0c; Speech和跨模态模型&#x1f4e6;开箱即用的部署体验&#xff0c;并实现&#x1f51a;端到端的推理性能优化。包括 物…

自动 CAPTCHA 解决方案,最佳 CAPTCHA 解决方案扩展 2024?

自动 CAPTCHA 解决方案&#xff0c;最佳 CAPTCHA 解决方案扩展 2024&#xff1f; 在迅速发展的数字领域中&#xff0c;高效的 CAPTCHA&#xff08;Completely Automated Public Turing tests to tell Computers and Humans Apart&#xff0c;完全自动化的全球公共图灵测试&…

JavaScript 执行上下文与作用域

执行上下文与作用域 ​ 执行上下文的概念在 JavaScript 中是颇为重要的。变量或函数的上下文决定了它们可以访问哪些数据&#xff0c;以及它们的行为。每个上下文都有一个关联的变量对象&#xff08;variable object&#xff09;&#xff0c; 而这个上下文中定义的所有变量和函…

C++:使用tinyXML生成矢量图svg

先说一下tinyXML库的配置&#xff1a; 很简单&#xff0c;去下面官网下载 TinyXML download | SourceForge.net 解压后是这样 直接将红框中的几个文件放到项目中即可使用 关于svg文件&#xff0c;SVG是基于XML的可扩展矢量图形&#xff0c;svg是xml文件&#xff0c;但是xml…

软件安装SQLyog

SQLyog 安装配置使用 首先下载SQLyog 软件&#xff0c;并解压 选择自己操作系统的版本 双击点击 .exe 文件&#xff0c;进行安装 选择安装语言&#xff0c;默认中文&#xff0c;直接点击【OK】即可 点击【下一步】 先【勾选】同意协议&#xff0c;再点击【下一步】 …

详解SpringCloud微服务技术栈:ElasticSearch实践1——RestClient操作索引库与文档

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位大四、研0学生&#xff0c;正在努力准备大四暑假的实习 &#x1f30c;上期文章&#xff1a;详解SpringCloud微服务技术栈&#xff1a;ElasticSearch原理精讲、安装、实践 &#x1f4da;订阅专栏&#xff1a;微服务技术全家…

【Tailwind】各种样式的进度条

基本样式进度条&#xff1a; <div class"mb-5 h-2 rounded-full bg-gray-200"><div class"h-2 rounded-full bg-orange-500" style"width: 50%"></div> </div>带文字的进度条&#xff1a; <div class"relativ…

npm install报错certificate has expired

报错&#xff1a; reason: certificate has expired 解决&#xff1a;更换npm镜像源 登录到服务器上&#xff0c;更换npm镜像源(或者在jenkins上配置) npm config set registry http://registry.cnpmjs.org npm config set registry http://registry.npm.taobao.org #如果上面…

人工智能时代:让AIGC成为你的外部智慧源(文末送书)

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;数据结构、网络奇遇记 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 什么是AIGC?二. AIGC如何运作&#xff1f;2.1 步骤一&#xff1a;收集数据2.…

postgresql12表膨胀解决(不锁表)

查看所有数据库占用磁盘空间 SELECTpg_database.datname AS "数据库名称",pg_size_pretty(pg_database_size(pg_database.datname)) AS "磁盘占用空间" FROMpg_database;发现有个数据库占用空间过大 查询库中所有表占用空间 SELECTtable_name,pg_size_…

Lucene 源码分析——BKD-Tree

Lucene 源码分析——BKD-Tree - AIQ Bkd-Tree Bkd-Tree作为一种基于K-D-B-tree的索引结构&#xff0c;用来对多维度的点数据(multi-dimensional point data)集进行索引。Bkd-Tree跟K-D-B-tree的理论部分在本篇文章中不详细介绍&#xff0c;对应的两篇论文在附件中&#xff0c…

【LangChain学习之旅】—(9) 用SequencialChain链接不同的组件

【LangChain学习之旅】—&#xff08;9&#xff09;用SequencialChain链接不同的组件 什么是 ChainLLMChain&#xff1a;最简单的链链的调用方式直接调用通过 run 方法通过 predict 方法通过 apply 方法通过 generate 方法 Sequential Chain&#xff1a;顺序链首先&#xff0c;…

Oracle篇—分区表的管理(第二篇,总共五篇)

☘️博主介绍☘️&#xff1a; ✨又是一天没白过&#xff0c;我是奈斯&#xff0c;DBA一名✨ ✌✌️擅长Oracle、MySQL、SQLserver、Linux&#xff0c;也在积极的扩展IT方向的其他知识面✌✌️ ❣️❣️❣️大佬们都喜欢静静的看文章&#xff0c;并且也会默默的点赞收藏加关注❣…