PyTorch DataLoader 学习

1. DataLoader的核心概念

DataLoader是PyTorch中一个重要的类,用于将数据集(dataset)和数据加载器(sampler)结合起来,以实现批量数据加载和处理。它可以高效地处理数据加载、多线程加载、批处理和数据增强等任务。

核心参数

  • dataset: 数据集对象,必须是继承自torch.utils.data.Dataset的类。
  • batch_size: 每个批次的大小。
  • shuffle: 是否在每个epoch开始时打乱数据。
  • sampler: 定义数据加载顺序的对象,通常与shuffle互斥。
  • num_workers: 使用多少个子进程加载数据。
  • collate_fn: 如何将单个样本合并为一个批次的函数。
  • pin_memory: 是否将数据加载到CUDA固定内存中。

2. 基本使用方法

定义数据集类

首先定义一个数据集类,该类需要继承自torch.utils.data.Dataset并实现__len____getitem__方法。

import torch
from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample# 创建一些示例数据
data = torch.randn(100, 3, 64, 64)  # 100个样本,每个样本为3x64x64的图像
labels = torch.randint(0, 2, (100,))  # 100个标签,0或1dataset = CustomDataset(data, labels)

创建DataLoader

使用自定义数据集类创建DataLoader对象。

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

迭代DataLoader

遍历DataLoader获取批量数据。

for batch in dataloader:data, labels = batch['data'], batch['label']print(data.shape, labels.shape)

3. 进阶技巧

自定义collate_fn

如果需要自定义如何将样本合并为批次,可以定义自己的collate_fn函数。

def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

使用Sampler

Sampler定义了数据加载的顺序。可以自定义一个Sampler来实现更复杂的数据加载策略。

from torch.utils.data import Samplerclass CustomSampler(Sampler):def __init__(self, data_source):self.data_source = data_sourcedef __iter__(self):return iter(range(len(self.data_source)))def __len__(self):return len(self.data_source)custom_sampler = CustomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=custom_sampler, num_workers=2)

数据增强

在图像处理中,数据增强(Data Augmentation)是提高模型泛化能力的一种有效方法。可以使用torchvision.transforms进行数据增强。

import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

4. 实战示例:CIFAR-10数据集

以下是使用CIFAR-10数据集的完整示例代码,包括数据加载、数据增强和模型训练。

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10# 定义数据增强和标准化
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 加载训练和测试数据集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义简单的卷积神经网络
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100}')running_loss = 0.0print('Finished Training')# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

5. 数据加载加速技巧

使用多进程数据加载

通过设置num_workers参数,可以启用多进程数据加载,加速数据读取过程。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

如果使用GPU进行训练,将pin_memory设置为True可以加速数据传输。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

使用prefetch_factor参数来预取数据,以减少数据加载等待时间。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

6. 处理不规则数据

在某些情况下,数据样本可能不规则,例如变长序列。可以使用自定义的collate_fn来处理这种数据。

def custom_collate_fn(batch):batch = sorted(batch, key=lambda x: len(x['data']), reverse=True)data = [item['data'] for item in batch]labels = [item['label'] for item in batch]data_padded = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)labels = torch.tensor(labels)return {'data': data_padded, 'label': labels}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

7. 使用中应注意的问题

数据加载效率

设置num_workers

  • 多线程数据加载: num_workers参数决定了用于数据加载的子进程数量。合理设置num_workers可以显著提升数据加载速度。一般来说,设置为CPU核心数的一半或等于核心数是一个不错的选择,但需要根据具体情况进行调整。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

  • 固定内存: 当使用GPU进行训练时,将pin_memory设置为True可以加速数据从CPU传输到GPU的速度。固定内存使得数据可以直接从页面锁定内存复制到GPU内存。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

  • 预取因子: 使用prefetch_factor参数来预取数据,以减少数据加载等待时间。默认情况下,预取因子为2。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

数据集与DataLoader的兼容性

正确实现 __getitem____len__

  • 数据集类的实现: 确保自定义数据集类正确实现了__getitem____len__方法,确保DataLoader能够正确地索引和迭代数据。
class CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample

数据增强与预处理

数据增强

  • 变换操作: 在图像处理中,数据增强可以提高模型的泛化能力。可以使用torchvision.transforms进行数据增强和标准化。
import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

数据加载过程中的内存问题

避免内存泄漏

  • 防止内存泄漏: 在使用DataLoader时,尤其是多进程加载时,注意内存泄漏问题。确保在训练过程中及时释放不再使用的数据。

合理设置batch_size

  • 批次大小: 根据GPU显存和内存大小合理设置batch_size。过大可能导致内存不足,过小可能导致计算效率低。
batch_size = 64  # 根据实际情况调整
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据顺序与随机性

shufflesampler

  • 数据随机性: 在训练集上使用shuffle=True,可以在每个epoch开始时打乱数据,防止模型过拟合。
  • 使用Sampler: 对于特殊的数据加载顺序需求,可以自定义Sampler。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据不一致性

自定义collate_fn

  • 处理变长序列:在处理变长序列或不规则数据时,自定义collate_fn函数,确保每个批次的数据能够正确合并。
def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

数据加载调试

调试与错误处理

  • 调试: 在数据加载过程中,可以打印或检查部分数据样本,确保数据预处理和加载过程正确无误。
  • 错误处理: 使用try-except块捕捉并处理数据加载中的异常,防止程序崩溃。
for i, data in enumerate(dataloader, 0):try:inputs, labels = data['data'], data['label']# 数据处理和训练代码except Exception as e:print(f"Error loading data at batch {i}: {e}")

性能优化

数据加载性能

  • Profile数据加载: 使用profiling工具(如PyTorch的torch.utils.bottleneck)分析数据加载和训练过程中的性能瓶颈,进行相应优化。
import torch.utils.bottleneck# 在命令行运行以下命令进行性能分析
# python -m torch.utils.bottleneck <script.py>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/46212.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

udp协议模拟远程输入指令控制xshell

不了解udp协议的可以先看一下udp协议下的socket函数_udp socket函数-CSDN博客 我之前还写过模拟实现xshell的模拟实现简单的shell-CSDN博客 如今我们要模拟的是让别人连网络连到我们主机&#xff0c;他可以执行命令&#xff1a; 1.接口 我们之前是用execl系列的函数来实现的…

第三方登录、任意用户登录漏洞总结

目录 1. 第三方昵称为XSS名称 2. 分享接口存在xss漏洞 3. 退出第三方账号仍可以登入 4. 第三方登录绑定漏洞利用(账号接管) 5. 泄漏token接口,任意账号登录 6. jwt未验参/弱密钥 7. cookie可伪造 8. 凭证过早返回 9. 逻辑漏洞导致的任意用户登录 9.1 登录完全依赖数…

IEEE(常用)参考文献引用格式详解 | LaTeX参考文献规范(IEEE Trans、Conf、Arxiv)

IEEE参考文献引用格式注意事项 期刊已正式出版&#xff08;有期卷号&#xff09;录用后在线访问即Early access&#xff08;无期卷号&#xff09; Arxiv论文会议论文IEEE缩写进阶其他 IEEE论文投稿前的参考文献格式检查&#xff01;&#xff08;如果一些细节你采用别的形式&…

香橙派AIpro:体验强劲算力,运行ROS系统

文章目录 前言一、香橙派AIpro开箱及功能介绍1.1香橙派AIpro开箱1.2香橙派AIpro功能介绍 二、香橙派AIpro资料下载及环境搭建2.1资料下载2.2环境搭建2.3使用串口启动进入开发板2.4使用HDMI线接入屏幕启动 三、部署ROS系统四、香橙派AIpro的使用和体验感受 前言 本篇文章将带体…

映射器代理工厂

我们在使用Mybatis时&#xff0c;只需要写Mapper和Dao接口就行&#xff0c;使用的时候只需要调用Dao中的方法就能完成数据的增删改查。那么Dao中的方法是谁实现的呢&#xff1f;难道Mybatis自动帮我们写了一个Dao的实现类吗&#xff1f;非也&#xff01;而是使用了映射器代理工…

在 SwiftUI 中实战使用 MapKit API

文章目录 前言新 MapKit API 的引入控制初始地图位置相机位置的双向绑定总结 前言 SwiftUI 与 MapKit 的集成在今年发生了重大变化。在之前的 SwiftUI 版本中&#xff0c;我们将 MKMapView 的基本功能封装到名为 Map 的 SwiftUI 视图中。幸运的是&#xff0c;事情发生了变化&a…

ontape备份跨服务器还原的样例

1. 查看实例备份参数文件 [gbasedbtiZ2ze5s78e4tanwe5q2znxZ ~]$ onstat -g dis Your evaluation license will expire on 2025-05-26 00:00:00 On-Line -- Up 00:00:15 -- 266536 Kbytes There are 1 servers found Server : node1 Server Number : 1 Server Type :…

虚拟环境操作

1、对虚拟环境的操作 查看虚拟环境列表 conda env list 创建虚拟环境 conda create -n 虚拟环境名称 python3.x 激活虚拟环境 conda activate 虚拟环境名称 退出虚拟环境 conda deactivate 删除虚拟环境 conda remove -n 虚拟环境名称 all 2、对虚拟环境下的包的操作…

力扣题解(分割回文串II)

132. 分割回文串 II 给你一个字符串 s&#xff0c;请你将 s 分割成一些子串&#xff0c;使每个子串都是 回文串 。 返回符合要求的 最少分割次数 思路&#xff1a; 规定dp[i]是以i位置为最后一个元素&#xff0c;&#xff08;0-i&#xff09;的最少分割次数&#xff0c;此…

硅谷并购中的牙刷测试

注&#xff1a;机翻&#xff0c;未校对。 In Silicon Valley, Mergers Must Meet the Toothbrush Test By David Gelles August 17, 2014 9:22 pm Credit Liz Grauman/The New York Times MOUNTAIN VIEW, Calif. — When deciding whether Google should spend millions or …

DP讨论——简单工厂模式

学而时习之&#xff0c;温故而知新。 敌人出招&#xff08;使用场景&#xff09; 不同的业务场景下要创建不同的对象&#xff0c;但是这些对象又有共同的特点。如何复用代码呢&#xff1f;你会想到&#xff0c;这些对象可以抽象出一个基类/抽象类就行了&#xff0c;那么随着业…

docker安装nginx并配置https

参考 docker安装nginx并配置https-腾讯云开发者社区-腾讯云 (tencent.com) 证书的生成 参见&#xff1a;SpringBoot项目配置HTTPS接口的安全访问&#xff08;openssl配置&#xff09;_配置接口访问-CSDN博客 步骤 1: 拉取Nginx镜像 docker pull nginx 好使的镜像如下&#x…

【AI】目标检测算法【R-CNN:Regions with CNN features】

1. 常用目标检测算法介绍 目标检测是计算机视觉领域的一个重要分支&#xff0c;它旨在识别并定位图像中的各种对象。以下是一些流行的目标检测算法&#xff1a; 1.1 二阶段目标检测算法 R-CNN (Regions with CNN features): 通过选择性搜索算法选取候选区域&#xff0c;然后…

vue3-vite-pinia模板

模板说明 下载 git clone https://github.com/AIxiaoHanBao/vue-template.gitmodule参数 node版本 16 UI组件库 element-plus 持久化 pinia 网络请求 axios 路由 vue-router 使用说明 权限管理目录access资源目录assets组件目录components页面目录pages网络请求目录re…

【AI原理解析】—对抗学习(AL)原理

目录 一、基本原理 二、核心模型 三、对抗性损失函数 四、训练过程 五、对抗学习的优势 六、对抗学习的挑战与解决方案 七、对抗学习的应用 八、未来展望 一、基本原理 对抗学习的核心思想是通过两个模型的相互对抗&#xff0c;使得生成模型&#xff08;Generator&…

AI agents 印象

1、flowise ai 拖拉方式用llm创建AI agents Flowise - Low code LLM Apps Builder GitHub - FlowiseAI/Flowise: Drag & drop UI to build your customized LLM flow 2、coze 也是用拖拉方式&#xff0c;还可以多Agents 引入GPT4,Claude&#xff0c;Gemin最好的模型 3、…

【数学建模】——数学规划模型

目录 一、线性规划&#xff08;Linear Programming&#xff09; 1.1 线性规划的基本概念 1.2 线性规划的图解法 模型建立&#xff1a; 二、整数规划&#xff08;Integer Programming&#xff09; 2.1 整数规划的基本概念 2.2 整数规划的求解方法 三、非线性规划&#x…

LeetCode刷题笔记第3011题:判断一个数组是否可以变为有序

LeetCode刷题笔记第3011题&#xff1a;判断一个数组是否可以变为有序 题目&#xff1a; 想法&#xff1a; 使用冒泡排序进行排序&#xff0c;在判断大小条件时加入判断二进制下数位为1的数目是否相同&#xff0c;相同则可以进行互换。最后遍历数组&#xff0c;相邻两两之间是…

Java中实现一维数组逆序交换的完整解决方案

引言 ❤❤点个关注吧~~编程梦想家&#xff08;大学生版&#xff09;-CSDN博客 在日常编程中&#xff0c;处理数组时经常会遇到需要逆序交换数组元素的情况。逆序交换即是将数组的第一个元素与最后一个元素交换&#xff0c;第二个元素与倒数第二个元素交换&#xff0c;依此类推…

浏览器出现 502 Bad Gateway的原理分析以及解决方法

目录 前言1. 问题所示2. 原理分析3. 解决方法 前言 此类问题主要作为疑难杂症 1. 问题所示 2. 原理分析 502 Bad Gateway 错误表示服务器作为网关或代理时&#xff0c;从上游服务器收到了无效的响应 通常出现在充当代理或网关的网络服务器上&#xff0c;例如 Nginx、Apache…