【YOLO改进】换遍IoU损失函数之EIoU Loss(基于MMYOLO)

EIoU损失函数

设计原理

一、IoU的局限性

IoU(Intersection over Union)是一种常用于评估目标检测模型性能的指标,特别是在计算预测边界框与真实边界框之间的重叠程度时。然而,IoU存在一些局限性,尤其是当两个边界框没有任何交集时,IoU 的值为0,这使得梯度更新停滞,不利于模型的进一步学习和优化。

二、EIoU的引入

为了解决这一问题,引入了EIoU(Enhanced Intersection over Union)损失函数。EIoU 不仅考虑了边界框间的重叠区域,还引入了其他度量,如边界框中心点的距离,以及边界框的宽度和高度的相对差异。这样的设计使得即使两个边界框不重叠,损失函数仍然可以提供有效的梯度,从而促进模型的训练和收敛。

计算步骤

一、计算IoU

  • 计算两个边界框A和B的交集面积I。

  • 计算两个边界框的并集面积U。

  • IoU计算公式为:I/U

二、计算中心点距离的公式

中心点距离是预测框和真实框中心点之间的欧氏距离。设预测框中心为(x_p,y_p),真实框中心为 (x_g,y_g),则中心距离D_c计算为:

D_c = \sqrt{(x_p - x_g)^2 + (y_p - y_g)^2}

三、计算宽高比的差异

宽度差异w_{diff}和高度差异 h_{diff} 分别为预测框和真实框宽度和高度的相对差值。计算方法可以是简单的差值或者比例差等。

四、整合以上度量

EIoU将上述度量整合到一个损失函数中,通常形式为:

\text{EIoU Loss} = 1 - \text{IoU} + \lambda_1 D_c + \lambda_2 (w_{\text{diff}} + h_{\text{diff}})

其中,\lambda_1​ 和\lambda_2是调节中心距离和宽高差异影响的超参数。

使用PyTorch实现EIoU计算的源代码

import torch
import torch.nn.functional as Fdef bbox_iou(boxes1, boxes2):"""计算两组边界框的IoU。boxes1, boxes2: [N, 4] (x1, y1, x2, y2)"""area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)union_area = area1 + area2 - inter_areareturn inter_area / union_areadef eiou_loss(pred_boxes, target_boxes, lambda1=1, lambda2=1):"""计算EIoU损失。pred_boxes, target_boxes: [N, 4] (x1, y1, x2, y2)"""iou = bbox_iou(pred_boxes, target_boxes)# 计算中心点center_pred = (pred_boxes[:, :2] + pred_boxes[:, 2:4]) / 2center_target = (target_boxes[:, :2] + target_boxes[:, 2:4]) / 2# 计算中心点距离dc = torch.sqrt(torch.sum((center_pred - center_target) ** 2, dim=1))# 计算宽高差异wh_pred = pred_boxes[:, 2:4] - pred_boxes[:, :2]wh_target = target_boxes[:, 2:4] - target_boxes[:, :2]wh_diff = torch.abs(wh_pred - wh_target).sum(dim=1)# 计算EIoU损失loss = 1 - iou + lambda1 * dc + lambda2 * wh_diffreturn loss.mean()# 示例用法
pred_boxes = torch.tensor([[25, 25, 75, 75], [50, 50, 100, 100]], dtype=torch.float32)
target_boxes = torch.tensor([[30, 30, 70, 70], [70, 70, 120, 120]], dtype=torch.float32)loss = eiou_loss(pred_boxes, target_boxes)
print(f"EIoU Loss: {loss}")

替换EIoU损失函数(基于MMYOLO)

由于MMYOLO中没有实现EIoU损失函数,所以需要在mmyolo/models/iou_loss.py中添加EIoU的计算和对应的iou_mode,修改完以后在终端运行

python setup.py install

再在配置文件中进行修改即可。修改例子如下:

    elif iou_mode == "eiou":# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )# calculate enclose area (c^2)enclose_area = enclose_w**2 + enclose_h**2 + eps# calculate ρ^2(b_pred,b_gt):# euclidean distance between b_pred(bbox2) and b_gt(bbox1)# center point, because bbox format is xyxy -> left-top xy and# right-bottom xy, so need to / 4 to get center point.rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4rho2_right_item = ((bbox2_y1 + bbox2_y2) -(bbox1_y1 + bbox1_y2))**2 / 4rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)rho_w2 = ((bbox2_x2 - bbox2_x1) - (bbox1_x2 - bbox1_x1)) ** 2rho_h2 = ((bbox2_y2 - bbox2_y1) - (bbox1_y2 - bbox1_y1)) ** 2cw2 = enclose_w ** 2 + epsch2 = enclose_h ** 2 + epsious = ious - (rho2 / enclose_area + rho_w2 / cw2 + rho_h2 / ch2)

修改后的配置文件(以configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py为例)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image pathnum_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True# -----model related-----
# Basic size of multi-scale prior box
anchors = [[(10, 13), (16, 30), (33, 23)],  # P3/8[(30, 61), (62, 45), (59, 119)],  # P4/16[(116, 90), (156, 198), (373, 326)]  # P5/32
]# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochsmodel_test_cfg = dict(# The config of multi-label for multi-class prediction.multi_label=True,# The number of boxes before NMSnms_pre=30000,score_thr=0.001,  # Threshold to filter out boxes.nms=dict(type='nms', iou_threshold=0.65),  # NMS type and thresholdmax_per_img=300)  # Max number of detections of each image# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(type='BatchShapePolicy',batch_size=val_batch_size_per_gpu,img_size=img_scale[0],# The image scale of padding should be divided by pad_size_divisorsize_divisor=32,# Additional paddings for pixel scaleextra_pad_ratio=0.5)# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)# ===============================Unmodified in most cases====================
model = dict(type='YOLODetector',data_preprocessor=dict(type='mmdet.DetDataPreprocessor',mean=[0., 0., 0.],std=[255., 255., 255.],bgr_to_rgb=True),backbone=dict(##使用YOLOv8的主干网络type='YOLOv8CSPDarknet',deepen_factor=deepen_factor,widen_factor=widen_factor,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),neck=dict(type='YOLOv5PAFPN',deepen_factor=deepen_factor,widen_factor=widen_factor,in_channels=[256, 512, 1024],out_channels=[256, 512, 1024],num_csp_blocks=3,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),bbox_head=dict(type='YOLOv5Head',head_module=dict(type='YOLOv5HeadModule',num_classes=num_classes,in_channels=[256, 512, 1024],widen_factor=widen_factor,featmap_strides=strides,num_base_priors=3),prior_generator=dict(type='mmdet.YOLOAnchorGenerator',base_sizes=anchors,strides=strides),# scaled based on number of detection layersloss_cls=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_cls_weight *(num_classes / 80 * 3 / num_det_layers)),# 修改此处实现IoU损失函数的替换loss_bbox=dict(type='IoULoss',iou_mode='eiou',bbox_format='xywh',eps=1e-7,reduction='mean',loss_weight=loss_bbox_weight * (3 / num_det_layers),return_iou=True),loss_obj=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_obj_weight *((img_scale[0] / 640)**2 * 3 / num_det_layers)),prior_match_thr=prior_match_thr,obj_level_weights=obj_level_weights),test_cfg=model_test_cfg)albu_train_transforms = [dict(type='Blur', p=0.01),dict(type='MedianBlur', p=0.01),dict(type='ToGray', p=0.01),dict(type='CLAHE', p=0.01)
]pre_transform = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='LoadAnnotations', with_bbox=True)
]train_pipeline = [*pre_transform,dict(type='Mosaic',img_scale=img_scale,pad_val=114.0,pre_transform=pre_transform),dict(type='YOLOv5RandomAffine',max_rotate_degree=0.0,max_shear_degree=0.0,scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),# img_scale is (width, height)border=(-img_scale[0] // 2, -img_scale[1] // 2),border_val=(114, 114, 114)),dict(type='mmdet.Albu',transforms=albu_train_transforms,bbox_params=dict(type='BboxParams',format='pascal_voc',label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),keymap={'img': 'image','gt_bboxes': 'bboxes'}),dict(type='YOLOv5HSVRandomAug'),dict(type='mmdet.RandomFlip', prob=0.5),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip','flip_direction'))
]train_dataloader = dict(batch_size=train_batch_size_per_gpu,num_workers=train_num_workers,persistent_workers=persistent_workers,pin_memory=True,sampler=dict(type='DefaultSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,ann_file=train_ann_file,data_prefix=dict(img=train_data_prefix),filter_cfg=dict(filter_empty_gt=False, min_size=32),pipeline=train_pipeline))test_pipeline = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='YOLOv5KeepRatioResize', scale=img_scale),dict(type='LetterResize',scale=img_scale,allow_scale_up=False,pad_val=dict(img=114)),dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape','scale_factor', 'pad_param'))
]val_dataloader = dict(batch_size=val_batch_size_per_gpu,num_workers=val_num_workers,persistent_workers=persistent_workers,pin_memory=True,drop_last=False,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,test_mode=True,data_prefix=dict(img=val_data_prefix),ann_file=val_ann_file,pipeline=test_pipeline,batch_shapes_cfg=batch_shapes_cfg))test_dataloader = val_dataloaderparam_scheduler = None
optim_wrapper = dict(type='OptimWrapper',optimizer=dict(type='SGD',lr=base_lr,momentum=0.937,weight_decay=weight_decay,nesterov=True,batch_size_per_gpu=train_batch_size_per_gpu),constructor='YOLOv5OptimizerConstructor')default_hooks = dict(param_scheduler=dict(type='YOLOv5ParamSchedulerHook',scheduler_type='linear',lr_factor=lr_factor,max_epochs=max_epochs),checkpoint=dict(type='CheckpointHook',interval=save_checkpoint_intervals,save_best='auto',max_keep_ckpts=max_keep_ckpts))custom_hooks = [dict(type='EMAHook',ema_type='ExpMomentumEMA',momentum=0.0001,update_buffers=True,strict_load=False,priority=49)
]val_evaluator = dict(type='mmdet.CocoMetric',proposal_nums=(100, 1, 10),ann_file=data_root + val_ann_file,metric='bbox')
test_evaluator = val_evaluatortrain_cfg = dict(type='EpochBasedTrainLoop',max_epochs=max_epochs,val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/5602.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[python趣味实战]----基于python代码实现浪漫爱心 დ

正文 01-效果演示 下图是代码运行之后的爱心显示结果: 下面的视频该爱心是动态效果,较为简洁,如果需要使用,可以进行完善,这里只是一个趣味实战,下面将对代码实现进行非常详细地描述: 浪漫爱心…

Java数据结构-模拟实现ArrayList

MyArrayList顺序结构: 接口和MyArrayList重写接口 接口 接口中的方法是很多类通用的,所以可以写到接口中 public interface IList {public void add(int data) ;// 在 pos 位置新增元素public void add(int pos, int data);// 判定是否包含某个元素p…

踏上R语言之旅:解锁数据世界的神秘密码(三)

多元相关与回归分析及R使用 文章目录 多元相关与回归分析及R使用一.变量间的关系分析1.两变量线性相关系数的计算2.相关系数的假设检验 二.一元线性回归分析的R计算三、回归系数的假设检验总结 一.变量间的关系分析 变量间的关系及分析方法如下: 1.两变量线性相关…

LeetCode 727. 菱形

输入一个奇数 n n n,输出一个由 * 构成的 n n n阶实心菱形。 输入格式 一个奇数 n n n。 输出格式 输出一个由 * 构成的 n n n阶实心菱形。 具体格式参照输出样例。 数据范围 1 ≤ n ≤ 99 1≤n≤99 1≤n≤99 输入样例: 5输出样例: *…

JAVA的多态

在Java中,多态(Polymorphism)是面向对象编程的三大特性之一,它允许一个引用变量在运行时引用不同类的对象,并根据实际对象的类型来执行对应的方法。多态的存在增加了代码的灵活性和可扩展性。 多态的实现通常依赖于以下…

一文掌握python上下文管理器(with语句)

目录 一、上下文管理协议 二、with 语句 三、自定义上下文管理器 四、生成器上下文管理器 五、几个常用例子 1、自动关闭网络连接 2、临时更改目录 3、数据库事务管理 4、计时器上下文管理器 5、日志记录上下文管理器 6、资源锁定上下文管理器 7、临时修改环境变量…

windows远程访问树莓派ubuntu22.04 桌面 - NoMachine

通过nomachine 实现 windows 安装 nomachine 下载:链接:https://pan.baidu.com/s/10rGBREs-AnwRz7D7QbLQ1A?pwd8651 提取码:8651 安装:下一步 下一步 使用: 下一步 下一步 ubuntu 安装 nomachine服务 下载&#…

Java基础知识总结(81)

JUC容器 JUC基于非阻塞算法(Lock Free 无锁编程)提供了一组高并发的List、Set、Queue、Map容器。 JUC高并发容器是基于非阻塞算法实现的容器类,无锁编程算法主要通过CAS(Compare And Swap)volatile的组合实现&#x…

【C++程序员的自我修炼】string 库中常见的用法 (一)

唤起一天明月照我满怀冰雪浩荡百川流鲸饮未吞海 剑气已横秋 目录 string 库的简介 string 的一些小操作 构造函数的使用 拷贝构造的常规使用 指定拷贝内容的拷贝构造 拷贝字符串开始的前 n 个字符 用 n 个字符初始化 计算字符串的长度 string 的三种遍历方式 常规的for循环 op…

利用大型语言模型提升数字产品创新:提示,微调,检索增强生成和代理的应用

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Linux基础part-6

一、Shell编程理论和运用 1、程序的编程风格和执行模式 编程风格(Programming Style) 过程式编程:以指令为中心,来进行写程序,数据服务于指令。(bash shell) C 以指令为中心,程序的逻辑由一系…

「笔试刷题」:字母收集

一、题目 描述 有一个 𝑛∗𝑚 的矩形方阵,每个格子上面写了一个小写字母。 小红站在矩形的左上角,她每次可以向右或者向下走,走到某个格子上就可以收集这个格子的字母。 小红非常喜欢 "love" 这四个字母。…

FFmpeg开发笔记(二十三)使用OBS Studio开启RTMP直播推流

OBS是一个开源的直播录制软件,英文全称叫做Open Broadcaster Software,广泛用于视频录制、实时直播等领域。OBS不但开源,而且跨平台,兼容Windows、Mac OS、Linux等操作系统。 OBS的官网是https://obsproject.com/,录制…

【报错处理】ib_write_bw执行遇到Couldn‘t listen to port 18515原因与解决办法?

要点 要点: ib默认使用18515端口 相关命令: netstat -tuln | grep 18515 ib_write_bw --help |grep port# server ib_write_bw --ib-devmlx5_1 --port88990 # client ib_write_bw --ib-devmlx5_0 1.1.1.1 --port88990现象: 根因&#xff…

为什么公共事业机构会偏爱 TiDB :TiDB 数据库在某省妇幼健康管理系统的应用

本文介绍了某省妇幼健康管理系统的建设和数据库架构优化的过程。原有的数据库架构使用了 StarRocks 作为分析层,但随着业务的发展,这套架构暴露出诸多痛点,不再适应妇幼业务的需求。为解决这些问题,该系统选择了将原有架构中的 St…

OBSERVER(观察者)-- 对象行为模式

意图: 定义对象间地一种一对多地依赖关系,当一个对象地状态发生改变时,所有对于依赖于它的对象都得到通知并被自动更新。 别名: 依赖(Dependents), 发布-订阅(Publish-Subsribe) 动机: 将一个系统分割成一系列相互协…

使用Python及R语言绘制简易数据分析报告

Pytohn实现 在python中有很多包可以实现绘制数据分析报告的功能,推荐两个较为方便的包:pandas-profiling 和 sweetviz 。 使用 pandas-profiling 包(功能全面) 这个包的个别依赖包与机器学习的 sklearn 包的依赖包存在版本冲突&a…

【C++中的模板】

和你有关,观后无感................................................................................................................. 目录 前言 一、【模板的引入和介绍】 1、泛型编程 2、【模板的介绍】 二、【 函数模板】 2.1【模函数板的介绍】 1.…

修改word文件的创作者方法有哪些?如何修改文档的作者 这两个方法你一定要知道

在数字化时代,文件创作者的信息往往嵌入在文件的元数据中,这些元数据包括创作者的姓名、创建日期以及其他相关信息。然而,有时候我们可能需要修改这些创作者信息,出于隐私保护、版权调整或者其他实际需求。那么,有没有…

【开源设计】京东慢SQL组件:sql-analysis

京东慢SQL组件:sql-analysis 一、背景二、源码简析三、总结 地址:https://github.com/jd-opensource/sql-analysis 一、背景 开发中,无疑会遇到慢SQL问题,而常见的处理思路都是等上线,然后由监控报警之后再去定位对应…