YOLOv5改进(七)--改进损失函数EIoU、Alpha-IoU、SIoU、Focal-EIOU

文章目录

  • 1、前言
  • 2、损失函数代码实现
    • 2.1、修改metrics.py
    • 2.2、修改loss.py
  • 3、替换EIOU
  • 4、替换SIoU
  • 5、替换Alpha-IoU
  • 6、替换Focal-EIOU
  • 7、目标检测系列文章

1、前言

YOLOv5默认使用损失函数为CIoU,本文主要针对损失函数进行修改,主要将bbox_iou函数进行修改,添加 EIoU、Alpha-IoU、SIoU、Focal-IOU等边界框回归损失。

2、损失函数代码实现

2.1、修改metrics.py

(1)首先找到utils/metrics.py文件,然后找到该python文件下的bbox_iou函数,其实在yolov5源码中设置是有GIoU, DIoU, CIoU这些边界框iou损失,但是默认值都为False

在这里插入图片描述

(2)将原始的bbox_iou函数代码注释掉,替换成如下代码,这段代码是将EIoU、Alpha-IoU、SIoU、Focal-EIOU这几个功能集中在一起,如果想要使用不同的Iou计算边界框损失,只需要修改utils/loss.py下的iou方法即可。

# 优化后的代码
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=False, alpha=1,gamma=0.5, eps=1e-7):# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)# Get the coordinates of bounding boxesif xywh:  # transform from xywh to xyxy(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_else:  # x1, y1, x2, y2 = box1b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)# Intersection areainter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)# Union Areaunion = w1 * h1 + w2 * h2 - inter + eps# IoU# iou = inter / union # ori iouiou = torch.pow(inter / (union + eps), alpha)  # alpha iouif CIoU or DIoU or GIoU or EIoU or SIoU:cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) widthch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex heightif CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squaredrho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)with torch.no_grad():alpha_ciou = v / (v - iou + (1 + eps))if Focal:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),gamma)  # Focal_CIoUelse:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoUelif EIoU:rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2cw2 = torch.pow(cw ** 2 + eps, alpha)ch2 = torch.pow(ch ** 2 + eps, alpha)if Focal:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps),gamma)  # Focal_EIouelse:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIouelif SIoU:# SIoU Loss https://arxiv.org/pdf/2205.12740.pdfs_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + epss_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + epssigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)sin_alpha_1 = torch.abs(s_cw) / sigmasin_alpha_2 = torch.abs(s_ch) / sigmathreshold = pow(2, 0.5) / 2sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)rho_x = (s_cw / cw) ** 2rho_y = (s_ch / ch) ** 2gamma = angle_cost - 2distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)if Focal:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter / (union + eps), gamma)  # Focal_SIouelse:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIouif Focal:return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoUelse:return iou - rho2 / c2  # DIoUc_area = cw * ch + eps  # convex areaif Focal:return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps),gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdfelse:return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdfif Focal:return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoUelse:return iou  # IoU

要点

  1. gamma参数: 是Focal EloU中的gamma参数,一般就是为0.5,有需要可以自行更改。
  2. alpha参数:为Alpha-IOU中的alpha参数,默认为1,即使用原始I0U。若需要使用Alpha-IOU,只需将其设置为任意值(论文中默认设置为3)

2.2、修改loss.py

找到utils/loss.py损失函数计算文件,修改ComputeLoss类下面的__call__函数,通过修改iou = bbox_iou(pbox, tbox[i],x1y1x2y2=False, CIoU=True)里面第4个参数实现不同的损失函数。

在这里插入图片描述

将红框内容替换成如下代码:

iou = bbox_iou(pbox, tbox[i], CIoU=True)  # iou(prediction, target)
if type(iou) is tuple:lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()iou = iou[0].squeeze()
else:lbox += (1.0 - iou.squeeze()).mean()  # iou lossiou = iou.squeeze()

在这里插入图片描述

3、替换EIOU

如果想要使用EIOU,只需要将CIoU替换成EIOU:

iou = bbox_iou(pbox, tbox[i], EIoU=True) 

4、替换SIoU

如果想要使用SIoU,只需要将CIoU替换成SIoU:

iou = bbox_iou(pbox, tbox[i], SIoU=True) 

5、替换Alpha-IoU

如果想要使用Alpha-IoU,只需要添加alpha=3这个参数项开启Alpha,如果不设置该参数,alpha默认为1:

iou = bbox_iou(pbox, tbox[i], CIoU=True, alpha=3) 

6、替换Focal-EIOU

Focal-EIOU相对于EIOU只多了一个Focal项,这两个iou损失都是出自同一篇论文,只需要设置Focal=True即可。

iou = bbox_iou(pbox, tbox[i], EIoU=True, Focal=True) 

当然Focal项也可以用于CIoU、SIoU,至于效果需要根据不同数据集进行测试,修改如下:

Focal-CIoU

iou = bbox_iou(pbox, tbox[i], CIOU=True, Focal=True) 

Focal-SIoU

iou = bbox_iou(pbox, tbox[i], SIOU=True, Focal=True) 

7、目标检测系列文章

  1. YOLOv5s网络模型讲解(一看就会)
  2. 生活垃圾数据集(YOLO版)
  3. YOLOv5如何训练自己的数据集
  4. 双向控制舵机(树莓派版)
  5. 树莓派部署YOLOv5目标检测(详细篇)
  6. YOLO_Tracking 实践 (环境搭建 & 案例测试)
  7. 目标检测:数据集划分 & XML数据集转YOLO标签
  8. DeepSort行人车辆识别系统(实现目标检测+跟踪+统计)
  9. YOLOv5参数大全(parse_opt篇)
  10. YOLOv5改进(一)-- 轻量化YOLOv5s模型
  11. YOLOv5改进(二)-- 目标检测优化点(添加小目标头检测)
  12. YOLOv5改进(三)-- 引进Focaler-IoU损失函数
  13. YOLOv5改进(四)–轻量化模型ShuffleNetv2
  14. YOLOv5改进(五)-- 轻量化模型MobileNetv3
  15. YOLOv5改进(六)–引入YOLOv8中C2F模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/34998.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云渲染农场使用指南:如何以最低成本享受最快渲染速度?

​云渲染农场怎么低成本享受快速渲染? 云渲染农场利用其分布式计算能力,为视觉艺术家提供了一种经济高效的渲染选择。它特别适用于高质量的影视动画和视觉效果制作。下面一起来看看如何以最低的成本实现快速渲染的策略。 在追求成本效益的同时&#xff…

第一百二十七节 Java面向对象设计 - Java枚举方法

Java面向对象设计 - Java枚举方法 因为枚举类型实际上是一个类类型,所以我们可以在枚举类型体中声明一切,我们可以在类体中声明它。 以下代码使用字段,构造函数和方法定义了一个级别枚举。 public enum Level {LOW(30), MEDIUM(15), HIGH(7…

2024年好用的加密工具,迅软DSE加密系统原来这么强大

加密软件具有灵活的加密方式和用户友好的操作界面,可定制个性化的安全方案,同时支持数据备份和恢复功能,确保数据的完整性和可用性,是保护数据安全、维护商业机密、防范信息泄露的重要工具。 2024好用的加密工具是哪个&#xff1f…

云计算:未来科技的基石

目录 什么是云计算? 云计算的分类 1. 基础设施即服务 (IaaS) 2. 平台即服务 (PaaS) 3. 软件即服务 (SaaS) 云计算的优势 1. 成本效益 2. 灵活性和可扩展性 3. 高可用性和可靠性 4. 创新和快速迭代 云计算的应用场景 1. 数据存储和备份 2. 大数据分析 3…

MySQL——基本的Select语句和别名使用

DQL (Data Query Language:数据查询语言) 所有的查询操作都用它 Select 简单或者复杂的查询都能做 数据库中最核心的语言,最重要的语句 使用频率最高的语言 指定查询字段 -- 查询全部的学生 SELECT 字段 FROM 表…

【SpringCloud-Seata客户端源码分析01】

文章目录 启动seata客户端1.导入依赖2.自动装配 发送请求的核心方法客户端开启事务的核心流程服务端分布式事务的处理机制 启动seata客户端 1.导入依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent…

【DS Solutions】一个反欺诈产品的进化,Stripe Radar

Stripe Radar 是 Stripe 提供的一项防欺诈服务&#xff0c;它利用机器学习技术来帮助商家检测和阻止信用卡欺诈行为。这篇文章是Stripe公司关于其反欺诈解决方案Stripe Radar的构建过程的介绍。文章从Stripe的防欺诈团队工程师的角度出发&#xff0c;详细讲述了Stripe Radar的工…

车辆数据的提取、定位和融合 精确车辆定位(其三.一 共十二篇)随机复合

第一篇&#xff1a; System Introduction 第二篇&#xff1a;State of the Art 第三篇&#xff1a;localization 第四篇&#xff1a;Submapping and temporal weighting 第五篇&#xff1a;Mapping of Point-shaped landmark data 第六篇&#xff1a;Clustering of landma…

礼让,不是一昧地退让,而是表达我们的素养、品德

礼 / 让&#xff0c;发心是文明相处&#xff0c;互助互让&#xff0c;是君子之交

覆盖容器的默认设置

覆盖容器都默认设置 目录 覆盖网络端口设置环境变量限制容器资源使用试一试 运行多个 Postgres 数据库实例在受控网络中运行 Postgres 容器管理资源在 Docker Compose 中覆盖默认的 CMD 和 ENTRYPOINT使用 docker run 覆盖默认的 CMD 和 ENTRYPOINT 额外资源下一步 当 Docker…

UnityShader SDF有向距离场简单实现

UnityShader SDF有向距离场简单实现 前言项目场景布置连连看画一个圆复制一个圆计算修改shader参数 鸣谢 前言 突然看到B站的一个教程&#xff0c;还不错&#xff0c;记录一下 项目 场景布置 使用ASE连连看&#xff0c;所以先要导入Amplify Shader Editor 连连看 画一个…

面试-JMM的内存可见性

1.JAVA内存模型 分析&#xff1a; 由于JVM运行程序的实体是线程&#xff0c;而每个线程创建时&#xff0c;JVM都会 为其创建一个工作内存(栈空间),用于存储线程私有的数据。而java内存模型中规定所有变量都存储在主内存中。主内存是共享内存区域&#xff0c;所有线程都可以访问…

Python-PDF文件密码破解小工具

背景 经常从网络上下载的PDF笔记被加了密&#xff0c;在自己学习的过程中想要添加书签却因为没有密码无法添加&#xff0c;所以通过Python实现一个解密小工具&#xff0c;亲测大多数密码都可以破解。 代码 import os import tkinter as tk from tkinter import filedialog #…

你还不知道Modbus RTU???

1. 什么是Modbus RTU Modbus RTU&#xff08;Remote Terminal Unit&#xff09;是Modbus通信协议的一种变种&#xff0c;用于串行通信。它是一种常见的工业控制系统通信协议&#xff0c;通常用于采集传感器数据、控制执行器和监控设备状态。Modbus RTU采用二进制编码&#xff0…

基于ruoyi-app的手机短信登录(uniapp)

本篇用于记录h5的框架搭建 组件地址:短信验证码登陆&#xff0c;手机号&#xff0c;验证码倒计时 - DCloud 插件市场 调整后的表单组件代码: <template><view class"login-view"><!-- <input type"tel" confirm-type"确认"…

073、类的三大特征初识

&#xff08;1&#xff09;继承 类之间可以通过继承建立父子关系&#xff0c;子类可以继承父类的属性和方法&#xff0c;并可以添加自己的特定属性和方法。如下是一个简单示例&#xff1a; class Student(Person):def __init__(self, name, age, student_id):super().__init_…

Follow Carl To Grow|【LeetCode】93.复原IP地址,78.子集,90.子集II

【LeetCode】93.复原IP地址 题意&#xff1a;有效 IP 地址 正好由四个整数&#xff08;每个整数位于 0 到 255 之间组成&#xff0c;且不能含有前导 0&#xff09;&#xff0c;整数之间用 ‘.’ 分隔。 例如&#xff1a;“0.1.2.201” 和 “192.168.1.1” 是 有效 IP 地址&…

【深度学习】实现基于MNIST数据集的TensorFlow/Keras深度学习案例

基于TensorFlow/Keras的深度学习案例 实现基于MNIST数据集的TensorFlow/Keras深度学习案例0. 什么是深度学习&#xff1f;1. TensorFlow简介2. Keras简介3. 安装TensorFlow前的注意事项4. 安装Anaconda3及搭建TensorFlow环境1&#xff09; 下载安装Anaconda Navigator2&#xf…

go语言day06 数组 切片

数组 : 定长且元素类型一致,在索引逻辑上连续存储,数组的内存地址是存储的第一个元素的内存地址 几种创建方式: 仅声明 var nums [ 3 ] int 声明并赋值 var nums [ 2 ] string {"武沛齐","alex"} 声明并按下标赋值 var nums [ 3 ] int {0:88,2:1 } 省略…

ffmpeg+nginx+video实现rtsp流转hls流,web页面播放

项目场景&#xff1a; 最近调试海康摄像头需要将rtsp流在html页面播放,因为不想去折腾推拉流&#xff0c;所以我选择ffmpeg转hls流&#xff0c;nginx转发&#xff0c;html直接访问就好了 1.首先要下载nginx和ffmpeg 附上下载地址&#xff1a; nginx nginx news ffmpeg htt…