PyTorch数据加载流程解析

1. 定义最简单的Dataset
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data):self.data = data  # 假设data是一个列表,如[10, 20, 30, 40]def __len__(self):return len(self.data)  # 返回数据总量def __getitem__(self, idx):return self.data[idx]  # 返回单个数据样本# 示例数据
my_data = [10, 20, 30, 40]
dataset = MyDataset(my_data)
2. 创建DataLoader
loader = DataLoader(dataset, batch_size=2,  # 每批2个样本shuffle=True)  # 打乱数据顺序
3. 遍历DataLoader时的内部操作

当执行以下代码时:

for batch in loader:print(batch)

实际发生的步骤

  1. DataLoader自动调用dataset.__len__()获取数据总量(这里是4)
  2. 根据batch_size=2生成索引序列(如[1,3][0,2],因shuffle=True而随机)
    • 索引生成逻辑
    • PyTorch通过以下设计保证索引不重复:
      • 采样器隔离:每个epoch生成独立的随机排列。
      • 批次切割:按固定步长切分排列,避免交叉。
      • 全局控制Sampler严格管理索引分配。
  3. 对每个索引调用dataset.__getitem__(idx)
    • 第一次取idx=1idx=3 → 返回2040
    • 自动堆叠为张量tensor([20, 40])
  4. 输出结果示例:
    tensor([20, 40])  # 第一批
    tensor([10, 30])  # 第二批
    
4. 关键点图解
数据集: [10, 20, 30, 40]│   │   │   │
索引:     0   1   2   3DataLoader操作:
1. 随机选索引(如[1,3]) → 取数据2040 → 堆叠为tensor([20, 40])
2. 随机选索引(如[0,2]) → 取数据1030 → 堆叠为tensor([10, 30])
5. 如果数据是元组

假设每个样本是(用户ID, 物品ID)

class PairDataset(Dataset):def __init__(self):self.pairs = [(1,101), (2,102), (3,103)]  # (用户, 物品)def __len__(self):return len(self.pairs)  # 必须实现:返回数据总量def __getitem__(self, idx):return self.pairs[idx]  # 返回一个元组loader = DataLoader(PairDataset(), batch_size=2)
for batch in loader:print(batch)

输出:

# 每个元组字段自动堆叠
[tensor([1, 2]), tensor([101, 102])]  # 第一批
[tensor([3]), tensor([103])]          # 第二批(最后不足batch_size)

总结

  1. Dataset:定义数据存储和单个样本获取方式(必须实现__len____getitem__
  2. DataLoader
    • 根据batch_size生成索引
    • 自动调用__getitem__获取数据
    • 将样本堆叠成批次张量
  3. 核心特性
    • 支持多进程加速(num_workers参数)
    • 自动打乱数据(shuffle=True
    • 灵活处理各种数据结构(标量、元组、字典等)

这就是PyTorch数据加载的核心机制!其他复杂功能都是基于这个简单流程的扩展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75084.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

xsync脚本是一个基于rsync的工具

xsync脚本是一个基于rsync的工具,用于在集群间高效同步文件或目录。以下是xsync脚本的详细使用方法和配置步骤: 一、xsync脚本的作用 xsync脚本可以快速将文件或目录分发到集群中的多个节点,避免了手动逐台复制文件的繁琐操作。它利用rsync…

250408_解决加载大量数据集速度过慢,耗时过长的问题

250408_解决加载Cifar10等大量数据集速度过慢,耗时过长的问题(加载数据时多线程的坑) 在做Cifar10图像分类任务时,发现每个step时间过长,且在资源管理器中查看显卡资源调用异常,主要表现为,显卡…

Ansible的使用2

#### 一、Ansible变量 ##### facts变量 > facts组件是Ansible用于采集被控节点机器的设备信息,比如IP地址、操作系统、以太网设备、mac 地址、时间/日期相关数据,硬件信息等 - setup模块 - 用于获取所有facts信息 shell ## 常用参数 filter…

多模态大语言模型arxiv论文略读(六)

FashionLOGO: Prompting Multimodal Large Language Models for Fashion Logo Embeddings ➡️ 论文标题:FashionLOGO: Prompting Multimodal Large Language Models for Fashion Logo Embeddings ➡️ 论文作者:Zhen Wang, Da Li, Yulin Su, Min Yang,…

MySQL深入

体系结构 连接层:主要处理客户端的连接进行授权认证、校验权限等相关操作 服务层:如sql的接口、解析、优化在这里完成,所有跨存储引擎的操作在这里完成 引擎层:索引是在存储引擎层实现的,所以不同的存储引擎他的索引…

智能 SQL 优化工具 PawSQL 月度更新 | 2025年3月

📌 更新速览 本月更新包含 21项功能增强 和 9项问题修复,重点提升SQL解析精度与优化建议覆盖率。 一、SQL解析能力扩展 ✨ 新增SQL语法解析支持 SELECT...INTO TABLE 语法解析(3/26) ALTER INDEX RENAME/VISIBLE 语句解析&#…

数组划分使元素总和最接近

0划分 - 蓝桥云课 将一个数组划分为两个元素总和最接近的两个数组 要使得两组权值的乘积最大,根据数学原理,当两组权值越接近时,它们的乘积就越大。因此,可以将这个问题转化为一个 0 - 1 背包问题,把所有数的总和的一…

多线程代码案例(线程池)- 4

目录 引入 标准库中的线程池 -- ThreadPoolExecutor 研究一下这个方法的几个参数 1. int corePoolSize 2. int maximumPoolSize 3. long keepAliveTime 4. TimeUnit unit 5. BolckingQueue workQueue 6. ThreadFactory threadFactory 7. RejectedExecutionHandler h…

C,C++,C#

C、C 和 C# 是三种不同的编程语言,虽然它们名称相似,但在设计目标、语法特性、运行环境和应用场景上有显著区别。以下是它们的核心区别: 1. 设计目标和历史 语言诞生时间设计目标特点C1972(贝尔实验室)面向过程&#…

nginx 代理 https 接口

代码中需要真实访问的接口是:https://sdk2.028lk.com/application-localizationdev.yml文件中配置: url: http:/111.34.80.138:18100/sdk2.028lk.com/该服务器111.34.80.138上 18100端口监听,配置信息为: location /sdk2.028lk.c…

数据结构实验3.1:顺序栈的基本操作与进制转换

文章目录 一,问题描述二,基本要求三,算法分析四,示例代码五,实验操作六,运行效果 一,问题描述 在数据处理中,常常会遇到需要对链接存储的线性表进行操作的情况。本次任务聚焦于将链…

经典频域分析法(Bode图、Nyquist判据) —— 理论、案例与交互式 GUI 实现

目录 经典频域分析法(Bode图、Nyquist判据) —— 理论、案例与交互式 GUI 实现一、引言二、经典频域分析方法的基本原理2.1 Bode 图分析2.2 Nyquist 判据三、数学建模与公式推导3.1 一阶系统的频域响应3.2 多极系统的 Bode 图绘制3.3 Nyquist 判据的数学描述四、经典频域分析…

Vue知识点(5)-- 动画

CSS 动画是 Vue3 中实现组件动画效果的高效方式,主要通过 CSS transitions 和 keyframes 动画 CSS Keyframes(关键帧动画) 用来创建复杂的动画序列,可以精确控制动画的各个阶段。 核心语法: keyframes animationNa…

小型园区网实验

划分VLAN SW3 [sw3]vlan batch 2 3 20 30 [sw3]interface GigabitEthernet 0/0/1 [sw3-GigabitEthernet0/0/1]port link-type access [sw3-GigabitEthernet0/0/1]port default vlan 2 [sw3-GigabitEthernet0/0/1]int g0/0/2 [sw3-GigabitEthernet0/0/2]port link-type acces…

使用LangChain Agents构建Gradio及Gradio Tools(6)——创建自己的GradioTool

使用LangChain Agents构建Gradio及Gradio Tools(6)——创建自己的GradioTool 本篇摘要16. 使用LangChain Agents构建Gradio及Gradio Tool16.6 创建自己的GradioTool16.6.1 创建步骤16.6.2 创建示例StableDiffusionTool参考文献本章目录如下: 《使用LangChain Agents构建Grad…

SDL显示YUV视频

文章目录 1. **宏定义和初始化**2. **全局变量**3. **refresh_video_timer 函数**4. **WinMain 函数**主要功能及工作流程:总结: 1. 宏定义和初始化 #define REFRESH_EVENT (SDL_USEREVENT 1) // 请求画面刷新事件 #define QUIT_EVENT (SDL…

AnimateCC基础教学:随机抽取花名册,不能重复

一.核心代码: this.btnStartObj.addEventListener("click", switchBtn); this.btnOkObj.addEventListener("click", oKBtn); createjs.Ticker.addEventListener("tick", updateRandom); var _this this; var nameArr ["张三", &quo…

软考 系统架构设计师系列知识点 —— 设计模式之抽象工厂模式

本文内容参考: 软考 系统架构设计师系列知识点之设计模式(2)_系统架构设计师中考设计模式吗-CSDN博客 https://baike.baidu.com/item/%E6%8A%BD%E8%B1%A1%E5%B7%A5%E5%8E%82%E6%A8%A1%E5%BC%8F/2361182 特此致谢! Abstract Fac…

P2040 打开所有的灯

题目背景 pmshz在玩一个益(ruo)智(zhi)的小游戏,目的是打开九盏灯所有的灯,这样的游戏难倒了pmshz。。。 题目描述 这个灯很奇(fan)怪(ren),点一下就会将这个灯和其周围四盏灯的开关状态全部改变。现在你的任务就是就是告诉pmshz要全部打开…

汉得企业级 PaaS 平台 H-ZERO 1.12.0 发布!四大维度升级,构建企业数字化新底座

汉得企业级 PaaS 平台(以下简称"H-ZERO")是一款基于微服务架构的企业级数字化 PaaS 平台,可支持企业各类系统搭建、产品研发,帮助企业快速构架技术中台。 H-ZERO于2025年3月底正式发布 V1.12.0 ,此次发布聚…