Pytorch:torch.utils.data.DataLoader()

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoaderDataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,collate_fn=None,pin_memory=False,)

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):name = self.ids[idx]mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))img_file = list(self.images_dir.glob(name + '.*'))assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'mask = load_image(mask_file[0])img = load_image(img_file[0])assert img.size == mask.size, \f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)return {'image': torch.as_tensor(img.copy()).float().contiguous(),'mask': torch.as_tensor(mask.copy()).long().contiguous()}

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/180973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

聊聊clickhouse分布式表的操作

序 本文主要研究一下clickhouse分布式表的操作 创建分布式表 CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster] (name1 [type1] [DEFAULT|MATERIALIZED|ALIAS expr1],name2 [type2] [DEFAULT|MATERIALIZED|ALIAS expr2],... ) ENGINE Distributed(clust…

接口01-Java

接口-Java 一、引入(快速入门案例)二、接口介绍1、概念2、语法 三、应用场景四、接口使用注意事项五、练习题1 一、引入(快速入门案例) usb插槽就是现实中的接口。 你可以把手机、相机、u盘都插在usb插槽上,而不用担心那个插槽是专门插哪个的,原因是做u…

解决git action发布失败报错:Error: Resource not accessible by integration

现象: 网上说的解决方法都是什么到github个人中心setting里面的action设置里面去找。 可这玩意根本就没有! 正确解决办法: 在你的仓库页面,注意是仓库页面的setting里面: Actions> General>Workflow permisss…

苹果手机如何格式化?五个步骤快速掌握!

如果手机出现异常情况,例如运行缓慢、频繁崩溃,又或者想将手机出售、转让给他人,那么将手机格式化可以有助于解决问题。苹果手机如何格式化?本文将为您介绍解决方法,只需要五个步骤就能搞定,帮助您快速掌握…

【新手解答5】深入探索 C 语言:宏中的文本、标识符和字符串 + 递归运算、条件语句、循环 + `switch-case` 与多项条件和枚举的差别

C语言的相关问题解答 写在最前面问题1编程中的一般概念1. 文本2. 标识符3. 字符串 宏中的文本、标识符和字符串例子规范 问题二的笔记梳理递归运算条件语句循环中断(提前退出)、继续循环break 语句(补充)continue 语句&#xff08…

天软高频时序数据仓库

1天软高频时序数仓方案架构 天软高频时序数据仓库是深圳天软科技开发有限公司专为金融用户提供的专业高频行情数据处理方案,集数据接入、检查、处理、存储、查询、订阅、计算于一体。 方案支持各类系统的实时行情、非实时行情接入;还支持压缩存储、分布式…

使用 DMA 在 FPGA 中的 HDL 和嵌入式 C 之间传输数据

使用 DMA 在 FPGA 中的 HDL 和嵌入式 C 之间传输数据 该项目介绍了如何在 PL 中的 HDL 与 FPGA 中的处理器上运行的嵌入式 C 之间传输数据的基本结构。 介绍 鉴于机器学习和人工智能等应用的 FPGA 设计中硬件加速的兴起,现在是剥开几层“云雾”并讨论 HDL 之间来回传…

xv6 内核空间共享

首发公号:Rand_cs 共享内核空间 我们常说,每个进程都有自己的虚拟地址空间,但其中内核部分是共享的。 这就有个问题,如何共享的? 系统启动时创建了一张内核页表,里面记录着内核地址空间与物理地址空间的…

Peter算法小课堂—高精度减法

给大家看个小视频高精度减法_哔哩哔哩_bilibili 基本思想 计算机模拟人类做竖式计算,从而得到正确答案 大家还记得小学时学的“减法竖式”吗?是不是这样 x-y问题 函数总览: 1.converts() 字符串转为高精度大数 2.le() 判断大小 3.sub() …

【技术干货】宇视IPC音频问题解决步骤

近期技术人员从宇视官网下载sdk进行二次开发时,在启动实时直播,并通过回调函数拿到流数据,发现没有音频流数据。 通过下面的数据发现,codeType此字段一直是28,代表的是H.264数据,但未没发现有音频的数据包…

【C++】define宏定义

define宏定义 define是C语言中的一个宏定义命令&#xff0c;它用来将一个标识符定义为一个字符串&#xff0c;该标识符被称为宏名&#xff0c;被定义的字符串称为替换文本&#xff1b; define <宏名> (<参数表>) <宏体>操作符 # &#xff1a;可将参数转化为…

什么是Geo Trust OV证书

一、GeoTrust OV证书的介绍 GeoTrust OV证书是由GeoTrust公司提供的SSL证书&#xff0c;它是一种支持OpenSSL的数字证书&#xff0c;具有更高的安全性和可信度。GeoTrust是全球领先的网络安全解决方案提供商&#xff0c;为各类用户提供SSL证书和信任管理服务。GeoTrust OV证书…

如何使用ArcGIS实现生态廊道模拟

生态廊道是指一种连接不同生态系统的走廊或通道&#xff0c;其建立有助于解决人类活动对野生动植物栖息地破碎化和隔离化的问题&#xff0c;提高生物多样性&#xff0c;减轻生态系统的压力。在城市化和农业开发不断扩张的背景下&#xff0c;生态廊道对于野生动植物的生存和繁衍…

重生之我是一名程序员 44 ——字符串函数(3)

哈喽啊大家晚上好&#xff01;迄今为止我已近给大家介绍了2个字符串函数&#xff0c;今天呢再给大家带来一个字符串函数——strcmp函数。 首先呢还是先带大家认识一下它。strcmp函数是C语言中的字符串函数之一&#xff0c;用于比较两个字符串是否相等。 该函数原型为&#xf…

mysql中的锁及其作用

在MySQL中&#xff0c;锁是用于控制对数据库对象的并发访问的一种机制。锁可以防止多个事务同时对同一数据进行修改或删除&#xff0c;以确保数据的完整性和一致性。 MySQL中的锁有以下几种类型&#xff1a; 共享锁&#xff08;Shared Lock&#xff09;&#xff1a;也称为读锁&…

短视频运营常用的ChatGPT通用提示词模板

短视频定位和策划&#xff1a;请帮助我明确短视频的定位和策划&#xff0c;包括目标受众、主题、风格、内容等方面的内容&#xff0c;以便我能够更好地制定短视频运营策略。 短视频制作&#xff1a;请帮助我制作高质量的短视频&#xff0c;包括脚本编写、拍摄、剪辑、特效等方…

Hive安装与配置

你需要掌握&#xff1a; 1.Hive的基本安装&#xff1b; 2.Mysql的安装与设置&#xff1b; 3.Hive 的配置。 注意&#xff1a;Hive的安装与配置建立在Hadoop已安装配置好的情况下。 hadopp安装与配置 Hive 的基本安装 从 官网 下载Hive二进制包&#xff0c;下载好放在/op…

万人拼团团购小程序源码系统+拼团设置+拼团管理 附带完整的搭建教程

随着互联网的快速发展&#xff0c;电子商务和社交电商的兴起&#xff0c;团购作为一种高效的营销策略和消费方式&#xff0c;受到了广大消费者的热烈欢迎。在此背景下&#xff0c;我们开发了一款基于微信小程序的万人拼团团购系统&#xff0c;旨在为用户提供一种更加便捷、高效…

python爬虫进阶教程之如何正确的使用cookie

文章目录 前言一、获取cookie二、程序实现三、动态获取cookie四、其他关于Python爬虫技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集①Python工具包②Python实战案例③Python小游戏源码五、面试资料六、Pytho…

lxml 总结

xm 和 lxml库 哪个更好用点 1. 性能&#xff1a; lxml 通常比 xml.etree.ElementTree 更快。lxml 使用了 C 编写的底层解析器&#xff0c;因此在处理大型 XML 文档时可能更高效。 如果性能对你的应用很重要&#xff0c;特别是在处理大型 XML 文件时&#xff0c;选择 lxml 可能…