YOLOv5改进(七)--改进损失函数EIoU、Alpha-IoU、SIoU、Focal-EIOU

文章目录

  • 1、前言
  • 2、损失函数代码实现
    • 2.1、修改metrics.py
    • 2.2、修改loss.py
  • 3、替换EIOU
  • 4、替换SIoU
  • 5、替换Alpha-IoU
  • 6、替换Focal-EIOU
  • 7、目标检测系列文章

1、前言

YOLOv5默认使用损失函数为CIoU,本文主要针对损失函数进行修改,主要将bbox_iou函数进行修改,添加 EIoU、Alpha-IoU、SIoU、Focal-IOU等边界框回归损失。

2、损失函数代码实现

2.1、修改metrics.py

(1)首先找到utils/metrics.py文件,然后找到该python文件下的bbox_iou函数,其实在yolov5源码中设置是有GIoU, DIoU, CIoU这些边界框iou损失,但是默认值都为False

在这里插入图片描述

(2)将原始的bbox_iou函数代码注释掉,替换成如下代码,这段代码是将EIoU、Alpha-IoU、SIoU、Focal-EIOU这几个功能集中在一起,如果想要使用不同的Iou计算边界框损失,只需要修改utils/loss.py下的iou方法即可。

# 优化后的代码
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, Focal=False, alpha=1,gamma=0.5, eps=1e-7):# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)# Get the coordinates of bounding boxesif xywh:  # transform from xywh to xyxy(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_else:  # x1, y1, x2, y2 = box1b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)# Intersection areainter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)# Union Areaunion = w1 * h1 + w2 * h2 - inter + eps# IoU# iou = inter / union # ori iouiou = torch.pow(inter / (union + eps), alpha)  # alpha iouif CIoU or DIoU or GIoU or EIoU or SIoU:cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) widthch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex heightif CIoU or DIoU or EIoU or SIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squaredrho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)with torch.no_grad():alpha_ciou = v / (v - iou + (1 + eps))if Focal:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),gamma)  # Focal_CIoUelse:return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoUelif EIoU:rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2cw2 = torch.pow(cw ** 2 + eps, alpha)ch2 = torch.pow(ch ** 2 + eps, alpha)if Focal:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps),gamma)  # Focal_EIouelse:return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIouelif SIoU:# SIoU Loss https://arxiv.org/pdf/2205.12740.pdfs_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + epss_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + epssigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)sin_alpha_1 = torch.abs(s_cw) / sigmasin_alpha_2 = torch.abs(s_ch) / sigmathreshold = pow(2, 0.5) / 2sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)rho_x = (s_cw / cw) ** 2rho_y = (s_ch / ch) ** 2gamma = angle_cost - 2distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)if Focal:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter / (union + eps), gamma)  # Focal_SIouelse:return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIouif Focal:return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoUelse:return iou - rho2 / c2  # DIoUc_area = cw * ch + eps  # convex areaif Focal:return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps),gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdfelse:return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdfif Focal:return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoUelse:return iou  # IoU

要点

  1. gamma参数: 是Focal EloU中的gamma参数,一般就是为0.5,有需要可以自行更改。
  2. alpha参数:为Alpha-IOU中的alpha参数,默认为1,即使用原始I0U。若需要使用Alpha-IOU,只需将其设置为任意值(论文中默认设置为3)

2.2、修改loss.py

找到utils/loss.py损失函数计算文件,修改ComputeLoss类下面的__call__函数,通过修改iou = bbox_iou(pbox, tbox[i],x1y1x2y2=False, CIoU=True)里面第4个参数实现不同的损失函数。

在这里插入图片描述

将红框内容替换成如下代码:

iou = bbox_iou(pbox, tbox[i], CIoU=True)  # iou(prediction, target)
if type(iou) is tuple:lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()iou = iou[0].squeeze()
else:lbox += (1.0 - iou.squeeze()).mean()  # iou lossiou = iou.squeeze()

在这里插入图片描述

3、替换EIOU

如果想要使用EIOU,只需要将CIoU替换成EIOU:

iou = bbox_iou(pbox, tbox[i], EIoU=True) 

4、替换SIoU

如果想要使用SIoU,只需要将CIoU替换成SIoU:

iou = bbox_iou(pbox, tbox[i], SIoU=True) 

5、替换Alpha-IoU

如果想要使用Alpha-IoU,只需要添加alpha=3这个参数项开启Alpha,如果不设置该参数,alpha默认为1:

iou = bbox_iou(pbox, tbox[i], CIoU=True, alpha=3) 

6、替换Focal-EIOU

Focal-EIOU相对于EIOU只多了一个Focal项,这两个iou损失都是出自同一篇论文,只需要设置Focal=True即可。

iou = bbox_iou(pbox, tbox[i], EIoU=True, Focal=True) 

当然Focal项也可以用于CIoU、SIoU,至于效果需要根据不同数据集进行测试,修改如下:

Focal-CIoU

iou = bbox_iou(pbox, tbox[i], CIOU=True, Focal=True) 

Focal-SIoU

iou = bbox_iou(pbox, tbox[i], SIOU=True, Focal=True) 

7、目标检测系列文章

  1. YOLOv5s网络模型讲解(一看就会)
  2. 生活垃圾数据集(YOLO版)
  3. YOLOv5如何训练自己的数据集
  4. 双向控制舵机(树莓派版)
  5. 树莓派部署YOLOv5目标检测(详细篇)
  6. YOLO_Tracking 实践 (环境搭建 & 案例测试)
  7. 目标检测:数据集划分 & XML数据集转YOLO标签
  8. DeepSort行人车辆识别系统(实现目标检测+跟踪+统计)
  9. YOLOv5参数大全(parse_opt篇)
  10. YOLOv5改进(一)-- 轻量化YOLOv5s模型
  11. YOLOv5改进(二)-- 目标检测优化点(添加小目标头检测)
  12. YOLOv5改进(三)-- 引进Focaler-IoU损失函数
  13. YOLOv5改进(四)–轻量化模型ShuffleNetv2
  14. YOLOv5改进(五)-- 轻量化模型MobileNetv3
  15. YOLOv5改进(六)–引入YOLOv8中C2F模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/34998.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云渲染农场使用指南:如何以最低成本享受最快渲染速度?

​云渲染农场怎么低成本享受快速渲染? 云渲染农场利用其分布式计算能力,为视觉艺术家提供了一种经济高效的渲染选择。它特别适用于高质量的影视动画和视觉效果制作。下面一起来看看如何以最低的成本实现快速渲染的策略。 在追求成本效益的同时&#xff…

第一百二十七节 Java面向对象设计 - Java枚举方法

Java面向对象设计 - Java枚举方法 因为枚举类型实际上是一个类类型,所以我们可以在枚举类型体中声明一切,我们可以在类体中声明它。 以下代码使用字段,构造函数和方法定义了一个级别枚举。 public enum Level {LOW(30), MEDIUM(15), HIGH(7…

2024年好用的加密工具,迅软DSE加密系统原来这么强大

加密软件具有灵活的加密方式和用户友好的操作界面,可定制个性化的安全方案,同时支持数据备份和恢复功能,确保数据的完整性和可用性,是保护数据安全、维护商业机密、防范信息泄露的重要工具。 2024好用的加密工具是哪个&#xff1f…

【SpringCloud-Seata客户端源码分析01】

文章目录 启动seata客户端1.导入依赖2.自动装配 发送请求的核心方法客户端开启事务的核心流程服务端分布式事务的处理机制 启动seata客户端 1.导入依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent…

【DS Solutions】一个反欺诈产品的进化,Stripe Radar

Stripe Radar 是 Stripe 提供的一项防欺诈服务&#xff0c;它利用机器学习技术来帮助商家检测和阻止信用卡欺诈行为。这篇文章是Stripe公司关于其反欺诈解决方案Stripe Radar的构建过程的介绍。文章从Stripe的防欺诈团队工程师的角度出发&#xff0c;详细讲述了Stripe Radar的工…

车辆数据的提取、定位和融合 精确车辆定位(其三.一 共十二篇)随机复合

第一篇&#xff1a; System Introduction 第二篇&#xff1a;State of the Art 第三篇&#xff1a;localization 第四篇&#xff1a;Submapping and temporal weighting 第五篇&#xff1a;Mapping of Point-shaped landmark data 第六篇&#xff1a;Clustering of landma…

礼让,不是一昧地退让,而是表达我们的素养、品德

礼 / 让&#xff0c;发心是文明相处&#xff0c;互助互让&#xff0c;是君子之交

UnityShader SDF有向距离场简单实现

UnityShader SDF有向距离场简单实现 前言项目场景布置连连看画一个圆复制一个圆计算修改shader参数 鸣谢 前言 突然看到B站的一个教程&#xff0c;还不错&#xff0c;记录一下 项目 场景布置 使用ASE连连看&#xff0c;所以先要导入Amplify Shader Editor 连连看 画一个…

面试-JMM的内存可见性

1.JAVA内存模型 分析&#xff1a; 由于JVM运行程序的实体是线程&#xff0c;而每个线程创建时&#xff0c;JVM都会 为其创建一个工作内存(栈空间),用于存储线程私有的数据。而java内存模型中规定所有变量都存储在主内存中。主内存是共享内存区域&#xff0c;所有线程都可以访问…

Python-PDF文件密码破解小工具

背景 经常从网络上下载的PDF笔记被加了密&#xff0c;在自己学习的过程中想要添加书签却因为没有密码无法添加&#xff0c;所以通过Python实现一个解密小工具&#xff0c;亲测大多数密码都可以破解。 代码 import os import tkinter as tk from tkinter import filedialog #…

你还不知道Modbus RTU???

1. 什么是Modbus RTU Modbus RTU&#xff08;Remote Terminal Unit&#xff09;是Modbus通信协议的一种变种&#xff0c;用于串行通信。它是一种常见的工业控制系统通信协议&#xff0c;通常用于采集传感器数据、控制执行器和监控设备状态。Modbus RTU采用二进制编码&#xff0…

【深度学习】实现基于MNIST数据集的TensorFlow/Keras深度学习案例

基于TensorFlow/Keras的深度学习案例 实现基于MNIST数据集的TensorFlow/Keras深度学习案例0. 什么是深度学习&#xff1f;1. TensorFlow简介2. Keras简介3. 安装TensorFlow前的注意事项4. 安装Anaconda3及搭建TensorFlow环境1&#xff09; 下载安装Anaconda Navigator2&#xf…

go语言day06 数组 切片

数组 : 定长且元素类型一致,在索引逻辑上连续存储,数组的内存地址是存储的第一个元素的内存地址 几种创建方式: 仅声明 var nums [ 3 ] int 声明并赋值 var nums [ 2 ] string {"武沛齐","alex"} 声明并按下标赋值 var nums [ 3 ] int {0:88,2:1 } 省略…

ffmpeg+nginx+video实现rtsp流转hls流,web页面播放

项目场景&#xff1a; 最近调试海康摄像头需要将rtsp流在html页面播放,因为不想去折腾推拉流&#xff0c;所以我选择ffmpeg转hls流&#xff0c;nginx转发&#xff0c;html直接访问就好了 1.首先要下载nginx和ffmpeg 附上下载地址&#xff1a; nginx nginx news ffmpeg htt…

HttpServletRequest・getContentLeng・getContentType区别

getContentLength()&#xff1a; 获取客户端发送到服务器的HTTP请求主体内容的字节数&#xff08;长度&#xff09; 如果请求没有正文内容&#xff08;如GET&#xff09;&#xff0c;或者请求头中没有包含Content-Length字段&#xff0c;则该方法返回 -1 getContentType()&am…

eclipse中svn从分支合并到主干及冲突解决

1、将分支先commit&#xff0c;然后再update&#xff0c;然后再clean一下&#xff0c;将项目多余的target都清理掉。 2、将branches切换到trunk 3、工程上右键-》Team-》合并&#xff08;或Merge&#xff09; 4、默认选项&#xff0c;点击Next 5、有未提交的改动&#xff0c;…

文献阅读:通过双线性建模来破译神经元类型连接的遗传密码

文献介绍 文献题目 Deciphering the genetic code of neuronal type connectivity through bilinear modeling 研究团队 Mu Qiao&#xff08;美国加州理工学院&#xff09; 发表时间 2024-06-10 发表期刊 eLife 影响因子 7.7 DOI 10.7554/eLife.91532.3 摘要 了解不同神经元…

打造爆款秘籍:阿里巴巴国际站测评补单优势全攻略

在阿里巴巴国际站&#xff0c;买家复购率和其他销售指标是衡量产品市场潜力和销售成功与否的关键指标。当系统评估出产品具有巨大的市场潜力时&#xff0c;它会相应地增加对产品的流量支持&#xff1b;反之&#xff0c;如果潜力不足&#xff0c;产品的排名将会受到影响&#xf…

使用 Spring Boot 3.x 与图形学技术,添加电子印章防伪特征

使用 Spring Boot 3.x 与图形学技术,添加电子印章防伪特征 在电子办公和无纸化办公日益普及的今天,电子印章的使用越来越广泛。然而,如何确保电子印章的安全性和防伪能力成为了一个亟待解决的问题。本文将通过 Spring Boot 3.x 和图形学技术,深入探讨如何为电子印章添加防…

Redis-实战篇-什么是缓存-添加redis缓存

文章目录 1、什么是缓存2、添加商户缓存3、前端接口4、ShopController.java5、ShopServiceImpl.java6、RedisConstants.java7、查看Redis Desktop Manager 1、什么是缓存 缓存就是数据交换的缓冲区&#xff08;称为Cache&#xff09;&#xff0c;是存贮数据的临时地方&#xff…