YOLOV8逐步分解(6)_模型训练初始设置之image size检测batch预设及dataloder初始化

yolov8逐步分解(1)--默认参数&超参配置文件加载

yolov8逐步分解(2)_DetectionTrainer类初始化过程

yolov8逐步分解(3)_trainer训练之模型加载

YOLOV8逐步分解(4)_模型的构建过程

YOLOV8逐步分解(5)_模型训练初始设置之混合精度训练AMP

        接逐步分解(5),继续模型训练初始设置的讲解,本章将讲解image size检测、batch预设及dataloder初始代码。

1. image size代码

        # Check imgszgs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)         # grid size (max stride)self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

这段代码是用来处理输入图像的尺寸(imgsz)。它的作用:

1.1 gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32):

        计算出模型的最大 stride 值。

        如果模型有 stride 属性,就取它的最大值;否则默认为 32。

        这个 gs 变量代表了模型的网格大小(grid size),它是用于确定输入图像尺寸的一个重要参数。

1.2 self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1):

        代码调用了 check_imgsz() 函数,用于检查和设置输入图像的尺寸。

        self.args.imgsz 是用户传入的期望输入图像尺寸。

        check_imgsz() 函数会根据模型的 stride 值和其他参数,对 self.args.imgsz 进行调整和验证。

        具体来说:

                stride=gs: 使用计算出的网格大小作为 stride 参数。

                floor=gs: 确保输入图像尺寸是网格大小的倍数。

                max_dim=1: 限制输入图像的最大维度为 1。

        这段代码的目的是确保输入图像的尺寸与模型的特性(如 stride)匹配,以确保模型能够正确地处理输入数据。这有助于提高模型的性能和稳定性。

2. batch size设置

        # Batch sizeif self.batch_size == -1: #表示批量大小需要自动估计if RANK == -1:  # single-GPU only, estimate best batch sizeself.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)#估计最佳批量大小else:SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. ''Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')

处理训练时的批量大小(batch size):

2.1 if self.batch_size == -1::

        这个条件检查是否需要自动估计批量大小。

        如果 self.batch_size 为 -1,表示需要自动估计最佳批量大小。

2.2 if RANK == -1:

        这个条件检查当前是否处于单 GPU 训练模式。

        如果是单 GPU 训练,才能使用自动估计批量大小的功能。

2.3 self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp):

        在单 GPU 训练模式下,代码调用 check_train_batch_size() 函数来估计最佳批量大小。

        这个函数会根据模型、输入图像尺寸和是否启用混合精度训练,来计算出最佳的批量大小。

        计算结果会赋值给 self.args.batch 和 self.batch_size。

2.4 如果不是单 GPU 训练模式(即分布式训练),就会抛出一个 SyntaxError 异常。

        异常信息提示用户,在分布式训练时不能使用自动批量大小估计功能,需要手动设置一个有效的批量大小。

        这段代码的目的是尽可能自动地估计出最佳的批量大小,以提高训练的效率和性能。但这个功能只在单 GPU 训练模式下可用,在分布式训练中需要手动设置批量大小。

3. dataloader 初始化

        # Dataloadersbatch_size = self.batch_size // max(world_size, 1)self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')#获取训练集if RANK in (-1, 0):self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val') #获取测试集self.validator = self.get_validator() #创建验证器(validator),用于评估模型在验证数据集上的性能。metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?self.ema = ModelEMA(self.model)if self.args.plots and not self.args.v5loader: #如果 self.args.plots 为真且 self.args.v5loader 为假self.plot_training_labels() #绘制训练标签的图表

设置和获取训练集和测试集的数据加载器(dataloader):

3.1 batch_size = self.batch_size // max(world_size, 1):

        这一行代码计算出每个进程(process)使用的批量大小。

        它将 self.batch_size 除以 world_size (分布式训练时的进程数)或 1(单机训练时)。

        这样做是为了确保在分布式训练时,每个进程使用的批量大小是合理的。

3.2 self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train'):

        这一行代码通过调用 self.get_dataloader() 函数,获取训练数据集的数据加载器。

        self.trainset 是训练数据集,batch_size 是计算出的批量大小,rank 是当前进程的序号(在分布式训练时使用)。

        mode='train' 表示这是用于训练的数据加载器。

3.3 if RANK in (-1, 0)::

        这个条件检查当前是否处于单机训练模式或主进程(rank 为 0)。

        只有在这些情况下,才会执行以下操作。

3.4 self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val'):

        这一行代码获取验证数据集的数据加载器。

        self.testset 是验证数据集,批量大小是训练批量大小的两倍,rank 设置为 -1 表示不参与分布式训练。

        mode='val' 表示这是用于验证的数据加载器。

3.5 self.validator = self.get_validator():

        这一行代码创建了一个验证器(validator)对象,用于评估模型在验证数据集上的性能。

3.6 metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val'):

        这一行代码获取验证阶段需要的所有度量指标的键(key)。

3.7 self.validator.metrics.keys 是验证器定义的度量指标,self.label_loss_items(prefix='val') 是验证阶段的标签损失项。

3.8 self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))):

        这一行代码初始化了所有度量指标的值为 0。

3.9 self.ema = ModelEMA(self.model):

        这一行代码创建了一个指数移动平均(EMA)模型,用于在训练过程中保存模型的滚动平均值。

3.10 if self.args.plots and not self.args.v5loader: self.plot_training_labels():

        如果需要绘制训练标签的图表,并且不使用 v5 格式的数据加载器,就会调用 self.plot_training_labels() 函数。

        这段代码的主要目的是设置训练集和验证集的数据加载器,创建验证器,初始化度量指标,以及设置 EMA 模型等。这些都是训练模型时的常见操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17015.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenLayers6入门,OpenLayers实现在地图上拖拽编辑修改绘制图形

专栏目录: OpenLayers6入门教程汇总目录 前言 在前面一章中,我们已经学会了如何绘制基础的三种图形线段、圆形和多边形:《OpenLayers6入门,OpenLayers图形绘制功能,OpenLayers实现在地图上绘制线段、圆形和多边形》,那么本章将在此基础上实现图形的拖拽编辑功能,方便我…

使用Java 读取PDF表格数据并保存到TXT或Excel

目录 导入相关Java库 Java读取PDF表格数据并保存到TXT Java读取PDF表格数据并保存到Excel 在日常工作中,我们经常需要处理来自各种来源的数据。其中,PDF 文件是常见的数据来源之一。这类文件通常包含丰富的信息,其中可能包含重要的表格数据…

FreeRtos进阶——栈保存现场的几种场景

MCU架构 在认识栈的结构前,我们先来认识以下单片机的简单架构。在我们的CPU中有着很重要的一个模块——寄存器(R0-R15),其中R13,R14,R15的别称分别为SP栈顶指针、LR返回地址、PC当前指令地址。外部RAM是单片…

css中min-height

在CSS中&#xff0c;min-height 属性用于设置元素的最小高度。这意味着&#xff0c;即使内容没有达到指定的最小高度&#xff0c;元素也会尝试占据至少指定的最小高度。 例如&#xff0c;如果你有一个 <div> 元素&#xff0c;并希望它至少有200px的高度&#xff0c;即使…

Android Gradle plugin 版本和Gradle 版本

1.当看到这两个版本时&#xff0c;确实有点迷糊。但是他们是独立的&#xff0c;没有太大关联。 就是说在Android studio中看到的两个版本信息&#xff0c;并无太大关联&#xff0c;是相互独立的。Gradle插件版本决定了你的项目是如何构建的&#xff0c;而Gradle版本是执行构建…

对竞品分析的理解

一、竞品分析是什么 竞品分析即对竞争对手进行分析&#xff0c;是市场研究中的一项重要工作&#xff0c;它可以帮助企业了解竞争对手的产品、策略、市场表现等信息&#xff0c;通过竞品分析可以为自己的产品制定更加精准的策略。 二、为什么要做竞品分析 1.了解市场情况 了解…

002 访问修饰符 package

访问修饰符 在Java中&#xff0c;protected、private、public 和包级别访问权限&#xff08;有时称为default或package-private&#xff09;是用于控制类、变量、方法和构造器的可见性和可访问性的修饰符。下面是这些修饰符的主要区别&#xff1a; public&#xff1a; 可见性…

vue/core源码中ref源码的js化

起源&#xff1a; 当看见reactivity文件中的ref.ts文件长达五百多的ts代码后&#xff0c;突发奇想想看下转化成js有多少行。 进行转化&#xff1a; let shouldTrack true; // Define shouldTrack variable let activeEffect null; // Define activeEffect variable// 定义…

M2m中的采样

采样的完整代码 import torch import numpy as np from torchvision import datasets, transforms from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSamplerdef get_oversampled_data(dataset, num_sample_per_class):""" Gener…

C语言从头学12——流程控制(一)

C语言程序的执行顺序是从前到后依次序执行的。如果想要控制程序执行的流程&#xff0c;就必须使用 流程控制的语法结构&#xff0c;分为条件执行和循环执行。 1、if语句 if 语句在前面的举例中曾经出现过&#xff0c;这里做详细介绍。该语句用于条件判断&#xff0c;满…

Upstream最新发布2024年汽车网络安全报告-百度网盘下载

Upstream最新发布2024年汽车网络安全报告-百度网盘下载 2024年2月7日&#xff0c;Upstream Security发布了2024年Upstream《GLOBAL AUTOMOTIVE CYBERSECURITY REPORT》。这份报告的第六版着重介绍了汽车网络安全的拐点&#xff1a;从实验性的黑客攻击发展到规模庞大的攻击&…

fpga系列 HDL 00 : 可编程逻辑器件原理

一次性可编程器件&#xff08;融保险丝实现&#xff09; 一次性可编程器件&#xff08;One-Time Programmable Device&#xff0c;简称 OTP&#xff09;是一种在制造后仅能编程一次的存储设备。OTP器件在编程后数据不可更改。这些器件在很多应用场景中具有独特的优势和用途。 …

【软件设计师】——10.面向对象技术

目录 10.1 基本概念 10.2设计原则 10.3 设计模式的概念与分类 10.4 创建型模式 10.4.1 Singleton 单例模式 10.4.2 Builder 构建器模式 10.4.3 Abstract Factory 抽象工厂模式 10.4.4 Prototype原型模式 10.4.5 Factory Method工厂方法模式 10.5 结构型模式 10.5.1 A…

【LeetCode算法】第83题:删除排序链表中的重复元素

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路&#xff1a;双指针法&#xff0c;只需遍历一遍。使用low指向前面的元素&#xff0c;high用于查找low后面与low不同内容的节点。将具有不同内容的节点链接在low后面&#xff0c;实…

【c++】菱形虚拟继承的虚函数表如何继承

请看如下代码 #include <iostream>// 基类 class Base { public:virtual void foo() { std::cout << "Base::foo()" << std::endl; }virtual void bar() { std::cout << "Base::bar()" << std::endl; } };// 虚拟继承的中间…

全栈:session用户会话信息,用户浏览记录实例

PHP中的session是一种存储机制&#xff0c;它允许您存储和跟踪用户在访问Web应用程序时的信息。会话通常用于存储用户特定的数据&#xff0c;如用户ID、购物车内容、用户偏好设置等&#xff0c;这些数据需要在多个页面请求之间保持不变。 session详解 1. 会话是如何工作的 会…

西门子S7-1200加入MRP 环网用法

MRP&#xff08;介质冗余&#xff09;功能概述 SIMATIC 设备采用标准的冗余机制为 MRP&#xff08;介质冗余协议&#xff09;&#xff0c;符合 IEC62439-2 标准&#xff0c;典型重新组态时间为 200ms&#xff0c;每个环网最多支持 50个设备。​博途TIA/WINCC社区VX群 ​博途T…

Linux 批量网络远程PXE

一、搭建PXE远程安装服务器 1、yum -y install tftp-server xinetd #安装tftp服务 2、修改vim /etc/xinetd.d/tftpTFTP服务的配置文件 systemctl start tftp systemctl start xinetd 3、yum -y install dhcp #---安装服务 cp /usr/share/doc/dhc…

c 语言 ---- 结构体

什么是结构体 自定义的数据类型 结构体的声明定义 //1.先声明再定义 struct point{int x;int y; };struct point p1,p2;//2.声明的同时定义 struct point{int x;int y; }p1,p2;typedef定义别名 关键字typedef用于为系统固有的或者程序员自定义的数据类型定义一个别名。数据类…

利用Python队列生产者消费者模式构建高效爬虫

目录 一、引言 二、生产者消费者模式概述 三、Python中的队列实现 四、生产者消费者模式在爬虫中的应用 五、实例分析 生产者类&#xff08;Producer&#xff09; 消费者类&#xff08;Consumer&#xff09; 主程序 六、总结 一、引言 随着互联网的发展&#xff0c;信…