本篇文章首先介绍目标检测任务中的评价指标
混淆矩阵
的概念,然后介绍其在yolo源码
中的实现方法。
目标检测中的评价指标:
mAP概念及其计算方法(yolo源码/pycocotools)
混淆矩阵概念及其计算方法(yolo源码)
本文目录
- 1 概念
- 2 计算方法
1 概念
在分类任务中,混淆矩阵(Confusion Matrix)
是一种可视化工具,主要用于评价模型精度,将模型的分类结果显示在一个矩阵中。多分类任务的混淆矩阵
结构如图1所示,其中横轴表示模型预测结果,纵轴表示实际结果,图中的各类指标以cls_1的预测结果为例,其含义如下:
True Positive(TP)
:预测为正样本(cls_1
),且实际为正样本(cls_1
)各类别TP:混淆矩阵对角线的值
False Positive(FP)
:预测为正样本(cls_1
),但实际为负样本(cls_other
)各类别FP:混淆矩阵每列的和减去对应的TP
False Negative(FN)
:预测为负样本(cls_other
),但实际为正样本(cls_1
)各类别(FN:混淆矩阵每行的和减去对应的TP
True Negative(TN)
: 预测为负样本(cls_other
),且实际为负样本(cls_other
)各类别FN:混淆矩阵的和减去对应的TP、FP、FN
目标检测的任务为对目标进行分类
与定位
,模型的预测结果p为(cls, conf, pos),其中cls为目标的类别,conf为目标属于该类别的置信度,pos为目标的预测边框。目标检测任务综合类别预测结果
和预测边框与实际边框IoU
,对模型进行评价,其混淆矩阵结构如图2所示,图中的各类指标以cls_1的预测结果为例,其含义如下:
- 样本匹配(每一张图片):预测结果
gt
与实际结果dt
匹配IoU > IoU_thres
- 同一个
gt
至多匹配一个p
(若一个gt
匹配到多个p
,则选择IoU
最高的p
作为匹配结果) - 同一个
gt
至多匹配一个p
(若一个p
匹配到多个gt
,则选择IoU
最高的gt
作为匹配结果)
background
: 未成功匹配的gt
或dt
True Positive(TP)
:匹配结果为正样本(cls_1
),且实际为正样本(cls_1
)False Positive(FP)
:匹配结果正样本(cls_1
),但实际为负样本(cls_1 or background
)False Negative(FN)
:匹配结果为负样本(cls_other or backgroun
),但实际为正样本(cls_1
)True Negative(TN)
:匹配结果为负样本(cls_other or backgroun
),且实际为负样本(cls_other or backgroun
)
目标检测任务中的混淆矩阵
计算方法如图3所示。
2 计算方法
基于YOLO源码实现
混淆矩阵
计算(ConfusionMatrix
)
- 函数
- process_batch:实现预测结果与真实结果的匹配,混淆矩阵计算
- plot:混淆矩阵绘制
- tp_fp:根据混淆矩阵计算
TP/FP
class ConfusionMatrix:# Updated version of https://github.com/kaanakan/object_detection_confusion_matrixdef __init__(self, nc, conf=0.25, iou_thres=0.5):self.matrix = np.zeros((nc + 1, nc + 1))self.nc = nc # number of classesself.conf = conf # 类别置信度self.iou_thres = iou_thres # IoU置信度def process_batch(self, detections, labels):"""Return intersection-ove-unionr (Jaccard index) of boxes.Both sets of boxes are expected to be in (x1, y1, x2, y2) format.Arguments:detections (Array[N, 6]), x1, y1, x2, y2, conf, classlabels (Array[M, 5]), class, x1, y1, x2, y2Returns:None, updates confusion matrix accordingly"""if detections is None:gt_classes = labels.int()for gc in gt_classes:self.matrix[self.nc, gc] += 1 # 预测为背景,但实际为目标returndetections = detections[detections[:, 4] > self.conf] # 小于该conf认为为背景gt_classes = labels[:, 0].int() # 实际类别detection_classes = detections[:, 5].int() # 预测类别iou = box_iou(labels[:, 1:], detections[:, :4]) # 计算所有结果的IoUx = torch.where(iou > self.iou_thres) # 根据IoU匹配结果,返回满足条件的索引 x(dim0), (dim1)if x[0].shape[0]: # x[0]:存在为True的索引(gt索引), x[1]当前所有下True的索引(dt索引)# shape:[n, 3] 3->[label, detect, iou]matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()if x[0].shape[0] > 1:matches = matches[matches[:, 2].argsort()[::-1]] # 根据IoU从大到小排序matches = matches[np.unique(matches[:, 1], return_index=True)[1]] # 若一个dt匹配多个gt,保留IoU最高的gt匹配结果matches = matches[matches[:, 2].argsort()[::-1]] # 根据IoU从大到小排序matches = matches[np.unique(matches[:, 0], return_index=True)[1]] # 若一个gt匹配多个dt,保留IoU最高的dt匹配结果else:matches = np.zeros((0, 3))n = matches.shape[0] > 0 # 是否存在和gt匹配成功的dtm0, m1, _ = matches.transpose().astype(int) # m0:gt索引 m1:dt索引for i, gc in enumerate(gt_classes): # 实际的结果j = m0 == i # 预测为该目标的预测结果序号if n and sum(j) == 1: # 该实际结果预测成功self.matrix[detection_classes[m1[j]], gc] += 1 # 预测为目标,且实际为目标else: # 该实际结果预测失败self.matrix[self.nc, gc] += 1 # 预测为背景,但实际为目标if n:for i, dc in enumerate(detection_classes): # 对预测结果处理if not any(m1 == i): # 若该预测结果没有和实际结果匹配self.matrix[dc, self.nc] += 1 # 预测为目标,但实际为背景def tp_fp(self):tp = self.matrix.diagonal() # true positivesfp = self.matrix.sum(1) - tp # false positives# fn = self.matrix.sum(0) - tp # false negatives (missed detections)return tp[:-1], fp[:-1] # remove background class@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')def plot(self, normalize=True, save_dir='', names=()):import seaborn as snplt.rc('font', family='Times New Roman', size=15)array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columnsarray[array < 0.005] = 0.00 # don't annotate (would appear as 0.00)fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)nc, nn = self.nc, len(names) # number of classes, namessn.set(font_scale=1.0 if nc < 50 else 0.8) # for label sizelabels = (0 < nn < 99) and (nn == nc) # apply names to ticklabelsticklabels = (names + ['background']) if labels else 'auto'with warnings.catch_warnings():warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encounteredh = sn.heatmap(array,ax=ax,annot=nc < 30,annot_kws={'size': 20},cmap='Reds',fmt='.2f',linewidths=2,square=True,vmin=0.0,xticklabels=ticklabels,yticklabels=ticklabels,)h.set_facecolor((1, 1, 1))cb = h.collections[0].colorbar # 显示colorbarcb.ax.tick_params(labelsize=20) # 设置colorbar刻度字体大小。plt.xticks(fontsize=20)plt.yticks(fontsize=20)plt.rcParams["font.sans-serif"] = ["SimSun"]plt.rcParams["axes.unicode_minus"] = Falseax.set_xlabel('实际值')ax.set_ylabel('预测值')# ax.set_title('Confusion Matrix', fontsize=20)fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=100)plt.close(fig)def print(self):for i in range(self.nc + 1):print(' '.join(map(str, self.matrix[i])))