【YOLO改进】换遍IoU损失函数之EIoU Loss(基于MMYOLO)

EIoU损失函数

设计原理

一、IoU的局限性

IoU(Intersection over Union)是一种常用于评估目标检测模型性能的指标,特别是在计算预测边界框与真实边界框之间的重叠程度时。然而,IoU存在一些局限性,尤其是当两个边界框没有任何交集时,IoU 的值为0,这使得梯度更新停滞,不利于模型的进一步学习和优化。

二、EIoU的引入

为了解决这一问题,引入了EIoU(Enhanced Intersection over Union)损失函数。EIoU 不仅考虑了边界框间的重叠区域,还引入了其他度量,如边界框中心点的距离,以及边界框的宽度和高度的相对差异。这样的设计使得即使两个边界框不重叠,损失函数仍然可以提供有效的梯度,从而促进模型的训练和收敛。

计算步骤

一、计算IoU

  • 计算两个边界框A和B的交集面积I。

  • 计算两个边界框的并集面积U。

  • IoU计算公式为:I/U

二、计算中心点距离的公式

中心点距离是预测框和真实框中心点之间的欧氏距离。设预测框中心为(x_p,y_p),真实框中心为 (x_g,y_g),则中心距离D_c计算为:

D_c = \sqrt{(x_p - x_g)^2 + (y_p - y_g)^2}

三、计算宽高比的差异

宽度差异w_{diff}和高度差异 h_{diff} 分别为预测框和真实框宽度和高度的相对差值。计算方法可以是简单的差值或者比例差等。

四、整合以上度量

EIoU将上述度量整合到一个损失函数中,通常形式为:

\text{EIoU Loss} = 1 - \text{IoU} + \lambda_1 D_c + \lambda_2 (w_{\text{diff}} + h_{\text{diff}})

其中,\lambda_1​ 和\lambda_2是调节中心距离和宽高差异影响的超参数。

使用PyTorch实现EIoU计算的源代码

import torch
import torch.nn.functional as Fdef bbox_iou(boxes1, boxes2):"""计算两组边界框的IoU。boxes1, boxes2: [N, 4] (x1, y1, x2, y2)"""area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)union_area = area1 + area2 - inter_areareturn inter_area / union_areadef eiou_loss(pred_boxes, target_boxes, lambda1=1, lambda2=1):"""计算EIoU损失。pred_boxes, target_boxes: [N, 4] (x1, y1, x2, y2)"""iou = bbox_iou(pred_boxes, target_boxes)# 计算中心点center_pred = (pred_boxes[:, :2] + pred_boxes[:, 2:4]) / 2center_target = (target_boxes[:, :2] + target_boxes[:, 2:4]) / 2# 计算中心点距离dc = torch.sqrt(torch.sum((center_pred - center_target) ** 2, dim=1))# 计算宽高差异wh_pred = pred_boxes[:, 2:4] - pred_boxes[:, :2]wh_target = target_boxes[:, 2:4] - target_boxes[:, :2]wh_diff = torch.abs(wh_pred - wh_target).sum(dim=1)# 计算EIoU损失loss = 1 - iou + lambda1 * dc + lambda2 * wh_diffreturn loss.mean()# 示例用法
pred_boxes = torch.tensor([[25, 25, 75, 75], [50, 50, 100, 100]], dtype=torch.float32)
target_boxes = torch.tensor([[30, 30, 70, 70], [70, 70, 120, 120]], dtype=torch.float32)loss = eiou_loss(pred_boxes, target_boxes)
print(f"EIoU Loss: {loss}")

替换EIoU损失函数(基于MMYOLO)

由于MMYOLO中没有实现EIoU损失函数,所以需要在mmyolo/models/iou_loss.py中添加EIoU的计算和对应的iou_mode,修改完以后在终端运行

python setup.py install

再在配置文件中进行修改即可。修改例子如下:

    elif iou_mode == "eiou":# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )# calculate enclose area (c^2)enclose_area = enclose_w**2 + enclose_h**2 + eps# calculate ρ^2(b_pred,b_gt):# euclidean distance between b_pred(bbox2) and b_gt(bbox1)# center point, because bbox format is xyxy -> left-top xy and# right-bottom xy, so need to / 4 to get center point.rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4rho2_right_item = ((bbox2_y1 + bbox2_y2) -(bbox1_y1 + bbox1_y2))**2 / 4rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)rho_w2 = ((bbox2_x2 - bbox2_x1) - (bbox1_x2 - bbox1_x1)) ** 2rho_h2 = ((bbox2_y2 - bbox2_y1) - (bbox1_y2 - bbox1_y1)) ** 2cw2 = enclose_w ** 2 + epsch2 = enclose_h ** 2 + epsious = ious - (rho2 / enclose_area + rho_w2 / cw2 + rho_h2 / ch2)

修改后的配置文件(以configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py为例)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image pathnum_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True# -----model related-----
# Basic size of multi-scale prior box
anchors = [[(10, 13), (16, 30), (33, 23)],  # P3/8[(30, 61), (62, 45), (59, 119)],  # P4/16[(116, 90), (156, 198), (373, 326)]  # P5/32
]# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochsmodel_test_cfg = dict(# The config of multi-label for multi-class prediction.multi_label=True,# The number of boxes before NMSnms_pre=30000,score_thr=0.001,  # Threshold to filter out boxes.nms=dict(type='nms', iou_threshold=0.65),  # NMS type and thresholdmax_per_img=300)  # Max number of detections of each image# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(type='BatchShapePolicy',batch_size=val_batch_size_per_gpu,img_size=img_scale[0],# The image scale of padding should be divided by pad_size_divisorsize_divisor=32,# Additional paddings for pixel scaleextra_pad_ratio=0.5)# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)# ===============================Unmodified in most cases====================
model = dict(type='YOLODetector',data_preprocessor=dict(type='mmdet.DetDataPreprocessor',mean=[0., 0., 0.],std=[255., 255., 255.],bgr_to_rgb=True),backbone=dict(##使用YOLOv8的主干网络type='YOLOv8CSPDarknet',deepen_factor=deepen_factor,widen_factor=widen_factor,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),neck=dict(type='YOLOv5PAFPN',deepen_factor=deepen_factor,widen_factor=widen_factor,in_channels=[256, 512, 1024],out_channels=[256, 512, 1024],num_csp_blocks=3,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),bbox_head=dict(type='YOLOv5Head',head_module=dict(type='YOLOv5HeadModule',num_classes=num_classes,in_channels=[256, 512, 1024],widen_factor=widen_factor,featmap_strides=strides,num_base_priors=3),prior_generator=dict(type='mmdet.YOLOAnchorGenerator',base_sizes=anchors,strides=strides),# scaled based on number of detection layersloss_cls=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_cls_weight *(num_classes / 80 * 3 / num_det_layers)),# 修改此处实现IoU损失函数的替换loss_bbox=dict(type='IoULoss',iou_mode='eiou',bbox_format='xywh',eps=1e-7,reduction='mean',loss_weight=loss_bbox_weight * (3 / num_det_layers),return_iou=True),loss_obj=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_obj_weight *((img_scale[0] / 640)**2 * 3 / num_det_layers)),prior_match_thr=prior_match_thr,obj_level_weights=obj_level_weights),test_cfg=model_test_cfg)albu_train_transforms = [dict(type='Blur', p=0.01),dict(type='MedianBlur', p=0.01),dict(type='ToGray', p=0.01),dict(type='CLAHE', p=0.01)
]pre_transform = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='LoadAnnotations', with_bbox=True)
]train_pipeline = [*pre_transform,dict(type='Mosaic',img_scale=img_scale,pad_val=114.0,pre_transform=pre_transform),dict(type='YOLOv5RandomAffine',max_rotate_degree=0.0,max_shear_degree=0.0,scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),# img_scale is (width, height)border=(-img_scale[0] // 2, -img_scale[1] // 2),border_val=(114, 114, 114)),dict(type='mmdet.Albu',transforms=albu_train_transforms,bbox_params=dict(type='BboxParams',format='pascal_voc',label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),keymap={'img': 'image','gt_bboxes': 'bboxes'}),dict(type='YOLOv5HSVRandomAug'),dict(type='mmdet.RandomFlip', prob=0.5),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip','flip_direction'))
]train_dataloader = dict(batch_size=train_batch_size_per_gpu,num_workers=train_num_workers,persistent_workers=persistent_workers,pin_memory=True,sampler=dict(type='DefaultSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,ann_file=train_ann_file,data_prefix=dict(img=train_data_prefix),filter_cfg=dict(filter_empty_gt=False, min_size=32),pipeline=train_pipeline))test_pipeline = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='YOLOv5KeepRatioResize', scale=img_scale),dict(type='LetterResize',scale=img_scale,allow_scale_up=False,pad_val=dict(img=114)),dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape','scale_factor', 'pad_param'))
]val_dataloader = dict(batch_size=val_batch_size_per_gpu,num_workers=val_num_workers,persistent_workers=persistent_workers,pin_memory=True,drop_last=False,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,test_mode=True,data_prefix=dict(img=val_data_prefix),ann_file=val_ann_file,pipeline=test_pipeline,batch_shapes_cfg=batch_shapes_cfg))test_dataloader = val_dataloaderparam_scheduler = None
optim_wrapper = dict(type='OptimWrapper',optimizer=dict(type='SGD',lr=base_lr,momentum=0.937,weight_decay=weight_decay,nesterov=True,batch_size_per_gpu=train_batch_size_per_gpu),constructor='YOLOv5OptimizerConstructor')default_hooks = dict(param_scheduler=dict(type='YOLOv5ParamSchedulerHook',scheduler_type='linear',lr_factor=lr_factor,max_epochs=max_epochs),checkpoint=dict(type='CheckpointHook',interval=save_checkpoint_intervals,save_best='auto',max_keep_ckpts=max_keep_ckpts))custom_hooks = [dict(type='EMAHook',ema_type='ExpMomentumEMA',momentum=0.0001,update_buffers=True,strict_load=False,priority=49)
]val_evaluator = dict(type='mmdet.CocoMetric',proposal_nums=(100, 1, 10),ann_file=data_root + val_ann_file,metric='bbox')
test_evaluator = val_evaluatortrain_cfg = dict(type='EpochBasedTrainLoop',max_epochs=max_epochs,val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/5602.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[python趣味实战]----基于python代码实现浪漫爱心 დ

正文 01-效果演示 下图是代码运行之后的爱心显示结果: 下面的视频该爱心是动态效果,较为简洁,如果需要使用,可以进行完善,这里只是一个趣味实战,下面将对代码实现进行非常详细地描述: 浪漫爱心…

Java数据结构-模拟实现ArrayList

MyArrayList顺序结构: 接口和MyArrayList重写接口 接口 接口中的方法是很多类通用的,所以可以写到接口中 public interface IList {public void add(int data) ;// 在 pos 位置新增元素public void add(int pos, int data);// 判定是否包含某个元素p…

踏上R语言之旅:解锁数据世界的神秘密码(三)

多元相关与回归分析及R使用 文章目录 多元相关与回归分析及R使用一.变量间的关系分析1.两变量线性相关系数的计算2.相关系数的假设检验 二.一元线性回归分析的R计算三、回归系数的假设检验总结 一.变量间的关系分析 变量间的关系及分析方法如下: 1.两变量线性相关…

【C++程序员的自我修炼】string 库中常见的用法 (一)

唤起一天明月照我满怀冰雪浩荡百川流鲸饮未吞海 剑气已横秋 目录 string 库的简介 string 的一些小操作 构造函数的使用 拷贝构造的常规使用 指定拷贝内容的拷贝构造 拷贝字符串开始的前 n 个字符 用 n 个字符初始化 计算字符串的长度 string 的三种遍历方式 常规的for循环 op…

利用大型语言模型提升数字产品创新:提示,微调,检索增强生成和代理的应用

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

「笔试刷题」:字母收集

一、题目 描述 有一个 𝑛∗𝑚 的矩形方阵,每个格子上面写了一个小写字母。 小红站在矩形的左上角,她每次可以向右或者向下走,走到某个格子上就可以收集这个格子的字母。 小红非常喜欢 "love" 这四个字母。…

FFmpeg开发笔记(二十三)使用OBS Studio开启RTMP直播推流

OBS是一个开源的直播录制软件,英文全称叫做Open Broadcaster Software,广泛用于视频录制、实时直播等领域。OBS不但开源,而且跨平台,兼容Windows、Mac OS、Linux等操作系统。 OBS的官网是https://obsproject.com/,录制…

【报错处理】ib_write_bw执行遇到Couldn‘t listen to port 18515原因与解决办法?

要点 要点: ib默认使用18515端口 相关命令: netstat -tuln | grep 18515 ib_write_bw --help |grep port# server ib_write_bw --ib-devmlx5_1 --port88990 # client ib_write_bw --ib-devmlx5_0 1.1.1.1 --port88990现象: 根因&#xff…

为什么公共事业机构会偏爱 TiDB :TiDB 数据库在某省妇幼健康管理系统的应用

本文介绍了某省妇幼健康管理系统的建设和数据库架构优化的过程。原有的数据库架构使用了 StarRocks 作为分析层,但随着业务的发展,这套架构暴露出诸多痛点,不再适应妇幼业务的需求。为解决这些问题,该系统选择了将原有架构中的 St…

OBSERVER(观察者)-- 对象行为模式

意图: 定义对象间地一种一对多地依赖关系,当一个对象地状态发生改变时,所有对于依赖于它的对象都得到通知并被自动更新。 别名: 依赖(Dependents), 发布-订阅(Publish-Subsribe) 动机: 将一个系统分割成一系列相互协…

使用Python及R语言绘制简易数据分析报告

Pytohn实现 在python中有很多包可以实现绘制数据分析报告的功能,推荐两个较为方便的包:pandas-profiling 和 sweetviz 。 使用 pandas-profiling 包(功能全面) 这个包的个别依赖包与机器学习的 sklearn 包的依赖包存在版本冲突&a…

【C++中的模板】

和你有关,观后无感................................................................................................................. 目录 前言 一、【模板的引入和介绍】 1、泛型编程 2、【模板的介绍】 二、【 函数模板】 2.1【模函数板的介绍】 1.…

修改word文件的创作者方法有哪些?如何修改文档的作者 这两个方法你一定要知道

在数字化时代,文件创作者的信息往往嵌入在文件的元数据中,这些元数据包括创作者的姓名、创建日期以及其他相关信息。然而,有时候我们可能需要修改这些创作者信息,出于隐私保护、版权调整或者其他实际需求。那么,有没有…

【开源设计】京东慢SQL组件:sql-analysis

京东慢SQL组件:sql-analysis 一、背景二、源码简析三、总结 地址:https://github.com/jd-opensource/sql-analysis 一、背景 开发中,无疑会遇到慢SQL问题,而常见的处理思路都是等上线,然后由监控报警之后再去定位对应…

vue 前端读取Excel文件并解析

前端读取Excel文件并解析 前端如何解释Excel呢 平时项目中对于Excel的导入解析是很常见的功能,一般都是放在后端执行;但是也有特殊的情况,偶尔也有要求说前端执行解析,判空,校验等,最后组装成后端接口想要的…

【大数据】利用 Apache Ranger 管理 Amazon EMR 中的数据权限

利用 Apache Ranger 管理 Amazon EMR 中的数据权限 1.需求背景简介2.系统方案架构图3.主要服务和组件简介3.1 Amazon EMR3.2 Simple Active Directory3.3 Apache Ranger 4.部署步骤4.1 部署 Simple AD 服务4.2 部署 Apache Ranger4.3 部署 Amazon EMR4.4 在 Amazon EMR 的主节点…

【数据结构】二叉树(带图详解)

文章目录 1.树的概念1.2 树的结构孩子表示法孩子兄弟表示法 1.3 相关概念 2.二叉树的概念及结构2.1 二叉树的概念2.2 数据结构中的二叉树-五种形态2.3 特殊的二叉树2.4 二叉树的存储结构顺序存储链式存储 2.5 二叉树的性质 3. 堆3.1 堆的定义3.2 堆的实现堆的结构堆的插入向上调…

java技术栈快速复习02_前端基础知识总结

前端基础 经典三件套: html(盒子)css(样式)JavaScript(js:让盒子动起来) html & css HTML全称:Hyper Text Markup Language(超文本标记语言),不是编程语…

不科学上网使用Hugging Face的Transformers库

参考 Program Synthesis with CodeGen — ROCm Blogs (amd.com) HF-Mirror - Huggingface 镜像站 https://huggingface.co/docs/transformers/v4.40.1/zh/installation#%E7%A6%BB%E7%BA%BF%E6%A8%A1%E5%BC%8F 准备 apt show rocm-libs -a pip install transformers python …

计算机网络—数据链路层

一、数据链路层的基本概念 结点:主机、路由器 链路:网络中两个结点之间的物理通道,链路的传输介质主要有双绞线、光纤和微波。分为有线链路、无线链路 数据链路:网络中两个结点之间的逻辑通道,把实现控制数据协议的…