标签分配
部分内容参考自:https://www.bilibili.com/video/BV1ge41117va
简单介绍一些特点,主要结合动态标签分配的一个实例来看
从更高抽象的层面理解 assign
:
所有用于最终检测的特征图上的所有 point 都具备学习并预测目标的能力,在给定一幅图像及其目标 gt bbox 的情况下,为每个目标 gt bbox 选择恰当的特征图 point 进行学习预测的过程就是分配。
个人理解就是:
每个 anchor 锚点都有预测 bbox 的能力,对一张图像来说,将先验框 gt_bbox 与合适的锚点 points 进行匹配,训练 points 来预测。
一个锚点 point 分配给一个 gt_bbox (即标注框),但是一个 gt_bbox 可以和多个锚点 points 进行匹配
前提:
仅当感受野中心命中 gt bbox 的 point 才有可能被选用来预测这个 bbox
个人理解是:anchor 锚点要位于 gt_bbox 中才能被用来预测
两种类型的匹配机制:
基于规则的分配、自动分配(与网络的输出有关)
目标匹配是 One-stage Anchor-free 检测器核心中的核心!!!
实例分析:
以 NanoDet-Plus
使用的 DynamicSoftLabelAssigner
来分析,这里主要分析动态分配:
class DynamicSoftLabelAssigner(BaseAssigner):"""Computes matching between predictions and ground truth withdynamic soft label assignment.使用动态软标签分配计算预测与真实值之间的匹配Args:topk (int): Select top-k predictions to calculate dynamic kbest matchs for each gt. Default 13.为每个 gt 选择 k 个最佳预测来计算动态 k 最佳匹配。默认值为13iou_factor (float): The scale factor of iou cost. Default 3.0.IoU 代价的缩放因子。默认值为3.0ignore_iof_thr (int): whether ignore max overlaps or not.Default -1 (1 or -1).是否忽略最大重叠"""def __init__(self, topk=13, iou_factor=3.0, ignore_iof_thr=-1):self.topk = topkself.iou_factor = iou_factorself.ignore_iof_thr = ignore_iof_thr
以 gt 开头的变量为真实标注信息
num_priors 即锚点 point 的数量,num_gts 即一幅图中真实标注的数量,其中 decoded_bboxes 为根据 preds 信息预测的 bboxes
将 point 和 gt 进行匹配,还是刚才所说的,一个 point 对应一个 gt,一个 gt 可以对应多个 point
简单总结一下过程(具体详细的内容看代码及注释):
以下用()包住的内容为张量尺寸大小,对于理解也十分有帮助
重点部分
:
首先,初步选出可能匹配的锚点,在所有锚点(num_priors)中选出 gt_bboxes 包住的锚点,即初步的有效锚点 (num_valid)
然后,计算代价矩阵 cost_matrix,以及 IoU 矩阵 pairwise_ious,均为**(num_valid, num_gts)**大小,即初步有效的锚点与真实标注的交叉矩阵
调用 dynamic_k_matching,根据 iou 排序,选出一个 gt_bbox 对应的 topk 个锚点,计算 iou 的和,将其作为 dynamic_k,将其作为该 gt_bbox 匹配的锚点个数(规定下限为1个,cost 最小的前 dynamic_k 个锚点),对每个 gt_bbox 均为同样的操作,如果存在一个锚点与多个 gt_bbox 匹配,则只保留代价最小的那一个 gt_bbox,并更新有效锚点为匹配了 gt_bbox 的锚点
最终得到锚点与 gt_bbox 的匹配,一个或多个有效锚点 priors 匹配一个 gt_bbox
更多的细节查看下方提供的代码即注释:
def assign(self,pred_scores, # [num_priors, num_classes]priors, # [num_priors, 4] [cx, cy, stride_x, stride_y]decoded_bboxes, # [num_priors, 4] [tl_x, tl_y, br_x, br_y]gt_bboxes, # [num_gts, 4] [tl_x, tl_y, br_x, br_y]gt_labels, # [num_gts]gt_bboxes_ignore=None,):INF = 100000000num_gt = gt_bboxes.size(0)num_bboxes = decoded_bboxes.size(0)# assign 0 by default# 创建一个与 decoded_bboxes 在同一设备上# 长度为 num_bboxes, 类型为 torch.long 的一维向量assigned_gt_inds = decoded_bboxes.new_full((num_bboxes,), 0, dtype=torch.long)# 锚点中心 (N, 2)prior_center = priors[:, :2]# (N, M, 2) <= (N, 1, 2) - (M, 2) 广播规则lt_ = prior_center[:, None] - gt_bboxes[:, :2]# 同上 (N, M, 2)rb_ = gt_bboxes[:, 2:] - prior_center[:, None]# 合并左上角和右下角的相对位置信息 (N, M, 4)deltas = torch.cat([lt_, rb_], dim=-1)# 判断 N个 锚点是否在 M个 gt_bboxes 内部, 得到 (N, M) 尺寸的向量is_in_gts = deltas.min(dim=-1).values > 0# (N, M) => (N, ), 对每个锚点, 判断是否有一个 gt_bboxes 包含它# valid_mask 表示了用于预测的锚点是否在 gt_bboxes 内部 (num_priors)valid_mask = is_in_gts.sum(dim=1) > 0# 获取有效锚点 (在gt_bboxes内部) 的对应的 preds (label以及bbox)# (num_valid, 4)valid_decoded_bbox = decoded_bboxes[valid_mask]# (num_valid, num_classes)valid_pred_scores = pred_scores[valid_mask] # 被 gt_bboxes 包含的锚点数量 num_validnum_valid = valid_decoded_bbox.size(0)# 如果没有 gt_bboxes, 没有预测框或者没有有效匹配, 则直接返回空的分配结果if num_gt == 0 or num_bboxes == 0 or num_valid == 0:# No ground truth or boxes, return empty assignmentmax_overlaps = decoded_bboxes.new_zeros((num_bboxes,)) # 0if num_gt == 0:# No truth, assign everything to backgroundassigned_gt_inds[:] = 0if gt_labels is None:assigned_labels = Noneelse:assigned_labels = decoded_bboxes.new_full((num_bboxes,), -1, dtype=torch.long)return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)# 计算有效匹配的锚点预测的 bbox 与 gt_bboxes 之间的 IoU 矩阵# (num_valid, num_gts) <- (num_valid, 4), (num_gts, 4)pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)# 计算 IoU 的代价iou_cost = -torch.log(pairwise_ious + 1e-7)# 将真实的类别转换成 onehot 编码 (num_valid, num_gts, num_classes)gt_onehot_label = (F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1]).float().unsqueeze(0) # 在第一个维度上增加一个维度.repeat(num_valid, 1, 1) # 第一个维度重复 num_valid 次)# 赋值有效类别的分数 (num_valid, num_classes) -># (num_valid, 1, num_classes) -> (num_valid, num_gts, num_classes)valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)# 生成软标签, 考虑 IoU 权重 (num_valid, num_gts, num_classes)soft_label = gt_onehot_label * pairwise_ious[..., None]# 软标签(真实标签 * IoU) - 预测得分scale_factor = soft_label - valid_pred_scores.sigmoid()# 使用二元交叉熵损失计算分类损失 (num_valid, num_gts, num_classes)cls_cost = F.binary_cross_entropy_with_logits(valid_pred_scores, soft_label, reduction="none") * scale_factor.abs().pow(2.0)# (num_valid, num_gts)cls_cost = cls_cost.sum(dim=-1)# 计算总代价, cls 代价 + bbox 代价 (num_valid, num_gts)cost_matrix = cls_cost + iou_cost * self.iou_factor# 时刻记着: valid 指的是在 gt_bboxes 内的锚点 point 的索引# 根据代价矩阵, iou矩阵, 均为 (num_valid, num_gts)# 进行动态 K-matching, 得到匹配的部分锚点, 这些锚点每个都对应一个 gt_bbox# 每个锚点分配给一个 gt_bbox, 一个 gt_bbox 可以对应多个锚点matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(cost_matrix, pairwise_ious, num_gt, valid_mask)# convert to AssignResult format # matched_pred_ious 为锚点预测的 bbox 与匹配的 gt_bbox 的 iou# matched_gt_inds 为锚点匹配的 gt_bbox 的索引# 分配的 gt_bbox 的索引, 未分配的为 0(初始值)assigned_gt_inds[valid_mask] = matched_gt_inds + 1# 分配的标签 (num_priors)assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)# 得到分配的类别 根据 gt 的索引确认对应的类别 (num_priors)assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()# 最大 IoU (num_priors)max_overlaps = assigned_gt_inds.new_full((num_bboxes,), -INF, dtype=torch.float32)# 填入有效锚点对应的 IoUmax_overlaps[valid_mask] = matched_pred_ious# 这里的判断默认情况下不会为 Trueif (self.ignore_iof_thr > 0 # 默认 -1 > 0and gt_bboxes_ignore is not Noneand gt_bboxes_ignore.numel() > 0and num_bboxes > 0):ignore_overlaps = bbox_overlaps(valid_decoded_bbox, gt_bboxes_ignore, mode="iof")ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)ignore_idxs = ignore_max_overlaps > self.ignore_iof_thrassigned_gt_inds[ignore_idxs] = -1# 返回 num_gts, 锚点 priors 分配的 gt 索引以及对应的 IoU, 匹配的 gt 对应的标签return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
# 根据预测框与真实框之间 IoU 以及损失矩阵来进行匹配def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):"""Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX.Args:cost (Tensor): Cost matrix.pairwise_ious (Tensor): Pairwise iou matrix.num_gt (int): Number of gt.valid_mask (Tensor): Mask for valid bboxes."""# 初始化一个与 cost 同形状的匹配矩阵 (num_valid, num_gts)matching_matrix = torch.zeros_like(cost)# select candidate topk ious for dynamic-k calculationcandidate_topk = min(self.topk, pairwise_ious.size(0))# 选取每个真实框的前 topk 个最高 IoU 值 (candidate_topk, num_gt)topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)# calculate dynamic k for each gt# 计算每个 gt 的包含的锚点 points 的 IoU 最高的前 topk 个值的和, 作为动态 k# 即这里的 k 会根据前 topk 个 iou的值变化 (num_gts)dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)# 进行动态匹配, 遍历 gt_bboxes# 对于每个 gt 挑选其中的 dymamic_k 个锚点进行匹配for gt_idx in range(num_gt):# 选取当前 gt_bbox 对应的损失矩阵中前 k 个最小损失值的索引, 即对应的锚点的索引# cost 维度为 (num_priors, num_gts)_, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)# 将对应的匹配矩阵中的值置为 1, 对应的 dynamic_k 个锚点和该 gt 匹配 matching_matrix[:, gt_idx][pos_idx] = 1.0del topk_ious, dynamic_ks, pos_idx# matching_matrix 尺寸为 (num_priors, num_gt), 挑选出匹配的锚点 (num_priors)# 一个锚点与两个或更多个 gt_bbox 匹配 (num_priors)prior_match_gt_mask = matching_matrix.sum(1) > 1# 如果存在一个锚点和多个 gt_bboxes 匹配, 那么则选择代价最小的那一个if prior_match_gt_mask.sum() > 0:# 对于匹配多个 gt_bbox 的锚点, 选择代价最小的 gt_bbox 进行匹配 (num_priors) cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)# 将匹配多个 gt_bbox 的锚点的匹配清空, 选择代价最小的那一个 gt_bboxmatching_matrix[prior_match_gt_mask, :] *= 0.0matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0# 匹配了 gt_bbox 的锚点矩阵 (num_priors) # get foreground mask inside box and center priorfg_mask_inboxes = matching_matrix.sum(1) > 0.0# 更新有效 mask, valid_mask 表示的为匹配了 gt_bbox 的锚点 (num_priors)valid_mask[valid_mask.clone()] = fg_mask_inboxes# 获取有效匹配的每个预测框对应的 gt_bbox 的索引 maching_matrix (num_priors, num_gts)# argmax 获取最大的那一个值的索引 (num_valid_priros, num_gts) -> (num_valid_priors)matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)# 计算有效匹配的每个预测框与 gt_bbox 之间的 IoU # (num_valid, num_gts) -> (num_valid) -> (num_valid_priors)matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes]# 每个 prior 与一个 gt_bbox 对应 (一个 gt_bbox 可以对应多个预测框)return matched_pred_ious, matched_gt_inds