[杂记]mmdetection3.x中的数据流与基本流程详解(数据集读取, 数据增强, 训练)


之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的…

在这里插入图片描述

这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中.


0. MMDetection整体组成部分

让我们首先回顾一下C++的标准模板库(STL)是怎样设计的. STL的三个核心组件是容器, 算法与迭代器. 容器, 例如vector, queue等等, 他们是负责存储数据的, 算法是负责进行一些操作, 例如排序, 查找等等. 而迭代器是容器与算法之间的桥梁, 也就是算法可以通过迭代器去访问容器, 使得算法可以独立于容器的类型进行操作. 三个部分相辅相成, 就达到了泛型编程的理念.

再让我们回顾一下一套深度学习的代码包含什么部分. 从大的方面来说, 需要有数据的读取与增强(DataLoader), 模型的定义, 损失函数的计算, 负责梯度传播的优化器, 在验证(测试)集上的评估等. 同理, MMDetection也是按照这种方式来的, 并且每个部分接口相通, 就可以实现更广义的模型定义和训练方式.

mmengine/registry/__init__.py中, 我们可以看到, MMEngine(或者说MMDetection)总体有这些类型的模块:

from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS,INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS,MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,WEIGHT_INITIALIZERS)

那么以上这么多模块可以分成几类, 分别负责什么呢? 按照我个人的理解, MMDetection的整体组成部分可以表示为下图:

在这里插入图片描述

为了节省空间, 优化器相关并未画出

1. 认识config文件

mmdetection设计的核心思想是通过字典来配置整个的训练过程和模型定义, 这些字典放在一个.py的config文件中. 一般来说,config文件最重要的就是数据加载(train_dataloader, val_dataloader和test_dataloader), 模型定义(model)和训练与测试过程(train_pipeline, test_pipeline). 除此之外, 还有一些训练, 测试配置(train_cfg, test_cfg)等等. 具体config的例子可以参照官网Learn about configs.

需要注意的是, mmdetection中字典定义class的方式, 往往是键type表示类的名字, 之后的其他键都是类初始化需要的参数. 例如, 如果我想自定义一个模型, 叫做MyModel, 定义在当前目录下的./models/my_model.py中, 定义方式如下:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):def __init__(self, arg1=..., arg2=..., arg3=...):...def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表...

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


# 必须将自定义类的py文件导入 这样可以自动register自定义模型 否则模型初始化时找不到custom_imports = dict(imports=['models.my_model'],allow_failed_imports=False)# 现在就可以愉快的传参了
models=dict(type='MyModel', arg1=1, arg2=[16, 128], arg3=dict(channel=256), ...
)

同样, 我们可以自定义DataLoader, Loss, 等等.

此外, dict是可以嵌套的, 例如mmdetection将检测模型分成了backbone, neck和head三部分, 那么如果我们又自定义了一个Head, 叫MyHead:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmengine.model import BaseModule  # 一个模型基类@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyHead(BaseModule):def __init__(self, arg4=...):...

这样, 如果MyModel的前向传播过程中需要一个head, 则代码大致是这个样子:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):def __init__(self, arg1=..., arg2=..., arg3=...,head=...):self.head = MODELS.build(head)  # 建立Head的模型, 类型是nn.Module...def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表...  # 一些其他过程ret = self.head(inputs)  # forward...  # 后处理

配置文件中对应更改为:

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


custom_imports = dict(imports=['models.my_model', '自定义HEAD所在的py文件'],allow_failed_imports=False)models=dict(type='MyModel', arg1=1, arg2=[16, 128], arg3=dict(channel=256), head=dict(  # 定义headtype='MyHead',arg4=256,...)...
)

篇幅所限, 自定义损失函数, 数据增强之类的就不一一列举了.

2. 数据流

我们接下来以检测与跟踪任务为例, 看看数据到底是如何被读入的. 我们以训练过程说明.

在训练过程中, 我们会初始化一个RUNNER类, 其读入我们的config文件并依次完成各种(模型, 数据加载, 优化器, 钩子等等)的初始化. 我们以官方提供的train.py为例:

runner = Runner.from_cfg(cfg)

from_cfg()是一个类方法(classmethod), 在其中我们实例化了Runner类.

随后, 我们调用Runnertrain()方法进行训练. 首先, 我们实例化训练循环:

        self._train_loop = self.build_train_loop(self._train_loop)  # type: ignore

训练循环就属于LOOP类型.

在这里, 我们以最常用的EpochBasedTrainLoop为例. 在EpochBasedTrainLoop的初始化函数中, 根据config文件中的train_dataloader字典实例化出torchDataLoader类():
在这里插入图片描述

        data_loader = DataLoader(dataset=dataset,sampler=sampler if batch_sampler is None else None,batch_sampler=batch_sampler,collate_fn=collate_fn,worker_init_fn=init_fn,**dataloader_cfg)return data_loader

当然, 我们知道torch的DataLoader类在调用的时候, 会调用到dataset(类别是torch.utils.data.Dataset)的__getitem__方法. 因此, 我们从__getitem__入手来探索数据流.

在MMDetection的设计中, 数据集的类都是继承于MMengine中的BaseDataset, 其中的__getitem__是这样写的:
在这里插入图片描述

    def __getitem__(self, idx: int) -> dict:if not self._fully_initialized:print_log('Please call `full_init()` method manually to accelerate ''the speed.',logger='current',level=logging.WARNING)self.full_init()if self.test_mode:data = self.prepare_data(idx)if data is None:raise Exception('Test time pipline should not get `None` ''data_sample')return datafor _ in range(self.max_refetch + 1):data = self.prepare_data(idx)# Broken images or random augmentations may cause the returned data# to be Noneif data is None:idx = self._rand_another()continuereturn dataraise Exception(f'Cannot find valid image after {self.max_refetch}! ''Please check your image path and pipeline')

我们可以看到, 在__getitem__中最核心的是self.prepare_data(idx). 按照这种思路一级一级向上查找, 我们就可以总结出如下图的数据读取流程:

在这里插入图片描述
其中, 数据增强pipeline是一系列类型为TRANSFORMS类的列表, 再每经过一次数据增强时, 字典都会被更新.

我们以较为常用的随机便宜(RandomShift)来说, 其是这样定义的:


@TRANSFORMS.register_module()
class RandomShift(BaseTransform):def __init__(self,...@autocast_box_type()def transform(self, results: dict) -> dict:  # transform方法, 更新字典, 图像与对应的边界框等都需要被更新"""Transform function to random shift images, bounding boxes.Args:results (dict): Result dict from loading pipeline.Returns:dict: Shift results."""if self._random_prob() < self.prob:img_shape = results['img'].shape[:2]random_shift_x = random.randint(-self.max_shift_px,self.max_shift_px)random_shift_y = random.randint(-self.max_shift_px,self.max_shift_px)new_x = max(0, random_shift_x)ori_x = max(0, -random_shift_x)new_y = max(0, random_shift_y)ori_y = max(0, -random_shift_y)# TODO: support mask and semantic segmentation maps.bboxes = results['gt_bboxes'].clone()bboxes.translate_([random_shift_x, random_shift_y])# clip borderbboxes.clip_(img_shape)# remove invalid bboxesvalid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (bboxes.heights > self.filter_thr_px).numpy()# If the shift does not contain any gt-bbox area, skip this# image.if not valid_inds.any():return resultsbboxes = bboxes[valid_inds]results['gt_bboxes'] = bboxesresults['gt_bboxes_labels'] = results['gt_bboxes_labels'][valid_inds]if results.get('gt_ignore_flags', None) is not None:results['gt_ignore_flags'] = \results['gt_ignore_flags'][valid_inds]# shift imgimg = results['img']new_img = np.zeros_like(img)img_h, img_w = img.shape[:2]new_h = img_h - np.abs(random_shift_y)new_w = img_w - np.abs(random_shift_x)new_img[new_y:new_y + new_h, new_x:new_x + new_w] \= img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]results['img'] = new_imgreturn results

需要注意的是, 经过pipeline后, 字典最终会被更新成如下形式:

dict = {'inputs': torch.Tensor, 'data_samples': DetDataSample或TrackDataSample等}

其中'inputs'键对应的值就是转换为tensor的图片, 而'data_samples'键对应的值是表示样本的类, 在检测任务中, 是DetDataSample, 跟踪任务中, 是TrackDataSample. DetDataSample类有许多成员, 包括该样本(图片)的目标的边界框真值, 分割真值等:

在这里插入图片描述

class DetDataSample(BaseDataElement):"""A data structure interface of MMDetection. They are used as interfacesbetween different components.The attributes in ``DetDataSample`` are divided into several parts:- ``proposals``(InstanceData): Region proposals used in two-stagedetectors.- ``gt_instances``(InstanceData): Ground truth of instance annotations.- ``pred_instances``(InstanceData): Instances of detection predictions.- ``pred_track_instances``(InstanceData): Instances of trackingpredictions.- ``ignored_instances``(InstanceData): Instances to be ignored duringtraining/testing.- ``gt_panoptic_seg``(PixelData): Ground truth of panopticsegmentation.- ``pred_panoptic_seg``(PixelData): Prediction of panopticsegmentation.- ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.

以上过程可以借用MMEngine文档里的一个图说明:

在这里插入图片描述

最终, 模型的forward, loss, predict等方法都是接收inputs: torch.Tensordata_samples作为输入, 例如:

在这里插入图片描述

def loss(self, inputs: Tensor, data_samples: TrackSampleList,**kwargs) -> Union[dict, tuple]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/690887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AcuAutomate:一款基于Acunetix的大规模自动化渗透测试与漏洞扫描工具

关于AcuAutomate AcuAutomate是一款基于Acunetix的大规模自动化渗透测试与漏洞扫描工具&#xff0c;该工具旨在辅助研究人员执行大规模的渗透测试任务。 在大规模的安全测试活动中&#xff0c;AcuAutomate可以帮助我们同时启动或停止多个Acunetix扫描任务。除此之外&#xff…

【鸿蒙系统学习笔记】状态管理

一、介绍 资料来自官网&#xff1a;文档中心 在声明式UI编程框架中&#xff0c;UI是程序状态的运行结果&#xff0c;用户构建了一个UI模型&#xff0c;其中应用的运行时的状态是参数。当参数改变时&#xff0c;UI作为返回结果&#xff0c;也将进行对应的改变。这些运行时的状…

[AudioRecorder]iPhone苹果通话录音汉化破解版-使用巨魔安装-ios17绕道目前还不支持

首先你必须有巨魔才能使用&#xff01;&#xff01; 不会安装的&#xff0c;还没安装的移步这里&#xff0c;ios17 以上目前装不了&#xff0c;别看了&#xff1a;永久签名 | 网址分类目录 | 路灯iOS导航-苹果签名实用知识网址导航-各种iOS技巧-后厂村路灯 视频教程 【Audio…

科技云报道:云原生是大模型“降本增效”的解药吗?

科技云报道原创。 在过去一两年里&#xff0c;以GPT和Diffusion model为代表的大语言模型和生成式AI&#xff0c;将人们对AI的期待推向了一个新高峰&#xff0c;并吸引了千行百业尝试在业务中利用大模型。 国内各家大厂在大模型领域展开了激烈的军备竞赛&#xff0c;如&#…

Python set函数

在Python编程中&#xff0c;set()函数是一个重要且常用的内置函数&#xff0c;用于创建一个新的集合对象。集合是一种无序且不重复的数据类型&#xff0c;它可以用于存储唯一的元素。本文将深入探讨Python中的set()函数&#xff0c;包括基本用法、集合操作、实际应用场景&#…

PyCharm - Project Interpreter (项目解释器)

PyCharm - Project Interpreter [项目解释器] References File -> Settings… -> Project: -> Project Interpreter References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

基于数字双输入的超宽带Doherty功率放大器设计-从理论到ADS版图

基于数字双输入的超宽带Doherty功率放大器设计-从理论到ADS版图 参考论文: 高效连续型射频功率放大器研究 假期就要倒计时啦&#xff0c;估计是寒假假期的最后一个博客&#xff0c;希望各位龙年工作顺利&#xff0c;学业有成。 全部工程下载&#xff1a;基于数字双输入的超宽…

遇到问题(二) 中文乱码

例如这样&#xff1a; 原本是这样&#xff1a; 解决方法&#xff1a;点击扳手工具设置——Editor——Encoding——选chinese GB2312&#xff08;有的是UTF-8&#xff09;

LabVIEW高速信号测量与存储

LabVIEW高速信号测量与存储 介绍了LabVIEW开发的高速信号测量与存储系统&#xff0c;解决实验研究中信号捕获的速度和准确性问题。通过高效的数据处理和存储解决方案&#xff0c;本系统为用户提供了一种快速、可靠的信号测量方法。 项目背景 在科学研究和工业应用中&#xf…

Ubuntu18.04有线连接后,无法设置ip地址以及显示网口设置

前提&#xff1a;首先测试过网线是完全没问题的 桌面端找不到设置网口 终端输入&#xff1a; ifconfig 没有找到网口设置和对应IP 然后查询网口驱动是否正常安装&#xff0c;输入&#xff1a; lspci | grep Ethernet 有输出说明网口驱动正常安装 然后查询电脑的ip地址&am…

物流EDI:Verizon EDI 需求分析

作为物流行业的企业&#xff0c;Verizon与其供应商之间通过EDI来传输业务单据。在与Verizon建立EDI连接时&#xff0c;需要参考EDI 指南、采购订单条款和条件以及运输路线指南这三个文档。 点击此链接&#xff0c;获取上述的三个文档 Verizon供应商可以通过上述链接找到用于处…

ubuntu22.04@laptop OpenCV Get Started: 015_deep_learning_with_opencv_dnn_module

ubuntu22.04laptop OpenCV Get Started: 015_deep_learning_with_opencv_dnn_module 1. 源由2. 应用Demo2.1 C应用Demo2.2 Python应用Demo 3. 使用 OpenCV DNN 模块进行图像分类3.1 导入模块并加载类名文本文件3.2 从磁盘加载预训练 DenseNet121 模型3.3 读取图像并准备为模型输…

解决npm淘宝镜像到期问题

1 背景 由于node安装插件是从国外服务器下载&#xff0c;如果没有“特殊手法”&#xff0c;就可能会遇到下载速度慢、或其它异常问题。 所以如果npm的服务器在中国就好了&#xff0c;于是我们乐于分享的淘宝团队干了这事。你可以用此只读的淘宝服务代替官方版本&#xff0c;且…

ARM体系在linux中的中断抢占

上一篇说到系统调用等异常通过向量el1_sync做处理&#xff0c;中断通过向量el1_irq做处理&#xff0c;然后gic的工作都是为中断处理服务&#xff0c;在rtos中&#xff0c;我们一般都会有中断嵌套和优先级反转的概念&#xff0c;但是在linux中&#xff0c;中断是否会被其他中断抢…

js_三种方法实现深拷贝

深拷贝&#xff08; 递归 &#xff09; 适用于需要完全独立于原始对象的场景&#xff0c;特别是当对象内部有引用类型时&#xff0c;为了避免修改拷贝后的对象影响到原始对象&#xff0c;就需要使用深拷贝。 // 原始对象 const obj { uname: Lily,age: 19,hobby: [乒乓球, 篮球…

力扣 188. 买卖股票的最佳时机 IV

题目来源&#xff1a;https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-iv/description/ C题解&#xff1a;动态规划 思路同力扣 123. 买卖股票的最佳时机 III-CSDN博客&#xff0c;只是把最高2次换成k次。如果思路不清晰&#xff0c;可以将k从0写到4等找找规律…

Vue | (三)使用Vue脚手架(上) | 尚硅谷Vue2.0+Vue3.0全套教程

文章目录 &#x1f4da;初始化脚手架&#x1f407;创建初体验&#x1f407;分析脚手架结构&#x1f407;关于render&#x1f407;查看默认配置 &#x1f4da;ref与props&#x1f407;ref属性&#x1f407;props配置项 &#x1f4da;混入&#x1f4da;插件&#x1f4da;scoped样…

Linux网络编程——序列反序列化

文章目录 0. 前言1. 认识协议2. 序列号与反序列化3. 自定义协议——网络计算器4. json 本章Gitee仓库&#xff1a;序列反序列化 0. 前言 tcp是面向字节流的&#xff0c;但是如何保证读取的数据是一个完整的报文呢&#xff1f; 管道也是面向字节流&#xff0c;写端写了一大堆的…

Sora:新一代实时音视频通信框架

一、Sora简介 Sora是一个开源的实时音视频通信框架&#xff0c;旨在提供高效、稳定、可扩展的音视频通信解决方案。它基于WebRTC技术&#xff0c;支持跨平台、跨浏览器的实时音视频通信&#xff0c;并且具备低延迟、高并发、易集成等特点。 --点击进入Sora(一定要科学哦&#x…

机器学习基础(一)理解机器学习的本质

导读&#xff1a;在本文中&#xff0c;将深入探索机器学习的根本原理&#xff0c;包括基本概念、分类及如何通过构建预测模型来应用这些理论。 目录 机器学习 机器学习概念 相关概念 机器学习根本&#xff1a;模型 数据的语言&#xff1a;特征与标签 训练与测试&#xf…