YOLOv9训练损失、精度、mAP绘图功能 | 支持多结果对比,多结果绘在一个图片(消融实验、科研必备)

一、本文介绍

本文给大家带来的是YOLOv9系列的绘图功能,我将向大家介绍YOLO系列的绘图功能。我们在进行实验时,经常需要比较多个结果,针对这一问题,我写了点代码来解决这个问题,它可以根据训练结果绘制损失(loss)和mAP(平均精度均值)的对比图。这个工具不仅支持多个文件的对比分析,还允许大家在现有代码的基础上进行修,从而达到数据可视化的功能,大家也可以将对比图放在论文中进行对比也是非常不错的选择。

先展示一下效果图-> 

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

损失对比图片->

目录

一、本文介绍

二、绘图工具核心代码 

三、使用讲解 

四、本文总结


二、绘图工具核心代码 

import os
import pandas as pd
import matplotlib.pyplot as pltdef plot_metrics_and_loss(experiment_names, metrics_info, loss_info, metrics_subplot_layout, loss_subplot_layout,metrics_figure_size=(15, 10), loss_figure_size=(15, 10), base_directory='runs/train'):# Plot metricsplt.figure(figsize=metrics_figure_size)for i, (metric_name, title) in enumerate(metrics_info):plt.subplot(*metrics_subplot_layout, i + 1)for name in experiment_names:file_path = os.path.join(base_directory, name, 'results.csv')data = pd.read_csv(file_path)column_name = [col for col in data.columns if col.strip() == metric_name][0]plt.plot(data[column_name], label=name)plt.xlabel('Epoch')plt.title(title)plt.legend()plt.tight_layout()metrics_filename = 'metrics_curves.png'plt.savefig(metrics_filename)plt.show()# Plot lossplt.figure(figsize=loss_figure_size)for i, (loss_name, title) in enumerate(loss_info):plt.subplot(*loss_subplot_layout, i + 1)for name in experiment_names:file_path = os.path.join(base_directory, name, 'results.csv')data = pd.read_csv(file_path)column_name = [col for col in data.columns if col.strip() == loss_name][0]plt.plot(data[column_name], label=name)plt.xlabel('Epoch')plt.title(title)plt.legend()plt.tight_layout()loss_filename = 'loss_curves.png'plt.savefig(loss_filename)plt.show()return metrics_filename, loss_filename# Metrics to plot
metrics_info = [('metrics/precision', 'Precision'),('metrics/recall', 'Recall'),('metrics/mAP_0.5', 'mAP at IoU=0.5'),('metrics/mAP_0.5:0.95', 'mAP for IoU Range 0.5-0.95')
]# Loss to plot
loss_info = [('train/box_loss', 'Training Box Loss'),('train/cls_loss', 'Training Classification Loss'),('train/obj_loss', 'Training OBJ Loss'),('val/box_loss', 'Validation Box Loss'),('val/cls_loss', 'Validation Classification Loss'),('val/obj_loss', 'Validation obj Loss')
]# Plot the metrics and loss from multiple experiments
metrics_filename, loss_filename = plot_metrics_and_loss(experiment_names=['exp40', 'exp38'],metrics_info=metrics_info,loss_info=loss_info,metrics_subplot_layout=(2, 2),loss_subplot_layout=(2, 3)
)


三、使用讲解 

使用方式非常简单,我们首先创建一个文件,将核心代码粘贴进去,其中experiment_names这个参数就代表我们的每个训练结果的名字, 我们只需要修改这个即可,我这里就是五个结果进行对比,修改完成之后大家运行该文件即可。

五、热力图代码 

使用方式我会单独更一篇,这个热力图代码的进阶版,这里只是先放一下。 

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from ultralytics.nn.tasks import DetectionModel as Model
from ultralytics.utils.torch_utils import intersect_dicts
from ultralytics.utils.ops import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsdef letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):# Resize and pad image while meeting stride-multiple constraintsshape = im.shape[:2]  # current shape [height, width]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)# Scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])if not scaleup:  # only scale down, do not scale up (for better val mAP)r = min(r, 1.0)# Compute paddingratio = r, r  # width, height ratiosnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh paddingif auto:  # minimum rectangledw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh paddingelif scaleFill:  # stretchdw, dh = 0.0, 0.0new_unpad = (new_shape[1], new_shape[0])ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratiosdw /= 2  # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad:  # resizeim = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add borderreturn im, ratio, (dw, dh)class yolov8_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)self.__dict__.update(locals())def post_process(self, result):logits_ = result[:, 4:]boxes_ = result[:, :4]sorted, indices = torch.sort(logits_.max(1)[0], descending=True)return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, pre_post_boxes, post_boxes = self.post_process(result[0])for i in trange(int(post_result.size(0) * self.ratio)):if float(post_result[i].max()) < self.conf_threshold:breakself.model.zero_grad()# get max probability for this predictionif self.backward_type == 'class' or self.backward_type == 'all':score = post_result[i].max()score.backward(retain_graph=True)if self.backward_type == 'box' or self.backward_type == 'all':for j in range(4):score = pre_post_boxes[i, j]score.backward(retain_graph=True)# process heatmapif self.backward_type == 'class':gradients = grads.gradients[0]elif self.backward_type == 'box':gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3]else:gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3] + grads.gradients[4]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, :].argmax())], f'{self.model_names[int(post_result[i, :].argmax())]} {float(post_result[i].max()):.2f}', cam_image)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': 'yolov8n.pt','cfg': 'ultralytics/cfg/models/v8/yolov8n.yaml','device': 'cuda:0','method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[9]','backward_type': 'all', # class, box, all'conf_threshold': 0.6, # 0.6'ratio': 0.02 # 0.02-0.1}return paramsif __name__ == '__main__':model = yolov8_heatmap(**get_params())model(r'ultralytics/assets/bus.jpg', 'result')


四、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv9改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/2720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java的八大基本数据类型和 println 的介绍

前言 如果你有C语言的基础&#xff0c;这部分内容就会很简单&#xff0c;但是会有所不同~~ 这是我将要提到的八大基本数据类型&#xff1a; 注意&#xff0c;Java的数据类型是有符号的&#xff01;&#xff01;&#xff01;和C语言不同&#xff0c;Java不存在无符号的数据。 整…

Day:动态规划 LeedCode 123.买卖股票的最佳时机III 188.买卖股票的最佳时机IV

123. 买卖股票的最佳时机 III 给定一个数组&#xff0c;它的第 i 个元素是一支给定的股票在第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。你最多可以完成 两笔 交易。 注意&#xff1a;你不能同时参与多笔交易&#xff08;你必须在再次购买前出售掉之前的股票&a…

安全开发实战(2)---域名反查IP

目录 安全开发专栏 前言 域名与ip的关系 域名反查ip的作用 1.2.1 One 1.2.2 Two 1.2.3 批量监测 ​总结 安全开发专栏 安全开发实战http://t.csdnimg.cn/25N7H 这步是比较关键的一步,一般进行cdn监测后,获取到真实ip地址后,或是域名时,然后进行域名反查IP地址,进行进…

基于Springboot的职称评审管理系统

基于SpringbootVue的职称评审管理系统的设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录 首页 评审条件 论坛信息 系统公告 后台登录页面 用户管理 评审员管理 省份…

再谈C语言——理解指针(四)

assert断⾔ assert.h 头⽂件定义了宏 assert() &#xff0c;⽤于在运⾏时确保程序符合指定条件&#xff0c;如果不符合&#xff0c;就报错终⽌运⾏。这个宏常常被称为“断⾔”。 assert(p ! NULL); 上⾯代码在程序运⾏到这⼀⾏语句时&#xff0c;验证变量 p 是否等于 NULL 。…

​LeetCode解法汇总2385. 感染二叉树需要的总时间

目录链接&#xff1a; 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目&#xff1a; https://github.com/September26/java-algorithms 原题链接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 描述&#xff1a; 给你一棵二叉树的根节点 root &#xff0…

创建型设计模式

七大原则 1. 开闭原则&#xff08;Open-Closed Principle, OCP&#xff09; 详解&#xff1a;软件实体&#xff08;类、模块、函数等&#xff09;应该易于扩展&#xff0c;但是不易于修改。换句话说&#xff0c;当软件需求变化时&#xff0c;应该通过添加新代码来实现变化&am…

销冠必备:高效跟进客户的四个技巧

作为一名销售&#xff0c;高效而精准地跟进客户是取得成功的关键。今天&#xff0c;我将分享四个技巧&#xff0c;让你也能够高效的跟进客户。 1、善于发问 通过多询问客户&#xff0c;你可以更好地了解客户的需求和痛点。在与客户交流时&#xff0c;不要只是简单地回答问题&…

业务复习知识点Oracle查询

业务数据查询-1 单表查询 数据准备 自来水收费系统建表语句.sql 简单条件查询 精确查询 需求 &#xff1a;查询水表编号为 30408 的业主记录 查询语句 &#xff1a; select * from t_owners where watermeter 30408; 查询结果 &#xff1a; 模糊查询 需求 &#xff1a;查询业…

毕业设计注意事项(2024届更新中)

1.开题 根据学院发的开题报告模板完成&#xff0c;其中大纲部分可参考资料 2.毕设 根据资料中的毕设评价标准&#xff0c;对照工作量 3.论文 3.1 格式问题 非常重要&#xff0c;认真对比资料中我发的模板&#xff0c;格式有问题&#xff0c;答辩输一半&#xff01; 以word…

W801学习笔记十四:掌机系统——菜单——尝试打造自己的UI

未来将会有诸多应用&#xff0c;这些应用将通过菜单进行有序组织和管理。因此&#xff0c;我们需要率先打造好菜单。 LCD 驱动通常是直接写屏的&#xff0c;虽然速度较快&#xff0c;但用于界面制作则不太适宜。所以&#xff0c;最好能拥有一套 UI 框架。如前所述&#xff0c;…

【linux】编译器使用

目录 1. gcc &#xff0c;g 编译器使用 a. 有关gcc的指令&#xff08;g同理&#xff09; 2. .o 文件和库的链接方式 a. 链接方式 b. 动态库 和 静态库 优缺点对比 c. debug 版本 和 release 版本 1. gcc &#xff0c;g 编译器使用 a. 有关gcc的指令&#xff08;g同理&…

设计模式-创建型-抽象工厂模式-Abstract Factory

UML类图 工厂接口类 public interface ProductFactory {Phone phoneProduct();//生产手机Router routerProduct();//生产路由器 } 小米工厂实现类 public class XiaomiFactoryImpl implements ProductFactory {Overridepublic Phone phoneProduct() {return new XiaomiPhone…

Node.js -- fs模块

文章目录 1. 写入文件1.1 写入文件1.2 同步和异步1.3 文件追加写入1.4 流式写入1.5 文件写入的场景 2. 读取文件2.1 异步和同步读取2.2 读取文件应用场景2.3 流式读取2.4 fs 练习 -- 文件复制 3. 文件重命名和移动4. 文件删除5. 文件夹操作5.1 创建文件夹5.2 读取文件夹5.3 删除…

crossover和wine哪个好 wine和crossover有什么本质区别 苹果电脑运行Windows crossover24

CrossOver是Wine的延伸产品&#xff0c;CrossOver可以简单的理解为类虚拟机&#xff0c;那么wine是什么&#xff0c;许多小伙伴就可能有些一知半解。CrossOver和wine哪个好&#xff0c;wine和CrossOver有什么本质区别呢&#xff1f;下文将围绕着这两个问题展开。 一、CrossOve…

tcp inflight 守恒算法的几何解释

接上文&#xff1a;tcp inflight 守恒算法背后的哲学 在 tcp inflight 守恒算法正确性 中&#xff0c;E bw / srtt 的公平最优解是算出来的&#xff0c;如果自然可以用数学描述&#xff0c;那能算出来的东西反过来也一定能通过直感看出来&#xff0c;我倾向于用几何和力学描述…

力扣HOT100 - 199. 二叉树的右视图

解题思路&#xff1a; 相当于层序遍历&#xff0c;然后取每一层的最后一个节点。 class Solution {public List<Integer> rightSideView(TreeNode root) {if (root null) return new ArrayList<Integer>();Queue<TreeNode> queue new LinkedList<>…

Pushmall智能AI数字名片— —寻求商机合作的营销推广平台

Pushmall智能AI数字名片— —寻求商机合作的营销推广平台 开发计划 2024年2月开发计划&#xff1a; 1、优化名片注册、信息完善业务流程&#xff1b; 2、重构商机信息&#xff1a;供应信息、需求信息发布。 3、会员名片服务优化 4、企业名片&#xff1a;员工管理优化 5、CRM客…

【计算机网络】网络模型

OSI七层网络模型 七层模型如图所示 每层的概念和功能 物理层 职责&#xff1a;将数据以比特为单位&#xff0c;通过不同的传输介质将数据传输出去。 主要协议&#xff1a;物理媒介相关的协议&#xff0c;如RS232&#xff0c;V.35&#xff0c;以太网等。 数据链路层 职责&…

【WSL报错】执行:wsl --list --online;错误:0x80072ee7

【WSL报错】执行:wsl --list --online&#xff1b;错误:0x80072ee7 问题情况解决方法详细过程 问题情况 C:\Users\17569>wsl --list --online 错误: 0x80072ee7 解决方法 开系统代理&#xff0c;到外网即可修复&#xff01;&#xff01;&#xff01;&#xff01;&#x…