😎😎😎物体检测-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
18、训练配置文件
位置yolov5/data文件夹/hyp.scratch.yaml文件
文件内容:
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
giou: 0.05 # GIoU loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.5 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mixup: 0.0 # image mixup (probability)
解释部分参数含义
- lr0,初始学习率
- lrf,余弦退火,使用余弦函数动态降低学习率
- warmup_epochs,刚开始的几个epoch用小一点的学习率,让模型预热
- GIoU(Generalized Intersection over Union)损失用于改善模型对于边界框位置的预测准确性。这个参数控制着GIoU损失在总损失中的比重
- cls,一个超参数,控制分类损失在总损失中的权重
- cls_pw,对于二元交叉熵损失(BCELoss),这个参数给出正样本的权重,用于平衡正负样本之间的不均衡
- fl_gamma (焦点损失伽马): 焦点损失用于缓解正负样本不平衡问题,通过减少易分类样本的损失影响,这个参数控制了焦点损失的强度
- 从hsv_h到最后的mixup都是图像增强的参数
19、训练脚本
在train.py中的配置参数已经在yolov5系列的第一篇博客有过讲解
train.py写了500多行代码,但是实际上能用到的很少,因为v5是一个非常工程化的版本,适应了很多版本很多场景
训练脚本可先简单分为3个部分,然后再细分
- 导包部分,大约30行代码,这部分不做解读
- 训练函数def train(),300多行代码
- main函数,100多行代码
由于实在太长,不可能逐行解释,我会将训练函数分成多个部分来解读,这是导包部分的代码:
import argparse
import math
import os
import random
import time
import logging
from pathlib import Pathimport numpy as np
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmimport test # import test.py to get mAP after each epoch
from models.yolo import Model
from utils.datasets import create_dataloader
from utils.general import (torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging)
from utils.google_utils import attempt_download
from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts
第一部分主要是导入了一些常用python与深度学习任务的辅助工具
第二部分主要是torch相关的工具
第三部分主要是导入yolov5项目的各个模块的辅助函数与类的实例化,一共有26个,也就是说光是辅助函数的调用就有26个,确实太复杂了