PyTorch概述(四)---DataLoader

  • torch.utils.data.DataLoader是PyTorch数据加载工具的核心;
  • 表示一个Python可迭代数据集;

DataLoader支持的数据集类型

  • map-style 和 iterable-style 的数据集;
  • 可定制的数据加载顺序;
  • 自动批量数据集;
  • 单进程和多进程数据加载;
  • 自动内存固定;

DataLoader构造函数

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,generator=None,*prefetch_factor=2,persistent_workers=False,pin_memory_device='')

Dataset 类型

  • DataLoader构造函数中的最重要的参数是dataset;
  • dataset指示了加载数据的数据集对象;
  • PyTorch支持两种不同类型数据集;
  • map-style数据集;
  • 可迭代类型数据集;

Map-Style数据集

  • 实现了__getitem__()函数和__len__()函数;
  • 表示一个从指数/键值到数据样本的映射;
  • dataset[idx],可以从磁盘的文件夹中读取第idx序列的图像和相应的标签;

可迭代类型数据集

  • 是IterableDataset子类的一个实例;
  • 实现了__iter__()函数原型;
  • 表示了一个可迭代的数据样本集;
  • 该类型的数据集特别适合代价较高的随机读取或者不可随机读取的场合;
  • 批量数据的大小依赖于获取的数据;
  • 调用iter(dataset),可以返回一个数据流、远程的服务器或者实时生成的日志;

数据加载顺序和采样器

  • 对于可迭代数据集,数据加载的顺序取决于用户的定义;
  • 上述特性允许简单的实现块读取和动态批量大小数据读取;
  • 对于map-style类型的数据集:
  • torch.utils.data.Sampler类用于指定数据加载中的指数/键值序列;
  • 他们表示数据集的指数可迭代对象;
  • 在随机梯度下降(SGD)的情况下:
  • Sampler可以随机的排序指数列表且可即时生成一个指数列表;
  • 或者可以生成一个mini-batch SGD的小量值的指数序列;
  • 基于shuffle参数可以自动的构建一个序列或者被洗牌的采样器给到DataLoader;
  • 相反的,用户可用采样器参数指定一个定制的采样器对象,生成下一个要获取的指标/键值;
  • 定制的采样器可生成批量指数的列表并以batch_sampler参数传递给DataLoader;
  • 通过batch_size和drop_last参数可激活自动批量模式;

加载批量的和非批量的数据

  • 通过参数batch_size,drop_last,batch_sampler,和collate_fn;
  • DataLoader支持自动整理获取的数据样本到批量集合中;

自动批量(默认)

  • 最常见的情况;
  • 对应于获取一个小量数据,且整理他们到一个批次样本中;
  • 也就是包含张量的一个维度作为批量的维度(通常是第一维);
  • 当batch_size(默认1)非空时,数据加载器生成批量的样本;
  • batch_size和drop_last参数被用于指定数据加载器如何获得批量的数据集键值;
  • 对于映射数据集,用户可以指定batch_sampler,一次生成一个键值的列表;

失能自动批量

  • 某些情况下,用户可能想要手动处理批量数据集,或简单的加载几个样本;
  • 参数batch_size和batch_sampler都为None时,自动批量失能;
  • 每一个从数据集获取的样本被传递给collate_fn作为参数的函数所处理;
  • 自动批量失能时,默认collate_fn简单转换Numpy数组为PyTorch张量;
  • 保持一切不受影响;

单一进程和多进程数据加载

  • DataLoader默认使用单一进行加载数据;
  • 在一个python进程中,GIL(Global Interpreter Lock)阻止跨线程真全并行python代码运行;
  • 为了避免数据加载的计算代码;
  • PyTorch提供通过设置num_worker参数为正整数的简单设置切换执行多线程数据加载;

单一进程数据加载(默认)

  • 该模式下,在DataLoader初始化的同一进程中获取数据;
  • 因此数据加载可能会阻塞计算;
  • 在跨进程共享数据的资源(共享的内存和文件描述符)有限时,该模式被优先考虑;
  • 或者当整个数据集较小且可以被整体加载进内存时,该模式被优先考虑;
  • 另外,单进程加载通常显式更多的可读性错误追踪信息,更有利与调试;

多进程数据加载

  • 设置num_worker为正整数将会开启多进程数据加载;
  • 多进程的数量为num_worker的数量;
  • 一些次数的迭代后,加载器工作进程将会同父进程消耗相同的CPU内存;
  • 这在数据集含大量数据(比如在数据集构建时加载大量的文件名列表)时可能有问题;
  • 或者用户使用了多个进程(总的内存消耗=number_of_workers*size_of_parent_process)时可能会有问题;
  • 最简单的应变方式为使用非参考计数的表示(比如,Pandas,Numpy或者PyArrow对象)替换python对象;
  • 该模式下,在一个DataLoader迭代器被创建时,num_workers数量的进程被创建;
  • dataset,collate_fn和worker_init_fn被传递到每一个进程;
  • 上述三者被用于进程初始化和数据的获取;
  • 这意味着在工作进程的运行中内部IO,转换操作同数据获取被一同处理;
  • torch.utils.data.get_worker_info()在一个工作进程中返回多种有用信息(包括:进程id,数据集副本,初始化速度等);
  • 且在主进程中返回None;
  • 用户可以在数据集代码中使用torch.utils.data.get_worker_info()函数;
  • worker_init_fn独立配置每一个数据集副本以确定是否在工作进程中运行代码;
  • 对映射类型的数据集,主进程使用采样器生成索引并传送到工作进程中;
  • 任何随机洗牌在主程序中执行,主程序通过分配索引确定数据加载顺序;
  • 对可迭代类型数据集,每一个工作进程获取一个数据集对象的副本;
  • 原始的多进程加载将导致数据的复制;
  • 使用torch.utils.data.get_worker_info()和worker_init_fn,用户可以独立配置每一个副本;
  • 多线程加载中,drop_last参数去掉每一个进程中的可迭代数据集副本的不完整批量数据;
  • 当迭代的最后一位被达到时进程被关闭;
  • 基于多进程中使用CUDA和共享CUDA张量的细节原因在多进程加载中不推荐返回CUDA张量;
  • 推荐使用自动内存固定(设置pin_memory=True),能够更快传递数据到CUDA使能的GPU中;

基于平台的行为

  • 由于工作进程依赖于Python multiprocessing,进程启动行为windows和Unix是有区别的;
  • UNix上,fork() 是默认的multiprocessing启动方法;
  • 使用fork(),直接通过克隆的地址空间,子工作进程可获取dataset和Python参数;
  • windows或者MacOS上,spawn()是默认的multiprocessing启动方法;
  • 使用spawn(),另一个解释器被启动,运行用户的主要脚本;
  • 以及通过pickle序列化,接收数据集的内部工作进程函数、collate_fn和其他参数;
  • 以上单独的序列化意味着应该采取两个步骤以确保在使用多进程数据加载时与windows兼容;
  • 打包大部分主要的脚本代码在if __name__=='__main__':程序块中;
  • 确保当每一个工作进程被启动时,if __name__=='__main__':不再次启动;
  • 你可以在if __name__=='__main__':程序块中放置数据集和DataLoader·实例创建逻辑,因为在工作进程中其不需要被再次执行;
  • 确保collate_fn,worker_init_fn或者dataset代码声明在顶级的__main__检查之外的定义中;
  • 这就保证了上述代码声明在工作进程中是可用的;

多进程数据加载的随机性

  • 默认情况下,每一个工作进程将具有自己的PyTorch种子,设置为base_seed+worker_id;
  • 这里base_seed是一个由主进程使用他的RNG或者一个指定的生成器生成的长周期数据;
  • 然而,用于其他库的种子可以通过初始化工作进程被复制;
  • 导致每一个工作进程返回一致的随机数;
  • 在worker_init_fn中,你可以获取PyTroch种子集用于每一个工作进程,使用
  • torch.utils.data.get_worker_info().seed或者torch.initial_seed();
  • 也可以使用上述两个种子为其他库在数据加载之前设置种子;

内存锁定

  • 主机到GPU的拷贝更快,当数据来自锁定内存时;
  • 对数据加载,DataLoader的pin_memory=True时,自动将获取的数据张量放到锁定内存中;
  • 默认的内存锁定逻辑仅识别张量和映射以及包含张量的可迭代对象;
  • 默认情况下,锁定逻辑观察到一个批量定制数据类型(当有一个collate_fn返回一个定制批量类型时);
  • 或者批量数据中的每一个单元都是定制类型时;
  • 锁定逻辑不能识别他们,将返回不在锁定内存中的批量数据(或者单元);
  • 为了使能内存锁定用于定制批量数据或者数据类型,定义一个pin_memory()方法在你的定制类型中;

内存锁定实例

import torch
from torch.utils.data import DataLoader
class SimpleCustomBatch:def __init__(self,data):transposed_data=list(zip(*data))self.inp=torch.stack(transposed_data[0],0)self.tgt=torch.stack(transposed_data[1],0)def pin_memory(self):self.inp=self.inp.pin_memory()self.tgt=self.tgt.pin_memory()return selfdef collate_wrapper(batch):return SimpleCustomBatch(batch)inps=torch.arange(10*5,dtype=torch.float32).view(10,5)
tgts=torch.arange(10*5,dtype=torch.float32).view(10,5)
dataset=TensorDataset(inps,tgts)loader=DataLoader(dataset,batch_size=2,collate_fn=collate_wrapper,pin_memory=True)for batch_ndx,sample in enumerate(loader):print(sample.inp.is_pinned())print(sample.tgt.is_pinned())

DataLoader参数解析

  • DataLoader合并一个数据和一个采样器,提供一个可迭代的采样器;
  • DataLoader在单线程或多线程模式下,支持映射类型数据集和可迭代类型数据集;
  • DataLoader支持定制的加载顺序和可优化的自动批量整理和内存锁定;
  • dataset(Dataset)---加载数据的数据集;
  • batch_size(Int,optional)---每一批次加载多少数据样本(默认为1);
  • shuffle(bool,optional)---设置为True,每一代都进行数据洗牌(默认False);
  • sampler(Sampler or iterable,optional)---定义从数据集抽取样本的策略,可以是实现__len__功能的任意Iterable对象,如果指定的话,shuffle必须不指定;
  • batch_sampler(Sampler or iterable,optional)---类似与sampler,但是一次返回一个索引批次,同batch_size,shuffle,sampler和drop_last相互排斥;
  • num_workers(int,optional)---多少子进程用于数据加载,0意味着将在主进程加载数据,默认0;
  • collate_fn(Callable,optional)---合并一个样本列表为一个张量mini-batch的型式,当从映射数据集使用批次加载时使用;
  • pin_memory(bool,optional)---如果设置为True,在返回数据之前数据加载器将拷贝张量到设备/CUDA的锁定内存区.如果你的数据单元是一个定制类型,或者你的collate_fn返回一个批次定制类型,参考文档中的实例;
  • drop_last(bool,optional)---设置为true丢弃最后不完整的批次,如果数据集的大小不能被批量大小整除的话。如果设置为False,且数据集的大小不被批次大小整除,最后的批次将很小(默认False);
  • timeout(numeric,optional)---如果为正数,超时值用于收集一个工作进程的批次,应当总是非正数(默认0);
  • worker_init_fn(Callable,optional)---如果非空,该函数将会被在每一个工作子进程被调用,以工作进程id(一个正数,在【0,num_workers-1】范围内)作为输入;
  • multiprocessing_context(str or multiprocessing.context.BaseContext,optional)---如果为空,操作系统的默认多进程上下文会被使用(默认为空);
  • generator(torch.Generator,optional)---如果非空,随机采样器将使用RNG生成随机指标并多进程申城base_seed用于工作进程(默认为空);
  • prefetch_factor(int,optional,keywork-only arg)---每一个工作进程提前加载的批次数量,2意味着对于所有的工作进程将有总数为2*num_worker批次的预取数据(默认值依赖于参数num_worker的值,如果num_worker=0,默认值为空,否则,num_worker>0,默认值为2);
  • persistent_workers(bool,optional)---如果为True,数据加载器在一个数据集被消耗一次之后将不关闭工作进程,这允许保持工作进程数据集实例为活动状态(默认为False);
  • pin_memory_device(str,optional)---如果为True,设备锁定内存运行.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/702042.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT的回答从哪里来?

ChatGPT回答问题时通常比问题本身更长,这是因为它需要通过补充额外的信息来提供完整的答案。它的回答来源于对现有信息的抽取和整合,那么具体是怎么进行抽取和整合的呢,下面我们带着这个疑问来详细讨论一下它的工作原理。 首先,英…

五种多目标优化算法(MOBA、NSWOA、MOJS、MOAHA、MOPSO)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 多目标优化算法是用于解决具有多个目标函数的优化问题的一类算法。其求解流程通常包括以下几个步骤: 1. 定义问题:首先需要明确问题的目标函数和约束条件。多目标优化问题通常涉及多个目标函数,这些目标函数可能存在冲突,需要在不同目标之间进…

1_怎么看原理图之GPIO和门电路笔记

一、GPIO类 如下图:芯片输出高电平/3.3V,LED亮;当芯片输出低电平,则LED暗 如下图:输入引脚,当开关闭合,则输入为低电平/0V,当开关打开,则输入为高电平/3.3V 现在的引脚都…

【VIP专属】Python应用案例——基于Keras, OpenCV和MobileNet口罩佩戴识别

目录 1、导入所需库 2、加载人脸口罩检测数据集 3、对标签进行独热编码

Stable Diffusion 3 发布及其重大改进

1. 引言 就在 OpenAI 发布可以生成令人瞠目的视频的 Sora 和谷歌披露支持多达 150 万个Token上下文的 Gemini 1.5 的几天后,Stability AI 最近展示了 Stable Diffusion 3 的预览版。 闲话少说,我们快来看看吧! 2. 什么是Stable Diffusion…

微信小程序 uniapp+vue餐厅美食就餐推荐系统

本论文根据系统的开发流程以及一般论文的结构分为三个部分,第一个部分为摘要、外文翻译、目录;第二个部分为正文;第三个部分为致谢和参考文献。其中正文部分包括: (1)绪论,对课题背景、意义、目…

网络编程-NIO案例 与 AIO 案例

案例说明:一个简单的群聊实现,支持重复上下线。 NIO 服务端 public class NIOServer {public static void main(String[] args) throws IOException {ServerSocketChannel serverChannel ServerSocketChannel.open();// 初始化服务器serverChannel.b…

token的有状态和无状态

在身份验证和授权领域,"有状态"(stateful)和"无状态"(stateless)通常用来描述系统处理用户认证信息的方式。 有状态(Stateful): 有状态的认证系统在服务器端会维…

uvloop,一个强大的 Python 异步IO编程库!

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站零基础入门的AI学习网站~。 目录 ​编辑 前言 什么是uvloop库? 安装uvloop库 使用uvloop库 uvloop库的功能特性 1. 更…

DPDK常用API合集二

网络数据包缓冲管理(librte_mbuf) 1.1 rte_pktmbuf_alloc 是 DPDK(数据平面开发工具包)中的一个函数,用于在内存池中分配一个新的 mbuf(内存缓冲区) struct rte_mbuf *rte_pktmbuf_alloc(stru…

Spring ReflectionUtils 反射工具介绍和使用

一、ReflectionUtils 在 Java 中,反射(Reflection)是一种强大的机制,允许程序在运行时动态地检查类、获取类的信息、调用类的方法、访问或修改类的属性等。Java 的反射机制提供了一组类和接口,位于 java.lang.reflect…

WebLogic Server JNDI注入漏洞复现(CVE-2024-20931)

0x01 产品简介 Oracle WebLogic Server 是一个Java应用服务器,它全面实现了J2EE 1.5规范、最新的Web服务标准和最高级的互操作标准。WebLogic Server内核以可执行、可扩展和可靠的方式提供统一的安全、事务和管理服务。Oracle Fusion Middleware(Oracle融合中间件)和Oracle…

【二分查找】【浮点数的二分查找】【二分答案查找】

文章目录 前言一、二分查找(Binary Search)二、浮点数的二分查找三、二分答案总结 前言 今天记录一下基础算法之二分查找 一、二分查找(Binary Search) 二分查找(Binary Search)是一种在有序数组中查找目…

Nodejs+vue图书阅读评分个性化推荐系统

此系统设计主要采用的是nodejs语言来进行开发,采用 vue框架技术,对于各个模块设计制作有一定的安全性;数据库方面主要采用的是MySQL来进行开发,其特点是稳定性好,数据库存储容量大,处理能力快等优势&#x…

效率系列(九) macOS入门各式快捷操作

大家好,我是半虹,这篇文章来讲 macOS 中的各式快捷操作 零、序言 快捷操作这种东西,看得再多,不如实际用起来,用习惯之后,真的会感受到效率提高的 所以这篇文章主要是想总结下常用的触控板手势和键盘快捷…

数字热潮:iGaming 能否推动加密货币的普及?

过去十年,iGaming(互联网游戏)世界有了显著增长,每月有超过一百万的新用户加入。那么,这一主流的秘密是什么?让我们在本文中探讨一下。 领先一步:市场 数字时代正在重新定义娱乐,iG…

MySQL运维实战(7.2) MySQL复制server_id相关问题

作者:俊达 主库server_id没有设置 主库没有设置server_id Got fatal error 1236 from master when reading data from binary log: Misconfigured master - server_id was not set主库查看server_id mysql> show variables like server_id; ----------------…

如何在本地电脑部署HadSky论坛并发布至公网可远程访问【内网穿透】

文章目录 前言1. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2 Cpolar稳定隧道(云端设置)2.3 Cpolar稳定隧道(本地设置)2.4 公网访问测试 总结 前言 经过多年的基础…

哈希表在Java中的使用和面试常见问题

当谈到哈希表在Java中的使用和面试常见问题时,以下是一些重要的点和常见问题: 哈希表在Java中的使用 HashMap 和 HashTable 的区别: HashMap 和 HashTable 都实现了 Map 接口,但它们有一些重要的区别: HashMap 是非线…

Repeater:创建大量类似项

Repeater 类型用于创建大量类似项。与其它视图类型一样,Repeater有一个model和一个delegate。 首次创建Repeater时,会创建其所有delegate项。若存在大量delegate项,并且并非所有项都必须同时可见,则可能会降低效率。 有2种方式可…