还看不懂 DETR 的匈牙利损失函数?4个公式教你理解

看到 DETR 的损失函数的时候,你是否有下面的疑问:

  • 公式中的 σ ∈ S N \sigma \in \mathfrak{S}_N σSN 是什么意思?
  • 公式中的 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 的下标 σ ( i ) \sigma(i) σ(i) 又有什么含义?
  • DETR 的损失函数计算的完整流程又是怎么样的?
  • 为什么计算 box 损失的时候为什么要加上 GIOU 损失

等等问题,都可以在下面的文章中得到解答。

概述

在 DETR 中,进行梯度更新可以分成 2 步:

  1. 使用匈牙利匹配算法,根据优化函数求解集合 y y y y ^ \hat{y} y^ 的最佳匹配:集合 y ^ \hat{y} y^ 的排列 σ ^ \hat{\sigma} σ^ σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i)) L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))
  2. 根据集合 y ^ \hat{y} y^ 最佳排列 σ ^ \hat{\sigma} σ^ 带入损失函数中求解损失,并进行梯度更新。 L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))] L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

可以看出来,其实想要理解 DETR 的损失函数是怎么计算的,只要理解上面的 4 个公式就行了。

第一步:求最佳 σ ^ \hat{\sigma} σ^

σ ^ = arg min ⁡ σ ∈ S N ∑ i N L match ( y i , y ^ σ ( i ) ) \hat{\sigma}=\argmin_{\sigma \in \mathfrak{S}_N }\sum_i^{N}\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) σ^=σSNargminiNLmatch(yi,y^σ(i))(还不算是损失函数,只是通过匈牙利匹配算法求解最优排列的一个优化目标函数)

  • y = { y i } i = 1 N y=\{y_i\}_{i=1}^N y={yi}i=1N:表示 N N N个 ground truth 的集合,其中 y i y_i yi是第 i i i 个 ground truth,当然实际中,集合 y y y 中的 ground truth 数量是远小于 N N N的,为了让 y y y y ^ \hat{y} y^ 两个集合大小一致,在集合 y y y 中会使用 ∅ \varnothing (no object)来对集合进行填充。
  • y ^ = { y ^ i } i = 1 N \hat{y}=\{\hat{y}_i\}_{i=1}^N y^={y^i}i=1N:表示 N N N个 预测的集合,其中 y ^ i \hat{y}_i y^i是第 i i i 个预测。
  • σ \sigma σ:是一种预测值 y ^ \hat{y} y^ 的排列方式,我们知道集合 y y y 与集合 y ^ \hat{y} y^ 要一一匹配,然后进行排列,我们把 y y y 的排列顺序固定,就只需要调整 y ^ \hat{y} y^ 的排列顺序就可以了,而 σ \sigma σ就是表示的集合 y ^ \hat{y} y^ 的某种排列方式, y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i) 也只是表示,在 σ \sigma σ这种排列中,第 i i i 个预测值。
  • S N \mathfrak{S}_N SN:是排列 σ \sigma σ 的集合,也是一种对称群? arg min ⁡ σ ∈ S N \argmin_{\sigma \in \mathfrak{S}_N } argminσSN表示在 S N \mathfrak{S}_N SN内存在一种集合 y ^ \hat{y} y^ 的排列 σ \sigma σ,可以使得匈牙利匹配的 cost 最低。
  • L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i)):是 pair-wise matching cost ,一般是使用匈牙利算法进行计算的。

L match ( y i , y ^ σ ( i ) ) = − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i\ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)} ) Lmatch(yi,y^σ(i))=1{ci=}p^σ(i)(ci)+1{ci=}Lbox(bi,b^σ(i))每个 ground truth y i y_i yi都是由两部分信息组成的,类别 + 位置,也可以写成 y i = ( c i , b i ) y_i=(c_i, b_i) yi=(ci,bi),其中:

  • c i c_i ci:表示类别信息(也有可能是空 ∅ \varnothing
  • b i b_i bi:表示位置信息,是一个归一化(值都小于 1 )的向量,有 4 个值,分别表示 box 中心点的坐标和宽高。

对于预测值 y ^ σ ( i ) \hat{y}_{\sigma(i)} y^σ(i),我们将类别和位置信息定义为 y ^ σ ( i ) = ( p ^ σ ( i ) ( c i ) , b ^ σ ( i ) ) \hat{y}_{\sigma(i)}=(\hat{p}_{\sigma(i)}(c_i), \hat{b}_{\sigma(i)}) y^σ(i)=(p^σ(i)(ci),b^σ(i))

  • p ^ σ ( i ) ( c i ) \hat{p}_{\sigma(i)}(c_i) p^σ(i)(ci):我们已经知道了 ground truch 的类别信息 c i c_i ci,这个概率值是通过模型的分类器计算得出的,反映了模型对于该预测值属于类别 c i c_i ci 的确信度。

第二步:求损失

L Hungarian ( y , y ^ ) = ∑ i = 1 N [ − log ⁡ p ^ σ ^ ( i ) ( c i ) + 1 { c i ≠ ∅ } L box ( b i , b ^ σ ^ ( i ) ) ] \mathcal{L}_{\text{Hungarian}}(y, \hat{y}) = \sum_{i=1}^N \left[ -\log\hat{p}_{\hat{\sigma}(i)}(c_i) + \mathbb{1}_{\{c_i \ne \varnothing\}}\mathcal{L}_{\text{box}}(b_i, \hat{b}_{\hat{\sigma}(i)})\right] LHungarian(y,y^)=i=1N[logp^σ^(i)(ci)+1{ci=}Lbox(bi,b^σ^(i))]论文这里还将 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}(i)} b^σ^(i)打错成了 b ^ σ ^ ( i ) \hat{b}_{\hat{\sigma}}(i) b^σ^(i)

  • σ ^ \hat{\sigma} σ^是最优的排列,也就是使得整体 cost 最小的 y ^ \hat{y} y^ 排列。
  • log ⁡ \log log:这里存在一个问题,为什么上面的 − 1 { c i ≠ ∅ } p ^ σ ( i ) ( c i ) -\mathbb{1}_{\{c_i\ne \varnothing\}}\hat{p}_{\sigma(i)}(c_i) 1{ci=}p^σ(i)(ci) 在这里就变成了 − log ⁡ p ^ σ ^ ( i ) ( c i ) -\log\hat{p}_{\hat{\sigma}(i)}(c_i) logp^σ^(i)(ci)。一个 no object ∅ \varnothing y y y) 与预测值( y ^ \hat{y} y^)的 L match ( y i , y ^ σ ( i ) ) \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) Lmatch(yi,y^σ(i))匹配代价实际上并不取决于预测值,因为 c i = ∅ c_i = \varnothing ci=的时候,因为指示函数的关系, L match ( y i , y ^ σ ( i ) ) = 0 \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = 0 Lmatch(yi,y^σ(i))=0,也就是一个常数。在计算匹配代价(cost)的类别代价(cost)的时候,我们使用概率而不是对数概率,因为实际效果更好。
  • 上面的公式是为了求解最优的排列 σ ^ \hat{\sigma} σ^,而这里根据最优的排列 σ ^ \hat{\sigma} σ^ 来求解损失。一般来说,为了解决类间不平衡问题,会在 c i = ∅ c_i = \varnothing ci= 对数概率项前乘以 1 / 10 1/10 1/10来降低权重。

L box ( b i , b ^ σ ^ ( i ) ) = λ giou L giou ( b i , b ^ σ ^ ( i ) ) + λ L1 ∣ ∣ b i − b ^ σ ^ ( i ) ∣ ∣ 1 \mathcal{L}_{\text{box}}(b_i,\hat{b}_{\hat{\sigma}(i)}) =\lambda_{\text{giou}}\mathcal{L}_{\text{giou}}(b_i,\hat{b}_{\hat{\sigma}(i)}) + \lambda_{\text{L1}}||b_i - \hat{b}_{\hat{\sigma}(i)}||_1 Lbox(bi,b^σ^(i))=λgiouLgiou(bi,b^σ^(i))+λL1∣∣bib^σ^(i)1

  • 直接使用 L1 损失:因为 L1 损失是计算绝对值,但是在目标检测中,大目标和小目标,即便是有着相同的相对误差(relative error),其绝对误差,也就是 L1 损失值都会有很大差异,对尺度的支持比较差。
  • 所以为了缓解 L1 损失的尺度不变性比较差的问题,我们引入了 GIOU(关于 GIOU 的部分,我后续有时间再进行补充吧)

后言

还有什么疑问或者问题可以在评论区评论,我会尽可能的解答并更新文章。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/33565.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

arduino按钮

Arduino - Button Arduino - 按钮 参考: ezButton-按钮库从按钮开关看上拉pull-up电阻、下拉电阻按键的防抖动处理 The button is also called pushbutton, tactile button or momentary switch. It is a basic component and widely used in many Arduino projec…

【gif制作】Win下视频生成GIF;工具GifCam单色保存,灰度保存,调速,编辑删除帧添加文本

下载地址 https://blog.bahraniapps.com/gifcam/#download https://gifcam.en.softonic.com/ 界面功能 GifCam 简洁、小巧的 gif 录制软件。GifCam就像照相机一样位于所有窗口的顶部,可以移动它并调整其大小录屏所需的区域。 如图:空闲状态下窗口内…

【java】JUC

5. 阻塞队列 5.1 生产者消费者概念 生产者消费者是设计模式的一种。让生产者和消费者基于一个容器来解决强耦合问题。 生产者 消费者彼此之间不会直接通讯的,而是通过一个容器(队列)进行通讯。 所以生产者生产完数据后扔到容器中&#xff0c…

给XPTABLE添加右键菜单(XPTable控件使用说明十一)

用户右键点击TABLE控件,弹出一个菜单,选择菜单对应到相关的操作 1、增加一个contextMenuStrip6控件,在里面增加2个ITEM,名称用中文命名 2、给两个ITEM添加点击后的事件 3、在XPTABLE上增加点击事件: 4、当用户右键点击…

vantUI upload 上传组件v-model绑定问题

直接绑定一个数组会有问题,删除失效/上传不了等等 解决在v-model绑定的数组外包一个对象即可

零基础MySQL完整学习笔记

零基础MySQL完整学习笔记 1. 基础操作(必须会!)1.1 修改密码(4种方法)1.2 创建新用户1.3 导入数据库 2. SQL四种语言介绍2.1 DDL(数据库定义语言)2.2 DML(数据操纵语言)2.3 DCL(数据库控制语言)2.4 TCL(事务控制语言) 3. 数据库操作3.1 创建数据库3.2 查询数据库3.3 删除数据库…

聊聊 oracle varchar2 字段的gbk/utf8编码格式和字段长度问题

聊聊 oracle varchar2 字段的gbk/utf8编码格式和字段长度问题 1 问题现象 最近在排查某客户现场的数据同步作业报错问题时,发现了部分 ORACLE 表的 varchar2 字段,因为上游 ORACLE数据库采用 GBK 编码格式,而下游 ORACLE 数据库采用UTF8 编…

封装了一个优雅的iOS转场动画

效果图 代码 // // LBTransition.m // LBWaterFallLayout_Example // // Created by mac on 2024/6/16. // Copyright © 2024 liuboliu. All rights reserved. //#import "LBTransition.h"interface LBPushAnimation:NSObject<UIViewControllerAnimated…

【服务器02】之【阿里云平台】

百度一下阿里云官网 点击注册直接使用支付宝注册可以跳过认证 成功登录后&#xff0c;点击产品 点击免费试用 点击勾选 选一个距离最近的 点满GB 注意&#xff1a;一般试用的时用的是【阿里云】&#xff0c;真正做项目时用的是【腾讯云】 现在开始学习使用&#xff1a; 首先…

串口接收不定长数据实现思路

目录 帧头帧尾标志法&#xff1a; 长度字段法&#xff1a; 超时等待法&#xff1a; 基于STM32串口中断的方法&#xff1a; 基于回调函数的方法&#xff1a; 基于定长数据的方法&#xff08;如果数据包长度固定且已知&#xff09;&#xff1a; 串口实现不定长数据接收通常…

2024年综合艺术与媒体传播国际会议(ICIAMC 2024)

2024年综合艺术与媒体传播国际会议(ICIAMC 2024) 2024 International Conference on Integrated Arts and Media Communication (ICIAMC 2024) 会议地点&#xff1a;贵阳&#xff0c;中国 网址&#xff1a;www.iciamc.com 邮箱: iciamcsub-conf.com 投稿主题请注明:ICIAMC…

Java中如何处理ArithmeticException异常?

Java中如何处理ArithmeticException异常&#xff1f; 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01; 在Java编程中&#xff0c;ArithmeticException异常是开发…

【Python机器学习】DBSCAN(具有噪声的基于密度的空间聚类应用)

DBSCAN&#xff08;具有噪声的基于密度的空间聚类应用&#xff09;是一种非常有用的聚类算法&#xff0c;它的主要优点是不需要用户先验地设置簇的个数&#xff0c;可以划分具有复杂形状的簇&#xff0c;还可以找出不属于任何簇的点。DBSCAN比凝聚聚类和k均值稍慢&#xff0c;但…

常见加密方式:MD5、DES/AES、RSA、Base64

16/32位的数据&#xff0c;最有可能就是使用md5加密的 使用对称加密的时候&#xff0c;双方使用相同的私钥 私钥&#xff1a;单独请求/隐藏在前端的隐藏标签当中 二、RSA非对称密钥加密 公钥加密&#xff0c;私钥解密 私钥是通过公钥计算生成的 加密解密算法都在js源文件当…

简单了解java中的File类

1、File类 1.1、概述 File对象就表示一个路径&#xff0c;可以是文件路径也可以是文件夹路径&#xff0c;这个路径可以 是存在的&#xff0c;也可以是不存在的。 1.2、常见的构造方法 方法名称说明public File&#xff08;String pathname&#xff09;根据文件路径创建文件…

0620# C++八股记录

如何防止头文件被重复包含 1. 使用宏定义&#xff08;Include Guards&#xff09; #ifndef HEADER_FILE_NAME_H #define HEADER_FILE_NAME_H// 头文件的内容#endif // HEADER_FILE_NAME_H例如&#xff0c;假设有一个头文件名为example.h&#xff0c;可以这样编写&#xff1a;…

U盘数据恢复全攻略:从原理到实践

一、引言&#xff1a;为何U盘数据恢复至关重要 在信息化时代&#xff0c;U盘作为便携存储设备&#xff0c;广泛应用于各个领域。然而&#xff0c;U盘数据的丢失往往给个人和企业带来极大的困扰。数据丢失的原因多种多样&#xff0c;可能是误删除、格式化、文件系统损坏&#x…

session 共享、Nginx session 共享、Token、Json web Token 【JWT】等认证

.NET JWT JWT 》》Json Web Token header . payload . Signature 三部分组成 JWT 在线生成 》》 https://jwt.io/ 》》https://tooltt.com/jwt-encode/ 》》解码工具 https://tool.box3.cn/jwt.html JWT 特点 无状态 JWT不需要在服务端存储任何状态&#xff0c;客户端可以携…

【FFMPEG+Mediamtx】 本地RTSP测试推流记录

利用本地FFMPEGMediamtx 搭建本地RTSP测试推流电脑摄像头 起因 本来要用qt的qml的Video做摄像头测试。 &#x1f614;但是&#xff0c;不在现场&#xff0c;本地测试&#xff0c;又要测试rtsp流&#xff0c;又因为搜了一圈找不到一个比较好的在线测试rtsp推流网址&#x1f6…

自从用了这个 69k star 的项目,前端小姐姐再也不催我了

一般在开发前后端分离的项目时&#xff0c;双方会定义好前后端交互的 http 接口&#xff0c;根据接口文档各自进行开发。这样并行开发互不耽误&#xff0c;开发好后做个联调就可以提测了。 不过最近也不知道怎么回事&#xff0c;公司新来的前端小姐姐总是在刚开始开发的时候就…