pytorch之torch.utils.data学习

1、概述

PyTorch 数据加载利用的核心是torch.utils.data.DataLoader类 。它表示在数据集上 Python 可迭代,支持

map-style and iterable-style datasets(地图样式和可迭代样式数据集),
customizing data loading order(自定义数据加载顺序),
automatic batching(自动批处理),
single- and multi-process data loading(单进程和多进程数据加载),
automatic memory pinning(自动内存固定)。
这些选项由 a 的构造函数参数配置 DataLoader,其签名为:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

以下部分详细描述了这些选项的效果和用法。

2、Dataset Types(数据类型)

DataLoader 构造函数最重要的参数是dataset,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:

map-style datasets,(地图式数据集),
iterable-style datasets.(可迭代式数据集。)

map-style datasets 地图式数据集

映射式数据集是一种实现__getitem__()和 len()协议的数据集,并表示从(可能是非整数)索引/键到数据样本的映射。
例如,这样的数据集,当使用 访问时dataset[idx],可以从磁盘上的文件夹中读取第 idx-th个图像及其相应的标签。
请参阅Dataset了解更多详情。

iterable-style datasets可迭代式数据集

CLASStorch.utils.data.IterableDataset(*args, **kwds)

可迭代样式数据集是实现协议__iter__()的IterableDataset 子类的实例,并表示数据样本上的一个迭代迭代。这种类型的数据集特别适合随机读取成本昂贵甚至不可能的情况,以及批量大小取决于获取的数据的情况。

例如,这样的数据集在调用时iter(dataset)可以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。
所有表示数据样本可迭代对象的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用
请参阅IterableDataset了解更多详情。

Example 1: splitting workload across all workers in iter():

class MyIterableDataset(torch.utils.data.IterableDataset):def __init__(self, start, end):super(MyIterableDataset).__init__()assert end > start, "this example code only works with end >= start"self.start = startself.end = enddef __iter__(self):worker_info = torch.utils.data.get_worker_info()if worker_info is None:  # single-process data loading, return the full iteratoriter_start = self.startiter_end = self.endelse:  # in a worker process# split workloadper_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))worker_id = worker_info.iditer_start = self.start + worker_id * per_workeriter_end = min(iter_start + per_worker, self.end)return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])# Mult-process loading with two worker processes
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn:

class MyIterableDataset(torch.utils.data.IterableDataset):def __init__(self, start, end):super(MyIterableDataset).__init__()assert end > start, "this example code only works with end >= start"self.start = startself.end = enddef __iter__(self):return iter(range(self.start, self.end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):worker_info = torch.utils.data.get_worker_info()dataset = worker_info.dataset  # the dataset copy in this worker processoverall_start = dataset.startoverall_end = dataset.end# configure the dataset to only process the split workloadper_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))worker_id = worker_info.iddataset.start = overall_start + worker_id * per_workerdataset.end = min(dataset.start + per_worker, overall_end)# Mult-process loading with the custom `worker_init_fn`
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))

笔记

当使用多进程数据加载的IterableDataset时。在每个工作进程上复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复数据。请参阅IterableDataset文档了解如何实现这一点。

3、数据加载顺序和Sampler

对于可迭代样式的数据集,数据加载顺序完全由用户定义的可迭代控制。这允许更容易地实现块读取和动态批量大小(例如,通过每次生成批量样本)。

本节的其余部分涉及地图样式数据集的情况 。torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的序列。它们表示数据集索引上的可迭代对象。例如,在随机梯度下降 (SGD) 的常见情况下,a Sampler可以随机排列一系列索引并一次生成每个索引,或者为小批量 SGD 生成少量索引。

A sequential or shuffled sampler将根据DataLoader 的shuffle参数自动构。或者,用户可以使用sampler参数来指定一个自定义Sampler对象,该对象每次都会生成下一个要获取的索引/键。

自定义Sampler一次生成批次索引列表,可以作为batch_sampler参数传递。自动批处理也可以通过batch_size和drop_last 参数启用。有关这方面的更多详细信息,请参阅 下一节。

注意

sampler和batch_sampler都不兼容可迭代风格的数据集,因为这样的数据集没有概念。

4、加载批量和非批量数据

DataLoader支持自动将单独获取的数据样本通过参数 batch_size、drop_last、batch_sampler和 collate_fn(具有默认功能)整理为批次

自动批处理(默认)

这是最常见的情况,对应于获取小批量数据并将它们整理成批量样本,即包含一个维度为批量维度(通常是第一个维度)的张量。

当batch_size(默认1)不是None时,数据加载器将生成批量样本而不是单个样本。batch_size和 drop_last参数用于指定数据加载器如何获取批量的数据集键。对于地图样式数据集,用户也可以指定batch_sampler,这一次会生成一个键列表。

在使用来自sampler的索引获取样本列表之后,使用作为collate_fn参数传递的函数将样本列表整理成批。
在这种情况下,从地图样式的数据集加载大致相当于:

for indices in batch_sampler:yield collate_fn([dataset[i] for i in indices])

从可迭代风格的数据集加载大致相当于:

dataset_iter = iter(dataset)
for indices in batch_sampler:yield collate_fn([next(dataset_iter) for _ in indices])

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者简单地加载单个样本。例如,直接加载批处理数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批处理大小依赖于数据,或者程序被设计为处理单个样本。在这些场景下,最好不要使用自动批处理(其中collate_fn用于整理样本),而是让数据加载器直接返回数据集object的每个成员。

当batch_size和batch_sampler都为None时(batch_sampler的默认值已经为None),自动批处理被禁用。从数据集中获得的每个样本都使用作为collate_fn参数传递的函数进行处理。

当自动批处理被禁用时,默认的collate_fn只是将NumPy数组转换为PyTorch张量,并保持其他所有内容不变。
在这种情况下,从地图样式的数据集加载大致相当于:

for index in sampler:yield collate_fn(dataset[index])

从可迭代风格的数据集加载大致相当于:

for data in iter(dataset):yield collate_fn(data)

Working with collate_fn

例如,如果每个数据样本由一个3通道图像和一个整数类标签组成,也就是说,数据集的每个元素返回一个元组(image, class_index),默认collate_fn将这样的元组列表整理成一个批处理图像张量和一个批处理类标签张量的元组。特别是,默认的collate_fn具有以下属性:

5、Single- and Multi-process Data Loading

默认情况下,DataLoader使用单进程数据加载。

torch.utils.data.get_worker_info()返回工作进程中的各种有用信息(包括工作进程id,数据集副本,初始种子等),并在主进程中返回None。用户可以在数据集代码和/或worker_init_fn中使用这个函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,这在对数据集进行分片时特别有用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/220737.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css/less/scss代码注意事项

一.命名 1.类名使用小写字母,以中划线分割;id 使用 驼峰式命名; 2.less/scss中的函数、混合采用驼峰命名; 3. class 的命名不要使用 标签名,如.p .div .img; 二.选择器 尽量使用直接子选择器,否则,有时会造成性能损耗 .content .title { .…

【Spark精讲】Spark存储原理

目录 类比HDFS的存储架构 Spark的存储架构 存储级别 RDD的持久化机制 RDD缓存的过程 Block淘汰和落盘 类比HDFS的存储架构 HDFS集群有两类节点以管理节点-工作节点模式运行,即一个NameNode(管理节点)和多个DataNode(工作节点)。 Namenode管理文件系统的命名空…

JFlash烧写单片机bin/hex文件

1,安装压 JLink_Windows_V660c,官网可下载; 2,打开刚刚安装的 J-Flash V6.60c 选择创建新工程“Create a new project”,然后点击StartJ-Flash 点击之后跳出Select device框,选择TI 选择TI后&#xff0c…

TypeScript入门实战笔记 -- 04 什么是字面量类型、类型推断、类型拓宽和类型缩小?

🍍开发环境 1:使用vscode 新建一个 04.Literal.ts 文件,运行下列示例。 2:执行 tsc 04.Literal.ts --strict --alwaysStrict false --watch 3:安装nodemon( 全局安装npm install -g nodemon ) 检测.js文件变化重启项…

谈谈数据归一化与标准化

背景: 归一化(Normalization)和标准化(Standardization)是常用的数据预处理技术,用于将不同范围或不同单位的特征值转换为统一的尺度,以便更好地进行数据分析和模型训练。一句话:消…

Go EASY游戏框架 之 RPC Guide 03

1 Overview easy解决服务端通信问题,同样使用了RPC技术。easy使用的ETCDGRPC,直接将它们打包组合在了一起。随着服务发现的成熟,稳定,简单,若是不用,甚至你也并不需要RPC来分解你的架构。 GRPC 有默认res…

银河麒麟重置密码

桌面版银河麒麟重置密码 1.选择界面按e 出现银河麒麟系统选择的页面,我们点击键盘上的“e”键,进入电脑启动项编辑页 2.编辑启动页 在启动项编辑页面,我们将光标移动到linux这一行的最后,然后输入“init/bin/bash consoletty0”…

给一个容器添加el-popover/el-tooltip内容提示框

效果&#xff1a; html: <div class"evaluate"><div class"list flex-column-center" v-for"(item, index) in evaluateList" :key"index"mouseenter"mouseenterHandler(item)" mouseleave"mouseleaveHandle…

【Vue第5章】vuex_Vue2

目录 5.1 理解vuex 5.1.1 vuex是什么 5.1.2 什么时候使用vuex 5.1.3 案例 5.1.4 vuex工作原理图 5.2 vuex核心概念和API 5.2.1 state 5.2.2 actions 5.2.3 mutations 5.2.4 getters 5.2.5 modules 5.3 笔记与代码 5.3.1 笔记 5.3.2 23_src_求和案例_纯vue版 5.3…

什么是跨站脚本攻击(XSS)?如何防止它?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

【面试】数据库—优化—聚簇索引和非聚簇索引、回表查询

数据库—优化—聚簇索引和非聚簇索引、回表查询 1. 什么是聚簇索引什么是非聚簇索引 ? 聚集索引选取规则: 如果存在主键&#xff0c;主键索引就是聚集索引&#xff1b;如果不存在主键&#xff0c;将使用第一个唯一&#xff08;UNIQUE&#xff09;索引作为聚集索引&#xff1b…

【移动通讯】【MIMO】[P1]【科普篇】

前言&#xff1a; 前面几个月把CA 的技术总体复盘了一下,下面一段时间 主要结合各国一些MIMO 技术的文档,复盘一下MIMO. 这篇主要参考华为&#xff1a; info.support.huawei.com MIMO 技术使用多天线发送和接受信号。主要应用在WIFI 手机通讯等领域. 这种技术提高了系统容量&…

python 内存泄露

Python的内存泄漏问题主要是由于以下几个原因导致的&#xff1a; 循环引用&#xff1a;当两个或多个对象相互引用&#xff0c;并且没有其他引用指向这些对象时&#xff0c;即使这些对象不再被使用&#xff0c;Python也无法释放它们的内存空间&#xff0c;从而造成内存泄漏。大…

MySQL和Redis有什么区别?

目录 一、什么是MySQL 二、什么是Redis 三、MySQL和Redis的区别 一、什么是MySQL MySQL是一种开源的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;它是最流行的数据库之一。MySQL以其高性能、可靠性和易用性而闻名&#xff0c;广泛应用于各种Web应用程序…

ACM-MM2023 DITN详解:一个部署友好的超分Transformer

目录 1. Introduction2. Method2.1. Overview2.2. UFONE2.3 真实场景下的部署优化 3. 结果 Paper: Unfolding Once is Enough: A Deployment-Friendly Transformer Unit for Super-Resolution Code: https://github.com/yongliuy/DITN 1. Introduction CNN做超分的缺点 由于卷…

Leetcode—709.转换成小写字母【简单】

2023每日刷题&#xff08;五十八&#xff09; Leetcode—709.转换成小写字母 实现代码 char* toLowerCase(char* s) {int len strlen(s);for(int i 0; i < len; i) {if(s[i] > A && s[i] < Z) {s[i] tolower(s[i]);}}return s; }运行结果 之后我会持续更…

第20节: Vue3 计算属性

在UniApp中使用Vue3框架时&#xff0c;你可以使用计算属性来处理一些依赖其他属性的计算逻辑。计算属性会根据依赖属性的变化自动重新计算&#xff0c;并且只会在相关依赖发生改变时触发重新渲染。 下面是一个示例&#xff0c;演示了如何在UniApp中使用Vue3框架使用计算属性&a…

java全栈体系结构-架构师之路(持续更新中)

Java 全栈体系结构 数据结构与算法实战&#xff08;已更&#xff09;微服务解决方案数据结构模型(openresty/tengine)实战高并发JVM虚拟机实战性能调优并发编程实战微服务框架源码解读集合框架源码解读分布式架构解决方案分布式消息中间件原理设计模式JavaWebJavaSE新零售电商项…

(04730)半导体器件之晶体三极管

晶体三极管的结构和分类 晶体三极管具有三个区、两个PN结&#xff0c;从三个区分别引出三个电极而构成&#xff0c;其结构和符号如图2.1.13所示。 晶体三极管内部的三个区&#xff0c;分别称为发射区、基区和集电区&#xff0c;其中基区十分薄&#xff0c;一般为1um至几十um,掺…

单日30PB量级!火山引擎ByteHouse云原生的数据导入这么做

更多技术交流、求职机会&#xff0c;欢迎关注字节跳动数据平台微信公众号&#xff0c;回复【1】进入官方交流群 近期&#xff0c;火山引擎ByteHouse技术专家受邀参加DataFunCon2023&#xff08;深圳站&#xff09;活动&#xff0c;并以“火山引擎ByteHouse基于云原生架构的实时…