深度学习-损失函数

目录

1. 线性回归损失函数

1.1 MAE损失

1.2 MSE损失

2. CrossEntropyLoss

2.1 信息量

2.2 信息熵

2.3 KL散度

2.4 交叉熵

3. BCELoss

4. 总结


1. 线性回归损失函数

1.1 MAE损失

MAE(Mean Absolute Error,平均绝对误差)通常也被称为 L1-Loss,通过对预测值和真实值之间的绝对差取平均值来衡量他们之间的差异。。

MAE的公式如下:


\text{MAE} = \frac{1}{n} \sum_{i=1}^{n} \left| y_i - \hat{y}_i \right|

其中:

  • n 是样本的总数。

  • y_i 是第 i 个样本的真实值。

  • \hat{y}_i是第 i 个样本的预测值。

  • \left| y_i - \hat{y}_i \right|是真实值和预测值之间的绝对误差。

特点

  1. 鲁棒性:与均方误差(MSE)相比,MAE对异常值(outliers)更为鲁棒,因为它不会像MSE那样对较大误差平方敏感。

  2. 物理意义直观:MAE以与原始数据相同的单位度量误差,使其易于解释。

应用场景: MAE通常用于需要对误差进行线性度量的情况,尤其是当数据中可能存在异常值时,MAE可以避免对异常值的过度惩罚。

使用torch.nn.L1Loss即可计算MAE:

import torch
import torch.nn as nn
​
# 初始化MAE损失函数
mae_loss = nn.L1Loss()
​
# 假设 y_true 是真实值, y_pred 是预测值
y_true = torch.tensor([3.0, 5.0, 2.5])
y_pred = torch.tensor([2.5, 5.0, 3.0])
​
# 计算MAE
loss = mae_loss(y_pred, y_true)
print(f'MAE Loss: {loss.item()}')

1.2 MSE损失

均方差损失,也叫L2Loss。

MSE(Mean Squared Error,均方误差)通过对预测值和真实值之间的误差平方取平均值,来衡量预测值与真实值之间的差异。

MSE的公式如下:


\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2

其中:

  • n 是样本的总数。

  • y_i 是第 i 个样本的真实值。

  • \hat{y}_i 是第 i 个样本的预测值。

  • \left( y_i - \hat{y}_i \right)^2 是真实值和预测值之间的误差平方。

特点

  1. 平方惩罚:因为误差平方,MSE 对较大误差施加更大惩罚,所以 MSE 对异常值更为敏感。

  2. 凸性:MSE 是一个凸函数(国际的叫法,国内叫凹函数),这意味着它具有一个唯一的全局最小值,有助于优化问题的求解。

应用场景

MSE被广泛应用在神经网络中。

使用 torch.nn.MSELoss 可以实现:

import torch
import torch.nn as nn
​
# 初始化MSE损失函数
mse_loss = nn.MSELoss()
​
# 假设 y_true 是真实值, y_pred 是预测值
y_true = torch.tensor([3.0, 5.0, 2.5])
y_pred = torch.tensor([2.5, 5.0, 3.0])
​
# 计算MSE
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}')

2. CrossEntropyLoss

2.1 信息量

信息量用于衡量一个事件所包含的信息的多少。信息量的定义基于事件发生的概率:事件发生的概率越低,其信息量越大。其量化公式:

对于一个事件x,其发生的概率为 P(x),信息量I(x) 定义为:

性质

  1. 非负性:I(x)≥0。

  2. 单调性:P(x)越小,I(x)越大。

2.2 信息熵

信息熵是信息量的期望值。熵越高,表示随机变量的不确定性越大;熵越低,表示随机变量的不确定性越小。

公式由数学中的期望推导而来:

其中:

-logP(x_i)是信息量,P(x_i)是信息量对应的概率

2.3 KL散度

KL散度用于衡量两个概率分布之间的差异。它描述的是用一个分布 Q来近似另一个分布 P时,所损失的信息量。KL散度越小,表示两个分布越接近。

对于两个离散概率分布 P和 Q,KL散度定义为:

其中:P 是真实分布,Q是近似分布。

2.4 交叉熵

对KL散度公式展开:

由上述公式可知,P是真实分布,H(P)是常数,所以KL散度可以用H(P,Q)来表示;H(P,Q)叫做交叉熵。

如果将P换成y,Q换成\hat{y},则交叉熵公式为:

其中:

  • C 是类别的总数。

  • y 是真实标签的one-hot编码向量,表示真实类别。

  • \hat{y} 是模型的输出(经过 softmax 后的概率分布)。

  • y_i 是真实类别的第 i 个元素(0 或 1)。

  • \hat{y}_i 是预测的类别概率分布中对应类别 i 的概率。

函数曲线图:

特点:

  1. 概率输出:CrossEntropyLoss 通常与 softmax 函数一起使用,使得模型的输出表示为一个概率分布(即所有类别的概率和为 1)。PyTorch 的 nn.CrossEntropyLoss 已经内置了 Softmax 操作。如果我们在输出层显式地添加 Softmax,会导致重复应用 Softmax,从而影响模型的训练效果。

  2. 惩罚错误分类:该损失函数在真实类别的预测概率较低时,会施加较大的惩罚,这样模型在训练时更注重提升正确类别的预测概率。

  3. 多分类问题中的标准选择:在大多数多分类问题中,CrossEntropyLoss 是首选的损失函数。

应用场景:

CrossEntropyLoss 广泛应用于各种分类任务,包括图像分类、文本分类等,尤其是在神经网络模型中。

nn.CrossEntropyLoss基本原理:

由交叉熵公式可知:


\text{Loss}(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)

因为y_i是one-hot编码,其值不是1便是0,又是乘法,所以只要知道1对应的index就可以了,展开后:


\text{Loss}(y, \hat{y}) = - \log(\hat{y}_m)

其中,m表示真实类别。

因为神经网络最后一层分类总是接softmax,所以可以把\hat{y}_m直接看为是softmax后的结果。


\text{Loss}(i) = - \log(softmax(x_i))
 

所以,CrossEntropyLoss 实质上是两步的组合:Cross Entropy = Log-Softmax + NLLLoss

  • Log-Softmax:对输入 logits 先计算对数 softmax:log(softmax(x))

  • NLLLoss(Negative Log-Likelihood):对 log-softmax 的结果计算负对数似然损失。简单理解就是求负数。原因是概率值通常在 0 到 1 之间,取对数后会变成负数。为了使损失值为正数,需要取负数。

对于softmax(x_i),在softmax介绍中了解到,需要减去最大值以确保数值稳定。


\mathrm{Softmax}(x_i)=\frac{e^{x_i-\max(x)}}{\sum_{j=1}^ne^{x_j-\max(x)}}

则:


LogSoftmax(x_i) =log(\frac{e^{x_i-\max(x)}}{\sum_{j=1}^ne^{x_j-\max(x)}})\\ =x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)})

所以:


\text{Loss}(i) = - (x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)}))
 

总的交叉熵损失函数是所有样本的平均值:


\ell(x, y) = \begin{cases} \frac{\sum_{n=1}^N l_n}{N}, & \text{if reduction} = \text{`mean';}\\ \sum_{n=1}^N l_n, & \text{if reduction} = \text{`sum'.} \end{cases}

示例代码如下:

import torch
import torch.nn as nn
​
# 假设有三个类别,模型输出是未经softmax的logits
logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])
​
# 真实的标签
labels = torch.tensor([1, 2])  # 第一个样本的真实类别为1,第二个样本的真实类别为2
​
# 初始化CrossEntropyLoss
# 参数:reduction:mean-平均值,sum-总和
criterion = nn.CrossEntropyLoss()
​
# 计算损失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')

在这个例子中,CrossEntropyLoss 直接作用于未经 softmax 处理的 logits 输出和真实标签,PyTorch 内部会自动应用 softmax 激活函数,并计算交叉熵损失。

分析示例中的代码:

logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])

第一个样本的得分是 [1.5, 2.0, 0.5],分别对应类别 0、1 和 2 的得分。

第二个样本的得分是 [0.5, 1.0, 1.5],分别对应类别 0、1 和 2 的得分

labels = torch.tensor([1, 2])

第一个样本的真实类别是 1。

第二个样本的真实类别是 2。

CrossEntropyLoss 的计算过程可以分为以下几个步骤:

(1) LogSoftmax 操作

首先,对每个样本的 logits 应用 LogSoftmax 函数,将 logits 转换为概率分布。LogSoftmax 函数的公式是: LogSoftmax(x_i) =x_i-\max(x)-log(\sum_{j=1}^ne^{x_j-\max(x)})

对于第一个样本 [1.5, 2.0, 0.5]

减去最大值:

x_i-\max(x)=[1.5-2.0,2.0-2.0,0.5-2.0]=[-0.5,0,-1.5]

计算e^{x_i-\max(x)}

求和并取对数:

计算 log_softmax

对于第二个样本 [0.5, 1.0, 1.5]

减去最大值:

x_i-\max(x)=[0.5-1.5,1.0-1.5,1.5-1.5]=[-1.0,-0.5,0]

计算e^{x_i-\max(x)}

求和并取对数:

计算 log_softmax

(2) 计算每个样本的损失

接下来,根据真实标签 z_t 计算每个样本的交叉熵损失。交叉熵损失的公式是:

对于第一个样本:

  • 真实类别是 1,对应的 softmax 值是 -0.6041。

对于第二个样本:

  • 真实类别是 2,对应的 softmax 值是 -0.6803。

(3) 计算平均损失

最后,计算所有样本的平均损失:

3. BCELoss

二分类交叉熵损失函数,使用在输出层使用sigmoid激活函数进行二分类时。

由交叉熵公式:

\text{CELoss}(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log(\hat{y}_i)

对于二分类问题,真实标签 y的值为(0 或 1),假设模型预测为正类的概率为 \hat{y},则:

所以:

示例:

import torch
import torch.nn as nn
​
# y 是模型的输出,已经被sigmoid处理过,确保其值域在(0,1)
y = torch.tensor([[0.7], [0.2], [0.9], [0.7]])
# targets 是真实的标签,0或1
t = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)
​
# 计算损失方式一:
bceLoss = nn.BCELoss()
loss1 = bceLoss(y, t)
​
#计算损失方式二: 两种方式结果相同
loss2 = nn.functional.binary_cross_entropy(y, t)
​
print(loss1, loss2)

逐样本计算

样本y_it_i计算项 t_i * log(y_i) + (1-t_i) * log(1-y_i)
10.711*log(0.7) + 0*log(0.3) ≈ -0.3567
20.200*log(0.2) + 1*log(0.8) ≈ -0.2231
30.911*log(0.9) + 0*log(0.1) ≈ -0.1054
40.700*log(0.7) + 1*log(0.3) ≈ -1.2040

计算最终损失

4. 总结

  • 当输出层使用softmax多分类时,使用交叉熵损失函数;

  • 当输出层使用sigmoid二分类时,使用二分类交叉熵损失函数, 比如在逻辑回归中使用;

  • 当功能为线性回归时,使用均方差损失-L2 loss;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/77988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第六篇:linux之解压缩、软件管理

第六篇:linux之解压缩、软件管理 文章目录 第六篇:linux之解压缩、软件管理一、解压和压缩1、window压缩包与linux压缩包能否互通?2、linux下压缩包的类型3、打包与压缩 二、软件管理1、rpm1、什么是rpm?2、rpm包名组成部分3、如何…

Redis 键管理

Redis 键管理 以下从键重命名、随机返回键、键过期机制和键迁移四个维度展开详细说明,结合 Redis 核心命令与底层逻辑进行深入分析: 一、键重命名 1. ​RENAME​​ 与 ​RENAMENX​​ **RENAME key newkey​**: 功能:强制重命名…

OpenCV 模板匹配方法详解

文章目录 1. 什么是模板匹配?2. 模板匹配的原理2.1数学表达 3. OpenCV 实现模板匹配3.1基本步骤 4. 模板匹配的局限性5. 总结 1. 什么是模板匹配? 模板匹配(Template Matching)是计算机视觉中的一种基础技术,用于在目…

TextCNN 模型文本分类实战:深度学习在自然语言处理中的应用

在自然语言处理(NLP)领域,文本分类是研究最多且应用最广泛的任务之一。从情感分析到主题识别,文本分类技术在众多场景中都发挥着重要作用。最近,我参与了一次基于 TextCNN 模型的文本分类实验,从数据准备到…

Qt-创建模块化.pri文件

文章目录 一、.pri文件的作用与基本结构作用基本结构 二、创建.pri文件如何添加模块代码? 一、.pri文件的作用与基本结构 作用 在Qt开发中,.pri文件(Project Include File)是一种配置包含文件,用于模块化管理和复用项…

SpringCloud组件——Eureka

一.背景 1.问题提出 我们在一个父项目下写了两个子项目,需要两个子项目之间相互调用。我们可以发送HTTP请求来获取我们想要的资源,具体实现的方法有很多,可以用HttpURLConnection、HttpClient、Okhttp、 RestTemplate等。 举个例子&#x…

EAL4+与等保2.0:解读中国网络安全双标准

EAL4与等保2.0:解读中国网络安全双标准 在当今数字化时代,网络安全已成为各个行业不可忽视的重要议题。特别是在金融、政府、医疗等领域,保护信息的安全性和隐私性显得尤为关键。在中国,EAL4和等级保护2.0(简称“等保…

FFmpeg+Nginx+VLC打造M3U8直播

一、视频直播的技术原理和架构方案 直播模型一般包括三个模块:主播方、服务器端和播放端 主播放创造视频,加美颜、水印、特效、采集后推送给直播服务器 播放端: 直播服务器端:收集主播端的视频推流,将其放大后推送给…

【Redis】缓存三剑客问题实践(上)

本篇对缓存三剑客问题进行介绍和解决方案说明,下篇将进行实践,有需要的同学可以跳转下篇查看实践篇:(待发布) 缓存三剑客是什么? 缓存三剑客指的是在分布式系统下使用缓存技术最常见的三类典型问题。它们分…

Flink 2.0 编译

文章目录 Flink 2.0 编译第一个问题 java 版本太低maven 版本太低maven 版本太高开始编译扩展多版本jdk 配置 Flink 2.0 编译 看到Flink2.0 出来了,想去玩玩,看看怎么样,当然第一件事,就是编译代码,但是没想到这么多问…

获取印度股票市场列表、查询IPO信息以及通过WebSocket实时接收数据

为了对接印度股票市场,获取市场列表、查询IPO信息、查看涨跌排行榜以及通过WebSocket实时接收数据等步骤。 1. 获取市场列表 首先,您需要获取支持的市场列表,这有助于了解哪些市场可以交易或监控。 请求方法:GETURL&#xff1a…

云原生--CNCF-1-云原生计算基金会介绍(云原生生态的发展目标和未来)

1、CNCF定义与背景 云原生计算基金会(Cloud Native Computing Foundation,CNCF)是由Linux基金会于2015年12月发起成立的非营利组织,旨在推动云原生技术的标准化、开源生态建设和行业协作。其核心目标是通过开源项目和社区协作&am…

【Rust 精进之路之第5篇-数据基石·下】复合类型:元组 (Tuple) 与数组 (Array) 的定长世界

系列: Rust 精进之路:构建可靠、高效软件的底层逻辑 作者: 码觉客 发布日期: 2025-04-20 引言:从原子到分子——组合的力量 在上一篇【数据基石上】中,我们仔细研究了 Rust 的四种基本标量类型&#xff1…

MongoDB 集合名称映射问题

项目场景 在使用 Spring Data MongoDB 进行开发时,定义了一个名为 CompetitionSignUpLog 的实体类,并创建了对应的 Repository 接口。需要明确该实体类在 MongoDB 中实际对应的集合名称是 CompetitionSignUpLog 还是 competitionSignUpLog。 问题描述 …

物联网 (IoT) 安全简介

什么是物联网安全? 物联网安全是网络安全的一个分支领域,专注于保护、监控和修复与物联网(IoT)相关的威胁。物联网是指由配备传感器、软件或其他技术的互联设备组成的网络,这些设备能够通过互联网收集、存储和共享数据…

PCB原理图解析(炸鸡派为例)

晶振 这是外部晶振的原理图。 32.768kHz 的晶振,常用于实时时钟(RTC)电路,因为它的频率恰好是一天的分数(32768 秒),便于实现秒计数。 C25 和 C24:两个 12pF 的电容,用于…

Jupyter Notebook 中切换/使用 conda 虚拟环境的方式(解决jupyter notebook 环境默认在base下面的问题)

使用 nb_conda_kernels 添加所有环境 一键添加所有 conda 环境 conda activate my-conda-env # this is the environment for your project and code conda install ipykernel conda deactivateconda activate base # could be also some other environment conda in…

【JAVA】十三、基础知识“接口”精细讲解!(二)(新手友好版~)

哈喽大家好呀qvq,这里是乎里陈,接口这一知识点博主分为三篇博客为大家进行讲解,今天为大家讲解第二篇java中实现多个接口,接口间的继承,抽象类和接口的区别知识点,更适合新手宝宝们阅读~更多内容持续更新中…

基于MuJoCo物理引擎的机器人学习仿真框架robosuite

Robosuite 基于 MuJoCo 物理引擎,能支持多种机器人模型,提供丰富多样的任务场景,像基础的抓取、推物,精细的开门、拧瓶盖等操作。它可灵活配置多种传感器,提供本体、视觉、力 / 触觉等感知数据。因其对强化学习友好&am…

企业微信自建应用开发回调事件实现方案

目录 1. 前言 2. 正文 2.1 技术方案 2.2 策略上下文 2.2 添加客户策略实现类 2.3 修改客户信息策略实现类 2.4 默认策略实现类 2.5 接收事件的实体类(可以根据事件格式的参数做修改) 2.6 实际接收回调结果的接口 近日在开发企业微信的自建应用时…