Pytorch:torch.utils.data.DataLoader()

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoaderDataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,collate_fn=None,pin_memory=False,)

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):name = self.ids[idx]mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))img_file = list(self.images_dir.glob(name + '.*'))assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'mask = load_image(mask_file[0])img = load_image(img_file[0])assert img.size == mask.size, \f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)return {'image': torch.as_tensor(img.copy()).float().contiguous(),'mask': torch.as_tensor(mask.copy()).long().contiguous()}

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/180973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口01-Java

接口-Java 一、引入(快速入门案例)二、接口介绍1、概念2、语法 三、应用场景四、接口使用注意事项五、练习题1 一、引入(快速入门案例) usb插槽就是现实中的接口。 你可以把手机、相机、u盘都插在usb插槽上,而不用担心那个插槽是专门插哪个的,原因是做u…

解决git action发布失败报错:Error: Resource not accessible by integration

现象: 网上说的解决方法都是什么到github个人中心setting里面的action设置里面去找。 可这玩意根本就没有! 正确解决办法: 在你的仓库页面,注意是仓库页面的setting里面: Actions> General>Workflow permisss…

苹果手机如何格式化?五个步骤快速掌握!

如果手机出现异常情况,例如运行缓慢、频繁崩溃,又或者想将手机出售、转让给他人,那么将手机格式化可以有助于解决问题。苹果手机如何格式化?本文将为您介绍解决方法,只需要五个步骤就能搞定,帮助您快速掌握…

天软高频时序数据仓库

1天软高频时序数仓方案架构 天软高频时序数据仓库是深圳天软科技开发有限公司专为金融用户提供的专业高频行情数据处理方案,集数据接入、检查、处理、存储、查询、订阅、计算于一体。 方案支持各类系统的实时行情、非实时行情接入;还支持压缩存储、分布式…

使用 DMA 在 FPGA 中的 HDL 和嵌入式 C 之间传输数据

使用 DMA 在 FPGA 中的 HDL 和嵌入式 C 之间传输数据 该项目介绍了如何在 PL 中的 HDL 与 FPGA 中的处理器上运行的嵌入式 C 之间传输数据的基本结构。 介绍 鉴于机器学习和人工智能等应用的 FPGA 设计中硬件加速的兴起,现在是剥开几层“云雾”并讨论 HDL 之间来回传…

Peter算法小课堂—高精度减法

给大家看个小视频高精度减法_哔哩哔哩_bilibili 基本思想 计算机模拟人类做竖式计算,从而得到正确答案 大家还记得小学时学的“减法竖式”吗?是不是这样 x-y问题 函数总览: 1.converts() 字符串转为高精度大数 2.le() 判断大小 3.sub() …

【技术干货】宇视IPC音频问题解决步骤

近期技术人员从宇视官网下载sdk进行二次开发时,在启动实时直播,并通过回调函数拿到流数据,发现没有音频流数据。 通过下面的数据发现,codeType此字段一直是28,代表的是H.264数据,但未没发现有音频的数据包…

什么是Geo Trust OV证书

一、GeoTrust OV证书的介绍 GeoTrust OV证书是由GeoTrust公司提供的SSL证书,它是一种支持OpenSSL的数字证书,具有更高的安全性和可信度。GeoTrust是全球领先的网络安全解决方案提供商,为各类用户提供SSL证书和信任管理服务。GeoTrust OV证书…

如何使用ArcGIS实现生态廊道模拟

生态廊道是指一种连接不同生态系统的走廊或通道,其建立有助于解决人类活动对野生动植物栖息地破碎化和隔离化的问题,提高生物多样性,减轻生态系统的压力。在城市化和农业开发不断扩张的背景下,生态廊道对于野生动植物的生存和繁衍…

Hive安装与配置

你需要掌握: 1.Hive的基本安装; 2.Mysql的安装与设置; 3.Hive 的配置。 注意:Hive的安装与配置建立在Hadoop已安装配置好的情况下。 hadopp安装与配置 Hive 的基本安装 从 官网 下载Hive二进制包,下载好放在/op…

万人拼团团购小程序源码系统+拼团设置+拼团管理 附带完整的搭建教程

随着互联网的快速发展,电子商务和社交电商的兴起,团购作为一种高效的营销策略和消费方式,受到了广大消费者的热烈欢迎。在此背景下,我们开发了一款基于微信小程序的万人拼团团购系统,旨在为用户提供一种更加便捷、高效…

python爬虫进阶教程之如何正确的使用cookie

文章目录 前言一、获取cookie二、程序实现三、动态获取cookie四、其他关于Python爬虫技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集①Python工具包②Python实战案例③Python小游戏源码五、面试资料六、Pytho…

lxml 总结

xm 和 lxml库 哪个更好用点 1. 性能: lxml 通常比 xml.etree.ElementTree 更快。lxml 使用了 C 编写的底层解析器,因此在处理大型 XML 文档时可能更高效。 如果性能对你的应用很重要,特别是在处理大型 XML 文件时,选择 lxml 可能…

这款高性能分布式ID生成器,现在是你的了~

这是DDD&微服务系列的第17篇,欢迎持续关注~ 概述 在软件开发过程中,我们经常会遇到需要生成全局唯一流水号的场景,例如各种流水号和分库分表的分布式主键ID。特别是在使用MySQL数据库时,除了要求流水号具有“全局唯一”性外&…

继电保护-变压器纵联差动保护MATLAB仿真模型

微❤关注“电气仔推送”获得资料(专享优惠) 原理概述 差动保护是在两端设置的保护,通过比较两端测回来的电气量,进而看是否需要动作,纵联差动保护是变压器主保护。 纵联差动保护基本原则 双绕组变压器实现纵联差动…

泄密零容忍!迅软科技打造设计图纸安全防线,助您无忧创作!

对于建筑设计、鞋服设计、动漫设计、平面设计等设计行业而言,海量设计图纸都以电子数据的形式存在企业的终端电脑上,这些图纸蕴含着企业的核心竞争资源,一旦泄露将给企业带来巨大的经济损失。 因此,迅软科技采用了先进的数据加密技…

Ruoyi-cloud / 若依 SpringCloud服务器部署

1、redis 环境 服务器安装redis ,注意 密码 端口 2、mysql 环境 服务器安装 mysql 5.7 以上的版本 代码中的sql 文件夹中有 sql 文件 创建数据库ry-cloud并导入数据脚本ry_2021xxxx.sql(必须),quartz.sql(可选&…

同旺科技 USB 转 RS-485 适配器 -- 隔离型

内附链接 1、USB 转 RS-485 适配器 隔离版主要特性有: ● 支持USB 2.0/3.0接口,并兼容USB 1.1接口; ● 支持USB总线供电; ● 支持Windows系统驱动,包含WIN10 / WIN11 系统32 / 64位; ● 支持Windows …

使用vue-admin-template时,需要注意的问题,包括一定要去除mock.js注释

在使用vue-admin-template等前端框架时,如果你没有打算用他们的mock数据,在生产环境下一定要注释mock引用的代码,虽然它没有被调用,但是如果你不注释,就会被打包进去。 找到main.js,看如下代码&#xff1a…

八、Lua数组和迭代器

一、Lua数组 数组,就是相同数据类型的元素按一定顺序排列的集合,可以是一维数组和多维数组。 在 Lua 中,数组不是一种特定的数据类型,而是一种用来存储一组值的数据结构。 实际上,Lua 中并没有专门的数组类型&#xf…