深度学习常用损失函数详解

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、回归问题
    • 1. 均方误差(MSE)
    • 2. 均方根误差 (RMSE)
    • 3. 平均绝对误差 (MAE)
  • 二、分类问题
    • 1. 相关概念
    • 2. 交叉熵损失函数
    • 3. BCE Loss


前言

机器学习任务大概可以分为两类问题,分别是回归问题和分类问题。回归问题预测的是一个连续的数值,例如房价,气温等。而分类问题是将输入预测为不同的类别,例如猫狗分类等。总的来说,回归问题输出是一个实数范围,可以是任何数值;分类问题输出是离散的类别标签,通常是整数或特定的类别名称。接下来对回归问题和分类问题常用的损失函数进行介绍。

一、回归问题

1. 均方误差(MSE)

 MSE损失(Mean Squared Error)也称为 L 2 L_2 L2 Loss,是回归问题中比较常用的损失函数。其公式为:

L M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 L_{MSE}(y,\hat{y}) = \frac{1}{n}\sum_{i=1}^n(y_i-\hat{y}_i)^2 LMSE(y,y^)=n1i=1n(yiy^i)2

其中, y y y为模型的预测值, y ^ \hat{y} y^为标签值。
 优点:由于平方操作,MSE对较大的误差给予更大的惩罚,这有助于模型学习减少大的预测偏差。
 缺点:MSE对异常值或离群点非常敏感,这可能会影响模型的泛化能力。
代码实现:

import torch
import torch.nn as nny = torch.Tensor([1,2,3])
label = torch.Tensor([3,5,1])
criterion = nn.MSELoss()
loss = criterion(y,label)
print(loss)

输出为tensor(5.6667)

2. 均方根误差 (RMSE)

 RMSE误差(Root Mean Squared Error)的公式为:

L M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 L_{MSE}(y,\hat{y}) = \sqrt{\frac{1}{n}\sum_{i=1}^n(y_i-\hat{y}_i)^2} LMSE(y,y^)=n1i=1n(yiy^i)2

其相当于在MSE的基础上加了一个根号。
 优点:与MSE一样,RMSE对较大的误差给予更大的惩罚,同时因为开根号的缘故,它的单位与原始数据的单位相同,这使得它在解释预测误差时更加直观。
 缺点:对异常值或离群点非常敏感,计算量较大,RMSE比较难优化,因为它不是一个严格凸函数。

3. 平均绝对误差 (MAE)

 MAE(Mean Absolute Error)又称为 L 1 L_1 L1 Loss,计算所有样本预测值与实际值之差的绝对值的平均值。公式为:

L M A E ( y , y ^ ) = 1 n ∑ i = 1 n ∣ ( y i − y ^ i ) ∣ L_{MAE}(y,\hat{y}) = \frac{1}{n}\sum_{i=1}^n|(y_i-\hat{y}_i)| LMAE(y,y^)=n1i=1n(yiy^i)

其中, y y y为预测值, y ^ \hat{y} y^为标签值。
 优点:对异常值不敏感,单位与原始数据一致,易于直观理解,计算简单。
 缺点:可能不利于模型学习减少较大的预测偏差。
代码实现:

import torch
import torch.nn as nny = torch.Tensor([1,2,3])
label = torch.Tensor([3,5,1])
criterion = nn.L1Loss()
loss = criterion(y,label)
print(loss)

输出为tensor(2.3333)

二、分类问题

1. 相关概念

 首先我们先介绍一下相关概念。
信息熵:用来衡量信息的不确定性或信息的平均信息量。其公式为:

H ( x ) = − ∑ i = 1 n P ( x i ) log ⁡ P ( x i ) H(x) = -\sum_{i=1}^nP(x_i)\log P(x_i) H(x)=i=1nP(xi)logP(xi)

其中, P ( x i ) P(x_i) P(xi)为随机事件 x i x_i xi发生的概率。若事件发生的不确定性越大,则其熵越大,代表含有更多的信息量。
KL散度:是一种衡量一个概率分布P相对于另一个概率分布Q的非对称性差异,其公式为:

D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) D_{KL}(p||q) = \sum_{i=1}^np(x_i)\log (\frac{p(x_i)}{q(x_i)}) DKL(p∣∣q)=i=1np(xi)log(q(xi)p(xi))

其中, p p p q q q是两种不同的分布, p ( x i ) p(x_i) p(xi) q ( x i ) q(x_i) q(xi)分别代表随机事件在 p p p q q q这两个分布中发生的概率。其具有以下特性:

非负性:KL散度总是正的
非对称性: D K L ( p ∣ ∣ q ) D_{KL}(p||q) DKL(p∣∣q)不等于 D K L ( q ∣ ∣ p ) D_{KL}(q||p) DKL(q∣∣p),只有当两个分布相等时, D K L ( p ∣ ∣ q ) = D K L ( q ∣ ∣ p ) D_{KL}(p||q)=D_{KL}(q||p) DKL(p∣∣q)=DKL(q∣∣p)
在信息论中,KL散度可以被解释为在使用概率分布Q来拟合真实分布P时产生的信息损耗。

交叉熵:用于衡量两个概率分布之间的差异,是对称的。

H ( P , Q ) = − ∑ i = 1 n p ( x i ) log ⁡ q ( x i ) = H ( P ) + D K L ( p ∣ ∣ q ) H(P,Q) = -\sum_{i=1}^np(x_i)\log q(x_i) = H(P)+D_{KL}(p||q) H(P,Q)=i=1np(xi)logq(xi)=H(P)+DKL(p∣∣q)
H ( P , Q ) = H ( Q , P ) H(P,Q)=H(Q,P) H(P,Q)=H(Q,P)

其中, p p p q q q是两个分布。交叉熵经常用作分类问题的损失函数,其中,P可以看作是标签的分布,Q可以看作是模型预测的分布。

2. 交叉熵损失函数

 交叉熵损失函数(Cross-Entropy Loss)是分类问题中常用的损失函数,其衡量的是模型预测的概率分布和标签的真实分布之间的差异。前面说过,计算两个分布之间的交叉熵公式为:

H ( P , Q ) = − ∑ i = 1 n p ( x i ) log ⁡ q ( x i ) H(P,Q) = -\sum_{i=1}^np(x_i)\log q(x_i) H(P,Q)=i=1np(xi)logq(xi)

其中, P P P为标签的分布, Q Q Q为模型预测的分布。接下来我们举一个例子来说明。

例如我们要做一个三分类[汽车,猫,飞机]的物体分类,假设标签为1,我们会将标签1转换为one-hot编码[0,1,0],然后将模型的输出经过softmax操作进行归一化将其转换为概率值,然后计算交叉熵。

分类模型预测标签
汽车0.050
0.801
飞机0.150

则公式为

L o s s = − ( 0 ∗ log ⁡ 0.05 + 1 ∗ log ⁡ 0.80 + 0 ∗ log ⁡ 0.15 ) = − log ⁡ 0.80 Loss = -(0*\log0.05 + 1*\log0.80 + 0*\log0.15)=-\log0.80 Loss=(0log0.05+1log0.80+0log0.15)=log0.80

 因为在分类问题中,标签的分布中只有一个1,其他都为0。因此,在分类问题中,交叉熵损失函数可以简化为:

L o s s = − log ⁡ p ( x i ) Loss = -\log p(x_i) Loss=logp(xi)

其中, p ( x i ) p(x_i) p(xi)为模型预测的分类结果的概率值。例如,在上述例子中 p ( x i ) p(x_i) p(xi)的值为0.80。
代码实现:

import torch
import torch.nn as nnlogits = torch.tensor([[0.68,-0.8,0.75]]) #模型的输出  
labels = torch.tensor([2])  #标签值为2 
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss.item())

最终的输出结果为0.8329725861549377。

3. BCE Loss

 BCE Loss(Binary Cross-Entropy Loss)是在二分类问题中常用的损失函数,其公式为:

L B C E = − ( y ∗ log ⁡ ( p ) + ( 1 − y ) ∗ log ⁡ ( 1 − p ) ) L_{BCE} = -(y*\log(p)+(1-y)*\log(1-p)) LBCE=(ylog(p)+(1y)log(1p))

其中, y y y为标签值(0或者1), p p p为模型预测为正样本(标签值为1)的概率。在计算Loss之前,首先要将模型的输出经过Sigmoid激活函数将概率值 p p p映射到[0,1]之间。
 注意:与使用交叉熵损失做二分类不同的是,如果使用交叉熵损失,则最终模型的输出是一个向量,这个向量经过softmax操作后变成概率分布,代表是正样本的概率和负样本的概率,这些概率相加等于1;而使用BCE Loss时,模型的输出是一个数,然后经过sigmoid操作映射到[0,1]之间,映射后的值就代表模型预测该物品是正样本的概率。下面举一个例子来说明。

假设使用BCE Loss做一个二分类的问题,模型的输出为2.56,然后经过sigmoid操作映射成0.80(瞎编的),说明模型预测为正样本的概率为0.8。
如果标签是0(负样本),然后使用公式计算:
L o s s = − ( 0 ∗ log ⁡ ( 0.8 ) + 1 ∗ log ⁡ ( 0.2 ) ) = − l o g ( 0.2 ) Loss = -(0*\log(0.8) + 1*\log(0.2)) = -log(0.2) Loss=(0log(0.8)+1log(0.2))=log(0.2)
如果标签是1(正样本),然后使用公式计算:
L o s s = − ( 1 ∗ log ⁡ ( 0.8 ) + 0 ∗ log ⁡ ( 0.2 ) ) = − l o g ( 0.8 ) Loss = -(1*\log(0.8) + 0*\log(0.2)) = -log(0.8) Loss=(1log(0.8)+0log(0.2))=log(0.8)

如果使用的是交叉熵损失函数做二分类问题,标签0代表负样本,标签1代表正样本,假设经过softmax操作后的概率分布为[0.2(负样本),0.8(正样本)]
如果标签是0,转换为one-hot编码[1,0],使用交叉熵损失函数:
L o s s = − log ⁡ ( 0.2 ) Loss =-\log(0.2) Loss=log(0.2)
如果标签是1,转换为one-hot编码[0,1],使用交叉熵损失函数:
L o s s = − log ⁡ ( 0.8 ) Loss =-\log(0.8) Loss=log(0.8)

 通过上述两个例子可以看出,当做二分类任务时,BCE和交叉熵的Loss最终函数形式是一样的,个人认为,这两种损失函数主要的不同在于模型最终的输出形式和最终的概率映射方式不同。

代码实现:

import torch
import torch.nn as nnpredicted_probabilities = torch.tensor([0.8])  # 模型预测的为正样本概率
#1表示正样本,0表示负样本
true_labels = torch.tensor([1])
# 将布尔值转换为浮点数,因为PyTorch的BCELoss期望浮点数标签
true_labels = true_labels.float()
criterion = nn.BCELoss()
loss = criterion(predicted_probabilities, true_labels)
print(loss.item())

代码输出为0.2231435328722。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/877427.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32之SPI读写W25Q128芯片

SPI简介 STM32的SPI是一个串行外设接口。它允许STM32微控制器与其他设备(如传感器、存储器等)进行高速、全双工、同步的串行通信。通常包含SCLK(串行时钟)、MOSI(主设备输出/从设备输入Master Output Slave Input&…

Linux系统编程 --- 多线程

线程:是进程内的一个执行分支,线程的执行粒度,要比进程要细。 一、线程的概念 1、Linux中线程该如何理解 地址空间就是进程的资源窗口。 在一个程序里的一个执行路线就叫做线程(thread)。更准确的定义是&#xff1…

【vim 学习系列文章 15.1 -- vim 只显示高亮字符所在的行】

文章目录 vim 只显示高亮字符所在的行搜索并高亮字符仅显示高亮字符所在的行在快速修复列表中导航使用 :g 命令仅显示匹配的行Summary vim 只显示高亮字符所在的行 在 Vim 中,如果你想只显示包含高亮字符的行,可以使用一些 Vim 内置的命令与功能来实现。…

聊聊场景及场景测试

在我们进行测试过程中,有一种黑盒测试叫场景测试,我们完全是从用户的角度去理解系统,从而可以挖掘用户的隐含需求。 场景是指用户会使用这个系统来完成预定目标的所有情况的集合。 场景本身也代表了用户的需求,所以我们可以认为…

SpringBoot+Vue在线商城(电子商城)系统-附源码与配套论文

摘 要 随着互联网技术的发展和普及,电子商务在全球范围内得到了迅猛的发展,已经成为了一种重要的商业模式和生活方式。电子商城是电子商务的重要组成部分,是一个基于互联网的商业模式和交易平台,通过网络进行产品和服务的销售。…

计算机图形学 | 动画模拟

动画模拟 布料模拟 质点弹簧系统: 红色部分很弱地阻挡对折 Steep connection FEM:有限元方法 粒子系统 粒子系统本质上就是在定义个体和群体的关系。 动画帧率 VR游戏要不晕需要达到90fps Forward Kinematics Inverse Kinematics 只告诉末端p点,中间…

Delphi5实现色板程序——滑块型组件实例

效果图 参考 Delphi程序设计基础:教程、实验、习题 代码 unit Unit1;interfaceusesSysUtils, WinTypes, WinProcs, Messages, Classes, Graphics, Controls,Dialogs, Forms,Form, Formprpt, ExtCtrls, StdCtrls;typeTForm1 class(MForm)Label1: TLabel;Label2: …

公式编辑器 -vue-formula-editor

前言 公式编辑旨在帮助用户使用可视化的前提,能便捷的使用平台,例如低代码平台使用广泛 vue-formula-editor vue-formula-editor是一款开源的Vue公式计算组件,可以帮助开发者快速集成公式编辑 在线体验 demo & 源码 安装 npm i vue-form…

[Python学习日记-9] Python中的运算符

简介 计算机可以进行的运算有很多种,但可不只加减乘除这么简单,运算按种类可分为算数运算、比较运算、逻辑运算、赋值运算、成员运算、身份运算、位运算,而本篇我们暂只介绍算数运算、比较运算、逻辑运算、赋值运算 算数运算 一、运算符描述…

FunHPC算力平台评测

作为内测老用户,已经用DeepLn平台(现改名为FunHPC平台)好久了,一路见证了平台从最初100多人的小群到现在满群的状态,FunHpc平台确实在一步步的走向成熟,一步步的变大。趁着现在活动的时间,发篇文…

XSS反射型和DOM型+DOM破坏

目录 第一关 源码分析 payload 第二关 源码分析 payload 第三关 源码分析 payload 第四关 源码分析 payload 第五关 源码分析 payload 第六关 源码分析 第七关 源码分析 方法一:构造函数 方法二:parseInt 方法三:locat…

项目问题 | CentOS 7停止维护导致yum失效的解决办法

目录 centos停止维护意味着yum相关源伴随失效。 报错: 解决方案:将图中四个文件替换掉/etc/yum.repos.d/目录下同名文件 资源提交在博客头部,博客结尾也提供文件源码内容 CentOS-Base.repo CentOS-SCLo-scl.repo CentOS-SCLo-scl-rh.rep…

HTML5服装电商网上商城模板源码

文章目录 1.设计来源1.1 主界面1.2 购物车界面1.3 电子产品界面1.4 商品详情界面1.5 联系我们界面1.6 各种标签演示界面 2.效果和源码2.1 动态效果2.2 源代码 源码下载万套模板,程序开发,在线开发,在线沟通 【博主推荐】:前些天发…

【系统分析师】-综合知识-系统架构

1、设计模式 1)观察者模式定义了对象间的一种一对多依赖关系,使得每当一个对象改变状态,则所有依赖于它的对象都会得到通知并被自动更新【消息订阅】。在该模式中,发生改变的对象称为观察目标,被通知的对象称为观察者&…

Linux中查看正在监听的IP和端口

最近和其他终端设备联调时,需要去查看正在监听的IP和端口,以下在 Linux 系统中,可以使用以下命令查看正在监听的端口和 IP 地址。这个取决于当前Linux系统内已有的工具,如果不清楚,可以都试一下。 netstat -tuln -t&am…

Python爬虫使用实例

IDE:大部分是在PyCharm上面写的 解释器装的多 → 环境错乱 → error:没有配置,no model 爬虫可以做什么? 下载数据【文本/二进制数据(视频、音频、图片)】、自动化脚本【自动抢票、答题、采数据、评论、点…

vue3 响应式 API:ref() 和 reactive()

在 Vue 3 中,响应式系统是其核心特性之一,它使得数据的变化能够自动触发视图的更新。 官方文档: 响应式 API:核心 要更好地了解响应式 API,推荐阅读官方指南中的章节: 响应式基础 (with the API preference…

SX_初识GitLab_1

1、对GitLab的理解: 目前对GitLab的理解是其本质是一个远程代码托管平台,上面托管多个项目,每个项目都有一个master主分支和若干其他分支,远程代码能下载到本机,本机代码也能上传到远程平台 1.分支的作用&#xff1a…

源/目的检查开启导致虚拟IP背后的LVS无法正常访问

情况描述 近期发现48网段主机无法访问8.83这个VIP(虚拟IP),环境是 8.83 绑定了两个LVS实例,然后LVS实例转发到后端的nginx 静态资源;整个流程是,客户端发起对VIP的请求,LVS将请求转发到后端实例…

Oracle大师Roger Cornejo的推荐:使用ASH诊断Oracle解析故障

这篇文章被Oracle大师Roger Cornejo在X平台上推荐(见下图),英文原文在: Diagnosing Parsing Issue with ASH 解析,尤其是硬解析,是非生产性操作,会消耗大量系统资源,导致库缓存争用…