PyTorch处理数据--Dataset和DataLoader

       在 PyTorch 中,Dataset 和 DataLoader 是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。

一、Dataset:数据集的抽象

Dataset 是一个抽象类,用于表示数据集的接口。你需要继承 torch.utils.data.Dataset 并实现以下两个方法:

  • __len__(): 返回数据集的总样本数。
  • __getitem__(idx): 根据索引 idx 返回一个样本(数据和标签)。
示例:自定义 Dataset
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transform  # 数据预处理/增强函数def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {"data": self.data[idx], "label": self.labels[idx]}if self.transform:sample = self.transform(sample)return sample
使用场景
  • 加载图像、文本、表格数据等。
  • 支持数据预处理(如归一化、裁剪)和数据增强(如随机翻转)。

二、 DataLoader:高效加载数据

DataLoader 负责将 Dataset 包装成一个可迭代对象,支持批量加载、多线程加速和数据打乱。

基本用法
from torch.utils.data import DataLoader# 假设 dataset 是你的 CustomDataset 实例
data_loader = DataLoader(dataset,batch_size=32,       # 批量大小shuffle=True,        # 是否打乱数据(训练时建议开启)num_workers=4,       # 多线程加载数据的进程数drop_last=False      # 是否丢弃最后不足一个 batch 的数据
)

 ‌遍历 DataLoader

for batch in data_loader:data = batch["data"]    # 形状:[batch_size, ...]labels = batch["label"] # 形状:[batch_size]# 将数据送入模型训练...

、pytorch内置数据集

PyTorch 提供了一系列内置数据集,这些数据集可以直接用于训练模型。这些数据集涵盖了多种领域,如图像、文本、音频等。以下是一些常用的PyTorch内置数据集:

图像数据集
  1. MNIST: 手写数字数据集,包含0到9的手写数字图片。

    from torchvision import datasets
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  2. CIFAR10/CIFAR100: 包含彩色图片的数据集,CIFAR10有60000张32x32的彩色图片,分为10个类别;CIFAR100类似但有100个类别。

    cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  3. ImageNet: 包含超过1400万张图片的非常庞大的数据集,常用于图像识别和分类任务。

    import torchvision.datasets as datasets
    imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
  4. STL10: 一个用于计算机视觉研究的小型图像数据集,包含96x96的彩色图片。

    stl10_train = datasets.STL10(root='./data', split='train', download=True)
  5. SVHN: 包含数字图片的数据集,与MNIST类似但包含更多实际场景的图片。

    svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本数据集

    1.Text8: 一个用于自然语言处理的小型文本数据集。

from torchtext.datasets import Text8
text8_train = Text8(split=('train',))

    2. AG_NEWS: 包含新闻文章的文本数据集,分为4个类别。

from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))

音频数据集  

  1. Speech Commands: 一个用于语音识别的数据集,包含约65,000个单词发音的音频文件。 

from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)

 使用方法
要使用这些数据集,首先需要导入torchvision(对于图像数据集)、torchtext(对于文本数据集)或torchaudio(对于音频数据集),然后使用其提供的类来加载数据。通常还包括一些数据预处理步骤,例如转换(transforms)。

import torchvision.transforms as transforms
from torchvision import datasetstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

四、完整代码示例

步骤 1:创建数据集
import numpy as np
from torch.utils.data import Dataset, DataLoader# 生成示例数据(假设是 10 个样本,每个样本是长度为 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,))  # 二分类标签class MyDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return {"data": self.data[idx],"label": self.labels[idx]}dataset = MyDataset(data, labels)
步骤 2:创建 DataLoader
data_loader = DataLoader(dataset,batch_size=2,shuffle=True,num_workers=2
)

 ‌步骤 3:使用 DataLoader 训练模型

model = ...  # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch in data_loader:x = batch["data"]y = batch["label"]# 前向传播outputs = model(x)loss = loss_fn(outputs, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()

五、常见问题解决

(1)数据格式不匹配
  • 问题‌:DataLoader 返回的数据形状与模型输入不匹配。
  • 解决‌:检查 Dataset 的 __getitem__ 返回的数据类型和形状,确保与模型输入一致。
(2)多线程加载卡顿
  • 问题‌:设置 num_workers>0 时程序卡死或报错。
  • 解决‌:在 Windows 系统中,多线程可能需要将代码放在 if __name__ == "__main__": 块中运行。
(3)数据增强
  • 使用 torchvision.transforms 中的工具(如 RandomCropRandomHorizontalFlip)对图像数据进行增强:
    from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    
(4)内存不足
  • 对于大型数据集,使用 torch.utils.data.DataLoader 的 persistent_workers=True(PyTorch 1.7+)或优化数据加载逻辑。

六、高级功能

  • 分布式训练‌:使用 torch.utils.data.distributed.DistributedSampler 配合多 GPU。
  • 预加载数据‌:使用 torch.utils.data.TensorDataset 直接加载 Tensor 数据。
  • 自定义采样器‌:通过 sampler 参数控制数据采样顺序(如平衡类别采样)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/73713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 蓝牙/Wi-Fi通信协议之:经典蓝牙(BT 2.1/3.0+)介绍

在 Android 开发中,经典蓝牙(BT 2.1/3.0)支持多种协议,其中 RFCOMM/SPP(串口通信)、A2DP(音频流传输)和 HFP(免提通话)是最常用的。以下是它们在 Android 中的…

R002-云计算

1 概念 英文名:Cloud Computing 核心:云计算的核心概念就是以互联网为中心,在网站上提供快速且安全的云计算服务与数据存储,让每一个使用互联网的人都可以使用网络上的庞大计算资源与数据中心 2.分类 基础设施即服务(IaaS)它向…

降维(DimensionalityReduction)基础知识2

文章目录 五、基于局部结构保持的降维1、Laplacian Eigenmaps(拉普拉斯特征映射)(1)邻接矩阵(2)图论基础(3)Laplace算子1、散度(Divergence)2、拉普拉斯算子3…

物联网中的物模型是什么意思,在嵌入式软件开发中如何体现?

1. 物模型的概念 物模型(Thing Model)是物联网中对物理设备或虚拟设备的抽象描述,定义了设备的属性、事件和服务。它是设备与云平台或其他设备之间交互的基础,用于统一描述设备的能力和行为。 1.1 物模型的组成 属性&#xff0…

【蓝桥杯】单片机设计与开发,PWM

一、PWM概述 用来输出特定的模拟电压。 二、PWM的输出 三、例程一:单片机P34引脚输出1kHZ的频率 void Timer0Init(void);unsigned char PWMtt 0;void main(void) {P20XA0;P00X00;P20X80;P00XFF;Timer0Init();EA1;ET01;ET11;while(1);}void Timer0Init(void) //1…

C#中,什么是委托,什么是事件及它们之间的关系

1. 委托(Delegate) 定义与作用 ‌委托‌是类型安全的函数指针,用于封装方法,支持多播(链式调用)。‌核心能力‌:将方法作为参数传递或异步回调。 使用场景 回调机制(如异步操作完…

从替代到超越,禅道国产化替代解决方案2.0发布!

3月22日,由禅道携手上海惠艾信息科技、麦哲思科技共同举办的禅道・中国行北京站活动圆满落下帷幕。 除深入探究AI赋能研发项目管理外,禅道在活动现场正式发布了《禅道国产化替代解决方案2.0》,助力企业全方位构建自主可控的研发项目管理新体…

【VirtualBox 安装 Ubuntu 22.04】

网上教程良莠不齐,有一个CSDN的教程虽然很全面,但是截图冗余,看蒙了给我,这里记录一个整洁的教程链接。以备后患。 下载安装全流程 UP还在记录生活,看的我好羡慕,呜呜。 [VirtualBox网络配置超全详解]&am…

2025美国网络专线国内服务商推荐

在海外业务竞争加剧的背景下,稳定高效的美国网络专线已成为外贸企业、跨国电商及跨国企业的刚需。面对复杂的国际网络环境和严苛的业务要求,国内服务商Ogcloud凭借其创新的SD-WAN技术架构与全球化网络布局,正成为企业拓展北美市场的优选合作伙…

2.2.2 引入配置文件和定义配置类

本实战通过三种方式实现Spring Boot中的配置加载与管理。首先,通过PropertySource加载自定义配置文件,结合ConfigurationProperties注解将配置文件中的属性绑定到Java类中,实现配置的灵活管理。其次,利用ImportResource加载XML配置…

Django:构建高性能Web应用

引言:为何选择Django? 在当今快速发展的互联网时代,Web应用的开发效率与可维护性成为开发者关注的核心。Django作为一款基于Python的高级Web框架,以其"开箱即用"的特性、强大的ORM系统、优雅的URL路由设计,…

【银河麒麟高级服务器操作系统 】虚拟机运行数据库存储异常现象分析及处理全流程

更多银河麒麟操作系统产品及技术讨论,欢迎加入银河麒麟操作系统官方论坛 https://forum.kylinos.cn 了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer…

《2核2G阿里云神操作!Ubuntu+Ollama低成本部署Deepseek模型实战》

简介: “本文为AI开发者揭秘如何在阿里云2核2G轻量级ECS服务器上,通过Ubuntu系统与Ollama框架实现Deepseek模型的高效部署。无需昂贵硬件,手把手教程涵盖环境配置、资源优化及避坑指南,助力初学者用极低成本在云端跑通行业领先的大…

【bug解决】NameError: name ‘fused_act_ext‘ is not defined

问题 使用basicsr库做超分的时候发现NameError: name fused_act_ext is not defined这个问题,一直不断重复的使用pip uninstall basicsr 和 BASICSR_EXTTrue pip install basicsr 发现一直没有执行编译过程,导致一直推理失败 原因 之前已经安装过basi…

Anaconda开始菜单里添加JupyterLab快捷方式

Anaconda开始菜单里添加JupyterLab快捷方式 在 Windows 系统安装 Anaconda 后,发现开始菜单只有 Jupyter Notebook,却找不到Jupyter Lab入口。其实这是因为最新版 Anaconda 默认未预装 Lab 组件,本篇介绍一种添加 Jupyter Lab入口到开始菜单…

【Qt】modbus客户端笔记

Qt 中基于 Modbus 协议的通用客户端学习笔记 一、概述 本客户端利用 Qt 的 QModbusTcpClient 实现与 Modbus 服务器的通信,具备连接、读写寄存器、心跳检测、自动重连等功能,旨在提供一个可靠且易用的 Modbus 客户端框架,方便在不同项目中集…

解决Vmware 运行虚拟机Ubuntu22.04卡顿、终端打字延迟问题

亲测可用 打开虚拟机设置,关闭加速3D图形 (应该是显卡驱动的问题,不知道那个版本的驱动不会出现这个问题,所以干脆把加速关了)

【网络】Socket套接字

目录 一、端口号 二、初识TCP/UDP协议 三、网络字节序 3.1 概念 3.2 常用API 四、Socket套接字 4.1 概念 4.2 常用API (1)socket (2)bind sockaddr结构 (3)listen (4)a…

内联函数/函数重载/函数参数缺省

一、内联函数 为了减少函数调用的开销 在函数定义前加“inline”关键字,即可定义内联函数 二、函数重载 1.名字相同 2.参数个数或者参数类型不同 编译器根据调用语句实参的个数和类型判断应该调用哪个函数 三、函数的缺省参数 定义函数的时候可以让最右边的连…

基于神经网络的文本分类的设计与实现

标题:基于神经网络的文本分类的设计与实现 内容:1.摘要 在信息爆炸的时代,大量文本数据的分类处理变得至关重要。本文旨在设计并实现一种基于神经网络的文本分类系统。通过构建合适的神经网络模型,采用公开的文本数据集进行训练和测试。在实验中&#x…