Table of Contents
I. Regression Loss Functions (Bounding Box Regression Loss)
1 Inner-IoU
2 Focaler-IoU: A More Focused IoU Loss
II. Improving the YOLOv8 Loss Function
1 Overall Modifications
① The ultralytics/utils/metrics.py file
② The ultralytics/utils/loss.py file
③ The ultralytics/utils/tal.py file
2 Using the Various Mechanisms
Others
I. Regression Loss Functions (Bounding Box Regression Loss)
1 Inner-IoU
Official paper: official paper link (click to open)
Official code: official code link (click to open)
The paper analyzes the bounding-box regression process and points out a limitation of IoU loss: it does not generalize well across different detection tasks. Based on the inherent characteristics of the bounding-box regression problem, the authors propose Inner-IoU, a bounding-box regression loss built on auxiliary bounding boxes. A scale factor (ratio) controls how the auxiliary boxes used to compute the loss are generated, which accelerates training convergence, and the method can be integrated into existing IoU-based loss functions. A series of simulation and ablation experiments shows that it outperforms existing approaches, and that it works well not only on general detection tasks but also on very-small-target detection, confirming its generalization ability.
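To make the mechanism concrete, here is a minimal standalone sketch (our own illustration, not the official implementation): both boxes are shrunk (ratio < 1) or enlarged (ratio > 1) about their centers, and the IoU is computed between the two auxiliary boxes.

```python
import torch

def inner_iou(box1, box2, ratio=0.75, eps=1e-7):
    """Toy Inner-IoU for boxes in (xc, yc, w, h) format; auxiliary boxes are scaled by `ratio`."""
    (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
    # Corners of the auxiliary (scaled) boxes
    b1_x1, b1_x2 = x1 - w1 * ratio / 2, x1 + w1 * ratio / 2
    b1_y1, b1_y2 = y1 - h1 * ratio / 2, y1 + h1 * ratio / 2
    b2_x1, b2_x2 = x2 - w2 * ratio / 2, x2 + w2 * ratio / 2
    b2_y1, b2_y2 = y2 - h2 * ratio / 2, y2 + h2 * ratio / 2
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    union = w1 * h1 * ratio ** 2 + w2 * h2 * ratio ** 2 - inter + eps
    return inter / union

gt = torch.tensor([[50.0, 50.0, 20.0, 30.0]])    # (xc, yc, w, h)
pred = torch.tensor([[54.0, 48.0, 22.0, 28.0]])
print(inner_iou(gt, pred, ratio=0.75))
```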
The official code provides two ways of integrating it; the files are shown in the screenshot below:
The definition of Inner-IoU is given below:
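The original post shows the definition as an image; reconstructed from the Inner-IoU paper, the auxiliary boxes and the resulting IoU are

$$
b_l^{gt}=x_c^{gt}-\frac{w^{gt}\cdot \mathrm{ratio}}{2},\quad b_r^{gt}=x_c^{gt}+\frac{w^{gt}\cdot \mathrm{ratio}}{2},\quad b_t^{gt}=y_c^{gt}-\frac{h^{gt}\cdot \mathrm{ratio}}{2},\quad b_b^{gt}=y_c^{gt}+\frac{h^{gt}\cdot \mathrm{ratio}}{2}
$$

(and analogously $b_l, b_r, b_t, b_b$ for the anchor box), with

$$
\mathrm{inter}=\big(\min(b_r^{gt},b_r)-\max(b_l^{gt},b_l)\big)\cdot\big(\min(b_b^{gt},b_b)-\max(b_t^{gt},b_t)\big)
$$

$$
\mathrm{union}=w^{gt}h^{gt}\cdot \mathrm{ratio}^2+wh\cdot \mathrm{ratio}^2-\mathrm{inter},\qquad IoU^{inner}=\frac{\mathrm{inter}}{\mathrm{union}}
$$

so that, for example, $L_{Inner\text{-}CIoU}=L_{CIoU}+IoU-IoU^{inner}$.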
Experimental results of Inner-IoU
The detection results of CIoU vs. Inner-CIoU (ratio=0.7), Inner-CIoU (ratio=0.75), and Inner-CIoU (ratio=0.8) are shown below:
The detection results of SIoU vs. Inner-SIoU (ratio=0.7), Inner-SIoU (ratio=0.75), and Inner-SIoU (ratio=0.8) are shown below:
2 Focaler-IoU: A More Focused IoU Loss
Official paper: official paper link (click to open)
Official code: official code link (click to open)
The paper analyzes how the distribution of hard and easy samples affects object detection. When hard samples dominate, the detector needs to focus on them to improve performance; when easy samples make up the larger share, the opposite holds. The paper proposes Focaler-IoU, which reconstructs the original IoU loss through a linear interval mapping so as to focus on hard or easy samples, and comparative experiments show that the method effectively improves detection performance.
To focus on different regression samples across detection tasks, the IoU loss is reconstructed with a linear interval mapping, which helps improve bounding-box regression. The formula is as follows:
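The formula appears as an image in the original post; reconstructed from the Focaler-IoU paper, it is

$$
IoU^{focaler}=\begin{cases}
0, & IoU<d\\[2pt]
\dfrac{IoU-d}{u-d}, & d\le IoU\le u\\[2pt]
1, & IoU>u
\end{cases}
$$

where $IoU$ is the original IoU value and $[d,u]\subseteq[0,1]$; the loss is $L_{Focaler\text{-}IoU}=1-IoU^{focaler}$. Adjusting $d$ and $u$ makes the loss focus on easy or hard regression samples.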
Applying Focaler-IoU to the existing IoU-based bounding-box regression losses gives:
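Again reconstructed from the paper, the combined losses take the form

$$
L_{Focaler\text{-}GIoU}=L_{GIoU}+IoU-IoU^{focaler},\qquad
L_{Focaler\text{-}DIoU}=L_{DIoU}+IoU-IoU^{focaler},
$$

and likewise for CIoU, EIoU, and SIoU.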
The experimental results are as follows:
For an overview of GIoU, DIoU, CIoU, EIoU, MPDIoU, and others, see "Using the MPDIoU regression loss function to make YOLOv9 models better" (click here to open).
II. Improving the YOLOv8 Loss Function
1 Overall Modifications
First, let's integrate the loss functions that will be used later into the project.
① The ultralytics/utils/metrics.py file
In ultralytics/utils/metrics.py, replace the original bbox_iou() function with the replacement code given here. The function to be replaced is listed under "a Original code" below, and the code to put in its place under "b Replacement code".
- a Original code
```python
# before
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).

    Args:
        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
                               (x1, y1, x2, y2) format. Defaults to True.
        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)) / 4  # center dist**2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU
```
- b Replacement code
```python
# after
class WIoU_Scale:
    """
    monotonous: {
        None:  origin v1
        True:  monotonic FM v2
        False: non-monotonic FM v3
    }
    """

    iou_mean = 1.
    monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)
    _is_train = True

    def __init__(self, iou):
        self.iou = iou
        self._update(self)

    @classmethod
    def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + cls._momentum * self.iou.detach().mean().item()

    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                return beta / alpha
        return 1


def bbox_iou(box1, box2, xywh=True, ratio=1, GIoU=False, DIoU=False, CIoU=False,
             SIoU=False, EIoU=False, WIoU=False, MPDIoU=False, LMPDIoU=False,
             Inner=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
    # Compute the Intersection over Union (IoU) between box1 and box2
    # Get the coordinates of the bounding boxes
    if Inner:
        # NOTE: this branch parses the boxes as (x, y, w, h); the auxiliary (inner) boxes
        # are generated by scaling the widths/heights with `ratio`
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = (x1 - w1_ * ratio, x1 + w1_ * ratio,
                                      y1 - h1_ * ratio, y1 + h1_ * ratio)
        b2_x1, b2_x2, b2_y1, b2_y2 = (x2 - w2_ * ratio, x2 + w2_ * ratio,
                                      y2 - h2_ * ratio, y2 + h2_ * ratio)

        # Intersection area
        inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
                (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
        # Union area
        union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inter + eps
        iou = inter / union  # inner_iou
    else:
        # Returns the IoU of box1 to box2. box1 is 1x4, box2 is nx4
        if xywh:  # transform from xywh to xyxy
            (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
            w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
            b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
            b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
        else:  # x1, y1, x2, y2 = box1
            b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
            b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
            w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
            w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

        # Intersection area
        inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
                (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
        # Union area
        union = w1 * h1 + w2 * h2 - inter + eps
        # IoU
        iou = inter / union

    if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU or MPDIoU or LMPDIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # width of the smallest enclosing box
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # height of the smallest enclosing box
        if CIoU or DIoU or EIoU or SIoU or WIoU or MPDIoU or LMPDIoU:  # Distance or Complete IoU
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                     (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # squared center distance
            if CIoU:
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), \
                           torch.pow(inter / (union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif MPDIoU:
                d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
                d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
                w = (b2_x2 - b2_x1)  # x2 - x1
                h = (b2_y2 - b2_y1)  # y2 - y1
                if Focal:
                    return iou - ((d1 + d2) / (w ** 2 + h ** 2)), \
                           torch.pow(inter / (union + eps), gamma)  # Focal_MPDIoU
                else:
                    return iou - (d1 + d2) / (w ** 2 + h ** 2)  # MPDIoU
            elif LMPDIoU:
                d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
                d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
                w = (b2_x2 - b2_x1)  # x2 - x1
                h = (b2_y2 - b2_y1)  # y2 - y1
                if Focal:
                    return 1 - (iou - (d1 + d2) / (w ** 2 + h ** 2)), \
                           torch.pow(inter / (union + eps), gamma)  # Focal_LMPDIoU
                else:
                    return 1 - iou + d1 / (w ** 2 + h ** 2) + d2 / (w ** 2 + h ** 2)  # LMPDIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), \
                           torch.pow(inter / (union + eps), gamma)  # Focal_EIoU
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIoU
            elif SIoU:
                # SIoU
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                siou_gamma = angle_cost - 2  # renamed from `gamma` so the Focal exponent below is not shadowed
                distance_cost = 2 - torch.exp(siou_gamma * rho_x) - torch.exp(siou_gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), \
                           torch.pow(inter / (union + eps), gamma)  # Focal_SIoU
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIoU
            elif WIoU:
                self = WIoU_Scale(1 - (inter / union))  # updates the running mean of the IoU loss
                dist = getattr(WIoU_Scale, '_scaled_loss')(self)
                return iou * dist  # WIoU
            if Focal:
                return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), \
                   torch.pow(inter / (union + eps), gamma)  # Focal_GIoU
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU
    if Focal:
        return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU
```
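As a quick sanity check (a hypothetical snippet, not part of the original post), the new function can be called directly once it is in metrics.py:

```python
import torch
from ultralytics.utils.metrics import bbox_iou

box1 = torch.tensor([[50.0, 50.0, 20.0, 30.0]])  # (x, y, w, h)
box2 = torch.tensor([[54.0, 48.0, 22.0, 28.0]])

ciou = bbox_iou(box1, box2, xywh=True, CIoU=True)                                 # plain CIoU
inner_ciou = bbox_iou(box1, box2, xywh=True, CIoU=True, Inner=True, ratio=0.75)   # Inner-CIoU
iou_term, focal_weight = bbox_iou(box1, box2, xywh=True, CIoU=True, Focal=True)   # Focal variants return a tuple
```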
② The ultralytics/utils/loss.py file
Next, modify the contents of the loss.py file.
(before/after comparison screenshots from the original post)
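Since the screenshots cannot be reproduced here, the essential change (taken from the full loss.py listing in the "Others" section below) is the bbox_iou() call inside BboxLoss.forward() plus the tuple handling for Focal variants:

```python
# ultralytics/utils/loss.py, inside BboxLoss.forward()
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,
               GIoU=False, DIoU=False, CIoU=True, SIoU=False, EIoU=False, WIoU=False,
               MPDIoU=False, LMPDIoU=False, Inner=True, Focal=True)
if type(iou) is tuple:
    if len(iou) == 2:
        # Focal variants return an (iou_term, focal_weight) tuple
        loss_iou = ((1.0 - iou[0]) * iou[1].detach() * weight).sum() / target_scores_sum
    else:
        loss_iou = (iou[0] * iou[1] * weight).sum() / target_scores_sum
else:
    loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
```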
③ The ultralytics/utils/tal.py file
(before/after comparison screenshots from the original post)
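The screenshots are likewise unavailable. In the stock file, TaskAlignedAssigner computes candidate overlaps with plain CIoU; presumably the edit mirrors the flags chosen in loss.py. A sketch under that assumption (the exact method name and context may differ between ultralytics versions):

```python
# ultralytics/utils/tal.py, inside TaskAlignedAssigner (sketch)
def iou_calculation(self, gt_bboxes, pd_bboxes):
    """IoU calculation for horizontal bounding boxes."""
    # Keep the flags consistent with the bbox_iou() call in loss.py, e.g. swap
    # CIoU=True for SIoU=True. Focal stays False here because the assigner needs
    # a plain IoU score, not an (iou, weight) tuple.
    return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
```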
2 Using the Various Mechanisms
By analogy with the above: setting a mechanism's flag to True enables it, and setting it to False disables it. You can then experiment with various combinations when training the model.
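For example (hypothetical combinations, shown as variants of the bbox_iou() call in BboxLoss.forward()):

```python
# Inner-SIoU (pass ratio < 1 to shrink the auxiliary boxes, ratio > 1 to enlarge them)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, SIoU=True, Inner=True, ratio=0.75)
# Focal-EIoU (returns a tuple, handled by the branch shown in section ② above)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, EIoU=True, Focal=True)
# WIoU v3 (non-monotonic focusing, controlled by WIoU_Scale.monotonous)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, WIoU=True)
```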
Now, let's start training the model! 🌺🌺🌺
[YOLOv8] Training a model on your own dataset (click to open)
Others
If replacing the individual snippets is inconvenient, you can instead copy the files below wholesale to replace the contents of the original .py files:
- ① ultralytics/utils/metrics.py
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Model validation metrics."""

import math
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings

OKS_SIGMA = (
    np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
    / 10.0
)def bbox_ioa(box1, box2, iou=False, eps=1e-7):"""Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.Args:box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area."""# Get the coordinates of bounding boxesb1_x1, b1_y1, b1_x2, b1_y2 = box1.Tb2_x1, b2_y1, b2_x2, b2_y2 = box2.T# Intersection areainter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)# Box2 areaarea = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)if iou:box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)area = area + box1_area[:, None] - inter_area# Intersection over box2 areareturn inter_area / (area + eps)def box_iou(box1, box2, eps=1e-7):"""Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.pyArgs:box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2."""# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)# IoU = inter / (area1 + area2 - inter)return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)class WIoU_Scale:''' monotonous: {None: origin v1True: monotonic FM v2False: non-monotonic FM v3}'''iou_mean = 1.monotonous = False_momentum = 1 - 0.5 ** (1 / 7000)_is_train = Truedef __init__(self, iou):self.iou = iouself._update(self)@classmethoddef _update(cls, self):if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \cls._momentum * self.iou.detach().mean().item()@classmethoddef _scaled_loss(cls, self, gamma=1.9, delta=3):if isinstance(self.monotonous, bool):if self.monotonous:return (self.iou.detach() / self.iou_mean).sqrt()else:beta = self.iou.detach() / self.iou_meanalpha = delta * torch.pow(gamma, beta - delta)return beta / alphareturn 1def bbox_iou(box1, box2, xywh=True, ratio=1, GIoU=False, DIoU=False, CIoU=False,SIoU=False, EIoU=False, WIoU=False, MPDIoU=False, LMPDIoU=False,Inner=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):# 计算box1与box2之间的Intersection over Union(IoU)# 获取bounding box的坐标if Inner:(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_ * ratio, x1 + w1_ * ratio, \y1 - h1_ * ratio, y1 + h1_ * ratiob2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_ * ratio, x2 + w2_ * ratio, \y2 - h2_ * ratio, y2 + h2_ * ratio# 计算交集面积inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)# 计算并集面积union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inter + epsiou = inter / union # inner_iouelse:# Returns the IoU of box1 to box2. 
box1 is 4, box2 is nx4if xywh: # xywh转换为xyxy格式(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_else: # x1, y1, x2, y2 = box1b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + epsw2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps# 计算交集面积inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)# 计算并集面积union = w1 * h1 + w2 * h2 - inter + eps# 计算IoU值iou = inter / unionif CIoU or DIoU or GIoU or EIoU or SIoU or WIoU or MPDIoU or LMPDIoU:cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # 计算最小外接矩形的宽度ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # 计算最小外接矩形的高度if CIoU or DIoU or EIoU or SIoU or WIoU or MPDIoU or LMPDIoU: # Distance or Complete IoUc2 = (cw ** 2 + ch ** 2) ** alpha + eps # convex diagonal squaredrho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha # 中心点距离的平方if CIoU:v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)with torch.no_grad():alpha_ciou = v / (v - iou + (1 + eps))if Focal:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps), gamma) # Focal_CIoU的计算else:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)) # CIoUelif MPDIoU:d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2w = (b2_x2 - b2_x1) # x2 - x1h = (b2_y2 - b2_y1) # y2 - y1if Focal:return iou - ((d1 + d2) / (w ** 2 + h ** 2)), torch.pow(inter / (union + eps), gamma)# Focal_MPDIoUelse:return iou - (d1 + d2) / (w ** 2 + h ** 2)elif LMPDIoU:d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2w = (b2_x2 - b2_x1) # x2 - x1h = (b2_y2 - b2_y1) # y2 - y1if Focal:return 1 - (iou - (d1 + d2) / (w ** 2 + h ** 2)), torch.pow(inter / (union + eps), gamma)# Focal_MPDIo # MPDIoUelse:return 1 - iou + d1 / (w ** 2 + h ** 2) + d2 / (w ** 2 + h ** 2)elif EIoU:rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2cw2 = torch.pow(cw ** 2 + eps, alpha)ch2 = torch.pow(ch ** 2 + eps, alpha)if Focal:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps),gamma) # Focal_EIouelse:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIouelif SIoU:# SIoUs_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + epss_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + epssigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)sin_alpha_1 = torch.abs(s_cw) / sigmasin_alpha_2 = torch.abs(s_ch) / sigmathreshold = pow(2, 0.5) / 2sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)rho_x = (s_cw / cw) ** 2rho_y = (s_ch / ch) ** 2gamma = angle_cost - 2distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)if Focal:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter / (union + eps), gamma) # Focal_SIou的计算else:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIouelif WIoU:self = WIoU_Scale(1 - 
(inter / union))dist = getattr(WIoU_Scale, '_scaled_loss')(self)return iou * dist # WIoUif Focal:return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma) # Focal_DIoUelse:return iou - rho2 / c2 # DIoUc_area = cw * ch + eps # convex areaif Focal:return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps), gamma)# Focal_GIoUelse:return iou - torch.pow((c_area - union) / c_area + eps, alpha) # GIoUif Focal:return iou, torch.pow(inter / (union + eps), gamma) # Focal_IoUelse:return iou # IoU的值def mask_iou(mask1, mask2, eps=1e-7):"""Calculate masks IoU.Args:mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is theproduct of image width and height.mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is theproduct of image width and height.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(torch.Tensor): A tensor of shape (N, M) representing masks IoU."""intersection = torch.matmul(mask1, mask2.T).clamp_(0)union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersectionreturn intersection / (union + eps)def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):"""Calculate Object Keypoint Similarity (OKS).Args:kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.sigma (list): A list containing 17 values representing keypoint scales.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities."""d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )kpt_mask = kpt1[..., 2] != 0 # (N, 17)e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval# e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formulareturn ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)def _get_covariance_matrix(boxes):"""Generating covariance matrix from obbs.Args:boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.Returns:(torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes."""# Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)a, b, c = gbbs.split(1, dim=-1)cos = c.cos()sin = c.sin()cos2 = cos.pow(2)sin2 = sin.pow(2)return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sindef probiou(obb1, obb2, CIoU=False, eps=1e-7):"""Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.Args:obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.eps (float, optional): A small value to avoid division by zero. 
Defaults to 1e-7.Returns:(torch.Tensor): A tensor of shape (N, ) representing obb similarities."""x1, y1 = obb1[..., :2].split(1, dim=-1)x2, y2 = obb2[..., :2].split(1, dim=-1)a1, b1, c1 = _get_covariance_matrix(obb1)a2, b2, c2 = _get_covariance_matrix(obb2)t1 = (((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.25t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5t3 = (((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))/ (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)+ eps).log() * 0.5bd = (t1 + t2 + t3).clamp(eps, 100.0)hd = (1.0 - (-bd).exp() + eps).sqrt()iou = 1 - hdif CIoU: # only include the wh aspect ratio partw1, h1 = obb1[..., 2:4].split(1, dim=-1)w2, h2 = obb2[..., 2:4].split(1, dim=-1)v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)with torch.no_grad():alpha = v / (v - iou + (1 + eps))return iou - v * alpha # CIoUreturn ioudef batch_probiou(obb1, obb2, eps=1e-7):"""Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.Args:obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.Returns:(torch.Tensor): A tensor of shape (N, M) representing obb similarities."""obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2x1, y1 = obb1[..., :2].split(1, dim=-1)x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))a1, b1, c1 = _get_covariance_matrix(obb1)a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))t1 = (((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.25t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5t3 = (((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))/ (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)+ eps).log() * 0.5bd = (t1 + t2 + t3).clamp(eps, 100.0)hd = (1.0 - (-bd).exp() + eps).sqrt()return 1 - hddef smooth_BCE(eps=0.1):"""Computes smoothed positive and negative Binary Cross-Entropy targets.This function calculates positive and negative label smoothing BCE targets based on a given epsilon value.For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.Args:eps (float, optional): The epsilon value for label smoothing. 
Defaults to 0.1.Returns:(tuple): A tuple containing the positive and negative label smoothing BCE targets."""return 1.0 - 0.5 * eps, 0.5 * epsclass ConfusionMatrix:"""A class for calculating and updating a confusion matrix for object detection and classification tasks.Attributes:task (str): The type of task, either 'detect' or 'classify'.matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.nc (int): The number of classes.conf (float): The confidence threshold for detections.iou_thres (float): The Intersection over Union threshold."""def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):"""Initialize attributes for the YOLO model."""self.task = taskself.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))self.nc = nc # number of classesself.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passedself.iou_thres = iou_thresdef process_cls_preds(self, preds, targets):"""Update confusion matrix for classification task.Args:preds (Array[N, min(nc,5)]): Predicted class labels.targets (Array[N, 1]): Ground truth class labels."""preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):self.matrix[p][t] += 1def process_batch(self, detections, gt_bboxes, gt_cls):"""Update confusion matrix for object detection task.Args:detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.Each row should contain (x1, y1, x2, y2, conf, class)or with an additional element `angle` when it's obb.gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.gt_cls (Array[M]): The class labels."""if gt_cls.shape[0] == 0: # Check if labels is emptyif detections is not None:detections = detections[detections[:, 4] > self.conf]detection_classes = detections[:, 5].int()for dc in detection_classes:self.matrix[dc, self.nc] += 1 # false positivesreturnif detections is None:gt_classes = gt_cls.int()for gc in gt_classes:self.matrix[self.nc, gc] += 1 # background FNreturndetections = detections[detections[:, 4] > self.conf]gt_classes = gt_cls.int()detection_classes = detections[:, 5].int()is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimensioniou = (batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))if is_obbelse box_iou(gt_bboxes, detections[:, :4]))x = torch.where(iou > self.iou_thres)if x[0].shape[0]:matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()if x[0].shape[0] > 1:matches = matches[matches[:, 2].argsort()[::-1]]matches = matches[np.unique(matches[:, 1], return_index=True)[1]]matches = matches[matches[:, 2].argsort()[::-1]]matches = matches[np.unique(matches[:, 0], return_index=True)[1]]else:matches = np.zeros((0, 3))n = matches.shape[0] > 0m0, m1, _ = matches.transpose().astype(int)for i, gc in enumerate(gt_classes):j = m0 == iif n and sum(j) == 1:self.matrix[detection_classes[m1[j]], gc] += 1 # correctelse:self.matrix[self.nc, gc] += 1 # true backgroundif n:for i, dc in enumerate(detection_classes):if not any(m1 == i):self.matrix[dc, self.nc] += 1 # predicted backgrounddef matrix(self):"""Returns the confusion matrix."""return self.matrixdef tp_fp(self):"""Returns true positives and false positives."""tp = self.matrix.diagonal() # true positivesfp = self.matrix.sum(1) - tp # false positives# fn = self.matrix.sum(0) - tp # false negatives (missed detections)return (tp[:-1], fp[:-1]) if 
self.task == "detect" else (tp, fp) # remove background class if task=detect@TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")@plt_settings()def plot(self, normalize=True, save_dir="", names=(), on_plot=None):"""Plot the confusion matrix using seaborn and save it to a file.Args:normalize (bool): Whether to normalize the confusion matrix.save_dir (str): Directory where the plot will be saved.names (tuple): Names of classes, used as labels on the plot.on_plot (func): An optional callback to pass plots path and data when they are rendered."""import seaborn # scope for faster 'import ultralytics'array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columnsarray[array < 0.005] = np.nan # don't annotate (would appear as 0.00)fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)nc, nn = self.nc, len(names) # number of classes, namesseaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label sizelabels = (0 < nn < 99) and (nn == nc) # apply names to ticklabelsticklabels = (list(names) + ["background"]) if labels else "auto"with warnings.catch_warnings():warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encounteredseaborn.heatmap(array,ax=ax,annot=nc < 30,annot_kws={"size": 8},cmap="Blues",fmt=".2f" if normalize else ".0f",square=True,vmin=0.0,xticklabels=ticklabels,yticklabels=ticklabels,).set_facecolor((1, 1, 1))title = "Confusion Matrix" + " Normalized" * normalizeax.set_xlabel("True")ax.set_ylabel("Predicted")ax.set_title(title)plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'fig.savefig(plot_fname, dpi=250)plt.close(fig)if on_plot:on_plot(plot_fname)def print(self):"""Print the confusion matrix to the console."""for i in range(self.nc + 1):LOGGER.info(" ".join(map(str, self.matrix[i])))def smooth(y, f=0.05):"""Box filter of fraction f."""nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)p = np.ones(nf // 2) # ones paddingyp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y paddedreturn np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):"""Plots a precision-recall curve."""fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)py = np.stack(py, axis=1)if 0 < len(names) < 21: # display per-class legend if < 21 classesfor i, y in enumerate(py.T):ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)else:ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())ax.set_xlabel("Recall")ax.set_ylabel("Precision")ax.set_xlim(0, 1)ax.set_ylim(0, 1)ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")ax.set_title("Precision-Recall Curve")fig.savefig(save_dir, dpi=250)plt.close(fig)if on_plot:on_plot(save_dir)@plt_settings()
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):"""Plots a metric-confidence curve."""fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)if 0 < len(names) < 21: # display per-class legend if < 21 classesfor i, y in enumerate(py):ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)else:ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)y = smooth(py.mean(0), 0.05)ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")ax.set_xlabel(xlabel)ax.set_ylabel(ylabel)ax.set_xlim(0, 1)ax.set_ylim(0, 1)ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")ax.set_title(f"{ylabel}-Confidence Curve")fig.savefig(save_dir, dpi=250)plt.close(fig)if on_plot:on_plot(save_dir)def compute_ap(recall, precision):"""Compute the average precision (AP) given the recall and precision curves.Args:recall (list): The recall curve.precision (list): The precision curve.Returns:(float): Average precision.(np.ndarray): Precision envelope curve.(np.ndarray): Modified recall curve with sentinel values added at the beginning and end."""# Append sentinel values to beginning and endmrec = np.concatenate(([0.0], recall, [1.0]))mpre = np.concatenate(([1.0], precision, [0.0]))# Compute the precision envelopempre = np.flip(np.maximum.accumulate(np.flip(mpre)))# Integrate area under curvemethod = "interp" # methods: 'continuous', 'interp'if method == "interp":x = np.linspace(0, 1, 101) # 101-point interp (COCO)ap = np.trapz(np.interp(x, mrec, mpre), x) # integrateelse: # 'continuous'i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changesap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curvereturn ap, mpre, mrecdef ap_per_class(tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
):"""Computes the average precision per class for object detection evaluation.Args:tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).conf (np.ndarray): Array of confidence scores of the detections.pred_cls (np.ndarray): Array of predicted classes of the detections.target_cls (np.ndarray): Array of true classes of the detections.plot (bool, optional): Whether to plot PR curves or not. Defaults to False.on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.Returns:(tuple): A tuple of six arrays and one array of unique classes, where:tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).x (np.ndarray): X-axis values for the curves. Shape: (1000,).prec_values: Precision values at mAP@0.5 for each class. 
Shape: (nc, 1000)."""# Sort by objectnessi = np.argsort(-conf)tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]# Find unique classesunique_classes, nt = np.unique(target_cls, return_counts=True)nc = unique_classes.shape[0] # number of classes, number of detections# Create Precision-Recall curve and compute AP for each classx, prec_values = np.linspace(0, 1, 1000), []# Average precision, precision and recall curvesap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))for ci, c in enumerate(unique_classes):i = pred_cls == cn_l = nt[ci] # number of labelsn_p = i.sum() # number of predictionsif n_p == 0 or n_l == 0:continue# Accumulate FPs and TPsfpc = (1 - tp[i]).cumsum(0)tpc = tp[i].cumsum(0)# Recallrecall = tpc / (n_l + eps) # recall curver_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases# Precisionprecision = tpc / (tpc + fpc) # precision curvep_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score# AP from recall-precision curvefor j in range(tp.shape[1]):ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])if plot and j == 0:prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5prec_values = np.array(prec_values) # (nc, 1000)# Compute F1 (harmonic mean of precision and recall)f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have datanames = dict(enumerate(names)) # to dictif plot:plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 indexp, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 valuestp = (r * nt).round() # true positivesfp = (tp / (p + eps) - tp).round() # false positivesreturn tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_valuesclass Metric(SimpleClass):"""Class for computing evaluation metrics for YOLOv8 model.Attributes:p (list): Precision for each class. Shape: (nc,).r (list): Recall for each class. Shape: (nc,).f1 (list): F1 score for each class. Shape: (nc,).all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).ap_class_index (list): Index of class for each AP score. Shape: (nc,).nc (int): Number of classes.Methods:ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].mp(): Mean precision of all classes. Returns: Float.mr(): Mean recall of all classes. Returns: Float.map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.mean_results(): Mean of results, returns mp, mr, map50, map.class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).fitness(): Model fitness as a weighted combination of metrics. 
Returns: Float.update(results): Update metric attributes with new evaluation results."""def __init__(self) -> None:"""Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""self.p = [] # (nc, )self.r = [] # (nc, )self.f1 = [] # (nc, )self.all_ap = [] # (nc, 10)self.ap_class_index = [] # (nc, )self.nc = 0@propertydef ap50(self):"""Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.Returns:(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available."""return self.all_ap[:, 0] if len(self.all_ap) else []@propertydef ap(self):"""Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.Returns:(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available."""return self.all_ap.mean(1) if len(self.all_ap) else []@propertydef mp(self):"""Returns the Mean Precision of all classes.Returns:(float): The mean precision of all classes."""return self.p.mean() if len(self.p) else 0.0@propertydef mr(self):"""Returns the Mean Recall of all classes.Returns:(float): The mean recall of all classes."""return self.r.mean() if len(self.r) else 0.0@propertydef map50(self):"""Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.Returns:(float): The mAP at an IoU threshold of 0.5."""return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0@propertydef map75(self):"""Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.Returns:(float): The mAP at an IoU threshold of 0.75."""return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0@propertydef map(self):"""Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.Returns:(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05."""return self.all_ap.mean() if len(self.all_ap) else 0.0def mean_results(self):"""Mean of results, return mp, mr, map50, map."""return [self.mp, self.mr, self.map50, self.map]def class_result(self, i):"""Class-aware result, return p[i], r[i], ap50[i], ap[i]."""return self.p[i], self.r[i], self.ap50[i], self.ap[i]@propertydef maps(self):"""MAP of each class."""maps = np.zeros(self.nc) + self.mapfor i, c in enumerate(self.ap_class_index):maps[c] = self.ap[i]return mapsdef fitness(self):"""Model fitness as a weighted combination of metrics."""w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]return (np.array(self.mean_results()) * w).sum()def update(self, results):"""Updates the evaluation metrics of the model with a new set of results.Args:results (tuple): A tuple containing the following evaluation metrics:- p (list): Precision for each class. Shape: (nc,).- r (list): Recall for each class. Shape: (nc,).- f1 (list): F1 score for each class. Shape: (nc,).- all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).- ap_class_index (list): Index of class for each AP score. 
Shape: (nc,).Side Effects:Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` basedon the values provided in the `results` tuple."""(self.p,self.r,self.f1,self.all_ap,self.ap_class_index,self.p_curve,self.r_curve,self.f1_curve,self.px,self.prec_values,) = results@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return []@propertydef curves_results(self):"""Returns a list of curves for accessing specific metrics curves."""return [[self.px, self.prec_values, "Recall", "Precision"],[self.px, self.f1_curve, "Confidence", "F1"],[self.px, self.p_curve, "Confidence", "Precision"],[self.px, self.r_curve, "Confidence", "Recall"],]class DetMetrics(SimpleClass):"""This class is a utility class for computing detection metrics such as precision, recall, and mean average precision(mAP) of an object detection model.Args:save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.Attributes:save_dir (Path): A path to the directory where the output plots will be saved.plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.on_plot (func): An optional callback to pass plots path and data when they are rendered.names (tuple of str): A tuple of strings that represents the names of the classes.box (Metric): An instance of the Metric class for storing the results of the detection metrics.speed (dict): A dictionary for storing the execution time of different parts of the detection process.Methods:process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.keys: Returns a list of keys for accessing the computed detection metrics.mean_results: Returns a list of mean values for the computed detection metrics.class_result(i): Returns a list of values for the computed detection metrics for a specific class.maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.fitness: Computes the fitness score based on the computed detection metrics.ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.results_dict: Returns a dictionary that maps detection metric keys to their computed values.curves: TODOcurves_results: TODO"""def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:"""Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""self.save_dir = save_dirself.plot = plotself.on_plot = on_plotself.names = namesself.box = Metric()self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}self.task = "detect"def process(self, tp, conf, pred_cls, target_cls):"""Process predicted results for object detection and update metrics."""results = ap_per_class(tp,conf,pred_cls,target_cls,plot=self.plot,save_dir=self.save_dir,names=self.names,on_plot=self.on_plot,)[2:]self.box.nc = len(self.names)self.box.update(results)@propertydef keys(self):"""Returns a list of keys for accessing specific metrics."""return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]def 
mean_results(self):"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""return self.box.mean_results()def class_result(self, i):"""Return the result of evaluating the performance of an object detection model on a specific class."""return self.box.class_result(i)@propertydef maps(self):"""Returns mean Average Precision (mAP) scores per class."""return self.box.maps@propertydef fitness(self):"""Returns the fitness of box object."""return self.box.fitness()@propertydef ap_class_index(self):"""Returns the average precision index per class."""return self.box.ap_class_index@propertydef results_dict(self):"""Returns dictionary of computed performance metrics and statistics."""return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]@propertydef curves_results(self):"""Returns dictionary of computed performance metrics and statistics."""return self.box.curves_resultsclass SegmentMetrics(SimpleClass):"""Calculates and aggregates detection and segmentation metrics over a given set of classes.Args:save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.plot (bool): Whether to save the detection and segmentation plots. Default is False.on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.names (list): List of class names. Default is an empty list.Attributes:save_dir (Path): Path to the directory where the output plots should be saved.plot (bool): Whether to save the detection and segmentation plots.on_plot (func): An optional callback to pass plots path and data when they are rendered.names (list): List of class names.box (Metric): An instance of the Metric class to calculate box detection metrics.seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.speed (dict): Dictionary to store the time taken in different phases of inference.Methods:process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.class_result(i): Returns the detection and segmentation metrics of class `i`.maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.fitness: Returns the fitness scores, which are a single weighted combination of metrics.ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score."""def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:"""Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""self.save_dir = save_dirself.plot = plotself.on_plot = on_plotself.names = namesself.box = Metric()self.seg = Metric()self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}self.task = "segment"def process(self, tp, tp_m, conf, pred_cls, target_cls):"""Processes the detection and segmentation metrics over the given set of predictions.Args:tp (list): List of True Positive boxes.tp_m (list): List of True Positive masks.conf (list): List of confidence scores.pred_cls (list): List of predicted classes.target_cls 
(list): List of target classes."""results_mask = ap_per_class(tp_m,conf,pred_cls,target_cls,plot=self.plot,on_plot=self.on_plot,save_dir=self.save_dir,names=self.names,prefix="Mask",)[2:]self.seg.nc = len(self.names)self.seg.update(results_mask)results_box = ap_per_class(tp,conf,pred_cls,target_cls,plot=self.plot,on_plot=self.on_plot,save_dir=self.save_dir,names=self.names,prefix="Box",)[2:]self.box.nc = len(self.names)self.box.update(results_box)@propertydef keys(self):"""Returns a list of keys for accessing metrics."""return ["metrics/precision(B)","metrics/recall(B)","metrics/mAP50(B)","metrics/mAP50-95(B)","metrics/precision(M)","metrics/recall(M)","metrics/mAP50(M)","metrics/mAP50-95(M)",]def mean_results(self):"""Return the mean metrics for bounding box and segmentation results."""return self.box.mean_results() + self.seg.mean_results()def class_result(self, i):"""Returns classification results for a specified class index."""return self.box.class_result(i) + self.seg.class_result(i)@propertydef maps(self):"""Returns mAP scores for object detection and semantic segmentation models."""return self.box.maps + self.seg.maps@propertydef fitness(self):"""Get the fitness score for both segmentation and bounding box models."""return self.seg.fitness() + self.box.fitness()@propertydef ap_class_index(self):"""Boxes and masks have the same ap_class_index."""return self.box.ap_class_index@propertydef results_dict(self):"""Returns results of object detection model for evaluation."""return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return ["Precision-Recall(B)","F1-Confidence(B)","Precision-Confidence(B)","Recall-Confidence(B)","Precision-Recall(M)","F1-Confidence(M)","Precision-Confidence(M)","Recall-Confidence(M)",]@propertydef curves_results(self):"""Returns dictionary of computed performance metrics and statistics."""return self.box.curves_results + self.seg.curves_resultsclass PoseMetrics(SegmentMetrics):"""Calculates and aggregates detection and pose metrics over a given set of classes.Args:save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.plot (bool): Whether to save the detection and segmentation plots. Default is False.on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.names (list): List of class names. 
Default is an empty list.Attributes:save_dir (Path): Path to the directory where the output plots should be saved.plot (bool): Whether to save the detection and segmentation plots.on_plot (func): An optional callback to pass plots path and data when they are rendered.names (list): List of class names.box (Metric): An instance of the Metric class to calculate box detection metrics.pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.speed (dict): Dictionary to store the time taken in different phases of inference.Methods:process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.class_result(i): Returns the detection and segmentation metrics of class `i`.maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.fitness: Returns the fitness scores, which are a single weighted combination of metrics.ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score."""def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:"""Initialize the PoseMetrics class with directory path, class names, and plotting options."""super().__init__(save_dir, plot, names)self.save_dir = save_dirself.plot = plotself.on_plot = on_plotself.names = namesself.box = Metric()self.pose = Metric()self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}self.task = "pose"def process(self, tp, tp_p, conf, pred_cls, target_cls):"""Processes the detection and pose metrics over the given set of predictions.Args:tp (list): List of True Positive boxes.tp_p (list): List of True Positive keypoints.conf (list): List of confidence scores.pred_cls (list): List of predicted classes.target_cls (list): List of target classes."""results_pose = ap_per_class(tp_p,conf,pred_cls,target_cls,plot=self.plot,on_plot=self.on_plot,save_dir=self.save_dir,names=self.names,prefix="Pose",)[2:]self.pose.nc = len(self.names)self.pose.update(results_pose)results_box = ap_per_class(tp,conf,pred_cls,target_cls,plot=self.plot,on_plot=self.on_plot,save_dir=self.save_dir,names=self.names,prefix="Box",)[2:]self.box.nc = len(self.names)self.box.update(results_box)@propertydef keys(self):"""Returns list of evaluation metric keys."""return ["metrics/precision(B)","metrics/recall(B)","metrics/mAP50(B)","metrics/mAP50-95(B)","metrics/precision(P)","metrics/recall(P)","metrics/mAP50(P)","metrics/mAP50-95(P)",]def mean_results(self):"""Return the mean results of box and pose."""return self.box.mean_results() + self.pose.mean_results()def class_result(self, i):"""Return the class-wise detection results for a specific class i."""return self.box.class_result(i) + self.pose.class_result(i)@propertydef maps(self):"""Returns the mean average precision (mAP) per class for both box and pose detections."""return self.box.maps + self.pose.maps@propertydef fitness(self):"""Computes classification metrics and speed using the `targets` and `pred` inputs."""return self.pose.fitness() + self.box.fitness()@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return 
["Precision-Recall(B)","F1-Confidence(B)","Precision-Confidence(B)","Recall-Confidence(B)","Precision-Recall(P)","F1-Confidence(P)","Precision-Confidence(P)","Recall-Confidence(P)",]@propertydef curves_results(self):"""Returns dictionary of computed performance metrics and statistics."""return self.box.curves_results + self.pose.curves_resultsclass ClassifyMetrics(SimpleClass):"""Class for computing classification metrics including top-1 and top-5 accuracy.Attributes:top1 (float): The top-1 accuracy.top5 (float): The top-5 accuracy.speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.Properties:fitness (float): The fitness of the model, which is equal to top-5 accuracy.results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.keys (List[str]): A list of keys for the results_dict.Methods:process(targets, pred): Processes the targets and predictions to compute classification metrics."""def __init__(self) -> None:"""Initialize a ClassifyMetrics instance."""self.top1 = 0self.top5 = 0self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}self.task = "classify"def process(self, targets, pred):"""Target classes and predicted classes."""pred, targets = torch.cat(pred), torch.cat(targets)correct = (targets[:, None] == pred).float()acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracyself.top1, self.top5 = acc.mean(0).tolist()@propertydef fitness(self):"""Returns mean of top-1 and top-5 accuracies as fitness score."""return (self.top1 + self.top5) / 2@propertydef results_dict(self):"""Returns a dictionary with model's performance metrics and fitness score."""return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))@propertydef keys(self):"""Returns a list of keys for the results_dict property."""return ["metrics/accuracy_top1", "metrics/accuracy_top5"]@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return []@propertydef curves_results(self):"""Returns a list of curves for accessing specific metrics curves."""return []class OBBMetrics(SimpleClass):def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:self.save_dir = save_dirself.plot = plotself.on_plot = on_plotself.names = namesself.box = Metric()self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}def process(self, tp, conf, pred_cls, target_cls):"""Process predicted results for object detection and update metrics."""results = ap_per_class(tp,conf,pred_cls,target_cls,plot=self.plot,save_dir=self.save_dir,names=self.names,on_plot=self.on_plot,)[2:]self.box.nc = len(self.names)self.box.update(results)@propertydef keys(self):"""Returns a list of keys for accessing specific metrics."""return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]def mean_results(self):"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""return self.box.mean_results()def class_result(self, i):"""Return the result of evaluating the performance of an object detection model on a specific class."""return self.box.class_result(i)@propertydef maps(self):"""Returns mean Average Precision (mAP) scores per class."""return self.box.maps@propertydef fitness(self):"""Returns the fitness of box object."""return self.box.fitness()@propertydef ap_class_index(self):"""Returns the average precision index per class."""return 
self.box.ap_class_index@propertydef results_dict(self):"""Returns dictionary of computed performance metrics and statistics."""return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))@propertydef curves(self):"""Returns a list of curves for accessing specific metrics curves."""return []@propertydef curves_results(self):"""Returns a list of curves for accessing specific metrics curves."""return []
② ultralytics/utils/loss.py
Replace the entire contents of ultralytics/utils/loss.py with the complete modified file below. Relative to the stock file, the functional changes are confined to BboxLoss.forward: the bbox_iou(...) call gains the new loss switches, and the return value may now be a tuple, which is handled explicitly.
# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors

from .metrics import bbox_iou, probiou
from .tal import bbox2dist


class VarifocalLoss(nn.Module):
    """
    Varifocal loss by Zhang et al.

    https://arxiv.org/abs/2008.13367.
    """

    def __init__(self):
        """Initialize the VarifocalLoss class."""
        super().__init__()

    @staticmethod
    def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        """Computes varfocal loss."""
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with torch.cuda.amp.autocast(enabled=False):
            loss = (
                (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                .mean(1)
                .sum()
            )
        return loss


class FocalLoss(nn.Module):
    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""

    def __init__(self):
        """Initializer for FocalLoss class with no parameters."""
        super().__init__()

    @staticmethod
    def forward(pred, label, gamma=1.5, alpha=0.25):
        """Calculates and updates confusion matrix for object detection/classification tasks."""
        loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = pred.sigmoid()  # prob from logits
        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
        modulating_factor = (1.0 - p_t) ** gamma
        loss *= modulating_factor
        if alpha > 0:
            alpha_factor = label * alpha + (1 - label) * (1 - alpha)
            loss *= alpha_factor
        return loss.mean(1).sum()


class BboxLoss(nn.Module):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, GIoU=False, DIoU=False, CIoU=True,
                       SIoU=False, EIoU=False, WIoU=False, MPDIoU=False, LMPDIoU=False, Inner=True, Focal=True)
        # loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
        if type(iou) is tuple:
            if len(iou) == 2:
                # Special handling for Focal losses: bbox_iou returns a tuple, whose second element
                # is the focal weight and must be detached before weighting the loss
                loss_iou = ((1.0 - iou[0]) * iou[1].detach() * weight).sum() / target_scores_sum
            else:
                loss_iou = (iou[0] * iou[1] * weight).sum() / target_scores_sum
        else:
            loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

    @staticmethod
    def _df_loss(pred_dist, target):
        """
        Return sum of left and right DFL losses.

        Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
        https://ieeexplore.ieee.org/document/9792391
        """
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)


class RotatedBboxLoss(BboxLoss):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__(reg_max, use_dfl)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl


class KeypointLoss(nn.Module):
    """Criterion class for computing training losses."""

    def __init__(self, sigmas) -> None:
        """Initialize the KeypointLoss class."""
        super().__init__()
        self.sigmas = sigmas

    def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
        """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
        d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
        kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
        # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
        e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)  # from cocoeval
        return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()


class v8DetectionLoss:
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    out[j, :n] = targets[matches, 1:]
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)


class v8SegmentationLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
        super().__init__(model)
        self.overlap = model.args.overlap_mask

    def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(4, device=self.device)  # box, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()


class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses."""

    def __init__(self, model):  # model must be de-paralleled
        """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """Calculate the total loss and detach it."""
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points, pred_kpts):
        """Decodes predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss
        is a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            (tuple): Returns a tuple containing:
                - kpts_loss (torch.Tensor): The keypoints loss.
                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss


class v8ClassificationLoss:
    """Criterion class for computing training losses."""

    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items


class v8OBBLoss(v8DetectionLoss):
    def __init__(self, model):
        """
        Initializes v8OBBLoss with model, assigner, and rotated bbox loss.

        Note model must be de-paralleled.
        """
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
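Which mechanism is active is controlled entirely by the boolean switches in the bbox_iou(...) call inside BboxLoss.forward (the file above ships with Inner-CIoU plus the Focal weighting). A minimal sketch of swapping mechanisms, assuming the replacement bbox_iou from ① treats the base-IoU flags as mutually exclusive switches, as the calls in this file suggest:

# In BboxLoss.forward, pick one base IoU and the optional Inner/Focal modifiers.
# Example: plain Inner-SIoU with no Focal weighting (flag combination shown for
# illustration; only the keyword switches change, not the positional arguments).
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,
               GIoU=False, DIoU=False, CIoU=False, SIoU=True, EIoU=False, WIoU=False,
               MPDIoU=False, LMPDIoU=False, Inner=True, Focal=False)

Note the two return conventions the forward pass already handles: with Focal=True the call is expected to return an (iou, focal_weight) tuple whose second element is detached before weighting, while a plain flag combination returns a single tensor. The three-element branch appears intended for variants such as WIoU that return extra terms.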
③ ultralytics/utils/tal.py
Likewise, replace the entire contents of ultralytics/utils/tal.py with the complete modified file below; here the change is the bbox_iou(...) call inside TaskAlignedAssigner.get_box_metrics.
# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn

from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy

TORCH_1_10 = check_version(torch.__version__, "1.10.0")


class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Get in_gts mask, (b, max_num_obj, h*w)."""
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all mask to a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Compute alignment metric given predicted and ground truth bounding boxes."""
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        # Get the scores of each grid for each gt cls
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, GIoU=False, DIoU=False, CIoU=True, SIoU=False,
                                     EIoU=False, WIoU=False, MPDIoU=False, LMPDIoU=False, Inner=True,
                                     Focal=False).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for horizontal bounding boxes."""
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                              max_num_obj is the maximum number of objects, and h*w represents the
                              total number of anchor points.
            largest (bool): If True, select the largest values; otherwise, select the smallest values.
            topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                                topk is the number of top candidates to consider. If not provided,
                                the top-k values are automatically computed based on the given metrics.

        Returns:
            (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk)
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
        # Filter invalid bboxes
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
                                    anchor points, with shape (b, h*w), where h*w is the total
                                    number of anchor points.
            fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                              (foreground) anchor points.

        Returns:
            (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
                - target_labels (Tensor): Shape (b, h*w), containing the target labels for
                  positive anchor points.
                - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
                  for positive anchor points.
                - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
                  for positive anchor points, where num_classes is the number of object classes.
        """
        # Assigned target labels, (b, 1)
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select the positive anchor center in gt.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 4)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

        Args:
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
            overlaps (Tensor): shape(b, n_max_boxes, h*w)

        Returns:
            target_gt_idx (Tensor): shape(b, h*w)
            fg_mask (Tensor): shape(b, h*w)
            mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        """
        # (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find each grid serve which gt(index)
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos


class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for rotated bounding boxes."""
        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor center in gt for rotated bounding boxes.

        Args:
            xy_centers (Tensor): shape(h*w, 2)
            gt_bboxes (Tensor): shape(b, n_boxes, 5)

        Returns:
            (Tensor): shape(b, n_boxes, h*w)
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # (b, n_boxes, 1, 2)
        a, b, _, d = corners.split(1, dim=-2)
        ab = b - a
        ad = d - a

        # (b, n_boxes, h*w, 2)
        ap = xy_centers - a
        norm_ab = (ab * ab).sum(dim=-1)
        norm_ad = (ad * ad).sum(dim=-1)
        ap_dot_ab = (ap * ab).sum(dim=-1)
        ap_dot_ad = (ap * ad).sum(dim=-1)
        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)


def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted object bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points, (h*w, 2).

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    # (bs, h*w, 1)
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
    x, y = xf * cos - yf * sin, xf * sin + yf * cos
    xy = torch.cat([x, y], dim=dim) + anchor_points
    return torch.cat([xy, lt + rb], dim=dim)
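Two details of this file are worth calling out. First, get_box_metrics computes the overlaps with Inner=True but Focal=False, because overlaps[mask_gt] must receive a plain tensor rather than the (iou, weight) tuple a Focal variant returns. Second, the assigner ranks anchors by align_metric = bbox_scores**alpha * overlaps**beta; with the alpha=0.5, beta=6.0 that v8DetectionLoss passes in, localization quality dominates the ranking. A toy numeric check (values chosen purely for illustration):

import torch

alpha, beta = 0.5, 6.0                  # values passed by v8DetectionLoss
score = torch.tensor([0.90, 0.90])      # identical classification scores
overlap = torch.tensor([0.80, 0.95])    # different Inner-CIoU overlaps
print(score.pow(alpha) * overlap.pow(beta))  # ≈ tensor([0.2487, 0.6974])

Even with equal scores, the better-localized anchor's metric is almost three times larger, so it wins the top-k selection in select_topk_candidates.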
That wraps up everything this post set out to share. Meeting you here was fate, and I'm grateful for it! 💛 💙 💜 ❤️ 💚