动态标签分配 - 以 Nanodet-plus 中的代码为例

标签分配

部分内容参考自:https://www.bilibili.com/video/BV1ge41117va

简单介绍一些特点,主要结合动态标签分配的一个实例来看

从更高抽象的层面理解 assign

所有用于最终检测的特征图上的所有 point 都具备学习并预测目标的能力,在给定一幅图像及其目标 gt bbox 的情况下,为每个目标 gt bbox 选择恰当的特征图 point 进行学习预测的过程就是分配。

个人理解就是:

每个 anchor 锚点都有预测 bbox 的能力,对一张图像来说,将先验框 gt_bbox 与合适的锚点 points 进行匹配,训练 points 来预测。

一个锚点 point 分配给一个 gt_bbox (即标注框),但是一个 gt_bbox 可以和多个锚点 points 进行匹配

前提:

仅当感受野中心命中 gt bbox 的 point 才有可能被选用来预测这个 bbox

个人理解是:anchor 锚点要位于 gt_bbox 中才能被用来预测

两种类型的匹配机制:

基于规则的分配、自动分配(与网络的输出有关)

目标匹配是 One-stage Anchor-free 检测器核心中的核心!!!

实例分析:

NanoDet-Plus 使用的 DynamicSoftLabelAssigner 来分析,这里主要分析动态分配:

class DynamicSoftLabelAssigner(BaseAssigner):"""Computes matching between predictions and ground truth withdynamic soft label assignment.使用动态软标签分配计算预测与真实值之间的匹配Args:topk (int): Select top-k predictions to calculate dynamic kbest matchs for each gt. Default 13.为每个 gt 选择 k 个最佳预测来计算动态 k 最佳匹配。默认值为13iou_factor (float): The scale factor of iou cost. Default 3.0.IoU 代价的缩放因子。默认值为3.0ignore_iof_thr (int): whether ignore max overlaps or not.Default -1 (1 or -1).是否忽略最大重叠"""def __init__(self, topk=13, iou_factor=3.0, ignore_iof_thr=-1):self.topk = topkself.iou_factor = iou_factorself.ignore_iof_thr = ignore_iof_thr

以 gt 开头的变量为真实标注信息

num_priors 即锚点 point 的数量,num_gts 即一幅图中真实标注的数量,其中 decoded_bboxes 为根据 preds 信息预测的 bboxes

将 point 和 gt 进行匹配,还是刚才所说的,一个 point 对应一个 gt,一个 gt 可以对应多个 point

简单总结一下过程(具体详细的内容看代码及注释):

以下用()包住的内容为张量尺寸大小,对于理解也十分有帮助


重点部分

首先,初步选出可能匹配的锚点,在所有锚点(num_priors)中选出 gt_bboxes 包住的锚点,即初步的有效锚点 (num_valid)

然后,计算代价矩阵 cost_matrix,以及 IoU 矩阵 pairwise_ious,均为**(num_valid, num_gts)**大小,即初步有效的锚点与真实标注的交叉矩阵

调用 dynamic_k_matching,根据 iou 排序,选出一个 gt_bbox 对应的 topk 个锚点,计算 iou 的和,将其作为 dynamic_k,将其作为该 gt_bbox 匹配的锚点个数(规定下限为1个,cost 最小的前 dynamic_k 个锚点),对每个 gt_bbox 均为同样的操作,如果存在一个锚点与多个 gt_bbox 匹配,则只保留代价最小的那一个 gt_bbox,并更新有效锚点为匹配了 gt_bbox 的锚点

最终得到锚点与 gt_bbox 的匹配,一个或多个有效锚点 priors 匹配一个 gt_bbox


更多的细节查看下方提供的代码即注释:

   def assign(self,pred_scores,		# [num_priors, num_classes]priors,				# [num_priors, 4]	 	[cx, cy, stride_x, stride_y]decoded_bboxes,		# [num_priors, 4]	 	[tl_x, tl_y, br_x, br_y]gt_bboxes,			# [num_gts, 4]		 	[tl_x, tl_y, br_x, br_y]gt_labels,			# [num_gts]gt_bboxes_ignore=None,):INF = 100000000num_gt = gt_bboxes.size(0)num_bboxes = decoded_bboxes.size(0)# assign 0 by default# 创建一个与 decoded_bboxes 在同一设备上# 长度为 num_bboxes, 类型为 torch.long 的一维向量assigned_gt_inds = decoded_bboxes.new_full((num_bboxes,), 0, dtype=torch.long)# 锚点中心 (N, 2)prior_center = priors[:, :2]# (N, M, 2) <= (N, 1, 2) - (M, 2)   广播规则lt_ = prior_center[:, None] - gt_bboxes[:, :2]# 同上 (N, M, 2)rb_ = gt_bboxes[:, 2:] - prior_center[:, None]# 合并左上角和右下角的相对位置信息 (N, M, 4)deltas = torch.cat([lt_, rb_], dim=-1)# 判断 N个 锚点是否在 M个 gt_bboxes 内部, 得到 (N, M) 尺寸的向量is_in_gts = deltas.min(dim=-1).values > 0# (N, M) => (N, ), 对每个锚点, 判断是否有一个 gt_bboxes 包含它# valid_mask 表示了用于预测的锚点是否在 gt_bboxes 内部	(num_priors)valid_mask = is_in_gts.sum(dim=1) > 0# 获取有效锚点 (在gt_bboxes内部) 的对应的 preds 	(label以及bbox)# (num_valid, 4)valid_decoded_bbox = decoded_bboxes[valid_mask]# (num_valid, num_classes)valid_pred_scores = pred_scores[valid_mask]			# 被 gt_bboxes 包含的锚点数量 num_validnum_valid = valid_decoded_bbox.size(0)# 如果没有 gt_bboxes, 没有预测框或者没有有效匹配, 则直接返回空的分配结果if num_gt == 0 or num_bboxes == 0 or num_valid == 0:# No ground truth or boxes, return empty assignmentmax_overlaps = decoded_bboxes.new_zeros((num_bboxes,))	# 0if num_gt == 0:# No truth, assign everything to backgroundassigned_gt_inds[:] = 0if gt_labels is None:assigned_labels = Noneelse:assigned_labels = decoded_bboxes.new_full((num_bboxes,), -1, dtype=torch.long)return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)# 计算有效匹配的锚点预测的 bbox 与 gt_bboxes 之间的 IoU 矩阵# (num_valid, num_gts)  <- 	(num_valid, 4), (num_gts, 4)pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)# 计算 IoU 的代价iou_cost = -torch.log(pairwise_ious + 1e-7)# 将真实的类别转换成 onehot 编码 (num_valid, num_gts, num_classes)gt_onehot_label = (F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1]).float().unsqueeze(0)				# 在第一个维度上增加一个维度.repeat(num_valid, 1, 1)	# 第一个维度重复 num_valid 次)# 赋值有效类别的分数  (num_valid, num_classes) -># (num_valid, 1, num_classes) -> (num_valid, num_gts, num_classes)valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)# 生成软标签, 考虑 IoU 权重 (num_valid, num_gts, num_classes)soft_label = gt_onehot_label * pairwise_ious[..., None]# 软标签(真实标签 * IoU) - 预测得分scale_factor = soft_label - valid_pred_scores.sigmoid()# 使用二元交叉熵损失计算分类损失 (num_valid, num_gts, num_classes)cls_cost = F.binary_cross_entropy_with_logits(valid_pred_scores, soft_label, reduction="none") * scale_factor.abs().pow(2.0)# (num_valid, num_gts)cls_cost = cls_cost.sum(dim=-1)# 计算总代价, cls 代价 + bbox 代价 (num_valid, num_gts)cost_matrix = cls_cost + iou_cost * self.iou_factor# 时刻记着: valid 指的是在 gt_bboxes 内的锚点 point 的索引# 根据代价矩阵, iou矩阵, 均为 (num_valid, num_gts)# 进行动态 K-matching, 得到匹配的部分锚点, 这些锚点每个都对应一个 gt_bbox# 每个锚点分配给一个 gt_bbox, 一个 gt_bbox 可以对应多个锚点matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(cost_matrix, pairwise_ious, num_gt, valid_mask)# convert to AssignResult format		# matched_pred_ious 为锚点预测的 bbox 与匹配的 gt_bbox 的 iou# matched_gt_inds 为锚点匹配的 gt_bbox 的索引# 分配的 gt_bbox 的索引, 未分配的为 0(初始值)assigned_gt_inds[valid_mask] = matched_gt_inds + 1# 分配的标签		(num_priors)assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)# 得到分配的类别 根据 gt 的索引确认对应的类别 	(num_priors)assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()# 最大 IoU 	(num_priors)max_overlaps = assigned_gt_inds.new_full((num_bboxes,), -INF, dtype=torch.float32)# 填入有效锚点对应的 IoUmax_overlaps[valid_mask] = matched_pred_ious# 这里的判断默认情况下不会为 Trueif (self.ignore_iof_thr > 0					# 默认 -1 > 0and gt_bboxes_ignore is not Noneand gt_bboxes_ignore.numel() > 0and num_bboxes > 0):ignore_overlaps = bbox_overlaps(valid_decoded_bbox, gt_bboxes_ignore, mode="iof")ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)ignore_idxs = ignore_max_overlaps > self.ignore_iof_thrassigned_gt_inds[ignore_idxs] = -1# 返回 num_gts, 锚点 priors 分配的 gt 索引以及对应的 IoU, 匹配的 gt 对应的标签return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)  
# 根据预测框与真实框之间 IoU 以及损失矩阵来进行匹配def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):"""Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX.Args:cost (Tensor): Cost matrix.pairwise_ious (Tensor): Pairwise iou matrix.num_gt (int): Number of gt.valid_mask (Tensor): Mask for valid bboxes."""# 初始化一个与 cost 同形状的匹配矩阵 (num_valid, num_gts)matching_matrix = torch.zeros_like(cost)# select candidate topk ious for dynamic-k calculationcandidate_topk = min(self.topk, pairwise_ious.size(0))# 选取每个真实框的前 topk 个最高 IoU 值     (candidate_topk, num_gt)topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)# calculate dynamic k for each gt# 计算每个 gt 的包含的锚点 points 的 IoU 最高的前 topk 个值的和, 作为动态 k# 即这里的 k 会根据前 topk 个 iou的值变化   (num_gts)dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)# 进行动态匹配, 遍历 gt_bboxes# 对于每个 gt 挑选其中的 dymamic_k 个锚点进行匹配for gt_idx in range(num_gt):# 选取当前 gt_bbox 对应的损失矩阵中前 k 个最小损失值的索引, 即对应的锚点的索引# cost 维度为	 (num_priors, num_gts)_, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)# 将对应的匹配矩阵中的值置为 1, 对应的 dynamic_k 个锚点和该 gt 匹配 matching_matrix[:, gt_idx][pos_idx] = 1.0del topk_ious, dynamic_ks, pos_idx# matching_matrix 尺寸为 (num_priors, num_gt), 挑选出匹配的锚点    (num_priors)# 一个锚点与两个或更多个 gt_bbox 匹配	(num_priors)prior_match_gt_mask = matching_matrix.sum(1) > 1# 如果存在一个锚点和多个 gt_bboxes 匹配, 那么则选择代价最小的那一个if prior_match_gt_mask.sum() > 0:# 对于匹配多个 gt_bbox 的锚点, 选择代价最小的 gt_bbox 进行匹配		(num_priors) cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)# 将匹配多个 gt_bbox 的锚点的匹配清空, 选择代价最小的那一个 gt_bboxmatching_matrix[prior_match_gt_mask, :] *= 0.0matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0# 匹配了 gt_bbox 的锚点矩阵  (num_priors) 	# get foreground mask inside box and center priorfg_mask_inboxes = matching_matrix.sum(1) > 0.0# 更新有效 mask, valid_mask 表示的为匹配了 gt_bbox 的锚点 (num_priors)valid_mask[valid_mask.clone()] = fg_mask_inboxes# 获取有效匹配的每个预测框对应的 gt_bbox 的索引 maching_matrix     (num_priors, num_gts)# argmax 获取最大的那一个值的索引      (num_valid_priros, num_gts) -> (num_valid_priors)matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)# 计算有效匹配的每个预测框与 gt_bbox 之间的 IoU    # (num_valid, num_gts) -> (num_valid) -> (num_valid_priors)matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes]# 每个 prior 与一个 gt_bbox 对应 (一个 gt_bbox 可以对应多个预测框)return matched_pred_ious, matched_gt_inds

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/671263.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言基础语法..

1.函数的基本语法 函数的格式为&#xff1a; 返回值类型 函数名(参数列表){ 函数体(包括返回值语句) } 利用上述的格式 我们可以自己整一个实现加法功能的函数 int add(int a, int b){return a b; } int main(){int c add(10, 20);printf("%d", c);// 30return …

基于Vue2用keydown、setTimeout事件实现连续按键(连击)任意键(或组合键)3秒触发自定义事件(以F1键为例)

核心代码 <template></template> <script> export default {created() {//监听弹起快捷键addEventListener("keyup", this.keyup);},destroyed(d) {//移除监听弹起快捷键removeEventListener("keyup", this.keyup);},methods: {keyup(…

golang开发window环境搭建

1.本人开发环境&#xff1a;window10,idea2020.1.3 2.Go语言环境版本1.5.1 2.1. go语言插件 下载地址 csdn - 安全中心 2.1.1 go的各个版本官网Other Versions - GoLand 2.2下载安装 3.idea配置go环境 4.创建go项目 、5.运行

Linux 性能调优之虚拟化(Virtualization tuned)调优

写在前面 考试整理相关笔记博文内容涉及Linux 虚拟化常见管理操作以及部分调优配置理解不足小伙伴帮忙指正 不必太纠结于当下&#xff0c;也不必太忧虑未来&#xff0c;当你经历过一些事情的时候&#xff0c;眼前的风景已经和从前不一样了。——村上春树 使用工具进行调优 可以…

BLEUScore AttributeError: ‘list‘ object has no attribute ‘split‘——问题解决

目录 问题解决 问题 出现错误&#xff1a; BLEUScore AttributeError: ‘list’ object has no attribute ‘split’ 解决 应该是torchmetrics版本对torch的要求&#xff0c;需要对应版本: pip install torchmetrics0.6.2具体需要根据自己版本去降低&#xff0c;一般是往低…

Unet 实战分割项目、多尺度训练、多类别分割

1. 介绍 之前写了篇二值图像分割的项目&#xff0c;支持多尺度训练&#xff0c;网络采用backbone为vgg的unet网络。缺点就是没法实现多类别的分割&#xff0c;具体可以参考&#xff1a;二值图像分割统一项目 本章只对增加的代码进行介绍&#xff0c;其余的参考上述链接博文 本…

在本地运行大型语言模型 (LLM) 的六种方法(2024 年 1 月)

一、说明 &#xff08;开放&#xff09;本地大型语言模型&#xff08;LLM&#xff09;&#xff0c;特别是在 Meta 发布LLaMA和后Llama 2&#xff0c;变得越来越好&#xff0c;并且被越来越广泛地采用。 在本文中&#xff0c;我想演示在本地&#xff08;即在您的计算机上&#x…

DataX详解和架构介绍

系列文章目录 一、 DataX详解和架构介绍 二、 DataX源码分析 JobContainer 三、DataX源码分析 TaskGroupContainer 四、DataX源码分析 TaskExecutor 五、DataX源码分析 reader 六、DataX源码分析 writer 七、DataX源码分析 Channel 文章目录 系列文章目录DataX是什么&#xff…

【QT】VS-code报错:LNK2019: 无法解析的外部符号

目录 0.环境 1.问题简述 2.分析报错原因 3.解决方法 1&#xff09;set() 相关语句 2&#xff09;target_link_libraries() 相关语句 4.参考 0.环境 windows11 、 vs-code 、 qt 、 c、编译器为vs2019-x86_amd64 1.问题简述 项目编译release版本时会报错&#xff1a;报错…

页面单跳转换率统计案例分析

需求说明 页面单跳转化率 计算页面单跳转化率&#xff0c;什么是页面单跳转换率&#xff0c;比如一个用户在一次 Session 过程中访问的页面路径 3,5,7,9,10,21&#xff0c;那么页面 3 跳到页面 5 叫一次单跳&#xff0c;7-9 也叫一次单跳&#xff0c; 那么单跳转化率就是要统计…

c语言--指针的传值调用和传址调用

目录 一、前言二、传值调用。三、传址调用四、总结 一、前言 学习指针的目的是使用指针解决问题&#xff0c;那什么问题&#xff0c;非指针不可呢&#xff1f; 二、传值调用。 写个函数&#xff0c;交换两个整数的内容。 #include<stdio.h> void Swap1(int x, int y)…

在 Java 中处理整数上溢和下溢

本文介绍整数数据类型的上溢和下溢以及该问题的处理。 Java 中整数上溢和下溢概述 如果您使用整数值&#xff0c;则可能会遇到上溢或下溢错误。 当我们错误地声明变量时&#xff0c;就会发生这种情况&#xff0c;例如分配的值超出了声明的数据类型的范围。 众所周知&#xff…

LabVIEW双光子荧光显微成像系统开发

双光子显微成像是一种高级荧光显微技术&#xff0c;广泛用于生物学和医学研究&#xff0c;尤其是用于活体组织的深层成像。在双光子成像过程中&#xff0c;振镜&#xff08;Galvo镜&#xff09;扮演了非常关键的角色&#xff0c;它负责精确控制激光束在样本上的扫描路径。以下是…

读分布式稳定性建设指南文档

最近还是在做一些和稳定性建设相关的事情&#xff0c;找到一份《分布式稳定性建设指南》文档&#xff0c;摘抄了其中的重点&#xff0c;以便后续回顾方便&#xff0c;一直没上传好资源&#xff0c;我之后再试试&#xff0c;原文内容质量非常高。 大家可以先看一级目录即可&…

掌握Web服务器之王:Nginx 学习网站全攻略!

介绍&#xff1a;Nginx是一款高性能的Web服务器&#xff0c;同时也是一个反向代理、负载均衡和HTTP缓存服务器。具体介绍如下&#xff1a; 轻量级设计&#xff1a;Nginx的设计理念是轻量级&#xff0c;这意味着它在占用最少的系统资源的同时提供高效的服务。 高并发能力&#x…

go 内存二进制数据操作

go 内存二进制数据操作 go 内存二进制数据直接操作 以数字类型为例 int(linux/macos 为int32,windows 为int64). 如果不清楚可以使用unsafe.Sizeof函数来查看(函数出来的值*8就是int位数) 若不使用内存二进制数据操作&#xff0c;你需要在每次获取数字内容时调用binary.Big…

五、机器学习模型及其实现1

1_机器学习 1&#xff09;基础要求&#xff1a;所有的数据全部变为了特征&#xff0c;而不是eeg信号了 python基础已经实现了特征提取、特征选择&#xff08;可选&#xff09;进行了数据预处理.预处理指对数据进行清洗、转换等处理&#xff0c;使数据更适合机器学习的工具。S…

完全背包总结二

1.完全背包和0/1背包的区别&#xff1f; 完全背包的物体有无限个&#xff0c;可以多次放入 0/1背包的物体只有一个&#xff0c;只能放入一次 2.关于物品遍历顺序 在0/1背包中为了防止物品被重复放入&#xff0c;所以选择倒序遍历背包 而完全背包中&#xff0c;可以重复放入…

Datax3.0+DataX-Web部署分布式可视化ETL系统

一、DataX 简介 DataX 是阿里云 DataWorks 数据集成的开源版本&#xff0c;主要就是用于实现数据间的离线同步。DataX 致力于实现包括关系型数据库&#xff08;MySQL、Oracle 等&#xff09;、HDFS、Hive、ODPS、HBase、FTP 等各种异构数据源&#xff08;即不同的数据库&#x…

找单身狗(C语言)

题目叙述&#xff1a; 一个数组中只有两个数字是出现一次&#xff0c;其他所有数字都出现了两次。 编写一个函数找出这两个只出现一次的数字。 例如&#xff1a; 数组的元素是&#xff1a;1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5&#xff0c;1&#xff0c;…