YOLOV8 |搞懂检测头

代码:

yaml结构的最后一层,接了前面三个层的,有3个检测头:
 

# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)

检测头代码部分的关键点:

1. 初始化 (__init__ 方法)

初始化方法设置了模型的参数和网络结构。

def __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc  # number of classesself.nl = len(ch)  # number of detection layersself.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4  # number of outputs per anchorself.stride = torch.zeros(self.nl)  # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()if self.end2end:self.one2one_cv2 = copy.deepcopy(self.cv2)self.one2one_cv3 = copy.deepcopy(self.cv3)

关键点:

  • self.nc: 模型需要预测的类别数量。
  • self.nl: 检测层的数量,通常对应于特征图的数量。
  • self.reg_max: 用于提高边界框回归精度的 DFL 通道数。
  • self.no: 每个锚点的输出数量(类别数 + 回归通道数)。
  • c2 和 c3: 中间层的通道数,分别用于位置回归和类别分类。
  • self.cv2 和 self.cv3: 位置回归和类别分类的卷积层序列。
  • self.dfl: DFL 模块,用于将多通道表示转换为实际坐标。
  • 如果是端到端模式(end2end),复制 cv2 和 cv3 以创建 one2one_cv2 和 one2one_cv3

2. 前向传播 (forward 方法)

前向传播方法负责处理输入数据并生成最终的检测结果。

def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""if self.end2end:return self.forward_end2end(x)for i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:  # Training pathreturn xy = self._inference(x)return y if self.export else (y, x)

返回的目的

这些拼接后的特征图 x 会被传递给损失函数计算模块,用于计算损失并反向传播梯度以更新网络参数。具体步骤如下:

  • 关键点:
    • 如果是端到端模式(end2end),调用 forward_end2end
    • 否则,对每个检测层进行前向传播,将位置回归和类别分类的结果拼接在一起。
    • 在训练模式下,直接返回拼接后的结果。
    • 在推理模式下,调用 _inference 方法进行后处理,并返回最终的检测结果。
    • 对于每个检测层(self.nl 个),将 cv2 和 cv3 的输出拼接在一起。
    • cv2[i] 是位置回归的卷积层序列,cv3[i] 是类别分类的卷积层序列。
    • 拼接的结果是一个包含边界框回归和类别分类信息的特征图。
  • 训练模式下的返回

    在训练模式下,直接返回拼接后的结果 x。这些结果是 cv2cv3 的输出拼接在一起的特征图。具体来说:

  • cv2[i] 的输出是边界框回归的特征图。
  • cv3[i] 的输出是类别分类的特征图。
  • 这两个特征图在通道维度上拼接在一起,形成一个包含边界框和类别信息的特征图。
  • 损失计算:

    • 拼接后的特征图 x 包含了预测的边界框和类别信息。
    • 使用真实标签(ground truth)与这些预测结果计算损失。通常包括边界框回归损失(如 DFL 损失)和分类损失(如交叉熵损失)。
  • 反向传播:

    • 根据计算出的损失,通过反向传播算法计算每个参数的梯度。
    • 反向传播从损失函数开始,逐层计算每一层的梯度。
  • 参数更新:

    • 使用优化器(如 SGD、Adam 等)根据计算出的梯度更新网络中的权重和偏置。
    • 优化器使用学习率来控制每次更新时参数变化的步长

3. 端到端前向传播 (forward_end2end 方法)

这个方法在端到端模式下使用,生成两个检测结果:one2manyone2one

def forward_end2end(self, x):"""Performs forward pass of the v10Detect module.Args:x (tensor): Input tensor.Returns:(dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately."""x_detach = [xi.detach() for xi in x]one2one = [torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)]for i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:  # Training pathreturn {"one2many": x, "one2one": one2one}y = self._inference(one2one)y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)return y if self.export else (y, {"one2many": x, "one2one": one2one})
  • 关键点:
    • 对输入特征图进行分离,生成 one2one 和 one2many 的检测结果。
    • 在训练模式下,返回两个检测结果。
    • 在推理模式下,对 one2one 的结果进行后处理,并返回最终的检测结果。

4. 推理路径 (_inference 方法)

推理路径方法解码预测的边界框和类别概率。

def _inference(self, x):"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""shape = x[0].shape  # BCHWx_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapeif self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV opsbox = x_cat[:, : self.reg_max * 4]cls = x_cat[:, self.reg_max * 4 :]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)if self.export and self.format in {"tflite", "edgetpu"}:grid_h = shape[2]grid_w = shape[3]grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)norm = self.strides / (self.stride[0] * grid_size)dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])else:dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.stridesreturn torch.cat((dbox, cls.sigmoid()), 1)
  • 关键点:
    • 将多个特征图的预测结果拼接在一起。
    • 计算并更新锚点和步长。
    • 分割出边界框和类别概率。
    • 使用 DFL 将多通道表示转换为实际坐标。
    • 返回解码后的边界框和类别的 sigmoid 概率.

5. 解码边界框 (decode_bboxes 方法)

这个方法将编码的边界框转换为实际的边界框坐标。

def decode_bboxes(self, bboxes, anchors):"""Decode bounding boxes."""return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
  • 关键点:
    • 使用 dist2bbox 函数将编码的边界框转换为实际的边界框坐标。

6. 后处理 (postprocess 方法)

后处理方法选择最高分的边界框并返回最终检测结果。

@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):"""Post-processes the predictions obtained from a YOLOv10 model.Args:preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).max_det (int): The maximum number of detections to keep.nc (int, optional): The number of classes. Defaults to 80.Returns:(torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),including bounding boxes, scores and cls."""assert 4 + nc == preds.shape[-1]boxes, scores = preds.split([4, nc], dim=-1)max_scores, index = torch.topk(scores.amax(dim=-1), min(max_det, scores.shape[1]), axis=-1)index = index.unsqueeze(-1)boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)labels = index % ncindex = index // ncboxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
  • 关键点:
    • 将预测结果分割成边界框和类别分数。
    • 选择最高分的边界框。
    • 返回包含边界框、得分和类别的最终检测结果。

总结

  • 初始化:设置模型参数和网络结构。
  • 前向传播:根据不同的模式(普通或端到端)进行前向传播。
  • 推理路径:解码边界框和类别概率。
  • 解码边界框:将编码的边界框转换为实际坐标。
  • 后处理:选择最高分的边界框并返回最终检测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/58421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

池化层笔记

池化层 文章目录 池化层二维池化层超参数池化层的分类代码实现填充和步幅 多个通道 总结 卷积对位置敏感,可以检测垂直边缘。需要有一定程度的平移不变性,而在平时图片的拍摄,会因为图片的照明,物体位置,比例&#xff…

大数据-191 Elasticsearch - ES 集群模式 配置启动 规划调优

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

mysql 5.7实现组内排序(连续xx天数)

需求:查询出连续登录的用户及其连续登录的天数 我先说一下思路:要实现连续登录的判断,可以找一下他们之间的规律。这里我拿一个用户来说,如果这个用户在1、2、3号都有登录记录,可以对这个用户的数据按照时间排序&…

J3学习打卡

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 DensNet模型 import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras import layers, models, initializersclass DenseLayer(lay…

基于微信小程序的小区管理系统设计与实现(lw+演示+源码+运行)

摘 要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。所以各大互联网厂商都瞄准移动互联网这个潮…

随机变量、取值、样本和统计量之间的关系

1. 随机变量 (Random Variable) 随机变量是用来量化随机现象结果的一种数学工具。随机变量是一个函数,它将实验结果映射到数值。随机变量可以是离散的或连续的。 离散随机变量:取有限或可数无限个值。例如,掷骰子的结果。连续随机变量&…

Matlab实现蚁群算法求解旅行商优化问题(TSP)(理论+例子+程序)

一、蚁群算法 蚁群算法由意大利学者Dorigo M等根据自然界蚂蚁觅食行为提岀。蚂蚁觅食行为表示大量蚂蚁组成的群体构成一个信息正反馈机制,在同一时间内路径越短蚂蚁分泌的信息就越多,蚂蚁选择该路径的概率就更大。 蚁群算法的思想来源于自然界蚂蚁觅食&a…

给哔哩哔哩bilibili电脑版做个手机遥控器

前言 bilibili电脑版可以在电脑屏幕上观看bilibili视频。然而,电脑版的bilibili不能通过手机控制视频翻页和调节音量,这意味着观看视频时需要一直坐在电脑旁边。那么,有没有办法制作一个手机遥控器来控制bilibili电脑版呢? 首先…

JavaEE初阶---网络原理之TCP篇(二)

文章目录 1.断开连接--四次挥手1.1 TCP状态1.2四次挥手的过程1.3time_wait等待1.4三次四次的总结 2.前段时间总结3.滑动窗口---传输效率机制3.1原理分析3.2丢包的处理3.3快速重传 4.流量控制---接收方安全机制4.1流量控制思路4.2剩余空间大小4.3探测包的机制 5.拥塞控制---考虑…

【C语言刷力扣】3216.交换后字典序最小的字符串

题目: 解题思路: 字典序最小的字符串:是指按照字母表顺序排列最前的字符串。即字符串在更靠前的位置出现比原字符串对应字符在字母表更早出现的字符。 枚举数组元素,尽早将较小的同奇偶的相邻字符交换。 char* getSmallestString…

Java:Map和Set练习

目录 查找字母出现的次数 只出现一次的数字 坏键盘打字 查找字母出现的次数 这道题的思路在后面的题目过程中能用到,所以先把这题给写出来 题目要求:给出一个字符串数组,要求输出结果为其中每个字符串及其出现次数。 思路:我…

【宠粉赠书】大模型项目实战:多领域智能应用开发

在当今的人工智能与自然语言处理领域,大型语言模型(LLM)凭借其强大的生成与理解能力,正在广泛应用于多个实际场景中。《大模型项目实战:多领域智能应用开发》为大家提供了全面的应用技巧和案例,帮助开发者深…

【商汤科技-注册/登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

Nginx防盗链配置

1. 什么是盗链? 盗链是指服务提供商自己不提供服务的内容,通过技术手段绕过其它有利益的最终用户界面(如广告),直接在自己的网站上向最终用户提供其它服务提供商的服务内容,骗取最终用户的浏览和点击率。受益者不提供…

Oracle+11g+笔记(8)-备份与恢复机制

Oracle11g笔记(8)-备份与恢复机制 8、备份与恢复机制 8.1 备份与恢复的方法 数据库的备份是对数据库信息的一种操作系统备份。这些信息可能是数据库的物理结构文件,也可能是某一部分数 据。在数据库正常运行时,就应该考虑到数据库可能出现故障&#…

基于Multisim的篮球比赛电子记分牌设计与仿真

一、设计任务与要求 设计一个符合篮球比赛规则的记分系统。 (1)有得1分、2分和3分的情况,电路要具有加、减分及显示的功能。 (2)有倒计时时钟显示,在“暂停时间到”和“比赛时间到”时,发出声光…

易友BOM管理软件

易友BOM管理软件介绍 易友BOM管理软件是一款功能齐全、操作简便、安全可靠的BOM管理系统。它为企业提供了多方面的BOM管理解决方案,帮助企业提高生产效率、降低成本、增强灵活性并提升竞争力。制造企业,都可以通过易友BOM管理软件来实现BOM管理的优化和…

【模型学习之路】手写+分析bert

手写分析bert 目录 前言 架构 embeddings Bertmodel 预训练任务 MLM NSP Bert 后话 netron可视化 code2flow可视化 fine tuning 前言 Attention is all you need! 读本文前,建议至少看懂【模型学习之路】手写分析Transformer-CSDN博客。 毕竟Bert是tr…

不用求人,4个方法快速恢复小米手机删除短信

手机短信作为我们日常办理事情的重要验收通道,往往承载着许多重要的信息。然而,由于各种原因,我们可能会不小心删除了重要的短信。那么,小米手机用户如何恢复这些被删除的短信呢?接下来,我们将分点为您详细…

爆肝整理14天AI工具宝藏合集(三)

🛠️以下是我为大家整理的AI工具宝藏合集(三): 💡AI搜索 1️⃣ 天工AI搜索 2️⃣ 秘塔AI搜索 3️⃣ 夸克AI搜索 4️⃣ 开搜AI搜索 💡 AI视频 1️⃣ 可灵AI 2️⃣ 即梦AI 3️⃣ Vidu 4️⃣ Stable Video …