深度学习快速入门系列---损失函数

在深度学习中,损失函数的作用是量化预测值和真实值之间的差异,使得网络模型可以朝着真实值的方向预测,损失函数通过衡量模型预测结果与真实标签之间的差异,反映模型的性能。同时损失函数作为一个可优化的目标函数,通过最小化损失函数来优化模型参数。在本篇文章中,我们介绍一下,深度学习中最常用的几种损失函数:

目录

一、适用于回归问题的损失函数

1、L1 LOSS

2、L2 LOSS

 3. smooth L1 loss

 二、适用于分类问题的损失函数

1、交叉熵损失函数

 2、Binary Cross-Entropy 交叉熵损失

3、Focal Loss


一、适用于回归问题的损失函数

1、L1 LOSS

L1损失函数也叫作平均绝对误差(MAE),它是一种常用的回归损失函数,是目标值与预测值之差绝对值的和,表示了预测值的平均误差幅度,而不需要考虑误差的方向。总的来说,它把目标值与估计值的绝对差值的总和最小化。L1 LOSS的数学公式为:

 下面演示,在pytorch中使用该函数:

import torch
import torch.nn as nnpredict=torch.randn(2,3);
target=torch.randn(2,3)print("predict:{}".format(predict))
print("target:{}".format(target))#方式一:使用Pytorch内置函数
loss1_fn=nn.L1Loss()
loss1=loss1_fn(predict,target)print("loss1:{}".format(loss1))#方式二:自己按照L1公式实现函数
loss2_fn=torch.abs(target-predict)
loss2=torch.mean(loss2_fn)print("loss2:{}".format(loss2))

从控制台可以看到,使用pytorch内置的L1函数与我们自己实现的L1函数结果相同。

2、L2 LOSS

也被称为均方误差(MSE, mean squared error),它把目标值与估计值的差值的平方和最小化。其数学公式为:

 下面在pytorch中实现MSE损失函数:

import torch
import torch.nn as nn
predict=torch.rand(2,3)
target=torch.rand(2,3)
#使用pytroch内置
loss1_fn=nn.MSELoss()
loss1=loss1_fn(predict,target)
print(loss1)
#自己按照公式实现
loss2_var=predict-target
loss2_var=loss2_var**2
loss2=torch.mean(loss2_var)
print(loss2)

 3. smooth L1 loss

在提到smooth L1 loss之前,很有必要提一下L1、L2损失函数的优缺点:

L1损失函数的导数公式如下:

在这里插入图片描述

 L2损失函数的导数公式如下:

在这里插入图片描述

 smooth L1的公式(其中公式中的x表示,预测值与真实值的差值的绝对值)

  smooth L1的导数:

  从上图中可以看出L1损失函数具有如下优缺点: 

  • 优点:无论对于什么样的输入值,都有着稳定的梯度,不会导致梯度爆炸问题,具有较为稳健性的解
  • 缺点:在中心点是折点,不能求导,梯度下降时要是恰好学习到w=0就没法接着进行了

L2损失函数: 

  • 优点:各点都连续光滑,方便求导,具有较为稳定的解
  • 缺点:不是特别的稳健,因为当函数的输入值距离真实值较远的时候,对应loss值很大在两侧,则使用梯度下降法求解的时候梯度很大,可能导致梯度爆炸

尽管L1收敛速度比L2损失函数要快,并且能提供更大且稳定的梯度,但是L1有致命的缺陷:导数不连续,导致求解困难,在训练后期损失函数将在稳定值附近波动,难以继续收敛达到更高精度。这也导致L1损失函数极其不受欢迎。使用MAE损失(特别是对于神经网络来说)的一个大问题就是,其梯度始终一样:

  1. 这意味着梯度即便是对于很小的损失值来说,也还会非常大,会出现难以收敛的问题;而 MSE 当损失变小的时候,梯度也会变小,从而更容易收敛

  2. 为了修正这一点,我们可以使用动态学习率,它会随着我们越来越接近最小值而逐渐变小。

  3. 在这种情况下,MSE会表现的很好,即便学习率固定,也会收敛。MSE损失的梯度对于更大的损失值来说非常高,当损失值趋向于0时会逐渐降低,从而让它在模型训练收尾时更加准确

而Smooth L1 Loss 是在 MAE 和 MSE 的基础上进行改进得到的;在 Faster R-CNN 以及 SSD 中对边框的回归使用的损失函数都是Smooth L1 作为损失函数。

仔细看上面的图像,smooth L1各点连续,在x较小时,对x的梯度也会变小,x很大时,对x的梯度的绝对值达到上限1,也不会太大导致训练不稳定。smooth L1避开L1和L2损失的缺陷

公式以及图像中可以看出,Smooth L1 Loss 从两个方面限制梯度:

  1. 当预测框与 ground truth 差别过大时,梯度值不至于过大,防止梯度爆炸

  2. 当预测框与 ground truth 差别很小时,梯度值足够小,有利于收敛;

Smooth L1 的优点是结合了 L1 和 L2 Loss:

  1. 相比于L1损失函数,可以收敛得更快;

  2. 相比于L2损失函数,对离群点、异常值不敏感,梯度变化相对更小,训练时不容易跑飞

下面在pytorch中实现smooth L1损失函数:

import torch
import torch.nn as nnpredict=torch.randn(2,3)
target=torch.randn(2,3)loss1_fn=nn.SmoothL1Loss()
loss1=loss1_fn(predict,target)
print(loss1)def smooth_l1_loss(x,y,beta=1):diff=torch.abs(x-y)loss2=torch.where(diff<beta,0.5*diff**2/beta,diff-0.5*beta)return loss2.mean()
loss2=smooth_l1_loss(predict,target)
print(loss2)

 二、适用于分类问题的损失函数

1、交叉熵损失函数

交叉熵损失函数Cross-Entropy Loss Function)一般用于分类问题。假设样本的标签y ∈ {1, · · · C}为离散的类别,模型f(x, θ) [0, 1] 的输出为类别标签的条件概率分布,即

 并满足

我们可以用一个C 维的one-hot向量y来表示样本标签。假设样本的标签为k,那么标签向量y只有第k 维的值为1,其余元素的值都为0。标签向量y可以看作是样本标签的真实概率分布,即第c维(记为yc1 c C)是类别为c的真实概率。假设样本的类别为k,那么它属于第k 类的概率为1,其它类的概率为0。 对于两个概率分布,一般可以用交叉熵来衡量它们的差异。标签的真实分布y和模型预测分布f(x, θ)之间的交叉熵为:

 比如对于三类分类问题,一个样本的标签向量为y = [0, 0, 1]T,模型预测的标签分布为f(x, θ) = [0.3, 0.3, 0.4]T,则它们的交叉熵为:

因为 y one-hot 向量,因此交叉熵损失函数公式也可以写为:

 其中fy(x, θ)可以看作真实类别y 的似然函数。因此,交叉熵损失函数也就是负对数似然损失函数(Negative Log-Likelihood Function)。

import randomimport torch
import torch.nn as nn
predict=torch.randn(2,3)
#随机生成标签
target=torch.tensor([random.randint(0,2) for _ in range(2)],dtype=torch.long)
# print(target)
#方式一:使用torch中定义好的函数
loss_fn=nn.CrossEntropyLoss()
loss=loss_fn(predict,target)
print(loss)
#方式二:自己按照公式实现
def cross_entropy_loss(predict,label):prob=nn.functional.softmax(predict,dim=1)log_prob=torch.log(prob)label_view=label.view(-1,1)loss=-log_prob.gather(1,label_view)loss_mean=loss.mean()return loss_mean
loss2=cross_entropy_loss(predict,target)
print(loss2)

 2、Binary Cross-Entropy 交叉熵损失

Binary Cross-Entropy 交叉熵损失用于二分类问题,它其实就是交叉熵损失函数Cross-Entropy Loss Function)的特例,也就是将多分类任务的特例化,变成二分类任务。这里不在做赘述。

3、Focal Loss

该损失函数由《Focal Loss for Dense Object Detection》论文首次提出,当时提出的背景是为了解决目标检测领域的突出问题:

Two-stage 的目标检测算法准确率高,但是速度比较慢;One-stage 的目标检测算法速度虽然快很多,但是准确率比较低

想要提高 One-stage 方法的准确率,就要找到其原因,作者提出的一个原因就是正负样本不均衡

  1. Focal loss 损失函数是为了解决 one-stage 目标检测中正负样本极度不平衡的问题;

    • 目标检测算法为了定位目标会生成大量的anchor box

    • 而一幅图中目标(正样本)个数很少,大量的anchor box处于背景区域(负样本),这就导致了正负样本极不平衡

  2. two-stage 的目标检测算法这种正负样本不平衡的问题并不突出,原因:

    • two-stage方法在第一阶段生成候选框,RPN只是对anchor box进行简单背景和前景的区分,并不对类别进行区分

    • 经过这一轮处理,过滤掉了大部分属于背景的anchor box,较大程度降低了anchor box正负样本的不平衡性

    • 同时在第二阶段采用启发式采样(如:正负样本比1:3)或者OHEM进一步减轻正负样本不平衡的问题

  3. One-Stage 目标检测算法,不能使用采样操作

    • Focal loss 就是专门为 one-stage 检测算法设计的,将易区分负例的 loss 权重降低

    • 使得网络不会被大量的负例带偏

focal loss,是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。为了证明focal loss的有效性,作者设计了一个dense detector:RetinaNet,并且在训练时采用focal loss训练。实验证明RetinaNet不仅可以达到one-stage detector的速度,也能有two-stage detector的准确率。在了解该损失函数的公式之前先了解一下在分类问题中常用到的交叉熵函数,交叉熵公式可以表示为:

 为了方便,将交叉熵公式写为:

为了解决正负样本不平衡的问题,我们通常会在交叉熵损失的前面加上一个参数 \alpha

既然在 One-stage 方法中,正负样本不均衡是存在的问题,那么一个比较常见的算法就是给正负样本加上权重:增大正样本的权重,减小负样本的权重

通过设定 \alpha 的值来控制正负样本对总的 loss 的共享权重;上面的方法虽然可以控制正负样本的权重,但是无法控制容易分类和难分类样本的权重。因此就设计了 Focal Loss。其公式为:

 其中

\gamma为常数,称之为 focusing parameter (\gamma ≥ 0),当 \gamma=0 时,Focal Loss 就与一般的交叉熵损失函数一致; (1-p_{t})^{\gamma }称之为调制系数,目的是通过减少易分类样本的权重,从而使得模型在训练时更专注于难分类的样本。当 \gamma 取不同的值,Focal Loss 曲线如下图所示,其中横坐标是 p_{t} 纵坐标是 loss

通过\gamma参数,解决了难易样本分类的难题,但是我们通常还会在Focal Loss 的公式前面再加上一个参数 \alpha用于解决正负样本不平衡的问题:

 实验表明\gamma 取2, \alpha取0.25的时候效果最佳。

Focal Loss实现:

def py_sigmoid_focal_loss(pred,target,weight=None,gamma=2.0,alpha=0.25,reduction='mean',avg_factor=None):pred_sigmoid = pred.sigmoid()target = target.type_as(pred)pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * focal_weightloss = weight_reduce_loss(loss, weight, reduction, avg_factor)return loss

这个代码很容易理解,先定义一个pt:

 然后计算:

focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)

也就是这个公式:

 然后再把BCE损失*focal_weight

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/36011.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10个微服务设计模式

微服务设计模式是一种指导微服务架构设计和开发的一系列原则和实践。微服务设计模式的目的是为了解决微服务架构中遇到的一些常见的问题和挑战&#xff0c;比如服务划分、服务通信、服务治理、服务测试等。微服务设计模式可以帮助我们构建出高效、可靠、可扩展、可维护的微服务…

使用AT命令操作Modem 3G/4G模块

1. 引言 AT命令是一种通信协议&#xff0c;用于控制和配置各种设备&#xff0c;尤其在通信领域中具有重要性。它的名称来源于"ATtention"&#xff08;注意&#xff09;&#xff0c;因为命令通常以"AT"开头。AT命令最早被用于调制解调器&#xff0c;用于与…

springboot整合rabbitmq

rabbitmq的七种模式 Hello word 客户端引入依赖 <!--rabbitmq 依赖客户端--><dependency><groupId>com.rabbitmq</groupId><artifactId>amqp-client</artifactId><version>5.8.0</version></dependency> 生产者 imp…

STM32 LoRa源码解读

目录结构&#xff1a; SX1278 |-- include | |-- fifo.h | |-- lora.h | |-- platform.h | |-- radio.h | |-- spi.h | |-- sx1276.h | |-- sx1276Fsk.h | |-- sx1276FskMisc.h | |-- sx1276Hal.h | |-- sx1276LoRa.h | -- sx1276LoRaMisc.h – src |-- fifo.c |-- lora.c |-- …

【解析postman工具的使用---基础篇】

postman前端请求详解 主界面1.常见类型的接口请求1.1 查询参数的接口请求1.1.1 什么是查询参数?1.1.2 postman如何请求 1.2 ❤表单类型的接口请求1.2.1 复习下http请求1.2.2❤ 什么是表单 1.3 上传文件的表单请求1.4❤ json类型的接口请求 2. 响应接口数据分析2.1 postman的响…

什么是DNS欺骗及如何进行DNS欺骗

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、什么是 DNS 欺骗&#xff1f;二、开始1.配置2.Ettercap启动3.操作 总结 前言 我已经离开了一段时间&#xff0c;我现在回来了&#xff0c;我终于在做一个教…

【AI】p54-p58导航网络、蓝图和AI树实现AI随机移动和跟随移动、靠近玩家挥拳、AI跟随样条线移动思路

p54-p58导航网络、蓝图和AI树实现AI随机移动和跟随移动、靠近玩家挥拳、AI跟随样条线移动思路 p54导航网格p55蓝图实现AI随机移动和跟随移动AI Move To&#xff08;AI进行移动&#xff09;Get Random Pointln Navigable Radius&#xff08;获取可导航半径内的随机点&#xff09…

时序预测 | MATLAB实现基于LSTM长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于LSTM长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于LSTM长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 Matlab实现LSTM长短期记忆神经…

识别和应对内存抖动

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、案例分析3.1 使用memory-profiler3.2 使用 cp…

磁粉制动器离合器收放卷应用介绍

张力控制系统的开环闭环应用介绍,请查看下面文章链接: PLC张力控制(开环闭环算法分析)_张力控制plc程序实例_RXXW_Dor的博客-CSDN博客里工业控制张力控制无处不在,也衍生出很多张力控制专用控制器,磁粉制动器等,本篇博客主要讨论PLC的张力控制相关应用和算法,关于绕线…

APP外包开发的iOS开发语言

学习iOS开发需要掌握Swift编程语言和相关的开发工具、框架和技术。而学习iOS开发需要时间和耐心&#xff0c;尤其是对于初学者。通过坚持不懈的努力&#xff0c;您可以逐步掌握iOS开发技能&#xff0c;构建出功能丰富、优质的移动应用。今天和大家分享学习iOS开发的一些建议方法…

【数据结构系列】链表

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

解决hbase节点已下线,但在status中显示为dead问题

工作中需要下线4台hbase小节点&#xff0c;下线完成后使用status 命令查看,有一台为dead状态: 使用status detailed 查看&#xff0c;发现“hd-03"这台节点是dead。 检查各节点配置文件无误&#xff0c;并使用 /opt/hbase/bin/hbase-daemon.sh restart master 重启两个…

less基本使用

1 less中的变量 //对值进行声明 link-color: #ccc//定义变量名称 .{sleName} {}bg: background-color; //定义属性名称 .container {{bg}: red; }2 继承&#xff08;复用重复样式&#xff09; //继承必须位于选择器最后 //继承选择器名不能为变量 .a:hover:extend(.b) {}.a {…

浅谈人工智能技术与物联网结合带来的好处

物联网是指通过互联网和各种技术将设备进行连接&#xff0c;实时采集数据、交互信息的网络&#xff0c;对设备实现智能化自动化感知、识别和控制&#xff0c;给人们带来便利。 人工智能是计算机科学的一个分支&#xff0c;旨在研究和开发能够模拟人类智能的技术和方法。人工智能…

后院失火、持续亏损!Mobileye半年报「不回避」竞争压力

"客户在2023年上半年非常谨慎&#xff0c;导致增长率低于正常水平&#xff0c;但我们已经看到下半年回暖趋势&#xff0c;预计下半年交付将比去年同期增长16%&#xff0c;远高于上半年。"这是Mobileye在近日公司半年报发布会上的预判。 公开数据显示&#xff0c;今年…

2023网络安全常用工具汇总(附学习资料+工具安装包)

几十年来&#xff0c;攻击方、白帽和安全从业者的工具不断演进&#xff0c;成为网络安全长河中最具技术特色的灯塔&#xff0c;并在一定程度上左右着网络安全产业发展和演进的方向&#xff0c;成为不可或缺的关键要素之一。 话不多说&#xff0c;网络安全10款常用工具如下 1、…

Opencv4基于C++基础入门笔记:图像 颜色 事件响应 图形 视频 直方图

效果图◕‿◕✌✌✌&#xff1a;opencv人脸识别效果图(请叫我真爱粉) 先看一下效果图勾起你的兴趣&#xff01; 文章目录&#xff1a; 一&#xff1a;环境配置搭建 二&#xff1a;图像 1.图像读取与显示 main.cpp 运行结果 2.图像色彩空间转换 2.1 换色彩 test.h …

感受RFID服装门店系统的魅力

嘿&#xff0c;亲爱的时尚追随者们&#xff01;今天小编要给你们带来一股时尚新风潮&#xff0c;让你们感受一下什么叫做“RFID服装门店系统”&#xff0c;这个超酷的东西&#xff01; 别着急&#xff0c;先别翻白眼&#xff0c;小编来解释一下RFID是什么玩意儿。它是射频识别…

云计算——存储虚拟化功能

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​ 目录 前期回顾 前言 一.存储虚拟化功能 1.精简磁盘和空间回收 2.快照 &#xff08;1&a…