`DataLoader` 类是 PyTorch 中用于构建数据加载器的一个重要工具,它可以对数据集进行批处理、洗牌和并行加载,以便于训练神经网络模型。
### 输入参数:
- **dataset**:数据集对象,通常是 `torch.utils.data.Dataset` 类的子类对象,用于包装需要加载的数据。
- **batch_size**:每个批次中包含的样本数量。
- **shuffle**:一个布尔值,表示是否在每个 epoch 前洗牌数据。
- **num_workers**:用于数据加载的子进程数量。
- **collate_fn**:用于自定义批处理方式的函数,通常在需要对每个批次进行一些自定义处理时使用。
- **drop_last**:一个布尔值,表示是否丢弃最后一个不完整的批次,当数据总数不能被 batch_size 整除时使用。
### 输出:
`DataLoader` 对象,可以通过迭代器的方式逐批次地加载数据,每个批次的数据以字典或元组的形式返回。下面是一个简单的示例:
```python
import torch
from torch.utils.data import Dataset, DataLoader# 定义自定义数据集类
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 创建数据集对象
data = [i for i in range(100)]
dataset = MyDataset(data)# 创建 DataLoader 对象
batch_size = 10
shuffle = True
num_workers = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)# 迭代加载数据
for step, batch in enumerate(train_dataloader):
```
在这个示例中,`batch` 是一个由数据组成的张量,它的大小为 `[batch_size]`。根据需要,你可以对 `collate_fn` 进行自定义来改变输出的形式。