yolov8源码解读Detect层

yolov8源码解读Detect层

  • Detect层解读
  • 网络各层解读及detect层后的处理

关于网络的backbone,head,以及detect层后处理,可以参考文章结尾博主的文章。

Detect层解读

先贴一下全部代码,下面一一解读。

class Detect(nn.Module):"""YOLOv8 Detect head for detection models."""dynamic = False  # force grid reconstructionexport = False  # export modeshape = Noneanchors = torch.empty(0)  # initstrides = torch.empty(0)  # initdef __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc  # number of classesself.nl = len(ch)  # number of detection layersself.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4  # number of outputs per anchorself.stride = torch.zeros(self.nl)  # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""shape = x[0].shape  # BCHW# print(">>>>", x[0].shape)# print(">>>>", x[1].shape)# print(">>>>", x[2].shape)for i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapex_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.stridesif self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_size# print(cls.shape)y = torch.cat((dbox, cls.sigmoid()), 1)# print(y.shape)return y if self.export else (y, x)
	dynamic = False #这个属性指示网格(通常是特征图上的锚框网格)是否需要动态地重建export = False  #这个属性用于指示模型是否处于导出模式。shape = None # 用于存储输入图像或特征图的尺寸。anchors = torch.empty(0)  # 创建了一个空的PyTorch张量strides = torch.empty(0)

步长(strides)是卷积神经网络中特征图相对于输入图像的缩小比例。
例如,如果步长是32,那么一个32x32像素的区域在特征图上就对应一个单元。
和anchors一样,这里的torch.empty(0)表示步长尚未初始化。

    def __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc  # number of classesself.nl = len(ch)  # number of detection layersself.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4  # number of outputs per anchorself.stride = torch.zeros(self.nl)  # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

nc:类别数
nl:检测层的数量,目标检测中为3。
ch:传入的图片通道尺寸,在yolov8n,图片大小为640*640时。这里的ch为(256,128,64)
no:两个卷积再拼接后输出通道数,为4×reg_max+nc
c2,c3:计算卷积层的通道数。
cv2,cv3:定义的卷积操作,以输出有关类别和选框的特征图。
dfl:通过将分布式的概率分布转化为单一的预测值

class DFL(nn.Module):def __init__(self, c1=16):"""Initialize a convolutional layer with a given number of input channels."""super().__init__()self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)x = torch.arange(c1, dtype=torch.float)self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))self.c1 = c1def forward(self, x):"""Applies a transformer layer on input tensor 'x' and returns a tensor."""b, c, a = x.shape  # batch, channels, anchorsreturn self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

self.conv:创建了一个输入通道为16,输出为1,没有偏置项,不需要进行梯度更新的卷积层。
这样的权重设置实际上模拟了一个积分过程,将卷积操作变成了加权求和的形式。
x:1到15的整数。
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)):
这里使用nn.Parameter将重塑后的张量设置为模型的参数,并且参数不会被更新。
假设前向传播中,x的形状为(1, 64, 8400),下面解释下forword中的变化。
1,x.view(b, 4, self.c1, a): 这个操作是对x的形状进行重塑。self.c1是16(因为输入通道数是64,即4*self.c1),那么a是8400(代表了所有锚点的数量)。b是批次大小,这里为1。所以x.view(b, 4, self.c1, a)将x从(1, 64, 8400)重塑为(1, 4, 16, 8400)。在这个形状中,我们得到了每个锚点的每个坐标轴(x, y, 宽度, 高度)上的16个预测值(可能代表某种概率分布)。
2,transpose(2, 1): 这个操作交换第二维和第三维。在应用transpose之后,张量的形状变为(1, 16, 4, 8400)。这样做的目的是让每组概率分布的16个预测值连续地排列在一起,为后面的softmax运算做准备。
3,softmax(1): softmax函数应用于第一维(现在是16个预测值的这一维)。softmax确保了这16个值之和为1,转换为一个有效的概率分布,表示每个预测值的可能性。
4,self.conv(…): 这个操作将配置好的卷积层应用在进行了softmax操作的张量上。由于卷积层的权重已被设置为从0到15的整数,并且不更新权重(不进行梯度下降优化),这个步骤实际上是在计算期望值。卷积层将每个离散的概率值乘以其相应的索引(也就是权重),然后对结果进行求和,得到该坐标的预测值。
5,view(b, 4, a): 最后一步是将张量的形状从卷积操作后的(1, 1, 4, 8400)转换回(1, 4, 8400)。这样确保了最终的输出张量与每个坐标轴的预测值(x, y, 宽度, 高度)和所有锚点的数量对齐。
总的来说,dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x,y,width,hight的加权平均值。
接下来是前向传播的过程。打印传入的x形状,发现通道数是64,128,256。

在这里插入图片描述
在这里插入图片描述
原因:Detect层接受15,18,21层的输入。原本通道数是1024,512,256。但是yolov8n还需要乘0.25。
在这里插入图片描述
在这里插入图片描述

经过cv2,通道数变为64,经过cv3通道数变为nc,我这里nc为2(二分类)。在经过cat拼接,在通道维度上拼接,所以x[i]的通道数变为66。

如果处于训练模式,就直接返回x。
否则执行下面的代码,将特征图列表x(1×66×40×40,1×66×80×80,1×66×20×20)传递给make_anchors()函数。make_anchors函数用于生成锚点(anchors),它通常用在目标检测网络中。每个锚点代表了特征图上的一个点,可以用来预测相对于该点的边界框。strides是这些特征图相对于原始图像的下采样步长。简单来说,生成了8400个锚点(40×40+80×80+20×20),变量为anchors,形状为1×2×8400)。同时生成了8400个步长,变量为strides,形状为1×8400。参数0.5表示每个锚点处于每个像素块的中央。

        if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shape

将xi按照2维度进行拼接,xi分别为1×66×40×40,1×66×80×80,1×66×20×20。拼接后的x_cat为1×66×8400

x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)

这段代码就是把x_cat进行拆分。box形状为1×64×8400,包含每个边界框的回归参数。cls形状为1×2×8400,会包含类别预测,2是因为我这里类别为2。

        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x1,y1,x2,y2的加权平均值。dist2bbox()函数的作用是将锚点x1,y1,x2,y2转换为x,y,width,hight的形式。最后在乘以步长,还原到原图的大小比例。

dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

在这里插入图片描述

此代码片段的作用是在模型导出为 Tensorflow Lite (tflite) 或 Edge TPU 兼容格式时,对预测框 (dbox) 进行归一化处理。

 if self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_size

此时y的形状为1×66×8400。代表有8400个锚点,每个锚点包含坐标框的x,y,width,hight,以及类别得分信息。

y = torch.cat((dbox, cls.sigmoid()), 1)

返回值,至此detect层结束。完整的预测,后续还需要进行一些处理。如进行非极大抑制,对这8400个锚点进行筛选

return y if self.export else (y, x)

另外,最后的bias_init()函数用于初始化一个目标检测模型中的Detect层的偏置。确保在训练开始时偏置值是基于合理假设的。这种方法的目标是为模型提供一个好的起点,并有助于加速训练过程中的收敛。

    def bias_init(self):"""Initialize Detect() biases, WARNING: requires stride availability."""m = self  # self.model[-1]  # Detect() module# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequencyfor a, b, s in zip(m.cv2, m.cv3, m.stride):  # froma[-1].bias.data[:] = 1.0  # boxb[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

网络各层解读及detect层后的处理

关于backbone,head层,以及detect层参考下面博主的文章,讲的非常好。

链接: Yolov 8源码超详细逐行解读+ 网络结构细讲(自我用的小白笔记)
链接: 最细致讲解yolov8模型推理完整代码–(前处理,后处理)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/690011.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【EndNote20】Endnote20和word的一些操作

文章目录 前言一、如何导入参考文献到EndNote201.1.在谷歌学术或知网上下载文献1.2.将下载好的文件导入EndNote20(可批量导入)1.3.书籍如何导入 二、Word中加入参考文献 前言 做毕设时学习了EndNote20的一些使用方法,并在此慢慢做汇总。 一、如何导入参考文献到End…

【python】python入门之输入与进入交互模式的方法

目录 一次性输入(进入交互模式): 输入函数: 一次性输入(进入交互模式): 交互模式介绍:写一行读一行,不用print也可以显示出来 (当进行某些一次性计算或者纠错…

在石家庄有哪家券商证券公司可以免费开量化软件Ptrade、QMT

在石家庄有少数证券公司可以免费开量化软件QMT、Ptrade,如国金证券、广发证券等,之前要100万才可开通,现在只需满足资金50万的条件即可免费办理使用,详情找客户经理孙经理咨询。 证券佣金低价标准是“成本价”,默认佣金…

多线程——

目录 一、为什么要有多线程? 1、线程与进程 2、多线程的应用场景 3、小结​编辑 二、多线程中的两个概念(并发和并行) 1、并发 2、并行 3、小结 三、多线程的三种实现方式 1、继承Thread类的方式进行实现 2、实现Runnable接口的方…

为什么有的代理IP速度比较慢?

“为什么有的IP代理速度比较慢?”随着数字化时代的不断发展,代理服务成为了许多网络操作的关键环节。然而,有时我们可能会遇到IP代理速度慢的问题,这可能会对我们的网络操作产生影响。让我们一起揭开这个谜团,探寻其中…

EMQX Enterprise 5.4 发布:OpenTelemetry 分布式追踪、OCPP 网关、Confluent 集成支持

EMQX Enterprise 5.4.0 版本已正式发布! 新版本提供 OpenTelemetry 分布式追踪与日志集成功能,新增了开放充电协议 OCPP 协议接入能力,并为数据集成添加了 Confluent 支持。此外,新版本还进行了多项改进以及 BUG 修复&#xff0c…

AI提示工程实战:从零开始利用提示工程学习应用大语言模型【文末送书-19】

文章目录 背景什么是提示工程?从零开始:准备工作设计提示调用大语言模型 实际应用示例文字创作助手代码生成持续优化与迭代数据隐私与安全性可解释性与透明度总结 AI提示工程实战:从零开始利用提示工程学习应用大语言模型【文末送书-19】⛳粉…

SpringMVC 的参数绑定之list集合、Map

标签中name属性的值就是pojo类的属性名 参数绑定4 list [对象] <form action"teaupd.do" method"post"> <c:forEach items"${list}" var"tea" varStatus "status"> 教师编号&#xff1a;<input…

希尔排序算法

目录 ShellSort希尔排序 整体思路 图解分析 【1】预排序 单组排序 多组并排 【2】直接插入排序 关于gap取值 总代码实现 时间复杂度 ShellSort希尔排序 希尔排序法又称缩小增量法。 希尔排序法的基本思想是&#xff1a;先选定一个整数&#xff0c;把待排序文件中所有…

Vue3快速上手(七) ref和reactive对比

一、ref和reactive对比 表格形式更加直观吧&#xff1a; 项目refreactive是否支持基本类型支持不支持是否支持对象类型支持支持对象类型是否支持属性直接赋值不支持&#xff0c;需要.value支持是否支持直接重新分配对象支持&#xff0c;因为操作的.value不支持&#xff0c;需…

120 Linux C++ 通讯架构实战 nginx整体结构,nginx进程模型,nginx调整worker进程数量,nginx重载配置文件,热升级,关闭

一 nginx整体结构 1.1 master进程和worker进程概览&#xff08;父子关系&#xff09; 启动nginx&#xff0c;看到了master进程和 worker 进程。 ps -ef | grep nginx 第一列&#xff1a;进程所属的用户id 第二列&#xff1a;进程ID&#xff0c;也叫做PID&#xff0c;用来唯…

@arco.design Modal renderContent 增加样式

方式一&#xff1a;通过 h 函数 import { h } from vueMessage.error({content: () > {return h(div, {}, [手机号 , h(span, { style: { color: red } }, staffPhone), 已存在])}, })方式二&#xff1a;通过 jsx 方式 注意&#xff1a;lang 需要改为 jsx 或者 tsx <s…

OSQP文档学习

OSQP官方文档 1 QSQP简介 OSQP求解形式为的凸二次规划&#xff1a; x ∈ R n x∈R^n x∈Rn&#xff1a;优化变量 P ∈ S n P∈S^n_ P∈Sn​&#xff1a;半正定矩阵 特征 &#xff08;1&#xff09;高效&#xff1a;使用了一种自定义的基于ADMM的一阶方法&#xff0c;只需…

关于Sora的一些紧迫问题...

OpenAI Sora 概述 OpenAI最新的创新&#xff0c;Sora&#xff0c;在人工智能领域开辟了新的天地。Sora是一个文本到视频的扩散模型&#xff0c;可以将文本描述转化为逼真的视频内容。它解决了一个重大的技术挑战&#xff0c;即在视频中保持主体的一致性&#xff0c;即使它们暂…

Java 线程池的基本操作

Java 线程池的基本操作 package com.zhong.thread.threadpool;import java.util.concurrent.*;/*** ClassName : ThreadPool* Description : 线程池的基本操作* Author : zhx* Date: 2024-02-19 18:03*/ public class ThreadPool {public static void main(String[] args) {// …

C语言每日一题(59)左叶子之和

题目链接 力扣网404 左叶子之和 题目描述 给定二叉树的根节点 root &#xff0c;返回所有左叶子之和。 示例 1&#xff1a; 输入: root [3,9,20,null,null,15,7] 输出: 24 解释: 在这个二叉树中&#xff0c;有两个左叶子&#xff0c;分别是 9 和 15&#xff0c;所以返回 2…

基于SpringBoot的高校竞赛管理系统

基于SpringBoot的高校竞赛管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBootMyBatis工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 主页 个人中心 管理员界面 老师界面 摘要 高校竞赛管理系统是为了有效管理学校…

K8s进阶之路-命名空间级-服务发现 :

服务发现&#xff1a; Service&#xff08;东西流量&#xff09;&#xff1a;集群内网络通信、负载均衡&#xff08;四层负载&#xff09;内部跨节点&#xff0c;节点与节点之间的通信&#xff0c;以及pod与pod之间的通信&#xff0c;用Service暴露端口即可实现 Ingress&#…

Vscode python pyside6 制作视频播放器

一、界面如下 包含控件 qcombox、qtablewidget、qpushbotton、qverticalslider 二、运行代码 media_player.py import sysfrom PySide6 import QtWidgets from PySide6.QtWidgets import * from PySide6.QtMultimedia import * from PySide6.QtMultimediaWidgets import QVi…

林浩然与杨凌芸的Java List大冒险

林浩然与杨凌芸的Java List大冒险 Lin Haoran and Yang Lingyun’s Java List Adventure 在一个阳光明媚的日子&#xff0c;程序员界的“侠客”林浩然和他那聪明伶俐的同事兼好友杨凌芸正在Java王国里进行一场别开生面的大冒险。这次他们的目标是征服两个强大的List家族成员——…