【进阶篇】YOLOv8实现K折交叉验证——解决数据集样本稀少和类别不平衡的难题,让你的模型评估更加稳健


在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


K折交叉验证

K折交叉验证(K-Fold Cross-Validation)是一种常用的机器学习模型评估方法,可以帮助我们评估模型的性能,特别适用于数据集相对较小的情况。

在K折交叉验证中,将原始数据集分成K个子集,然后依次将其中一个子集作为验证集,其余K-1个子集作为训练集进行模型训练和评估。这样可以得到K个模型,每个模型都在不同的验证集上进行评估,最后将K个模型的评估结果,求平均或取最优结果作为最终评估。

优点

  • 充分利用数据集: 在K折交叉验证中,整个数据集被划分为K个互斥的折叠(Folds)。每次训练模型时,都有K-1个折叠用于训练,而剩下的一个用于验证。这样,每个样本都有机会作为验证集的一部分,从而充分利用了数据集中的所有样本。
  • 减轻过拟合风险: 由于每个样本都会在训练集和验证集中出现,模型在验证集上的性能评估更具有代表性。这有助于减轻由于数据稀疏或类别不平衡导致的过拟合问题。
  • 更稳健的模型评估: K折交叉验证计算多次模型性能的平均值,提供了更稳健的性能评估。这对于对抗数据稀疏和类别不平衡等挑战的模型评估尤为重要,因为它减少了单次评估可能引入的随机性。
  • 处理类别不平衡: 如果数据集中某些类别的样本数量较少,K折交叉验证可以确保每个折叠中都包含这些少见类别的样本。这有助于确保模型在少见类别上的性能得到充分评估,并提高对类别不平衡的鲁棒性。

YOLOv8

YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。

YOLOv8的K折交叉验证实现步骤

  1. 首先,将数据集划分成K个子集,可以使用现有的数据集划分函数或手动划分。
  2. 然后,使用一个循环迭代K次,每次将其中一个子集作为验证集,其余K-1个子集作为训练集。
  3. 在每次迭代中,使用训练集进行模型训练,并使用验证集进行模型评估。
  4. 可以根据需要调整模型的超参数或其他设置。例如,可以尝试不同的学习率、迭代次数等。
  5. 最后,将K次迭代中的评估结果进行平均或取最优结果作为最终的模型评估结果。

官方项目下载

git clone https://github.com/ultralytics/ultralytics.git

🚀 YOLOv8-K折交叉验证代码实现

在上述下载好的项目文件夹下新建k-fold-train.py文件,并添加下述代码:

import argparse
import datetime
from itertools import chain
import os
from pathlib import Path
import shutil
import yaml
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from ultralytics import YOLONUM_THREADS = min(8, max(1, os.cpu_count() - 1))def parse_opt():parser = argparse.ArgumentParser()parser.add_argument('--data', default=r'./data')  # 数据集路径parser.add_argument('--ksplit', default=5, type=int)  # K-Fold交叉验证拆分数据集parser.add_argument('--im_suffixes', default=['jpg', 'png', 'jpeg'], help='images suffix')  # 图片后缀名return parser.parse_args()def run(func, this_iter, desc="Processing"):with ThreadPoolExecutor(max_workers=NUM_THREADS, thread_name_prefix='MyThread') as executor:results = list(tqdm(executor.map(func, this_iter), total=len(this_iter), desc=desc))return resultsdef main(opt):dataset_path, ksplit, im_suffixes = Path(opt.data), opt.ksplit, opt.im_suffixessave_path = Path(dataset_path / f'{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-Valid')save_path.mkdir(parents=True, exist_ok=True)# 获取所有图像和标签文件的列表images = sorted(list(chain(*[(dataset_path / "images").rglob(f'*.{ext}') for ext in im_suffixes])))# images = sorted(image_files)labels = sorted((dataset_path / "labels").rglob("*.txt"))root_directory = Path.cwd()print("当前文件运行根目录:", root_directory)if len(images) != len(labels):print('*' * 20)print('当前数据集和标签数量不一致!!!')print('*' * 20)# 从YAML文件加载类名classes_file = sorted(dataset_path.rglob('classes.yaml'))[0]assert classes_file.exists(), "请创建classes.yaml类别文件"if classes_file.suffix == ".txt":passelif classes_file.suffix == ".yaml":with open(classes_file, 'r', encoding="utf8") as f:classes = yaml.safe_load(f)['names']cls_idx = sorted(classes.keys())# 创建DataFrame来存储每张图像的标签计数indx = [l.stem for l in labels]  # 使用基本文件名作为ID(无扩展名)labels_df = pd.DataFrame([], columns=cls_idx, index=indx)# 计算每张图像的标签计数for label in labels:lbl_counter = Counter()with open(label, 'r') as lf:lines = lf.readlines()for l in lines:# YOLO标签使用每行的第一个位置的整数作为类别lbl_counter[int(l.split(' ')[0])] += 1labels_df.loc[label.stem] = lbl_counter# 用0.0替换NaN值labels_df = labels_df.fillna(0.0)kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # 设置random_state以获得可重复的结果kfolds = list(kf.split(labels_df))folds = [f'split_{n}' for n in range(1, ksplit + 1)]folds_df = pd.DataFrame(index=indx, columns=folds)# 为每个折叠分配图像到训练集或验证集for idx, (train, val) in enumerate(kfolds, start=1):folds_df[f'split_{idx}'].loc[labels_df.iloc[train].index] = 'train'folds_df[f'split_{idx}'].loc[labels_df.iloc[val].index] = 'val'# 计算每个折叠的标签分布比例fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# 为避免分母为零,向分母添加一个小值(1E-7)ratio = val_totals / (train_totals + 1E-7)fold_lbl_distrb.loc[f'split_{n}'] = ratiods_yamls = []for split in folds_df.columns:split_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)dataset_yaml = split_dir / f'{split}_dataset.yaml'ds_yamls.append(dataset_yaml.as_posix())split_dir = (root_directory / split_dir).as_posix()with open(dataset_yaml, 'w') as ds_y:yaml.safe_dump({'train': split_dir + '/train/images','val': split_dir + '/val/images','names': classes}, ds_y)# print(ds_yamls)with open(dataset_path / 'yaml_paths.txt', 'w') as f:for path in ds_yamls:f.write(path + '\n')args_list = [(image, save_path, folds_df) for image in images]run(split_images_labels, args_list, desc=f"Creating dataset")def split_images_labels(args):image, save_path, folds_df = argslabel = image.parents[1] / 'labels' / f'{image.stem}.txt'if label.exists():for split, k_split in folds_df.loc[image.stem].items():# 目标目录img_to_path = save_path / split / k_split / 'images'lbl_to_path = save_path / split / k_split / 'labels'shutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)if __name__ == "__main__":opt = parse_opt()main(opt)model = YOLO('yolov8n.pt', task='train')# 从文本文件中加载内容并存储到一个列表中ds_yamls = []with open(Path(opt.data) / 'yaml_paths.txt', 'r') as f:for line in f:# 去除每行末尾的换行符line = line.strip()ds_yamls.append(line)# 打印加载的文件路径列表print(ds_yamls)for k in range(opt.ksplit):dataset_yaml = ds_yamls[k]name = Path(dataset_yaml).stemmodel.train(data=dataset_yaml,batch=16,epochs=100,imgsz=640,device=0,workers=8,project="runs/train",name=name,)print("*"*40)print("K-Fold Cross Validation Completed.")print("*"*40)

命令行运行

python k-fold-train.py --data ./data --ksplit 5
  • data:数据集所在目录,数据集的分布如下:
    - data- images- labels
    
    images存放图片文件,labels存放txt标注文件。
  • ksplit:数据集拆分为ksplit个子集,该值通常取5或者10,某些情况下可取其他数值。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/235202.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

redis相关面试题

1、说一说你在项目中的redis的应用场景? 需要频繁查询的数据,分布式锁,spring session 5大value类型:string hash list set zset基本上就是缓存为的是服务无状态,延申思考,看你的项目有哪些数据结构或对象…

final的详解

在Java中,final 关键字用于表示不可改变的实体,可以应用于变量、方法、类和指令重排序。它有不同的作用,具体取决于它被应用的上下文。 1.对于变量: 如果一个变量被声明为 final,则该变量的值在一旦被赋予后就不能再被…

Starting the Docker Engine...一直转圈

出现的问题: 原因排查: 看了网上的很多篇文章,每个原因都排查了,没有发现问题。 遇到这样的情况应先看自己是否安装成功 打开运行,在空框中输入powershell并点击确定: docker version 显示版本证明安装…

微信小程序-选择和分割打开地图选择位置的信息

一、 前言 废话不多说,单刀直入。 本文要实现的功能是微信小程序中打开地图选择位置,以及将返回的位置信息分割。 例如返回的位置信息是:广东省深圳市龙岗区xxxxx小区 分割后变成: {province: "广东省",city: "深…

042.Python异常处理_异常捕获

我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈 入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈 虚 拟 环 境 搭 建 :👉&…

SpringBoot前后端分离开发项目部署时,项目打包准备工作

第一步:项目打包之前,拉前后端代码 拉完代码后,再执行下面操作(确保项目能正常启动并运行) 后端(执行如下操作) mvn clean install -T 8 -Dmaven.test.skiptrue -Dmaven.compile.forktrue 执行…

JDK17 SpringBoot3 整合常见依赖

JDK版本&#xff1a;17 SpringBoot 整合Mybatis Plus 、Redis等 依赖文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance&q…

【python】程序运行添加命令行参数argparse模块用法详解

Python标准库之argparse&#xff0c;详解如何创建一个ArgumentParser对象及使用 一. argparse介绍二. 使用步骤及参数介绍三. 具体使用3.1 设置必需参数3.2 传一个参数3.3 传多个参数3.4 位置参数和可选参数3.5 参数设置默认值3.6 其它用法 一. argparse介绍 很多时候&#xff…

Amazon CodeWhisperer 在 vscode 的应用

文章作者:旧花阴 CodeWhisperer 是一款可以帮助程序员更快、更安全地编写代码的工具&#xff0c;可以在他们的开发环境中实时提供代码建议和推荐。亚马逊云科技发布的这款代码生成工具 CodeWhisperer 最大的优势就是对于个人用户免费。以在 vscode 为例&#xff0c;演示安装过程…

LeetCode 1901. 寻找峰值 II:二分查找

【LetMeFly】1901.寻找峰值 II&#xff1a;二分查找 力扣题目链接&#xff1a;https://leetcode.cn/problems/find-a-peak-element-ii/ 一个 2D 网格中的 峰值 是指那些 严格大于 其相邻格子(上、下、左、右)的元素。 给你一个 从 0 开始编号 的 m x n 矩阵 mat &#xff0c…

【漏洞复现】CVE-2023-6895 IP网络对讲广播系统远程命令执行

漏洞描述 杭州海康威视数字技术有限公司IP网络对讲广播系统。 海康威视对讲广播系统3.0.3_20201113_RELEASE(HIK)存在漏洞。它已被宣布为关键。该漏洞影响文件/php/ping.php 的未知代码。使用输入 netstat -ano 操作参数 jsondata[ip] 会导致 os 命令注入。 开发语言:PHP 开…

原子学习笔记3——使用tslib库

一、tslib介绍 tslib 是专门为触摸屏设备所开发的 Linux 应用层函数库&#xff0c;并且是开源。 tslib 为触摸屏驱动和应用层之间的适配层&#xff0c;它把应用程序中读取触摸屏 struct input_event 类型数据&#xff08;这是输入设备上报给应用层的原始数据&#xff09;并进行…

2023-2024-2Java面向对象程序设计-阶段性测试2

填空题&#xff08;总分&#xff1a;10.00&#xff09; 1、Java程序中使用【 import 】关键字导入外部的包。 2、使用【 final 】关键字声明的类不能有子类。 4、JVM是【 Java Virtual Machine 】的英文简写。 5、面向对象编程思想的三个特性是【封装】、【继承】、【多态】。 …

数据分析师的职业规划与参考资料

数据分析师如何规划 参考&#xff1a;超详细的数据分析职业规划 一个产品的出现可以从业务和技术两个方向分析&#xff0c;业务需求技术支持产品的出现。 如果把职业也当成一个产品&#xff0c;也有类似的分析&#xff0c; 其中业务也就是领域&#xff0c;即这个业务领域的特点…

Power BI案例-医院数据集的仪表盘制作

数据集描述 医生数据集doctor 医生编号是唯一的&#xff0c;名称会存在重复 医疗项目数据projects 病例编号是唯一的&#xff0c;注意这个日期编号不是真正的日期。 日期数据date 这里的日期编号对应医疗项目数据中的日期编号 科室数据集Department 维度表 采购成本事实表…

知乎上高频提问:Redis到底是单线程还是多线程程序?

1.概述 这里我们先给出问题的全面回答&#xff1a;Redis到底是多线程还是单线程程序要看是针对哪个功能而言&#xff0c;对于核心业务功能部分(命令操作处理数据)&#xff0c;Redis是单线程的&#xff0c;主要是指 Redis 的网络 IO 和键值对读写是由一个线程来完成的&#xff…

海康rtsp拉流,rtmp推流,nginx部署转flv集成

海康rtsp拉流&#xff0c;rtmp推流&#xff0c;nginx部署转flv集成 项目实际使用并测试经正式使用无问题&#xff0c;有问题欢迎评论留言 核心后台java代码&#xff1a; try {// FFmpeg命令String command "ffmpeg -re -i my_video.mp4 -c copy -f flv rtmp://localho…

[学习笔记]批量迁移数据库文件

拷贝数据库文件 首先在本地运行如下SQL语句&#xff0c;查看数据库文件的磁盘位置 SELECT name, physical_name AS CurrentLocation, state_desc FROM sys.master_files默认是保存在C:\Program Files\Microsoft SQL Server\MSSQL13.MSSQLSERVER\MSSQL\DATA目录下 首先复制数据…

Ansible常用模块详解(附各模块应用实例和Ansible环境安装部署)

目录 一、ansible概述 1、简介 2、Ansible主要功能&#xff1a; 3、Ansible的另一个特点&#xff1a;所有模块都是幂等性 4、Ansible的优点&#xff1a; 5、Ansible的四大组件&#xff1a; 二、ansible环境部署&#xff1a; 1、环境&#xff1a; 2、安装ansible&#…

浅析RoPE旋转位置编码的远程衰减特性

为什么 θ i \theta_i θi​的取值会造成远程衰减性 旋转位置编码的出发点为&#xff1a;通过绝对位置编码的方式实现相对位置编码。 对词向量 q \boldsymbol{q} q添加绝对位置信息 m m m&#xff0c;希望找到一种函数 f f f&#xff0c;使得&#xff1a; < f ( q , m ) …