YOLOv8专栏导航:点击此处跳转
K折交叉验证
K折交叉验证(K-Fold Cross-Validation)是一种常用的机器学习模型评估方法,可以帮助我们评估模型的性能,特别适用于数据集相对较小的情况。
在K折交叉验证中,将原始数据集分成K个子集,然后依次将其中一个子集作为验证集,其余K-1个子集作为训练集进行模型训练和评估。这样可以得到K个模型,每个模型都在不同的验证集上进行评估,最后将K个模型的评估结果,求平均或取最优结果作为最终评估。
优点:
- 充分利用数据集: 在K折交叉验证中,整个数据集被划分为K个互斥的折叠(Folds)。每次训练模型时,都有K-1个折叠用于训练,而剩下的一个用于验证。这样,每个样本都有机会作为验证集的一部分,从而充分利用了数据集中的所有样本。
- 减轻过拟合风险: 由于每个样本都会在训练集和验证集中出现,模型在验证集上的性能评估更具有代表性。这有助于减轻由于数据稀疏或类别不平衡导致的过拟合问题。
- 更稳健的模型评估: K折交叉验证计算多次模型性能的平均值,提供了更稳健的性能评估。这对于对抗数据稀疏和类别不平衡等挑战的模型评估尤为重要,因为它减少了单次评估可能引入的随机性。
- 处理类别不平衡: 如果数据集中某些类别的样本数量较少,K折交叉验证可以确保每个折叠中都包含这些少见类别的样本。这有助于确保模型在少见类别上的性能得到充分评估,并提高对类别不平衡的鲁棒性。
YOLOv8
YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。
YOLOv8是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。
YOLOv8的K折交叉验证实现步骤:
- 首先,将数据集划分成K个子集,可以使用现有的数据集划分函数或手动划分。
- 然后,使用一个循环迭代K次,每次将其中一个子集作为验证集,其余K-1个子集作为训练集。
- 在每次迭代中,使用训练集进行模型训练,并使用验证集进行模型评估。
- 可以根据需要调整模型的超参数或其他设置。例如,可以尝试不同的学习率、迭代次数等。
- 最后,将K次迭代中的评估结果进行平均或取最优结果作为最终的模型评估结果。
官方项目下载
git clone https://github.com/ultralytics/ultralytics.git
🚀 YOLOv8-K折交叉验证代码实现
在上述下载好的项目文件夹下新建k-fold-train.py
文件,并添加下述代码:
import argparse
import datetime
from itertools import chain
import os
from pathlib import Path
import shutil
import yaml
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from ultralytics import YOLONUM_THREADS = min(8, max(1, os.cpu_count() - 1))def parse_opt():parser = argparse.ArgumentParser()parser.add_argument('--data', default=r'./data') # 数据集路径parser.add_argument('--ksplit', default=5, type=int) # K-Fold交叉验证拆分数据集parser.add_argument('--im_suffixes', default=['jpg', 'png', 'jpeg'], help='images suffix') # 图片后缀名return parser.parse_args()def run(func, this_iter, desc="Processing"):with ThreadPoolExecutor(max_workers=NUM_THREADS, thread_name_prefix='MyThread') as executor:results = list(tqdm(executor.map(func, this_iter), total=len(this_iter), desc=desc))return resultsdef main(opt):dataset_path, ksplit, im_suffixes = Path(opt.data), opt.ksplit, opt.im_suffixessave_path = Path(dataset_path / f'{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-Valid')save_path.mkdir(parents=True, exist_ok=True)# 获取所有图像和标签文件的列表images = sorted(list(chain(*[(dataset_path / "images").rglob(f'*.{ext}') for ext in im_suffixes])))# images = sorted(image_files)labels = sorted((dataset_path / "labels").rglob("*.txt"))root_directory = Path.cwd()print("当前文件运行根目录:", root_directory)if len(images) != len(labels):print('*' * 20)print('当前数据集和标签数量不一致!!!')print('*' * 20)# 从YAML文件加载类名classes_file = sorted(dataset_path.rglob('classes.yaml'))[0]assert classes_file.exists(), "请创建classes.yaml类别文件"if classes_file.suffix == ".txt":passelif classes_file.suffix == ".yaml":with open(classes_file, 'r', encoding="utf8") as f:classes = yaml.safe_load(f)['names']cls_idx = sorted(classes.keys())# 创建DataFrame来存储每张图像的标签计数indx = [l.stem for l in labels] # 使用基本文件名作为ID(无扩展名)labels_df = pd.DataFrame([], columns=cls_idx, index=indx)# 计算每张图像的标签计数for label in labels:lbl_counter = Counter()with open(label, 'r') as lf:lines = lf.readlines()for l in lines:# YOLO标签使用每行的第一个位置的整数作为类别lbl_counter[int(l.split(' ')[0])] += 1labels_df.loc[label.stem] = lbl_counter# 用0.0替换NaN值labels_df = labels_df.fillna(0.0)kf = KFold(n_splits=ksplit, shuffle=True, random_state=20) # 设置random_state以获得可重复的结果kfolds = list(kf.split(labels_df))folds = [f'split_{n}' for n in range(1, ksplit + 1)]folds_df = pd.DataFrame(index=indx, columns=folds)# 为每个折叠分配图像到训练集或验证集for idx, (train, val) in enumerate(kfolds, start=1):folds_df[f'split_{idx}'].loc[labels_df.iloc[train].index] = 'train'folds_df[f'split_{idx}'].loc[labels_df.iloc[val].index] = 'val'# 计算每个折叠的标签分布比例fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# 为避免分母为零,向分母添加一个小值(1E-7)ratio = val_totals / (train_totals + 1E-7)fold_lbl_distrb.loc[f'split_{n}'] = ratiods_yamls = []for split in folds_df.columns:split_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)(split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)dataset_yaml = split_dir / f'{split}_dataset.yaml'ds_yamls.append(dataset_yaml.as_posix())split_dir = (root_directory / split_dir).as_posix()with open(dataset_yaml, 'w') as ds_y:yaml.safe_dump({'train': split_dir + '/train/images','val': split_dir + '/val/images','names': classes}, ds_y)# print(ds_yamls)with open(dataset_path / 'yaml_paths.txt', 'w') as f:for path in ds_yamls:f.write(path + '\n')args_list = [(image, save_path, folds_df) for image in images]run(split_images_labels, args_list, desc=f"Creating dataset")def split_images_labels(args):image, save_path, folds_df = argslabel = image.parents[1] / 'labels' / f'{image.stem}.txt'if label.exists():for split, k_split in folds_df.loc[image.stem].items():# 目标目录img_to_path = save_path / split / k_split / 'images'lbl_to_path = save_path / split / k_split / 'labels'shutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)if __name__ == "__main__":opt = parse_opt()main(opt)model = YOLO('yolov8n.pt', task='train')# 从文本文件中加载内容并存储到一个列表中ds_yamls = []with open(Path(opt.data) / 'yaml_paths.txt', 'r') as f:for line in f:# 去除每行末尾的换行符line = line.strip()ds_yamls.append(line)# 打印加载的文件路径列表print(ds_yamls)for k in range(opt.ksplit):dataset_yaml = ds_yamls[k]name = Path(dataset_yaml).stemmodel.train(data=dataset_yaml,batch=16,epochs=100,imgsz=640,device=0,workers=8,project="runs/train",name=name,)print("*"*40)print("K-Fold Cross Validation Completed.")print("*"*40)
命令行运行
python k-fold-train.py --data ./data --ksplit 5
data
:数据集所在目录,数据集的分布如下:- data- images- labels
images
存放图片文件,labels
存放txt
标注文件。ksplit
:数据集拆分为ksplit
个子集,该值通常取5
或者10
,某些情况下可取其他数值。