【目标检测】模型验证:K-Fold 交叉验证

K-Fold 交叉验证

  • 1、引言
    • 1.1 K 折交叉验证概述
  • 2、配置
    • 2.1 数据集
    • 2.2 安装包
  • 3、 实战
    • 3.1 生成物体检测数据集的特征向量
    • 3.2 K 折数据集拆分
    • 3.3 保存记录
    • 3.4 使用 K 折数据分割训练YOLO
  • 4、总结

1、引言

我们将利用YOLO 检测格式和关键的Python 库(如 sklearn、pandas 和 PyYaml),完成必要的设置、生成特征向量的过程以及 K-Fold 数据集拆分的执行。

1.1 K 折交叉验证概述

无论你的项目涉及水果检测数据集还是自定义数据源,都可以使用 K 折交叉验证,
以提高项目的可靠性和稳健性。

书说简短,闲言少叙,咱进入正题
在这里插入图片描述

2、配置

2.1 数据集

该数据集共包含 8479 幅图像。
它包括 6 个类别标签,每个标签的实例总数如下:

类别计数
苹果7049
葡萄7202
菠萝1613
橙色15549
香蕉3536
西瓜1976

2.2 安装包

必要的Python 软件包包括

  • ultralytics
  • sklearn
  • pandas
  • pyyaml

这次实例中,我们使用 k=5 折叠次数

3、 实战

3.1 生成物体检测数据集的特征向量

具体步骤如下:

  • 1、首先创建一个新的 demo.py Python 文件来执行下面的步骤。

  • 2、继续检索数据集的所有标签文件。

from pathlib import Pathdataset_path = Path("./Fruit-detection")  # replace with 'path/to/dataset' for your custom data
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # all data in 'labels'
  • 3、现在,读取数据集 YAML 文件的内容并提取类标签的索引。
yaml_file = "path/to/data.yaml"  # your data YAML with data directories and names dictionary
with open(yaml_file, "r", encoding="utf8") as y:classes = yaml.safe_load(y)["names"]
cls_idx = sorted(classes.keys())
  • 4、初始化一个空的 pandas DataFrame.
import pandas as pdindex = [label.stem for label in labels]  # uses base filename as ID (no extension)
labels_df = pd.DataFrame([], columns=cls_idx, index=index)
  • 5、计算注释文件中每个类别标签的实例数。
from collections import Counterfor label in labels:lbl_counter = Counter()with open(label, "r") as lf:lines = lf.readlines()for line in lines:# classes for YOLO label uses integer at first position of each linelbl_counter[int(line.split(" ")[0])] += 1labels_df.loc[label.stem] = lbl_counterlabels_df = labels_df.fillna(0.0)  # replace `nan` values with `0.0`
  • 6、以下是已填充 DataFrame 的示例视图:
                                                       0    1    2    3    4    5
'0000a16e4b057580_jpg.rf.00ab48988370f64f5ca8ea4...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.7e6dce029fb67f01eb19aa7...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.bc4d31cdcbe229dd022957a...'  0.0  0.0  0.0  0.0  0.0  7.0
'00020ebf74c4881c_jpg.rf.508192a0a97aa6c4a3b6882...'  0.0  0.0  0.0  1.0  0.0  0.0
'00020ebf74c4881c_jpg.rf.5af192a2254c8ecc4188a25...'  0.0  0.0  0.0  1.0  0.0  0.0...                                                  ...  ...  ...  ...  ...  ...
'ff4cd45896de38be_jpg.rf.c4b5e967ca10c7ced3b9e97...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff4cd45896de38be_jpg.rf.ea4c1d37d2884b3e3cbce08...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff5fd9c3c624b7dc_jpg.rf.bb519feaa36fc4bf630a033...'  1.0  0.0  0.0  0.0  0.0  0.0
'ff5fd9c3c624b7dc_jpg.rf.f0751c9c3aa4519ea3c9d6a...'  1.0  0.0  0.0  0.0  0.0  0.0
'fffe28b31f2a70d4_jpg.rf.7ea16bd637ba0711c53b540...'  0.0  6.0  0.0  0.0  0.0  0.0

解析

  • 行是标签文件的索引,每个标签文件对应数据集中的一幅图像,列则对应类标签索引。
  • 每一行代表一个伪特征向量,其中包含数据集中每个类标签的计数。
  • 这种数据结构可以将 K 折交叉验证应用于对象检测数据集。

3.2 K 折数据集拆分

  • 1、使用 KFold 从 sklearn.model_selection 以产生 k 对数据集进行分割。

    • 敲黑板:
      • 设置 shuffle=True 确保了分班中班级的随机分布。
      • 通过设置 random_state=M 其中 M 是一个选定的整数,这样就可以得到可重复的结果。
from sklearn.model_selection import KFoldksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # setting random_state for repeatable resultskfolds = list(kf.split(labels_df))
  • 2、数据集现已分为 k 折叠,每个折叠都有一个 train 和 val 指数。我们将构建一个 DataFrame 来更清晰地显示这些结果。
folds = [f"split_{n}" for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=index, columns=folds)for i, (train, val) in enumerate(kfolds, start=1):folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"
  • 3、将计算每个褶皱的类别标签分布,并将其作为褶皱中出现的类别的比率。
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# To avoid division by zero, we add a small value (1E-7) to the denominatorratio = val_totals / (train_totals + 1e-7)fold_lbl_distrb.loc[f"split_{n}"] = ratio
最理想的情况是,每次分割和不同类别的所有类别比率都相当相似。不过,这取决于数据集的具体情况。
  • 4、为每个分割创建目录和数据集 YAML 文件。
import datetimesupported_extensions = [".jpg", ".jpeg", ".png"]# Initialize an empty list to store image file paths
images = []# Loop through supported extensions and gather image files
for ext in supported_extensions:images.extend(sorted((dataset_path / "images").rglob(f"*{ext}")))# Create the necessary directories and dataset YAML files (unchanged)
save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)
ds_yamls = []for split in folds_df.columns:# Create directoriessplit_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)# Create dataset YAML filesdataset_yaml = split_dir / f"{split}_dataset.yaml"ds_yamls.append(dataset_yaml)with open(dataset_yaml, "w") as ds_y:yaml.safe_dump({"path": split_dir.as_posix(),"train": "train","val": "val","names": classes,},ds_y,)
  • 5、最后,将图像和标签复制到每个分割的相应目录("train "或 “val”)中。
import shutilfor image, label in zip(images, labels):for split, k_split in folds_df.loc[image.stem].items():# Destination directoryimg_to_path = save_path / split / k_split / "images"lbl_to_path = save_path / split / k_split / "labels"# Copy image and label files to new directory (SamefileError if file already exists)shutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)

3.3 保存记录

将 K 折分割和标签分布数据框的记录保存为 CSV 文件。

folds_df.to_csv(save_path / "kfold_datasplit.csv")
fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")

3.4 使用 K 折数据分割训练YOLO

  • 首先,加载YOLO 模型。
from ultralytics import YOLOweights_path = "path/to/weights.pt"
model = YOLO(weights_path, task="detect")
  • 其次,遍历数据集 YAML 文件以运行训练。结果将保存到由 project 和 name 参数。默认情况下,该目录为 “exp/runs#”,其中 # 为整数索引。
results = {}# Define your additional arguments here
batch = 16
project = "kfold_demo"
epochs = 100for k in range(ksplit):dataset_yaml = ds_yamls[k]model = YOLO(weights_path, task="detect")model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)  # include any train argumentsresults[k] = model.metrics  # save output metrics for further analysis

4、总结

这篇小鱼使用了 K 折交叉验证来训练YOLO 物体检测模型的过程。

还创建报告 DataFrames 的程序,以可视化数据拆分和标签在这些拆分中的分布,清楚地了解训练集和验证集的结构。

此外,还保存了记录,这在大型项目或排除模型性能故障时尤为有用。

最后,在一个循环中使用每个拆分来执行实际的模型训练,保存训练结果,以便进一步分析和比较。

这种 K 折交叉验证技术是充分利用可用数据的一种稳健方法,有助于确保模型在不同数据子集中的性能是可靠和一致的。这将产生一个更具通用性和可靠性的模型,从而减少对特定数据模式的过度拟合。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)评测一等奖获得者

关注小鱼,学习【人工智能&大模型】/【深度学习&机器学习】领域最新最全的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/68102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android studio ternimal 中gradle 指令失效(gradle环境变量未配置)

默认gradle路径:C:\Users\ylwj.gradle\wrapper\dists\gradle-8.10.2-bin\a04bxjujx95o3nb99gddekhwo\gradle-8.10.2\bin 环境变量-系统环境变量-双击path-配置上即可-注意重启studio才会生效

Axure大屏可视化动态交互设计:解锁数据魅力,引领决策新风尚

可视化组件/模板预览:https://8dge09.axshare.com 一、大屏可视化技术概览 在数据驱动决策的时代,大屏可视化技术凭借直观、动态的展示方式,已成为众多行业提升管理效率和优化决策过程的关键工具。它能够将复杂的数据转化为易于理解的图形和…

Resnet 改进:尝试在不同位置加入Transform模块

目录 1. TransformerBlock 2. resnet 3. 替换部分卷积层 4. 在特定位置插入Transformer模块 5. 使用Transformer全局特征提取器 6. 其他 Tips:融入模块后的网络经过测试,可以直接使用,设置好输入和输出的图片维度即可 1. TransformerBlock TransformerBlock是Transfo…

MySQL调优02 - SQL语句的优化

SQL语句的优化 文章目录 SQL语句的优化一:SQL优化的小技巧1:编写SQL时的注意点1.1:查询时尽量不要使用*1.2:连表查询时尽量不要关联太多表1.3:多表查询时一定要以小驱大1.4:like不要使用左模糊或者全模糊1.…

langchain教程-12.Agent/工具定义/Agent调用工具/Agentic RAG

前言 该系列教程的代码: https://github.com/shar-pen/Langchain-MiniTutorial 我主要参考 langchain 官方教程, 有选择性的记录了一下学习内容 这是教程清单 1.初试langchain2.prompt3.OutputParser/输出解析4.model/vllm模型部署和langchain调用5.DocumentLoader/多种文档…

大模型中提到的超参数是什么

在大模型中提到的超参数是指在模型训练之前需要手动设置的参数,这些参数决定了模型的训练过程和最终性能。超参数与模型内部通过训练获得的参数(如权重和偏置)不同,它们通常不会通过训练自动学习,而是需要开发者根据任…

位运算及常用技巧

涉及位运算的运算符如下表所示: 位运算的运算律: 负数的位运算 首先,我们要知道,在计算机中,运算是使用的二进制补码,而正数的补码是它本身,负数的补码则是符号位不变,其余按位取反…

hot100(8)

71.10. 正则表达式匹配 - 力扣(LeetCode) 动态规划 题解:10. 正则表达式匹配题解 - 力扣(LeetCode) 72.5. 最长回文子串 - 力扣(LeetCode) 动态规划 1.dp数组及下标含义 dp[i][j] : 下标i到…

二进制/源码编译安装httpd 2.4,提供系统服务管理脚本并测试

方法一:使用 systemd 服务文件 安装所需依赖 yum install gcc make apr-devel apr-util-devel pcre-devel 1.下载源码包 wget http://archive.apache.org/dist/httpd/httpd-2.4.62.tar.gz 2.解压源码 tar -xf httpd-2.4.62.tar.gz cd httpd-2.4.62 3.编译安装 指定…

Java 中 LinkedList 的底层源码

在 Java 的集合框架中,LinkedList是一个独特且常用的成员。它基于双向链表实现,与数组结构的集合类如ArrayList有着显著差异。深入探究LinkedList的底层源码,有助于我们更好地理解其工作原理和性能特点,以便在实际开发中做出更合适…

金蝶云星空k3cloud webapi报“java.lang.Class cannot be cast to java.lang.String”的错误

最近在对接金蝶云星空k3cloud webapi时,报一个莫名其妙的转换异常,具体如下: 同步部门异常! ERP接口登录异常:java.lang.Class cannot be cast to java.lang.String at com.jkwms.k3cloudSyn.service.basics.DeptK3CloudService.…

【Android】jni开发之导入opencv和libyuv来进行图像处理

做视频图像处理时需要对其进行水印的添加,放在应用层调用工具性能方面不太满意,于是当下采用opencvlibyuv方法进行处理。 对于Android的jni开发不是很懂,我的需求是导入opencv方便在cpp中调用,但目前找到的教程都是把opencv作为模…

【MySQL】centos 7 忘记数据库密码

vim /etc/my.cnf文件; 在[mysqld]后添加skip-grant-tables(登录时跳过权限检查) 重启MySQL服务:sudo systemctl restart mysqld 登录mysql,输入mysql –uroot –p;直接回车(Enter) 输…

国产编辑器EverEdit - 自定义标记使用详解

1 自定义标记使用详解 1.1 应用场景 当阅读日志等文件,用于调试或者检查问题时,往往日志中会有很多关键性的单词,比如:ERROR, FATAL等,但由于文本模式对这些关键词并没有突出显示,造成检查问题时&#xff…

Golang 并发机制-6:掌握优雅的错误处理艺术

并发编程可能是提高软件系统效率和响应能力的一种强有力的技术。它允许多个工作负载同时运行,充分利用现代多核cpu。然而,巨大的能力带来巨大的责任,良好的错误管理是并发编程的主要任务之一。 并发代码的复杂性 并发编程增加了顺序程序所不…

JVM 四虚拟机栈

虚拟机栈出现的背景 由于跨平台性的设计,Java的指令都是根据栈来设计的。不同平台CPU架构不同,所以不能设计为基于寄存器的。优点是跨平台,指令集小,编译器容易实现,缺点是性能下降,实现同样的功能需要更多…

鼠标拖尾特效

文章目录 鼠标拖尾特效一、引言二、实现原理1、监听鼠标移动事件2、生成拖尾元素3、控制元素生命周期 三、代码实现四、使用示例五、总结 鼠标拖尾特效 一、引言 鼠标拖尾特效是一种非常酷炫的前端交互效果,能够为网页增添独特的视觉体验。它通常通过JavaScript和C…

6-图像金字塔与轮廓检测

文章目录 6.图像金字塔与轮廓检测(1)图像金字塔定义(2)金字塔制作方法(3)轮廓检测方法(4)轮廓特征与近似(5)模板匹配方法6.图像金字塔与轮廓检测 (1)图像金字塔定义 高斯金字塔拉普拉斯金字塔 高斯金字塔:向下采样方法(缩小) 高斯金字塔:向上采样方法(放大)…

RNN/LSTM/GRU 学习笔记

文章目录 RNN/LSTM/GRU一、RNN1、为何引入RNN?2、RNN的基本结构3、各种形式的RNN及其应用4、RNN的缺陷5、如何应对RNN的缺陷?6、BPTT和BP的区别 二、LSTM1、LSTM 简介2、LSTM如何缓解梯度消失与梯度爆炸? 三、GRU四、参考文献 RNN/LSTM/GRU …

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记 文章目录 qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记1.例程运行效果2.例程缩略图3.项目文件列表4.main.qml5.main.cpp6.CMakeLists.txt 1.例程运行效果 运行该项目需要自己准备一个模型文件 2.例程缩略图…