YOLOv8教程系列:三、K折交叉验证——让你的每一份标注数据都物尽其用(yolov8目标检测+k折交叉验证法)

YOLOv8教程系列:三、K折交叉验证——让你的每一份标注数据都物尽其用(yolov8目标检测+k折交叉验证法)

0.引言

k折交叉验证(K-Fold
Cross-Validation)是一种在机器学习中常用的模型评估技术,用于估计模型的性能和泛化能力。它的主要作用是在有限的数据集上对模型进行评估,以便更准确地了解模型在新数据上的表现。

K折交叉验证的基本思想是将原始数据集分成K个子集(折),然后依次将每个子集作为验证集,其他K-1个子集作为训练集,进行K次训练和验证。每次验证后,计算模型在验证集上的性能指标,如准确率、精确率、召回率等。最后,将K次验证的性能指标平均,作为模型在整个数据集上的性能估计。

K折交叉验证的作用包括:

  1. 模型性能评估: K折交叉验证可以更准确地评估模型在数据集上的性能,避免因数据分布不均匀而导致评估结果不准确的问题。
  2. 泛化能力估计: 通过在不同的训练集和验证集上进行多次评估,可以更好地估计模型的泛化能力,即模型在新数据上的表现。
  3. 减少过拟合: K折交叉验证可以帮助检测模型是否出现过拟合问题。如果模型在训练集上表现很好,但在验证集上表现较差,可能存在过拟合。
  4. 参数调优: 在每一轮交叉验证中,可以使用不同的参数设置来训练模型,以找到在验证集上表现最好的参数组合。
  5. 数据利用率: K折交叉验证充分利用了数据集中的所有样本,因为每个样本都会在不同的折中被用作训练和验证。

总之,K折交叉验证是一种有助于评估和改进模型性能的重要技术,尤其在数据有限的情况下,它能更准确地估计模型在新数据上的表现。
在这里插入图片描述

1.数据准备

使用交叉验证前,需要把数据准备为yolo格式,不知道如何数据准备的朋友可以看下这篇文章:YOLOv8教程系列:一、使用自定义数据集训练YOLOv8模型(详细版教程,你只看一篇->调参攻略),包含环境搭建/数据准备/模型训练/预测/验证/导出等
.
├── ./data
│ ├── ./data/Annotations
│ │ ├── ./data/Annotations/fall_0.xml
│ │ ├── ./data/Annotations/fall_1000.xml
│ │ ├── ./data/Annotations/fall_1001.xml
│ │ ├── ./data/Annotations/fall_1002.xml
│ │ ├── ./data/Annotations/fall_1003.xml
│ │ ├── ./data/Annotations/fall_1004.xml
│ │ ├── …
│ ├── ./data/images
│ │ ├── ./data/images/fall_0.jpg
│ │ ├── ./data/images/fall_1000.jpg
│ │ ├── ./data/images/fall_1001.jpg
│ │ ├── ./data/images/fall_1002.jpg
│ │ ├── ./data/images/fall_1003.jpg
│ │ ├── ./data/images/fall_1004.jpg
│ │ ├── …
│ ├── ./data/ImageSets
│ └── ./data/labels
│ │ ├── ./data/images/fall_0.txt
│ │ ├── ./data/images/fall_1000.txt
│ │ ├── ./data/images/fall_1001.txt
│ │ ├── ./data/images/fall_1002.txt
│ │ ├── ./data/images/fall_1003.txt
│ │ ├── ./data/images/fall_1004.txt
│ ├── ./data/classes.yaml
其中,特别要注意的一点是,需要新建个classes.yaml的文件,然后将自己的标签按序填写,如下所示:

names:0: your_label_11: your_label_2

2.代码准备

下面代码可以什么都不用改直接运行,前提是按我的数据格式,这个代码放在data的上层目录中

import datetime
import shutil
from pathlib import Path
from collections import Counter
import osimport yaml
import numpy as np
import pandas as pd
from ultralytics import YOLO
from sklearn.model_selection import KFold# 定义数据集路径
dataset_path = Path('./data')  # 替换成你的数据集路径# 获取所有标签文件的列表
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # 所有标签文件在'labels'目录中# 获取当前文件的绝对路径
current_file_path = os.path.abspath(__file__)# 获取当前文件所在的文件夹路径(即当前文件的根目录)
root_directory = os.path.dirname(current_file_path)print("当前文件运行根目录:", root_directory)# 从YAML文件加载类名
yaml_file = 'data/classes.yaml'
with open(yaml_file, 'r', encoding="utf8") as y:classes = yaml.safe_load(y)['names']
cls_idx = sorted(classes.keys())# 创建DataFrame来存储每张图像的标签计数
indx = [l.stem for l in labels]  # 使用基本文件名作为ID(无扩展名)
labels_df = pd.DataFrame([], columns=cls_idx, index=indx)# 计算每张图像的标签计数
for label in labels:lbl_counter = Counter()with open(label, 'r') as lf:lines = lf.readlines()for l in lines:# YOLO标签使用每行的第一个位置的整数作为类别lbl_counter[int(l.split(' ')[0])] += 1labels_df.loc[label.stem] = lbl_counter# 用0.0替换NaN值
labels_df = labels_df.fillna(0.0)# 使用K-Fold交叉验证拆分数据集
ksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # 设置random_state以获得可重复的结果
kfolds = list(kf.split(labels_df))
folds = [f'split_{n}' for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=indx, columns=folds)# 为每个折叠分配图像到训练集或验证集
for idx, (train, val) in enumerate(kfolds, start=1):folds_df[f'split_{idx}'].loc[labels_df.iloc[train].index] = 'train'folds_df[f'split_{idx}'].loc[labels_df.iloc[val].index] = 'val'# 计算每个折叠的标签分布比例
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)
for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# 为避免分母为零,向分母添加一个小值(1E-7)ratio = val_totals / (train_totals + 1E-7)fold_lbl_distrb.loc[f'split_{n}'] = ratio# 创建目录以保存分割后的数据集
save_path = Path(dataset_path / f'{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val')
save_path.mkdir(parents=True, exist_ok=True)# 获取图像文件列表
images = sorted((dataset_path / 'images').rglob("*.jpg"))  # 更改文件扩展名以匹配你的数据
ds_yamls = []# 循环遍历每个折叠并复制图像和标签
for split in folds_df.columns:# 为每个折叠创建目录split_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)# 创建数据集的YAML文件dataset_yaml = split_dir / f'{split}_dataset.yaml'ds_yamls.append(dataset_yaml.as_posix())split_dir = os.path.join(root_directory, split_dir.as_posix())with open(dataset_yaml, 'w') as ds_y:yaml.safe_dump({'path': split_dir,'train': 'train','val': 'val','names': classes}, ds_y)
print(ds_yamls)# 将文件路径保存到一个txt文件中
with open('data/file_paths.txt', 'w') as f:for path in ds_yamls:f.write(path + '\n')# 为每个折叠复制图像和标签到相应的目录
for image, label in zip(images, labels):for split, k_split in folds_df.loc[image.stem].items():# 目标目录img_to_path = save_path / split / k_split / 'images'lbl_to_path = save_path / split / k_split / 'labels'# 将图像和标签文件复制到新目录中# 如果文件已存在,可能会抛出SamefileErrorshutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)

运行代码后,会在data目录下生成一个文件夹,里面有5种不同划分的数据集

3.开始训练

下面的代码放在和上面代码的同级目录中,训练参数可以根据自己情况进行调整

from ultralytics import YOLOweights_path = 'checkpoints/yolov8s.pt'
model = YOLO(weights_path, task='train')
ksplit = 5
# 从文本文件中加载内容并存储到一个列表中
ds_yamls = []
with open('data/file_paths.txt', 'r') as f:for line in f:# 去除每行末尾的换行符line = line.strip()ds_yamls.append(line)# 打印加载的文件路径列表
print(ds_yamls)results = {}
for k in range(ksplit):dataset_yaml = ds_yamls[k]model.train(data=dataset_yaml, batch=6, epochs=2, imgsz=1280, device=0, workers=8, single_cls=False, ) 

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/48529.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java详解编译型和解释型语言

在计算机的高级编程语言类型分为两种,分别是编译型和解释型,而Java既有编译型又有解释型 什么是编译型?什么是解释型? 字面上来说编译和解释都有‘翻译’的意思,而她们两个的区别是‘翻译’的时机不同,什…

Python采集电商平台泳衣数据进行可视化分析

前言 嗨喽,大家好呀~这里是爱看美女的茜茜呐 环境使用: python 3.8 解释器 pycharm 编辑器 模块使用: 第三方模块 需要安装 requests —> 发送 HTTP请求 内置模块 不需要安装 csv —> 数据处理中经常会用到的一种文件格式 第三方模块安装&#xff1a…

实验五 Linux 内核的安装与加载

【实验目的】 掌握 uboot 的使用方法,能够使用 uboot 安装和加载内核 【实验环境】 ubuntu 14.04 发行版FS4412 实验平台 【注意事项】 实验步骤中以“$”开头的命令表示在 ubuntu 环境下执行,以“#”开头的命令表 示在开发板下执行 【实验步骤】 …

计算机视觉 -- 图像分割

文章目录 1. 图像分割2. FCN2.1 语义分割– FCN (Fully Convolutional Networks)2.2 FCN--deconv2.3 Unpool2.4 拓展–DeconvNet 3. 实例分割3.1 实例分割--Mask R-CNN3.2 Mask R-CNN3.3 Faster R-CNN与 Mask R-CNN3.4 Mask R-CNN:Resnet1013…

ES搭建集群

一、创建 elasticsearch-cluster 文件夹 创建 elasticsearch-7.8.0-cluster 文件夹,在内部复制三个 elasticsearch 服务。 然后每个文件目录中每个节点的 config/elasticsearch.yml 配置文件 node-1001 节点 #节点 1 的配置信息: #集群名称&#xff0…

【数据备份、恢复、迁移与容灾】上海道宁与云祺科技为企业用户提供云数据中心容灾备份解决方案

云祺容灾备份系统支持 主流虚拟化环境下的虚拟机备份 提供对云基础设施 云架构平台以及 应用系统的全方位数据保护 云祺容灾备份系统规范功能 增强决策能力 高效恢复数据至可用状态 有效降低恢复成本 更大限度减少业务中断时间 保障业务可访问性 开发商介绍 成都云祺…

LSTM数学计算公式

LSTM(长短期记忆网络)是一种循环神经网络(RNN)的变体,常用于处理时间序列相关的任务。下面将简要介绍LSTM的数学推导和公式模型。 在训练一般神经网络模型时,通常用,其中W为权重,X为输入&#…

算法通关村第九关——中序遍历与搜索树

1 中序遍历和搜索树原理 二叉搜索树按照中序遍历正好是一个递增序列。其比较规范的定义是: 若它的左子树不为空,则左子树上所有节点的值均小于它的根节点的值;若它的右子树不为空,则右子树所有节点的值均大于它的根节点的值&…

【网络层协议】ARP攻击与欺骗常见的手段以及工作原理

个人主页:insist--个人主页​​​​​​ 本文专栏:网络基础——带你走进网络世界 本专栏会持续更新网络基础知识,希望大家多多支持,让我们一起探索这个神奇而广阔的网络世界。 目录 一、ARP攻击的常见手段 第一种:IP…

【健康医疗】Axure用药提醒小程序原型图,健康管理用药助手原型模板

作品概况 页面数量:共 20 页 兼容软件:Axure RP 9/10,不支持低版本 应用领域:健康管理,用药助手 作品申明:页面内容仅用于功能演示,无实际功能 作品特色 本作品为「用药提醒」小程序原型图…

Spring Boot 知识集锦之actuator监控端点详解

文章目录 0.前言1.参考文档2.基础介绍默认支持的端点 3.步骤3.1. 引入依赖3.2. 配置文件3.3. 核心源码 4.示例项目5.总结 0.前言 背景: 一直零散的使用着Spring Boot 的各种组件和特性,从未系统性的学习和总结,本次借着这个机会搞一波。共同学…

Android NDK JNI与Java的相互调用

一、Jni调用Java代码 jni可以调用java中的方法和java中的成员变量,因此JNIEnv定义了一系列的方法来帮助我们调用java的方法和成员变量。 以上就是jni调用java类的大部分方法,如果是静态的成员变量和静态方法,可以使用***GetStaticMethodID、CallStaticObjectMethod等***。就…

『C语言』数据在内存中的存储规则

前言 小羊近期已经将C语言初阶学习内容与铁汁们分享完成,接下来小羊会继续追更C语言进阶相关知识,小伙伴们坐好板凳,拿起笔开始上课啦~ 一、数据类型的介绍 我们目前已经学了基本的内置类型: char //字符数据类型 short …

SpeedBI数据可视化工具:浏览器上做分析

SpeedBI数据分析云是一种在浏览器上进行数据可视化分析的工具,它能够将数据以可视化的形式呈现出来,并支持多种数据源和图表类型。 所有操作,均在浏览器上进行 在浏览器中打开SpeedBI数据分析云官网,点击【免费使用】进入&#…

微服务(多级缓存)

多级缓存 1.什么是多级缓存 传统的缓存策略一般是请求到达Tomcat后,先查询Redis,如果未命中则查询数据库,如图: 存在下面的问题: 请求要经过Tomcat处理,Tomcat的性能成为整个系统的瓶颈Redis缓存失效时&…

SpringCloud学习笔记(二)_Eureka注册中心

一、Eureka简介 Eureka是一项基于REST(代表性状态转移)的服务,主要在AWS云中用于定位服务,以实现负载均衡和中间层服务器的故障转移。我们称此服务为Eureka Server。Eureka还带有一个基于Java的客户端组件Eureka Client&#xff…

发布 net 项目 到 Docker

背景 因为发布到 centOS8 使用 screen -S 可以,想开机自启 使用 nohup 命令有启动不起来。环境问题不好找,就想尝试用 docker 运行 步骤 在生成的 Dockerfile 文件里增加修改时区指令 因为我们用的都是北京时间所以 创建镜像的时候就调整好 #设置时间…

B站视频码率用户上传视频的视频码率

一般来说,B站用户可以根据自己的视频内容和需求来选择视频的码率,但以下是一些常见的视频码率范围,供用户参考: 标清(SD): 码率范围可能在500 Kbps至1.5 Mbps左右,适用于480p的分辨率…

【JavaEE基础学习打卡05】JDBC之基本入门就可以了

目录 前言一、JDBC学习前说明1.Java SE中JDBC2.JDBC版本 二、JDBC基本概念1.JDBC原理2.JDBC组件 三、JDBC基本编程步骤1.JDBC操作的数据库准备2.JDBC操作数据库表步骤 四、代码优化1.简单优化2.with-resources探讨 总结 前言 📜 本系列教程适用于JavaWeb初学者、爱好…

小程序中的页面配置和网络数据请求

页面配置文件和常用的配置项 1.在msg.json中配置window中的颜色和背景色 "navigationBarBackgroundColor": "#efefef","navigationBarTextStyle": "black" 2.可以看到home中的没有发生变化但是msg的发生变化了,这个和前面的…