YOLOv8训练流程-原理解析[目标检测理论篇]

          关于YOLOv8的主干网络在YOLOv8网络结构介绍-CSDN博客介绍了,为了更好地学习本章内容,建议先去看预测流程的原理分析YOLOv8原理解析[目标检测理论篇]-CSDN博客,再次把YOLOv8网络结构图放在这里,方便随时查看。

1.前言

          YOLOv8训练流程这一块内容还是比较复杂的,所以先来谈一下训练流程的思路,一共就两步:第一步就是从网络预测的结果中找到正样本,并且确定正样本要预测的对象;第二步就是计算预测结果和标签之间的损失,分别计算预测框的损失以及预测类别的损失。

         如下图所示,假设一张图片中只有3个标签,那么需要从8400个Grid cell中找到这3个标签对应的正样本,然后通过计算正样本的预测值和标签值之间的损失,最后通过损失的反向传播更新模型的权值和偏差。

 

        为了更好地理解YOLOv8或者说是YOLO系列网络,需要对Grid cell建立概念,如下所示:

        首先可以看到通过网络输出的三个特征图的分辨率分别为:80*80,40*40,20*20,本文所说的Grid cell即为图中的红点、蓝点以及黄点,从图中可以得到以下信息:第一,红点代表的Grid cell是80*80分辨率中每个像素的中心点,因为红色Grid cell比较密集并且可以x8将红点映射回原图,所以80*80分辨率的特征图Grid cell是用来训练小目标的,蓝色和黄色Grid cell同理;第二,如果8400个Grid cell全部当成正样本的话是不实际的,所以必须从这8400个Grid cell中选出一些正样本;第三,由于YOLOv8是Anchor Free的模型,所以会将这三个尺度的特征图展开变成长度为8400的一维向量。

2. Task Aligned Assigner

        Task Aligned Assigner中文翻译为任务对齐分配器,是一种正负样本分配策略,也就是找正样本的方法,也就是训练流程中的第一步。

         在正式开始找正样本之前,需要先把网络预测值Box和Cls解码,同时也需要把标签的Box和Cls解码,过程如下图所示:首先是网络预测结果Pred的Box需要解码成4维,用来预测LTRB的(解码过程在预测原理第三章有提到YOLOv8预测流程-原理解析[目标检测理论篇]-CSDN博客),另外还需要转换为XYXY格式且预测的坐标值是相对于网络输入尺寸的(即640*640);Cls只需要使用Sigmoid()解码就行。其次是标签Target的解码,其实只有Box需要解码,为了和Pred的解码格式保持一致,需要将XYWH格式转换为XYXY,并且标签值对应的坐标是相对于网络输入尺寸(即640*640)。

        然后正式开始找正样本了,假设一张图片上只有一个GT Box,使用红色框作为标记。由于已经将三个特征图下的grid cell都转换到640*640坐标系了,结合GT框的位置和大小,找到合适的中心点作为训练的正样本,这就是TaskAlignedAssigner的任务,一共分成三步,即初步筛选,精细筛选,剔除多余三个步骤。

        (1)初步筛选:即select_candidates_in_gts,转换之后的Grid Cell落在GT Box内部,作为初步筛选的正样本,如图中所示红色点为初步筛选的Grid Cell,而落在GT Box外部的点或者落在GT Box角上或者边上的都需要过滤掉,如蓝色点所示;经过初步筛选,图2中9个红色的点作为初筛后留下的正样本点。

 (2)精细筛选:即get_box_metrics,select_topk_candidates,通过公式align_metric=s^α∗u^β(s和u分别表示分类得分和CIoU得分, a和b是权重系数,默认值分别0.5和6.0),计算出每个预测框的得分,然后把得分低的预测框给过滤掉,一般会取得分最高的top10个gird cell。

        其中分类得分,取的在是GT Box内对应的类别的预测值,比如该GT的类别下标为1,那么落在GT box内的点所预测的类别下标为1时的置信度。另外计算IoU使用的是CIoU,计算公式和计算过程如下所示,如何理解CIoU呢,IoU并无法充分表示预测框和标注框之间的关系,需要引入中心点距离,以及最小矩形框斜边距离,通过这两者的比值来表示预测框和标注框的相似度。所以会在IoU的基础上减去该比值,再减去由预测框宽高和标注框宽高组成的式子。

        (3)剔除多余:保证一个Grid Cell只预测一个GT框,如果一个Grid Cell同时匹配到两个GT Box,那么将从这两个GT中,选出与他CIoU值最大的一个作为他要预测的GT Box。如图所示,Grid Cell A、B、C负责预测GT1,包括预测GT1的类别和Box,而Grid Cell D负责预测GT2,也是预测类别和Box。 

3.Loss

         YOLOv8的Loss由三部分组成:Loss_box,Loss_cls,Loss_DFL分别表示回归框损失,类比损失和DFL损失(其实也是回归框的损失),下面会详细介绍这三种损失。

        还是先来简单了解下Loss计算的思路,如下图所示:左边Target表示标签值,右边Pred表示预测值,均需要借助上一章找到的正样本,然后通过对比同一个Grid cell正样本的预测值和标签值,计算对应的Loss。 

        先来看一下get_targets函数做了哪些处理,GT_Box是经过预处理的,(1,3,4)表示XYXY且相对于640*640尺度的坐标。GT_Cls没有经过处理,表示GT_Box1、 GT_Box2、 GT_Box3的类别。

    假设GT1的box是(x0,y0,X0,Y0),cls是0; GT2的box是(x1,y1,X1,Y1),cls是1;根据找到的正样本和负样本来举个例子,其中负样本为E,正样本为A/B/C/D.由图可以看到Target_Score在这一步已经区分了正负样本了,其中负样本使用[0,0]来表示。而Target_Bbox并没有区分正负样本,负样本统统会选择第1个GT的Box作为其Target_Bbox,换句话说,Target_Bbox值为[x0,y0,X0,Y0]的Grid cell可能为正样本也可能为负样本。

        

3.1Loss_cls 

         下面是YOLOv8中计算Loss_cls的代码:

target_scores_sum = max(target_scores.sum(), 1)loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum   # BCE

        而主要的部分是Loss_cls采用了BCELoss损失函数,损失计算公式如下(注意:YOLOv8中的Cls使用的是BCEWithLogitsLoss,传入的预测值是不需要自己进行Sigmoid,损失内部会自动进行sigmoid,但我这里演示使用的是BCELoss):

        假设当前只有两个类别,取出其中三个Grid cell的值,其中(0,0)表示负样本,(0,1)和(1,0)表示正样本,经过Normalize后得到带有权重的真实标签,这里正负样本均计算Loss.

3.2Loss_box

        下面是YOLOv8中计算Loss_box的代码:

iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

         而主要的部分是Loss_box采用了CIoULoss损失函数,损失计算公式如下:

        由前面可知,经过get_targets后的Target_Bbox并没有区分正负样本,因此下一步将利用fg_mask来区分正负样本,从而得到30个正样本。对Box会求两个损失,所以有Target_Bbox1和Target_Bbox2,都需要还原到各自的特征图的比例进行计算(可能这样数字比较小计算比较方便),并且分别采用XYXY格式和LTRB格式表示。

        另一方面,Pred_Box1需要通过网络预测的结果(1,64,8400)解码成(1,4,8400)并采用XYXY坐标的格式表示,并且找到对应的30个正样本和Target_Bbox1计算CIoU损失;Pred_Box2则是直接把网络预测的结果(1,64,8400)取出来,然后找到对应的30个正样本,和Target_Bbox1计算DFL损失。

         这里再说一下为什么会是30个正样本,因为有3个GT,每个GT取top10个得分最高的grid cell,并且这30个中没有因重复而被过滤掉的Grid cell。

3.3Loss_DFL 

         下面是YOLOv8中计算Loss_DFL的代码:

loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sum

        而主要的部分是Loss_DFL,损失计算公式如下:

        下面演示一个Grid cell正样本LTRB的计算过程:

        首先,Pred_Box2即Pred_dist,是一个(120,16)的矩阵,可以理解为(30*4,16),即共有30个正样本,每个正样本需要预测LTRB四个数值,并且这四个数又分别通过0~15来表示。其次,Target_Bbox2即Target,是一个(120,1)的向量,分别对应着30个样本中每个样本的LTRB真实值。最后,由于Target一般不会是整数值,所以需要计算相邻的两个真实值对应的损失。损失函数使用Cross_entropy损失.

        前面提到了由于Target一般不会是整数值,所以需要计算相邻的两个真实值对应的损失,那么如何选择呢?这两个损失之间的权重又是怎么样的呢?为了加深理解,又单独举例演示该Grid cell中的Top_loss是怎么计算的:

        该正样本需要对GT对应的LTRB中的T为例,该正样本的中心点距离上边框是7.29像素,因为网络预测只能是0~15的整数,那么只能选择7和8这两个相邻的值作为标签值,即yi=7和yi+1=8。接下来是选择这两个损失的权重,遵循一个原则:离得越近权重越大,所以当计算标签为7的时候,选择权重0.71,即yi+1-y;而计算标签为8的时候,选择权重0.21,即y-yi+1。

    前面也提到了损失函数使用Cross_entropy损失,和BCE损失有两点区别,第一是把网络预测的每个正样本的LTRB值都需要进行SoftMax(),使得∑value=1,这和预测的时候是一样的;第二是只选取标签值对应的值作为损失,比如在该正样本预测Top的损失计算中,有7和8两个标签值,那么7对应的损失值为1.4676,即-Log(Si),8对应的损失值为1.6825,即-log(Si+1)。最后该正样本的Loss_Top为1.53,该正样本的总损失为(Loss_Left+ Loss_Top+ Loss_Right+ Loss_Bottom)/4.

        训练过程的原理会稍微复杂,先整理成这样子,后面我再优化下表达,争取每个人都可以看得懂。

         

       

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/10613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Map中KEY去除下划线并首字母转换为大写工具类

在运维旧项目时候&#xff0c;碰上sql查询结果只能返回List<Map>&#xff0c;key为表单字段名&#xff0c;value为获取到的结果数据。 懒得一个一个敲出来&#xff0c;就直接写个方法转换&#xff0c;并赋值到相应实体对象里去。 Map中KEY去除下划线并首字母转换为大写&…

算法提高之矩阵距离

算法提高之矩阵距离 核心思想&#xff1a;多源bfs 从多个源头做bfs&#xff0c;求距离 先把所有1的坐标存入队列 再把所有1连接的位置存入 一层一层求 #include <iostream>#include <cstring>#include <algorithm>using namespace std;const int N 1…

Kafka 面试题(八)

1. Kafka&#xff1a;硬件配置选择和调优的建议 &#xff1f; Kafka的硬件配置选择和调优是确保Kafka集群高效稳定运行的关键环节。以下是一些建议&#xff1a; 硬件配置选择&#xff1a; 内存&#xff08;RAM&#xff09;&#xff1a;建议至少使用32GB内存的服务器。为Kafk…

Web3Tools - 助记词生成

Web3Tools - 助记词生成工具 本文介绍了一个简单的助记词生成工具&#xff0c;使用 React 和 Material-UI 构建。用户可以选择助记词的语言和长度&#xff0c;然后生成随机的助记词并显示在页面上 功能介绍 选择语言和长度&#xff1a; 用户可以在下拉菜单中选择助记词的语言&…

uniapp 图片添加水印代码封装(优化版、图片上传压缩、生成文字根据页面自适应比例、增加文字背景色

uniapp 图片添加水印代码封装(优化版、图片上传压缩、生成文字根据页面自适应比例、增加文字背景色 多张照片上传封装 <template><view class"image-picker"><uni-file-picker v-model"imageValue" :auto-upload"false" :title…

关于服务端接口知识的汇总

大家好&#xff0c;今天给大家分享一下之前整理的关于接口知识的汇总&#xff0c;对于测试人员来说&#xff0c;深入了解接口知识能带来诸多显著的好处。 一、为什么要了解接口知识&#xff1f; 接口是系统不同模块之间交互的关键通道。只有充分掌握接口知识&#xff0c;才能…

http-server实现本地服务器

要实现一个本地服务器&#xff0c;你可以使用Node.js的http-server模块。首先&#xff0c;确保你已经安装了Node.js和npm。然后&#xff0c;按照以下步骤操作&#xff1a; 打开终端或命令提示符&#xff0c;进入你想要作为服务器根目录的文件夹&#xff1b;运行以下命令安装ht…

Axure PR 10 制作顶部下拉三级菜单和侧边三级菜单教程和源码

在线预览地址&#xff1a;Untitled Document 2.侧边三级下拉菜单 在线预览地址&#xff1a;Untitled Document 文件包和教程下载地址&#xff1a;https://pan.quark.cn/s/77e55945bfa4 程序员必备资源网站&#xff1a;天梦星服务平台 (tmxkj.top)

Linux x86_64 dump_stack()函数基于FP栈回溯

文章目录 前言一、dump_stack函数使用二、dump_stack函数源码解析2.1 show_stack2.2 show_stack_log_lvl2.3 show_trace_log_lvl2.4 dump_trace2.5 print_context_stack 参考资料 前言 Linux x86_64 centos7 Linux&#xff1a;3.10.0 一、dump_stack函数使用 dump_stack函数…

Unity开发中导弹路径散射的原理与实现

Unity开发中导弹路径散射的原理与实现 前言逻辑原理代码实现导弹自身脚本外部控制脚本 应用效果结语 前言 前面我们学习了导弹的追踪的效果&#xff0c;但是在动画或游戏中&#xff0c;我们经常可以看到导弹发射后的弹道是不规则的&#xff0c;扭扭曲曲的飞行&#xff0c;然后击…

数字生态系统的演进与企业API管理的关键之路

数字生态系统的演进与企业API管理的关键之路 在数字化时代&#xff0c;企业正经历着一场转型的浪潮&#xff0c;而API&#xff08;应用程序编程接口&#xff09;扮演着至关重要的角色。API如同一座桥梁&#xff0c;将组织内部的价值转化为可市场化的产品&#xff0c;从而增强企…

韩国站群服务器在全球网络架构中的重要作用?

韩国站群服务器在全球网络架构中的重要作用? 在全球互联网的蓬勃发展中&#xff0c;站群服务器作为网络架构的核心组成部分之一&#xff0c;扮演着至关重要的角色。韩国站群服务器以其卓越的技术实力、优越的地理位置、稳定的网络基础设施和强大的安全保障能力&#xff0c;成…

LeetCode 题目 118:杨辉三角

题目描述 给定一个非负整数 numRows&#xff0c;生成杨辉三角的前 numRows 行。在杨辉三角中&#xff0c;每个数是它左上方和右上方的数的和。 杨辉三角解析 在这个详解中&#xff0c;我们将使用 ASCII 图形来说明杨辉三角的构建过程&#xff0c;包括逐行添加新的行的过程。…

250 基于matlab的5种时频分析方法((短时傅里叶变换)STFT

基于matlab的5种时频分析方法&#xff08;(短时傅里叶变换)STFT,Gabor展开和小波变换,Wigner-Ville&#xff08;WVD&#xff09;,伪Wigner-Ville分布(PWVD),平滑伪Wigner-Ville分布&#xff08;SPWVD&#xff09;,每条程序都有详细的说明&#xff0c;设置仿真信号进行时频输出。…

Parted分区大容量磁盘

创建了新的虚拟磁盘10T , 挂载后分区格式化一.fdisk无法创建大容量的分区 Fileserver:~ # fdisk /dev/sdb Welcome to fdisk (util-linux 2.29.2). Changes will remain in memory only, until you decide to write them. Be careful before using the write command. Device …

使用html和css实现个人简历表单的制作

根据下列要求&#xff0c;做出下图所示的个人简历&#xff08;表单&#xff09; 表单要求 Ⅰ、表格整体的边框为1像素&#xff0c;单元格间距为0&#xff0c;表格中前六列列宽均为100像素&#xff0c;第七列 为200像素&#xff0c;表格整体在页面上居中显示&#xff1b; Ⅱ、前…

git提交代码异常报错error:bad signature 0x00000000

报错信息 error:bad signature 0x00000000 异常原因 git 提交过程中异常关机或重启&#xff0c;造成当前项目工程中的.git/index 文件损坏&#xff0c;无法提交 解决步骤 删除.git/index文件 rm -f .git/index 重启git git reset

Java 【数据结构】 哈希(Hash超详解)HashSetHashMap【神装】

登神长阶 第十神装 HashSet 第十一神装 HashMap 目录 &#x1f454;一.哈希 &#x1f9e5;1.概念 &#x1fa73;2.Object类的hashCode()方法: &#x1f45a;3.String类的哈希码: &#x1f460;4.注意事项: &#x1f3b7;二.哈希桶 &#x1fa97;1.哈希桶原理 &#x…

Bert基础(二十二)--Bert实战:对话机器人

一 、概念简介 1.1 生成式对话机器人 1.1.1什么是生成式对话机器人? 生成式对话机器人是一种能够通过自然语言交互来理解和生成响应的人工智能系统。它们能够进行开放域的对话,即在对话过程中,机器人可以根据用户的需求和上下文信息,自主地生成新的、连贯的回复,而不仅…

如何使用CertCrunchy从SSL证书中发现和识别潜在的主机名称

关于CertCrunchy CertCrunchy是一款功能强大的网络侦查工具&#xff0c;该工具基于纯Python开发&#xff0c;广大研究人员可以利用该工具轻松从SSL证书中发现和识别潜在的主机信息。 支持的在线源 该工具支持从在线源或给定IP地址范围获取SSL证书的相关数据&#xff0c;并检索…