ByteTrack多目标跟踪——yolox_model代码详解

文章目录

  • yolox_model
  • YOLOPAFPN
  • YOLOXHead
    • model
    • 损失计算
      • 初步筛选
      • SimOTA 求解
  • 附:网络结构
    • Cls head
      • Cls_convs
      • Cls_preds
    • Reg head
      • Reg_convs
      • Reg_preds
    • Obj head
      • Obj_preds

yolox_model

yolox_model主要包括以下几个文件:yolox.pyyolo_pafpn.py以及yolo_head.py

train时将图像以及label传入模型:

self.model(inps, targets) 

传入模型之后进入 yolox.py,并且依次经过 YOLOXPAFPN 以及 YOLOXHead

YOLOPAFPN

fpn_outs = self.backbone(x)

进入 yolo_pafpn,py
这里可以参考我的另一篇文章ByteTrack多目标跟踪——YOLOX详解

举例:输入为 torch.Size ([1, 3, 896, 1600])
模型输出为:
{‘dark 3’: (1,320,112,200),
‘dark 4’: (1,640,56,100),
‘dark 5’: (1,1280,28,50)}
通过 PAFPN 进行融合后得到:
{‘0’: (1,320,112,200),
‘1’: (1,640,56,100),
‘2’: (1,1280,28,50)}

YOLOXHead

if self.training:assert targets is not Noneloss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg, settings = self.head(       fpn_outs, targets, x)outputs = {"total_loss": loss,"iou_loss": iou_loss,"l1_loss": l1_loss,"conf_loss": conf_loss,"cls_loss": cls_loss,     "num_fg": num_fg,"settings": settings            }
else:outputs = self.head(fpn_outs)

进入 yolo_head

model

总共有四个分支:

  • Cls_output:目标框的类别,预测分数。因为只有行人一个类别,所以大小为 1。
  • Obj_output:判断目标框是前景还是背景,大小也为 1。
  • Reg_output:对目标框的坐标信息 (x,y,w,h)进行预测,大小为 4。

具体网络结构参见附录部分

将三个分支输出结果合并

output = torch.cat([reg_output, obj_output, cls_output], 1) 

进入 self.get_output_and_grid 函数中

output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())

主要是创建特征图网络坐标点,并把神经网络前向推理的 bbox 投影输入图像的尺寸上

损失计算

输入:

  • Imgs:一个 batch 的图像
  • X_shifts、y_shifts: 特征图每个网格 grid 的 xy 坐标
  • Expanded_strides: 不同尺寸的特征输出与输入图像之间缩小的倍数
  • Labels: ground_truth 的类别号与 bbox (一个 batch 图像中的人工标注框与类别)
  • Torch.Cat (outputs, 1): 对三个尺度的输出进行合并
if self.training:return self.get_losses(imgs,x_shifts,y_shifts,expanded_strides,labels,torch.cat(outputs, 1),origin_preds,dtype=xin[0].dtype,darkfpn=darkfpn,)

获取 gt 数量

nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) 

在计算损失时,yolox 需要做标签分配,这是 yolox 的重要思想。其中涉及的函数为 self.get_assignments

输入:

  • Batch_idx: 批图像的索引
  • Num_gt: 一幅图像存在的目标数目;
  • Total_num_anchors: 总的 anchor 数目,yolox 提取的最后特征,每个方格表示一个 anchor
  • Gt_bboxes_per_image: 一幅图像人工标注的框 box 坐标;
  • Gt_classes: 一幅图像的标注框类别编号
  • Bboxes_preds_per_image: 一幅图像预测的 bbox
  • Expanded_strides: 三个尺度的每个特征方格相对于输入图的缩放像素[8,…],[16,…],[32,…]
  • X_shifts, y_shifts: 每个特征方格位置偏移量组成的向量 (一个 batch)
  • Cls_preds: 类别预测概率, 一个 batch 数据。[batch_num, anchors_all, num_cls]
  • Bbox_preds:目标框的预测
  • Obj_preds: 目标置信度概率
  • Labels: yolo 人工标注框,一个 batch 数据。
  • Imgs: 一个 batch 的图像数据
  • Gt_ids:gt 中 ID

输出:

  • Gt_matched_classes: 标签分配后,每列候选框预测目标的编号
  • Fg_mask: 初步筛选中,in_boxes 与 in_center 的并集[29400]
  • Pred_ious_this_matching: 由标签分配的 mask, 筛选真实框与预测框构成的 IoU 矩阵对应的 IoU 值
  • Matched_gt_inds: matrix_matching 矩阵,存在候选框的位置 idx
  • Num_fg_img: 标签分配完成后,总共存在的候选框个数 (matrix_matching 每列保证一个候选框)
(gt_matched_classes,                 # [matched_anchor], class of matched anchorsfg_mask,                            # [n_anchors], .sum()=matched_anchor, to mask out unmatched anchorspred_ious_this_matching,            # [matched_anchor], IoU of matched anchorsmatched_gt_inds,                    # [matched_anchor], index of gts for each matched anchornum_fg_img,                         # [1], matched_anchor
) = self.get_assignments(  # noqabatch_idx,num_gt,total_num_anchors,gt_bboxes_per_image,gt_classes,bboxes_preds_per_image,expanded_strides,x_shifts,y_shifts,cls_preds,bbox_preds,obj_preds,labels,imgs,gt_ids,)

初步筛选

筛选方式参考我的另一篇文章ByteTrack多目标跟踪——YOLOX详解

通过 get_in_boxes_info() 得到:
Fg_mask: in_boxes 与 in_center 的并集,
Is_in_boxes_and_center 为交集

fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)

单幅图像,根据正样本锚框的初步筛选

bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]  # [3717,4]
cls_preds_ = cls_preds[batch_idx][fg_mask] #[3717,1]
obj_preds_ = obj_preds[batch_idx][fg_mask] #[3717,1]
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]  
# 正样本锚框筛选的个数

即此时正样本框数量为 num_in_boxes_anchor个。

SimOTA 求解

先计算 bbox 的边界框损失与类别损失

pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)     # [gt_num, matched_anchor]
gt_cls_per_image = (                F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
)
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)   
with torch.cuda.amp.autocast(enabled=False):cls_preds_ = (      # [gt_num, matched_anchor, 1]cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()         * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()       )pair_wise_cls_loss = F.binary_cross_entropy(        # [gt_num, matched_anchor]cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)

然后计算 Cost 成本

# lambda=3.0, 设置anchor box的中心,不在以中心点构建框与目标框中的cost=100000
cost = (pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center)) 

最后进行 SimOTA 求解

### SimOTA, 求近似最优解 ###
# 输入:
#      cost: 通过回归损失和类别损失计算得到的cost
#      pair_wise_ious: size为[num_gt,num_in_boxes_anchor]的IoU计算,即所有真实框与预测框的IoU
#      gt_classes: 一幅图像ground truth标注框的类别编号向量
#      num_gt: 一幅图像的标注框个数
#      fg_mask: 根据中心点与目标框初步筛选并集掩码# 输出:
#      num_fg: 标签分配完成后,总共存在的候选框个数(matrix_matching每列保证一个候选框)
#      gt_matched_classes: 标签分配后,每列候选框预测目标的编号
#      gt_matched_ids
#      pred_ious_this_matching: 由标签分配的mask, 筛选真实框与预测框构成的IoU矩阵对应的IoU值
#      matched_gt_inds: matrix_matching矩阵,存在候选框的位置idx
(num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)

附:网络结构

Cls head

Cls_convs

ModuleList((0): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(1): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(2): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))
)

Cls_preds

ModuleList((0): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))
)

Reg head

Reg_convs

ModuleList((0): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(1): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(2): Sequential((0): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))
)

Reg_preds

ModuleList((0): Conv2d(320, 4, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(320, 4, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(320, 4, kernel_size=(1, 1), stride=(1, 1))
)

Obj head

Obj_preds

ModuleList((0): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/776397.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[AIGC] MySQL存储引擎详解

MySQL 是一种颇受欢迎的开源关系型数据库系统,它的强大功能、灵活性和开放性赢得了用户们的广泛赞誉。在 MySQL 中,有一项特别重要的技术就是存储引擎。在本文中,我们将详细介绍什么是存储引擎,以及MySQL中常见的一些存储引擎。 文…

申请GeoTrust数字证书

GeoTrust介绍: 大家应该都不陌生,作为最老资格的一批国际大牌证书,GeoTrust的品牌效益和使用群体非常庞大。在数字证书领域也是当之无愧的龙头地位,作为Symantec和Digicert的子品牌,证书安全性能方面毋庸置疑&#xf…

IP SSL证书注册流程

使用IP地址申请SSL证书,需要用公网IP地址申请,申请之前确保直接的IP地址可以开放80或者443端口两者选择1个就好,端口不需要一直开放,只要认证的几分钟内开放就可以了,然后IP地址根目录可以上传txt文件。 IP SSL证书认…

Codeforces Round 800 (Div. 1)C. Keshi in Search of AmShZ 反向dijkstra,并附带权值

Problem - C - Codeforces 目录 题意: 思路: 答疑: 1.为什么反向做呢? 2.为什么是到达点的剩余度数呢? 3.相同路是否可以去重,用个set? 4.如果有多条路相同呢? 参考代码&am…

【SecretFlow——SPU基础】

1.SPU基础 SPU设备在SecretFlow中负责执行MPC计算。 2.代码解读 2.1 创建设备 import secretflow as sf # 如果存在secretflow,先关闭已经存在的环境 sf.shutdown() # 初始化四个参与方 sf.init([alice, bob, carol, dave], addresslocal) # 寻找未占用的端口来…

【YOLOV5 入门】——detect.py简单解析模型检测基于torch.hub的检测方法

声明:笔记是毕设时根据B站博主视频学习时自己编写,请勿随意转载! 一、打开detect.py(文件解析) 打开上节桌面创建的yolov5-7.0文件夹里的detect.py文件(up主使用的是VScode,我这里使用pycharm…

NLP深入学习:结合源码详解 BERT 模型(三)

文章目录 1. 前言2. 预训练2.1 modeling.BertModel2.1.1 embedding_lookup2.1.2 embedding_postprocessor2.1.3 transformer_model 2.2 get_masked_lm_output2.3 get_next_sentence_output2.4 训练 3. 参考 1. 前言 前情提要: 《NLP深入学习:结合源码详…

PyQt5开发——QCheckBox 复选框用法与代码示例

1. 复选框 QCheckBox 是 Qt 框架中的一个控件,用于在界面中表示一个可以被选中或取消选中的复选框。它通常用于允许用户在多个选项之间进行选择。在 Python 中使用 PyQt 或 PySide 开发 GUI 应用程序时,可以使用 QCheckBox 控件来实现复选框。 2.基本用…

[ Linux ] git工具的基本使用(仓库的构建,提交)

1.安装git yum install -y git 2.打开Gitee,创建你的远程仓库,根据提示初始化本地仓库(这里以我的仓库为例) 新建好仓库之后跟着网页的提示初始化便可以了 3.add、commit、push三板斧 git add . //add仓库新增(变…

企业数字化转型:聊聊数据思维!

笔者曾在《深入聊一聊企业数字化转型这个事儿》 一文中给出了数字化转型的定义,即:通过应用数字化技术来重塑企业的信息化环境和业务过程。本质上来讲,企业数字化转型,不仅是技术方面的升级,更是企业文化、思维方式的转…

【计算机考研】408到底有多难?

你真以为大家是学不会408吗? 不是!单纯是因为时间不够!!! 再准确一些就是不会分配时间 408的知识其实并不难,要说想上130那确实有难度,但是100在时间充裕的情况下还是可以做到的 我本人是双…

非wpf应用程序项目【类库、用户控件库】中使用HandyControl

文章速览 前言参考文章实现方法1、添加HandyControl包;2、添加资源字典3、修改资源字典内容坚持记录实属不易,希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区! 谢谢~ 前言 wpf应用程序中,在入口项目中存在App.xaml文件,在这个文件中加上对各个…

Linux之进程控制进程终止进程等待进程的程序替换替换函数实现简易shell

文章目录 一、进程创建1.1 fork的使用 二、进程终止2.1 终止是在做什么?2.2 终止的3种情况&&退出码的理解2.3 进程常见退出方法 三、进程等待3.1 为什么要进行进程等待?3.2 取子进程退出信息status3.3 宏WIFEXITED和WEXITSTATUS(获取…

全球首位AI程序员Devin诞生,以此谈谈AI对程序员的影响

一、简介 全球首位 AI 程序员 Devin 是由初创公司 Cognition AI 创造的。这家公司成立仅四个月,却已经引起了广泛关注。 Devin作为人工智能的代表,将展示出人工智能在编程领域的潜力和能力,激发程序员探索和应用人工智能技术的兴趣。这将可…

NanoMQ的安装与部署

本文使用docker进行安装,因此安装之前需要已经安装了docker 拉取镜像 docker pull emqx/nanomq:latest 相关配置及密码认证 创建目录/usr/local/nanomq/conf以及配置文件nanomq.conf、pwd.conf # # # # MQTT Broker # # mqtt {property_size 32max_packet_siz…

6、ChatGLM3-6B 部署实践

一、ChatGLM3-6B介绍与快速入门 ChatGLM3 是智谱AI和清华大学 KEG 实验室在2023年10月27日联合发布的新一代对话预训练模型。ChatGLM3-6B 是 ChatGLM3 系列中的开源模型,免费下载,免费的商业化使用。 该模型在保留了前两代模型对话流畅、部署门槛低等众多…

官网怎么发布新文章,怎么在官方网站上发布新内容

随着企业和组织越来越重视官方网站的建设和更新,发布新内容成为了官方网站管理的重要一环。本文将探讨在官方网站上发布新内容的步骤和方法,以及如何确保发布的内容质量和效果。 1. 确定发布内容 在发布新内容之前,首先需要确定发布的内容。…

精品凉拌菜系列热卤系列课程

这一系列课程涵盖精美凉拌菜和美味热卤菜的制作技巧。学员将学习如何选材、调味和烹饪,打造口感丰富、色香俱佳的菜肴。通过实践训练,掌握独特的烹饪技能,为家庭聚餐或职业厨艺提升增添亮点。 课程大小:6.6G 课程下载&#xff1…

windows安装R4.3.3

官网地址The Comprehensive R Archive Network 下载后得到exe安装,默认安装到了C:\Program Files\R, 因为之前已经安装了4.2.3,所以新建了文件夹为4.3.3,两者互不干扰 安装完毕后,打开rstudio,设置 然后重…

基于springboot+vue+Mysql的酒店管理系统

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…