torch.utils.data

整体架构

平时使用 pytorch 加载数据时大概是这样的:

import numpy as np
from torch.utils.data import Dataset, DataLoaderclass ExampleDataset(Dataset):def __init__(self):self.data = [1, 2, 3, 4, 5]def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)def collate_fn(batch):return np.array(batch)dataset = ExampleDataset()  # create the dataset
dataloader = DataLoader(dataset=dataset,batch_size=2,shuffle=True,num_workers=4,collate_fn=collate_fn
)
for datapoint in dataloader:print(datapoint)
  1. 继承 Dataset 类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)__len__(self),分别实现如何获取一条数据和如何设定数据长度;
  2. 定义 collate_fn 函数,设定如何组织一个 batch
  3. 实例化 Dataset,并和 collate_fn 一起传入 DataLoader,参数 batch_size 设置批大小、shuffle 设置是否打乱、num_workers 设置并行加载数据的进程数。

然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader 的如 samplerbatch_samplerworker_init_fn 的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data 是如何工作的。


上图是数据加载的整体框架图,官网说 DataLoader 组合datasetsampler,多个 workers 根据 dataset 提供的数据副本sampler 提供的 keys 并行地加载数据,并通过 collate_fn 组成 batch 供用户迭代。需要注意的有:

  1. 每个 worker 持有数据的一个副本,故占用内存主线程内存 * num_workers”;
  2. 即使用户不提供 sampler 对象 (通常不提供),DataLoader 也会根据 shuffle 参数创建一个默认的 sampler 对象;一旦提供了,其前路的 shuffle 参数不能为 True (不提供就好);
  3. 即使用户不提供 batch_sampler 对象 (通常不提供),DataLoader 也会根据 batch_sampler, drop_last 参数创建一个默认的 batch_sampler 对象;一旦提供了,其前路的 shuffle, drop_last 不能为 Truebatch_size 必须为 1 1 1sampler 必须为 None,因为创建 BatchSampler 时已经有了这些参数;

    本质上是把创建 batch_sampler 的活拉出来由用户在 DataLoader 外自定义地做了。

Dataset

分为两种:map-styleiterable-style。前者的数据可通过 [idx or key] 访问,后者的数据只能通过迭代器 next 一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的iterable-style 的数据集的访问顺序由迭代器决定。

Sampler

torch.utils.data.Sampler 的子类或 Iterable,两个例子:

class AccedingSequenceLengthSampler(tu_data.Sampler[int]):def __init__(self, data: List[str]) -> None:super().__init__()self.data = datadef __len__(self) -> int:return len(self.data)def __iter__(self) -> Iterator[int]:""":return: 实现了按数据长短顺序访问数据集"""sizes = torch.tensor([len(x) for x in self.data])yield from torch.argsort(sizes).tolist()class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):def __init__(self, data: List[str], batch_size: int) -> None:super().__init__()self.data = dataself.batch_size = batch_sizedef __len__(self) -> int:return (len(self.data) + self.batch_size - 1) // self.batch_sizedef __iter__(self) -> Iterator[List[int]]:sizes = torch.tensor([len(x) for x in self.data])for batch in torch.chunk(torch.argsort(sizes), len(self)):  # 按块遍历yield batch.tolist()

Batch

batch_sampler 提供一批下标,取得一批数据后由 collate_fn 将这批数据整合:

if collate_fn is None:if self._auto_collation:collate_fn = _utils.collate.default_collateelse:  # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)collate_fn = _utils.collate.default_convert

分两种情况:

  • automatic batching is disabled:调用 default_convert 函数简单地将 NumPy arrays 转化为 PyTorch Tensor;
  • automatic batching is enabled:调用 default_collate 函数,转化会变得复杂一点:
from torch.utils import data as tu_data
import collections# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']# %% Example with `Map` inside the batch:
tu_data.default_collate([{'A': 0, 'B': 1},{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor

Multi-process Data Loading

dataset, collate_fn, and worker_init_fn are passed to each worker,大概能说明 batch 是在子进程内部合成的。

有一个需要注意的地方是内存增长问题,当 __get_item__(self, key) 访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/694203.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络入门基础

本专栏内容为:Linux学习专栏,分为系统和网络两部分。 通过本专栏的深入学习,你可以了解并掌握Linux。 💓博主csdn个人主页:小小unicorn ⏩专栏分类:网络 🚚代码仓库:小小unicorn的代…

32FLASH闪存

目录 一.FLASH简介 二.代码实现 (1)读写内部FLASH (2)读取芯片ID 一.FLASH简介 存储器地址要记得累 系统存储器是原厂写入的Bootloader程序(用于串口下载)&#xff0…

Python 写网络监控

大家好!我是爱摸鱼的小鸿,关注我,收看每期的编程干货。 网络监控是保障网络可靠性的一项重要任务。通过实时监控网络性能,我们可以及时发现异常,迅速采取措施,保障网络畅通无阻。本文将以 Python为工具&…

Windows / Linux dir 命令

Windows / Linux dir 命令 1. dir2. dir *.* > data.txt3. dir - list directory contentsReferences 1. dir 显示目录的文件和子目录的列表。 Microsoft Windows [版本 10.0.18363.900] (c) 2019 Microsoft Corporation。保留所有权利。C:\Users\cheng>dir驱动器 C 中…

线性代数:向量组的秩

目录 回顾“秩” 及 向量组线性表示 相关特性 向量组的秩 例1 例2 矩阵的“秩” 及 向量组线性表示 相关特性 向量组的秩 例1 例2

@Async引发的spring循环依赖的问题,

今天发现一个很有意思的问题,正常解决项目中产生的循环依赖,是找出今天添加的注入代码,然后一个个加lazy试过去,会涉及到类中新增的注入 但是今天修改了某个serviceimpl的方法,加入了Async方法后 就发生循环依赖了 ai…

如何实现一个K8S DevicePlugin?

什么是device plugin k8s允许限制容器对资源的使用,比如CPU和内存,并以此作为调度的依据。 当其他非官方支持的设备类型需要参与到k8s的工作流程中时,就需要实现一个device plugin。 Kubernetes提供了一个设备插件框架,你可以用…

机器视觉系统选型-为什么还要选用工业光源控制器

工业光源控制器最主要的用途是给光源供电,实现光源的正常工作。 1.开关电源启动时,电压是具有波浪的不稳定电压,其瞬间峰值电压超过了LED灯的耐压值,灯珠在多次高压冲击下严重降低了使用寿命; 2.使用专用的光源控制器&…

inBuilder低代码平台新特性推荐-第十六期

各位友友们,大家好~今天来给大家介绍一下inBuilder低代码平台社区版中的系列特性之一 —— 构件热加载! 01 概述 构件热加载指的是:构件代码修改后,无需重启应用,通过WebIDE的部署或发布工程后,即可正常调…

08-静态pod(了解即可,不重要)

我们都知道,pod是kubelet创建的,那么创建的流程是什么呐? 此时我们需要了解我们k8s中config.yaml配置文件了; 他的存放路径:【/var/lib/kubelet/config.yaml】 一、查看静态pod的路径 [rootk8s231 ~]# vim /var/lib…

华为配置直连三层组网直接转发示例

华为配置直连三层组网直接转发示例 组网图形 图1 配置直连三层组网直接转发示例组网图 业务需求组网需求数据规划配置思路配置注意事项操作步骤配置文件扩展阅读 业务需求 企业用户接入WLAN网络,以满足移动办公的最基本需求。且在覆盖区域内移动发生漫游时&#xff…

LeetCode 算法题 (数组)存在连续3个奇数的数组

问题: 输入一个数组,并输入长度,判断数组中是否存在连续3个元素都是奇数的情况,如果存在返回存在连续3个元素都是奇数的情况,不存在返回不存在连续3个元素都是奇数的情况 例一: 输入:a[1,2,3…

数论 - 博弈论(Nim游戏)

文章目录 前言一、Nim游戏1.题目描述输入格式输出格式数据范围输入样例:输出样例: 2.算法 二、台阶-Nim游戏1.题目描述输入格式输出格式数据范围输入样例:输出样例: 2.算法 三、集合-Nim游戏1.题目描述输入格式输出格式数据范围输…

React18原理: React核心对象之ReactElement对象和Fiber对象

React中的核心对象 在React应用中,有很多特定的对象或数据结构.了解这些内部的设计,可以更容易理解react运行原理列举从react启动到渲染过程出现频率较高,影响范围较大的对象,它们贯穿整个react运行时 如 ReactElement 对象如 Fi…

IO 作业 24/2/21

1、使用多线程完成两个文件的拷贝&#xff0c;第一个线程拷贝前一半&#xff0c;第二个线程拷贝后一半&#xff0c;主线程回收两个线程的资源 #include <myhead.h> //定义分支线程1 void *task1(void *arg) {int fdr-1;//只读打开被复制文件if((fdropen("./111.txt…

2024光伏展

2024年光伏展是一个专业的光伏行业展览会&#xff0c;旨在展示最新的光伏技术和产品&#xff0c;并促进光伏行业的发展和合作。 该展览会预计将吸引来自全球各地的光伏制造商、供应商、投资者和专业人士。参展的公司将有机会展示他们的最新产品和技术&#xff0c;与其他行业领导…

HTTP协议要点总结

一、什么是 HTTP 协议 1. 超文本传输协议 (HTTP &#xff0c; HyperText Transfer Protocol) 是互联网上应用广泛的一种网络协议。 是工作在 tcp/ip 协议基础上的 , 所有的 WWW 文件都遵守这个标准。 2. http1.0 短连接 http1.1 长连接 3. http 是 TCP/IP 协议的一个…

react实现转盘抽奖功能

看这个文章不错&#xff0c;借鉴 这个博主 的内容 样式是背景图片直接&#xff0c;没有设置。需要的话应该是 #bg { width: 650px; height: 600px; margin: 0 auto; background: url(turntable-bg.jpg) no-repeat; position: relative; } img[src^"pointer"] {positi…

马斯克称首位受试者可凭思维操控鼠标;字节低调推出视频模型丨 RTE 开发者日报 Vol.148

开发者朋友们大家好&#xff1a; 这里是 「RTE 开发者日报」 &#xff0c;每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE &#xff08;Real Time Engagement&#xff09; 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

微信小程序uniapp校园在线报修系统维修系统java+python+nodejs+php

管理员的主要功能有&#xff1a; 1.管理员输入账户登陆后台 2.个人中心&#xff1a;管理员修改密码和账户信息 3.用户管理&#xff1a;对注册的用户信息进行删除&#xff0c;查询&#xff0c;添加&#xff0c;修改 4.维修工管理&#xff1a;对维修工信息进行添加&#xff0c;修…