YOLOv5代码解读[01] train.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
import argparse
import math
import os
import random
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import numpy as np
import yaml
from tqdm import tqdmimport torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, AdamW, lr_scheduler# 解析成绝对路径
FILE = Path(__file__).resolve()
# YOLOv5 root directory
ROOT = FILE.parents[0]  
# add ROOT to PATH
if str(ROOT) not in sys.path:sys.path.append(str(ROOT)) 
# 用os.path.relpath把绝对路径转换为相对路径relative
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  import val  
from models.experimental import attempt_load
from models.yolo import Model
# from models2.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.metrics import fitness
from utils.plots import plot_lr_scheduler, plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first# DDP模式
# pytorch中的有两种分布式训练方式,一种是常用的DataParallel(DP),另外一种是DistributedDataParallel(DDP),两者都可以用来实现数据并行方式的分布式训练。
# DP是单进程多线程的实现方式,DDP是采用多进程的方式,DDP相比于DP训练速度要快。
# (1) 使用 torch.distributed.init_process_group 初始化进程组
# (2) 使用 torch.nn.parallel.DistributedDataParallel 创建分布式模型
# (3) 使用 torch.utils.data.distributed.DistributedSampler 创建 DataLoader
# (4) 调整其他必要的地方(tensor放到指定device上, S/L checkpoint,指标计算等)
# (5) 使用 torch.distributed.launch / torch.multiprocessing 或 slurm 开始训练
# 设置DDP模式的参数,world_size:表示全局进程个数,global_rank:进程编号,总共有多少个GPU。
# 只有多机多卡的时候,才会有'WORLD_SIZE'和'RANK'环境变量。因此来说,参数1和-1是单机多卡,单机单卡,CPU的情形# 设置DDP模式变量
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
# world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
# global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
# rank = global_rankdef train(hyp, opt, device, callbacks):version = opt.versionsave_dir = Path(opt.save_dir)epochs = opt.epochsbatch_size = opt.batch_sizeweights = opt.weightssingle_cls = opt.single_clsevolve = opt.evolvedata = opt.datacfg = opt.cfgresume = opt.resumenoval = opt.noval nosave = opt.nosaveworkers = opt.workersfreeze = opt.freeze#-------------------------------------------------------------------------------------------#"                                     训练权重保存路径                                       "#-------------------------------------------------------------------------------------------## 训练权重保存路径w = save_dir / 'weights'  (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir# 最后的模型last, 最好的模型bestlast = w / 'last.pt'best = w / 'best.pt'#-------------------------------------------------------------------------------------------#"                                   超参数Hyperparameters加载                                "#-------------------------------------------------------------------------------------------## 超参数Hyperparametersif isinstance(hyp, str):with open(hyp, errors='ignore') as f:# 加载hyps字典hyp = yaml.safe_load(f)  LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))# 保存opt.yaml和hyp.yaml配置if not evolve:with open(save_dir / 'hyp.yaml', 'w') as f:yaml.safe_dump(hyp, f, sort_keys=False)with open(save_dir / 'opt.yaml', 'w') as f:yaml.safe_dump(vars(opt), f, sort_keys=False)# 日志打印Loggersdata_dict = Noneif RANK in [-1, 0]:loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  if loggers.wandb:data_dict = loggers.wandb.data_dictif resume:weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size# Register actionsfor k in methods(loggers):callbacks.register_action(k, callback=getattr(loggers, k))# 是否绘制曲线plots = not evolve  cuda = device.type != 'cpu'init_seeds(1 + RANK)#-------------------------------------------------------------------------------------------#"                                      dataset数据集加载                                     "#-------------------------------------------------------------------------------------------## data: /home/easyits/road-risk-identification/yolo/YOLOV5/data/coco.yamlwith torch_distributed_zero_first(LOCAL_RANK):data_dict = data_dict or check_dataset(data)  # data_dict: # {'path': './datasets/roadrisk', #  'train': '/home/easyits/road-risk-identification/yolo/YOLOV5/datasets/roadrisk/train2017.txt', #  'val': '/home/easyits/road-risk-identification/yolo/YOLOV5/datasets/roadrisk/val2017.txt', #  'test': '/home/easyits/road-risk-identification/yolo/YOLOV5/datasets/roadrisk/val2017.txt', #  'nc': 6, 'names': ['build', 'person', 'sink', 'garbage', 'pit', 'trouble']} train_path = data_dict['train'] val_path = data_dict['val']# 类别数目class numbersnc = 1 if single_cls else int(data_dict['nc']) # 类别名字class namesnames = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # 判断类别设置是否一致assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # COCO datasetis_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') #-------------------------------------------------------------------------------------------#"                                        Model模型加载                                      "#-------------------------------------------------------------------------------------------## 加载预训练权重if version == 1:# 检查weights的后缀check_suffix(weights, '.pt')  pretrained = weights.endswith('.pt')if pretrained:# 如果在本地没有找到的话,就尝试下载# torch_distributed_zero_first(LOCAL_RANK): 用于同步不同进程对数据读取的上下文管理器with torch_distributed_zero_first(LOCAL_RANK):# google云盘下载weights = attempt_download(weights)  # 加载保存的checkpoint# load checkpoint to CPU to avoid CUDA memory leak# torch.load()会同时保存和加载模型的参数和结构信息ckpt = torch.load(weights, map_location='cpu')  # 定义模型结构model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # 排除exclude keys# 从零开始训练,由于anchor需要重新聚类,anchor要排除;断点续练,直接接在anchor;exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] csd = ckpt['model'].float().state_dict()  # 交集intersectcsd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # 模型加载状态字典# 设置strict参数为False,可以忽略那些没有匹配到的keys。model.load_state_dict(csd, strict=False)  LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # 构建模型else:model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) elif version == 2:# 检查weights的后缀check_suffix(weights, ['.pt', '.pth'])  pretrained = weights.endswith('.pt') or weights.endswith('.pth')if pretrained:ckpt = torch.load(weights, map_location=device)model = Model(cfg).to(device)if resume:# 排除exclude keys# 从零开始训练,由于anchor需要重新聚类,anchor要排除;断点续练,直接接在anchor;exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] csd = ckpt['model'].float().state_dict()  # 交集intersectcsd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # 模型加载状态字典# 设置strict参数为False,可以忽略那些没有匹配到的keys。model.load_state_dict(csd, strict=False)  else:# 删除有关分类类别的权重for k in list(ckpt.keys()):if "head" in k:del ckpt[k]if 'model' in ckpt.keys():csd = ckpt['model']  else:csd = ckpt # 模型加载状态字典# 设置strict参数为False,可以忽略那些没有匹配到的keys。model.load_state_dict(csd, strict=False)miss_key, unexpected_key = model.backbone.load_state_dict(csd, strict=False)print("预训练权重加载结果: \n") # LOGGER.info(f'miss_key:{miss_key}')# LOGGER.info(f'unexpected_key:{unexpected_key}')else:model = Model(cfg).to(device)  else:pass# Freeze是否冻住某些网络层或全部层# freeze: [0]freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # freeze: []# model.state_dict(): 字典的遍历默认是遍历key, 例如:conv1.weight, conv1.bias。# model.parameters()# model.named_parameters(): 字典的遍历是一个元组tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值。for k, v in model.named_parameters():# 训练所有的层train all layersv.requires_grad = True # 冻住某些层# k: model.2.cv2.bn.weight, model.2.m.0.cv1.conv.weight

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/691755.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件包含+文件上传漏洞(图片马绕过)

目录 一.文件包含二.文件上传三.图片马四.题目 一.文件包含 将已有的代码以文件形式包含到某个指定的代码中,从而使用其中的代码或者数据,一般是为了方便直接调用所需文件,文件包含的存在使得开发变得更加灵活和方便(若对用户输入…

网络安全-pikachu之文件上传漏洞1

文件上传漏洞是危害极大的,一旦成功,可以获取服务器的最高权限。 pikachu介绍: 文件上传功能在web应用系统很常见,比如很多网站注册的时候需要上传头像、上传附件等等。当用户点击上传按钮后,后台会对上传的文件进行判…

☀️将大华摄像头画面接入Unity 【1】配置硬件和初始化摄像头

一、硬件准备 目前的设想是后期采用网口供电的形式把画面传出来,所以这边我除了大华摄像头还准备了POE供电交换机,为了方便索性都用大华的了,然后全都连接电脑主机即可。 二、软件准备 这边初始化摄像头需要用到大华的Configtool软件&#…

【知识点】CNN中concat与add的区别

cat操作经常用于将特征联合,多个卷积特征提取框架提取的特征融合或者是将输出层的信息进行融合;而add层更像是信息之间的叠加。 add是在一个特征上增加其语义信息,对最终的图像的分类是有益;cat导致的结果改进可能是由于cat操作通…

工业自动化部署选择主板的关键因素

**在构建任何计算机时,选择合适的主板至关重要。**对于游戏台式机,您需要选择能够支持您玩的游戏类型而不会出现任何问题的最新和最佳规格。当涉及工业应用时,影响您决策的变量变得更加重要。作为任何基于计算的应用中最关键的组件之一&#…

搜维尔科技:分析OptiTrack光学动作捕捉应用领域!

虚拟制作 当今虚拟制作阶段低延迟、超精确摄像机跟踪的事实上的标准。 用于运动科学的 OptiTrack OptiTrack 系统提供世界领先的测量精度和简单易用的工作流程,为研究人员和生物力学师的研究提供理想的 3D 跟踪数据。对所有主要数字测力台、EMG 和模拟设备的本机即…

trojan 突然无法上网

[ERROR] 2024/02/19 18:14:45 github.com/p4gefau1t/trojan-go/tunnel/tls.(*Server).acceptLoop.func1:server.go:140 tls handshake failed | remote error: tls: bad certificate 报证书问题,更新证书发现无法解决 最后突然客户端有一个配置 验证证书&#xf…

淘宝、1688以图搜图api使用示例

识图?当我们不知道图片内的信息时,可以通过以图识图的方式,找到对应的图片,以及对该图片的介绍。 识图工具是通过AI技术实现的,但其实识图并不需安装任何软件,在搜索引擎中就可以完成。“以图搜图”也可以…

组态软件行业分析:预计2025年市场空间可达数千亿元

组态软件可以对从控制系统得到的以及自身产生的数据进行记录存储。在系统发生事故和故障的时候,利用记录的运行工况数据和历史数据,可以对系统故障原因等进行分析定位,责任追查等。通过对数据的质量统计分析,还可以提高自动化系统…

[office] Excel中DCOUNT函数在复杂的数据中统计应用图解教程 #职场发展#其他#媒体

Excel中DCOUNT函数在复杂的数据中统计应用图解教程 Excel中DCOUNT函数返回数据库或数据区域的列中满足指定条件并且包含数字的单元格的个数。 在Excel中使用DCOUNT函数可以轻松地从数据库或数据区域中查找符合指定条件并且是数字的单元格的数量。 Excel中DCOUNT函数在复杂的…

佳能2580的下载手册

凡是和电子产品有关的产品其内部都开始不断地进行内卷,在不断地内卷背后,意味着科技更新和换代,自己也入手了一台佳能2580的打印机,一台相对比较老式的打印机,以此不断地自己想要进行打印的需要。 下载的基础步骤&…

Eureka注册中心:实现微服务架构下的服务发现与治理的艺术(一)

本系列文章简介: 在本系列文章中,我们将深入探讨Eureka注册中心在微服务架构中的应用和实践。我们将介绍Eureka的基本原理、关键特性以及配置和优化方法。同时,我们还将分享如何通过监控和日志分析来保障Eureka注册中心的稳定运行。希望通过本…

【ansible】认识ansible,了解常用的模块

目录 一、ansible是什么? 二、ansible的特点? 三、ansible与其他运维工具的对比 四、ansible的环境部署 第一步:配置主机清单 第二步:完成密钥对免密登录 五、ansible基于命令行完成常用的模块学习 模块1:comma…

【附代码】Python Excel合并单元格(OpenPyXL) Pandas.DataFrame groupby样式保存xlsx

文章目录 相关文献Excel合并单元格并居中Pandas.DataFrame groupby样式保存Excel 作者:小猪快跑 基础数学&计算数学,从事优化领域5年,主要研究方向:MIP求解器、整数规划、随机规划、智能优化算法 如有错误,欢迎指…

解释 C++ 中的静态成员变量和静态成员函数。

解释 C 中的静态成员变量和静态成员函数。 在C中,静态成员变量和静态成员函数都与类本身相关联,而不是与类的各个实例相关联。这意味着无论创建了多少个类的对象,静态成员变量和静态成员函数的存储空间只分配一次,并且它们可以被…

Bert基础(二)--多头注意力

多头注意力 顾名思义,多头注意力是指我们可以使用多个注意力头,而不是只用一个。也就是说,我们可以应用在上篇中学习的计算注意力矩阵Z的方法,来求得多个注意力矩阵。让我们通过一个例子来理解多头注意力层的作用。以All is well…

在本地计算机上运行Python程序

在本地计算机上运行Python程序的详细步骤: 第一步:安装Python解释器 Python解释器是运行Python程序所必需的。你可以从Python的官方网站(https://www.python.org/downloads/)下载最新版本的Python解释器。下载完成后&#xff0c…

linux监控系统资源命令

当前CPU内核版本 [rootVM-12-12-centos ~]# cat /proc/version Linux version 3.10.0-1160.11.1.el7.x86_64 (mockbuildkbuilder.bsys.centos.org) (gcc version 4.8.5 20150623 (Red Hat 4.8.5-44) (GCC) ) #1 SMP Fri Dec 18 16:34:56 UTC 2020 当前系统版本 [rootVM-12-1…

Python六级考试笔记

Python六级考试笔记【源源老师】 六级标准 一、 掌握文件操作及数据格式化。 二、 掌握数据可视化操作。 三、 理解类与对象的概念,初步掌握类与对象的使用。 四、 掌握SQLite数据库基础编程。 五、 掌握简单的使用tkinter的GUI设计。 ​ 1. 文件操作 &#xff0…

3、安装插件

以下插件请按需安装 Mask Passwords 使用此插件可以将在console中出现的password加密,以防止密码泄露。 Job Import Plugin 支持从其他的Jenkins上远程导入job Extended E-mail Notification 在job构建后发送邮件 Python Adds the ability to execute python scrip…