1. 定义最简单的Dataset
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data):self.data = data # 假设data是一个列表,如[10, 20, 30, 40]def __len__(self):return len(self.data) # 返回数据总量def __getitem__(self, idx):return self.data[idx] # 返回单个数据样本# 示例数据
my_data = [10, 20, 30, 40]
dataset = MyDataset(my_data)
2. 创建DataLoader
loader = DataLoader(dataset, batch_size=2, # 每批2个样本shuffle=True) # 打乱数据顺序
3. 遍历DataLoader时的内部操作
当执行以下代码时:
for batch in loader:print(batch)
实际发生的步骤:
- DataLoader自动调用
dataset.__len__()
获取数据总量(这里是4) - 根据
batch_size=2
生成索引序列(如[1,3]
、[0,2]
,因shuffle=True而随机)- 索引生成逻辑:
- PyTorch通过以下设计保证索引不重复:
- 采样器隔离:每个epoch生成独立的随机排列。
- 批次切割:按固定步长切分排列,避免交叉。
- 全局控制:
Sampler
严格管理索引分配。
- 对每个索引调用
dataset.__getitem__(idx)
:- 第一次取
idx=1
和idx=3
→ 返回20
和40
- 自动堆叠为张量
tensor([20, 40])
- 第一次取
- 输出结果示例:
tensor([20, 40]) # 第一批 tensor([10, 30]) # 第二批
4. 关键点图解
数据集: [10, 20, 30, 40]│ │ │ │
索引: 0 1 2 3DataLoader操作:
1. 随机选索引(如[1,3]) → 取数据20和40 → 堆叠为tensor([20, 40])
2. 随机选索引(如[0,2]) → 取数据10和30 → 堆叠为tensor([10, 30])
5. 如果数据是元组
假设每个样本是(用户ID, 物品ID)
:
class PairDataset(Dataset):def __init__(self):self.pairs = [(1,101), (2,102), (3,103)] # (用户, 物品)def __len__(self):return len(self.pairs) # 必须实现:返回数据总量def __getitem__(self, idx):return self.pairs[idx] # 返回一个元组loader = DataLoader(PairDataset(), batch_size=2)
for batch in loader:print(batch)
输出:
# 每个元组字段自动堆叠
[tensor([1, 2]), tensor([101, 102])] # 第一批
[tensor([3]), tensor([103])] # 第二批(最后不足batch_size)
总结
- Dataset:定义数据存储和单个样本获取方式(必须实现
__len__
和__getitem__
) - DataLoader:
- 根据
batch_size
生成索引 - 自动调用
__getitem__
获取数据 - 将样本堆叠成批次张量
- 根据
- 核心特性:
- 支持多进程加速(
num_workers
参数) - 自动打乱数据(
shuffle=True
) - 灵活处理各种数据结构(标量、元组、字典等)
- 支持多进程加速(
这就是PyTorch数据加载的核心机制!其他复杂功能都是基于这个简单流程的扩展。