【目标检测】YOLOv5算法实现(八):模型验证

  本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。
  本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模块即可实现其他版本的YOLO算法。

文章地址:
YOLOv5算法实现(一):算法框架概述
YOLOv5算法实现(二):模型加载
YOLOv5算法实现(三):数据集加载
YOLOv5算法实现(四):损失计算
YOLOv5算法实现(五):预测结果后处理
YOLOv5算法实现(六):评价指标及实现
YOLOv5算法实现(七):模型训练
YOLOv5算法实现(八):模型验证
YOLOv5算法实现(九):模型预测(编辑中…)

本文目录

  • 1 引言
  • 2 模型验证(validation.py)

1 引言

  本篇文章综合之前文章中的功能,实现模型的验证。模型验证的逻辑如图1所示。
在这里插入图片描述

图1 模型验证流程

2 模型验证(validation.py)

def validation(parser_data):device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")print("Using {} device validation.".format(device.type))# read class_indictlabel_json_path = './data/object.json'assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)with open(label_json_path, 'r') as f:class_dict = json.load(f)category_index = {v: k for k, v in class_dict.items()}data_dict = parse_data_cfg(parser_data.data)test_path = data_dict["valid"]# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batchbatch_size = parser_data.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using %g dataloader workers' % nw)# load validation data setval_dataset = LoadImagesAndLabels(test_path, parser_data.img_size, batch_size,hyp=parser_data.hyp,rect=False)  # 将每个batch的图像调整到合适大小,可减少运算量(并不是512x512标准尺寸)val_dataset_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=True,num_workers=nw,pin_memory=True,collate_fn=val_dataset.collate_fn)# create modelmodel = Model(parser_data.cfg, ch=3, nc=parser_data.nc)weights_dict = torch.load(parser_data.weights, map_location='cpu')weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dictmodel.load_state_dict(weights_dict, strict=False)model.to(device)# evaluate on the test dataset# 计算PR曲线和APstats = []iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95niou = iouv.numel()# 混淆矩阵confusion_matrix = ConfusionMatrix(nc=3, conf=0.6)model.eval()with torch.no_grad():for imgs, targets, paths, shapes, img_index in tqdm(val_dataset_loader, desc="validation..."):imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0nb, _, height, width = imgs.shape  # batch size, channels, height, widthtargets = targets.to(device)preds = model(imgs)[0]  # only get inference resultpreds = non_max_suppression(preds, conf_thres=0.3, iou_thres=0.6, multi_label=False)targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)outputs = []for si, pred in enumerate(preds):'''labels: [clas, x, y, w, h] (训练图像上绝对坐标)pred: [x,y,x,y,obj,cls] (训练图像上绝对坐标)predn: [x,y,x,y,obj,cls] (输入图像上绝对坐标)labels: [x,y,x,y,class] (输入图像上绝对坐标)shapes[si][0]: 输入图像大小shapes[si][1]'''labels = targets[targets[:, 0] == si, 1:]  # 当前图片的标签信息nl = labels.shape[0]  # number of labels # 当前图片标签数量if pred is None:npr = 0else:npr = pred.shape[0]  # 预测结果数量correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # 判断在不同IoU下预测是否预测正确path, shape = Path(paths[si]), shapes[si][0]  # 当前图片shape(原图大小)if npr == 0:  # 若没有预测结果if nl:  # 没有预测结果但有实际目标# 不同IoU阈值下预测准确率,目标类别置信度,预测类别,实际类别stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))# 混淆矩阵计算(类别信息)confusion_matrix.process_batch(detections=None, labels=labels[:, 0])continuepredn = pred.clone()scale_boxes(imgs[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space predif nl:  # 有预测结果且有实际目标tbox = xywh2xyxy(labels[:, 1:5])  # target boxesscale_boxes(imgs[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labelslabelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labelscorrect = process_batch(predn, labelsn, iouv)confusion_matrix.process_batch(predn, labelsn)stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # 预测结果在不同IoU是否预测正确, 预测置信度, 预测类别, 实际类别confusion_matrix.plot(save_dir=parser_data.save_path, names=["normal", 'defect', 'leakage'])# 图片:预测结果在不同IoU下预测结果,预测置信度,预测类别,实际类别stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpyif len(stats) and stats[0].any():tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, names=["normal", 'defect', 'leakage'])ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()print(map50)if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description=__doc__)# 使用设备类型parser.add_argument('--device', default='cuda', help='device')# 检测目标类别数parser.add_argument('--nc', type=int, default=3, help='number of classes')file = 'yolov5s'cfg = f'cfg/models/{file}.yaml'parser.add_argument('--cfg', type=str, default=cfg, help="*.cfg path")parser.add_argument('--data', type=str, default='data/my_data.data', help='*.data path')parser.add_argument('--hyp', type=str, default='cfg/hyps/hyp.scratch-med.yaml', help='hyperparameters path')parser.add_argument('--img-size', type=int, default=640, help='test size')# 训练好的权重文件weight_1 = f'./weights/{file}/{file}' + '-best_map.pt'weight_2 = f'./weights/{file}/{file}' + '.pt'weight = weight_1 if os.path.exists(weight_1) else weight_2parser.add_argument('--weights', default=weight, type=str, help='training weights')parser.add_argument('--save_path', default=f'results/{file}', type=str, help='result save path')# batch sizeparser.add_argument('--batch_size', default=2, type=int, metavar='N',help='batch size when validation.')args = parser.parse_args()validation(args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/618302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaWeb,CSS的学习

CSS,层叠样式表(Cascading Style Sheets),能够对网页中元素位置的排版进行像素级精确控制,支持几乎所有的字体字号样式,拥有网页对象和模型样式编辑的能力,简单来说,美化页面。 CSS…

c++临时对象的探讨及相关性能提升

产生临时对象的情况 我们定义一个类进行测试 class tempVal { public:int v1, v2;tempVal(int v1 0, int v2 0);tempVal(const tempVal& t) :v1(t.v1), v2(t.v2) {cout << "调用拷贝构造函数" << endl;}virtual ~tempVal() {cout << "…

【python】——turtle动态画

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…

AR HUD全面「上新」

AR HUD赛道正在迎来新的时代。 上周&#xff0c;蔚来ET9正式发布亮相&#xff0c;新车定位为D级行政旗舰轿车&#xff0c;其中&#xff0c;在智能座舱交互层面&#xff0c;继理想L系列、长安深蓝S7之后&#xff0c;也首次取消仪表盘&#xff0c;取而代之的是业内首个全焦段AR H…

分块矩阵的定义、计算

目录 一、定义 二、分块矩阵的加减乘法 三、考点 一、定义 分块&#xff0c;顾名思义&#xff0c;将整个矩阵分成几部分&#xff0c;如下图所示 二、分块矩阵的加减乘法 三、考点 分块矩阵的考点不多&#xff0c;一般来说&#xff0c;有一种&#xff1a; 求分块矩阵的转置…

PHP如何拆分中文名字(包括少数民族名字)

/*** param string|null $name* return array|null*/ function splitName($name) {if (empty($name) || empty(trim($name))) {return null;}//该正则是用来提取$name参数里面的中文字符的。preg_match_all(/[\x{4e00}-\x{9fff}]/u, $name, $matchers);$matchersCount isset($…

2024年,谷歌云首席技术官眼中的生成AI三大支柱,来看看有啥新花样

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

App在线封装的革命性创新

随着移动互联网的蓬勃发展&#xff0c;App已经成为我们日常生活中不可或缺的一部分。从购物、交通、社交到娱乐&#xff0c;几乎每个人的智能手机都装载着数十个应用程序&#xff0c;以满足各式各样的需求。然而&#xff0c;对于许多非技术背景的企业家和小型企业而言&#xff…

java---final以及抽象类

final - 修饰变量&#xff1a;变量不能被改变 //演示final修饰变量class Aoo{final int num 5;void show(){//num 55; //编译错误&#xff0c;final的变量不能被改变}} - 修饰方法&#xff1a;方法不能被重写 //演示final修饰方法class Boo{final void show(){}}class Coo ex…

Spring整理-Spring Bean的作用域

在Spring框架中,Bean的作用域定义了Bean实例的生命周期和可见性。Spring提供了多种作用域选项,适用于不同的应用需求。 Spring中的主要Bean作用域 Singleton:默认的作用域。在Spring IoC容器中,对于每个Spring Bean配置,只创建一个实例。适用于无状态的服务,如配置、工具…

【机器学习】模型参数优化工具:Optuna使用分步指南(附XGB/LGBM调优代码)

常用的调参方式和工具包 常用的调参方式包括网格搜索(Grid Search)、**随机搜索(Random Search)和贝叶斯优化(Bayesian Optimization)**等。 工具包方面&#xff0c;Scikit-learn提供了GridSearchCV和RandomizedSearchCV等用于网格搜索和随机搜索的工具。另外&#xff0c;有一…

VS报错:error:LNK2005 _main 已经在 *.obj 中定义

应该是重定义了&#xff0c;但是又解决不了&#xff0c;看似又没有重定义啊&#xff0c;就在一个文件定义了啊&#xff1f;怎么会出现这种情况呢&#xff1f;关键是&#xff0c;编译报错&#xff0c;程序运行不了了。 这里提一下我的前期操作&#xff0c;是因为将一个头文件和…

云原生 微服务 restapi devops相关的一些概念说明(持续更新中)

云原生&#xff1a; 定义 云原生是一种构建和运行应用程序的方法&#xff0c;是一套技术体系和方法论。它是一种在云计算环境中构建、部署和管理现代应用程序的软件方法。云原生应用程序是基于微服务架构的&#xff0c;采用开源堆栈&#xff08;K8SDocker&#xff09;进行容器…

NULL是什么?

NULL是一个编程术语&#xff0c;通常用于表示一个空值或无效值。在很多编程语言中&#xff0c;NULL用于表示一个变量或指针不引用任何有效的对象或内存位置。 NULL可以看作是一个特殊的值&#xff0c;表示缺少有效的数据或引用。当一个变量被赋予NULL值时&#xff0c;它表示该变…

10年Java面试总结:Java程序员面试必备的面试技巧

作为一名资深10年Java技术专家&#xff0c;我参与了无数次的面试&#xff0c;无论是作为面试者还是面试官。在这里&#xff0c;我将分享我的一些面试经历和面试技巧&#xff0c;希望能帮助即将面临面试的Java程序员们。 本文已收录于&#xff0c;我的技术网站 ddkk.com&#x…

柳氏新论:慈不掌兵的两层含义

前几天在一个如何理解慈不掌兵的回答中&#xff0c;我提出了这句话实际上有两层含义。这个应该是我第一个提出的。所以单独摘录出来。 第一层含义&#xff0c;不能怕士兵伤亡 这一层&#xff0c;所有人都能理解。比如你是个连长&#xff0c;正在防守阵地&#xff0c;排长过来报…

CMake_02_如何编译可调试文件

软件开发过程中&#xff0c;调试是必不可少的环节之一&#xff0c;让可执行文件”明牌“执行&#xff0c;不会漏过每一行代码&#xff0c;每一个变量的信息。从而帮助开发者快速定位到问题点。 先看下没有调试信息的可执行文件是什么样子&#xff1f; rootlocalhost:~/testWo…

【面试宝典】图解ARP协议、TCP协议、UDP协议

一、ARP协议 二、TCP协议 三、UDP协议 四、TCP和UDP的区别

Linux Git打包部署JAVA项目 shell脚本

my-test-8080.jar.sh 脚本 #!/bin/bashBASE_PATH"/root/local"GIT_BASE_PATH"/root/local/publish/my-java-study"SCRIPT_NAME$(basename "$0")JAR_NAME"${SCRIPT_NAME%.sh}"BRANCH_NAME"dev"GIT_URL"gitgitee.com:xx…

LeetCode 36. 有效的数独

有效的数独 请你判断一个 9 x 9 的数独是否有效。只需要 根据以下规则 &#xff0c;验证已经填入的数字是否有效即可。 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。 一次遍历法 有效数独的三个…