Pytorch:Dataset类和DataLoader类

文章目录

    • 一、Dataset 类
      • 1、定义
      • 2、示例
    • 二、DataLoader 类
      • 1、定义
      • 2、参数
      • 3、示例:使用 DataLoader
    • 三、总结
    • 四、实战
      • 1、load_data函数:
      • 2、IrisDataset类
      • 3、DataLoader 的使用

  在机器学习和深度学习框架中,尤其是在 PyTorch 中,DatasetDataLoader 是处理和加载数据的重要工具。这里我们详细探讨这两个类的结构、用途和如何实际使用它们。
  数据集(Dataset)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。
  数据加载器(DataLoader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

一、Dataset 类

1、定义

Dataset 是一个抽象类,用于表示一个数据集的全部内容。在 PyTorch 中,任何继承自 torch.utils.data.Dataset 的自定义数据集需要实现两个必须的方法:

  • __getitem__(self, index)
    • 这个方法应该返回一个索引处的数据点和其对应的标签。例如,在图像数据集中,这可能是一对(图像,标签)。
  • __len__(self)
    • 这个方法返回数据集中的数据点的总数,即数据集的大小。

2、示例

下面是一个简单的形象化例子,展示如何创建一个用于加载图像数据集的自定义 Dataset 类:

import torch
from torch.utils.data import Dataset
class IceCreamDataset(Dataset):def __init__(self):self.flavors = ["vanilla", "chocolate", "strawberry"]def __len__(self):return len(self.flavors)def __getitem__(self, index):return f"One scoop of {self.flavors[index]} ice cream"
ice_cream_menu = IceCreamDataset()

在这个例子中,IceCreamDataset 类定义了一个冰激凌数据。

二、DataLoader 类

1、定义

DataLoader 是一个迭代器,用于将 Dataset 封装成易于访问的数据流,支持批量加载和多进程数据加载等操作。

2、参数

  • dataset: 要加载的 Dataset 对象。
  • batch_size(可选): 每个批次加载的样本数量。即对Dataset数据集进行等分,每份(每个batch)的大小为len(dataset)/batch_size,默认为1。batch_size通常是单次训练使用的数据量。
  • shuffle(可选): 是否在每个训练周期开始时打乱数据。
  • num_workers(可选): 用于数据加载的进程数。

3、示例:使用 DataLoader

一旦定义了 Dataset,就可以使用 DataLoader 来有效地加载数据:

from torch.utils.data import DataLoader# 创建 DataLoader,每批三份不同口味的冰激凌
ice_cream_loader = DataLoader(ice_cream_menu)#等价于ice_cream_loader = DataLoader(ice_cream_menu,batch_size=1)for batch in ice_cream_loader:print(batch)

在这个例子中,data_loader 会自动管理从 dataset 中加载数据的复杂性,如批量加载、打乱顺序和多进程加载。
输出:

['One scoop of vanilla ice cream']
['One scoop of chocolate ice cream']
['One scoop of strawberry ice cream']
ice_cream_loader = DataLoader(ice_cream_menu,batch_size=2)

输出:

['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream']
['One scoop of strawberry ice cream']
ice_cream_loader = DataLoader(ice_cream_menu,batch_size=3)#大于等于3的输出一样,因为就三个数据了
['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream', 'One scoop of strawberry ice cream']

三、总结

通过组合使用 DatasetDataLoader,PyTorch 用户可以高效、灵活地处理大规模数据集。Dataset 提供了一个清晰的接口来访问单个数据点__getitem__),而 DataLoader 管理整个数据集的批量处理和并行加载,这两者的结合极大地简化了在训练深度学习模型时的数据处理工作。

为了简单说明,以下我们将继承Dataset类的类,说成Dataset
根据上述简单的例子,我们可以知道,Dataset可以用来导入数据集,并规定整个数据集的长度是如何计算的,并规定单个数据点的格式;而DataLoader配合Dataset使用,可以导入数据集,并规定该数据集划分的批次数量和批次大小,以及导入数据集时是否打乱数据等。


对于:

for batch in dataloader:pass

每一个batch实际上就是DataLoaderDataset划分成的一个批次,每个batch的大小就是batch_size(除非数据集不是它的整数倍,上面也有体现)。所有batch加起来才构成整个Dataset
如果是图片数据集,batch_size可以认为,一个batchbatch_size张图片(如果该数据集规定单个数据点是一张图片的话。)(因为DataLoader访问数据时,会按照Dataset规定的数据点规格访问)。

四、实战

以上是一个简单的实例,方便理解,现在我们进行实战。

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y# 自定义鸢尾花数据类
class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)batch_size = 16# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

这段代码涉及到使用 PyTorch 加载和处理著名的鸢尾花(Iris)数据集,并将其分成训练集、验证集和测试集。下面逐部分详细解释:

1、load_data函数:

  1. 加载数据:

    • 使用 load_iris() 函数从 scikit-learn 库中加载鸢尾花数据集。这个函数返回包含特征(data)和目标(target)的数据结构。
    • 数据转换成 PyTorch 张量,方便后续使用 PyTorch 进行操作。
  2. 归一化:

    • 对特征进行归一化处理,使得每个特征的值范围都缩放到 [0, 1] 区间。这是通过从每个特征中减去最小值,然后除以其范围(最大值 - 最小值)来实现的。
    • 归一化有助于模型训练,因为它确保了所有特征都在相同的尺度上,从而加速学习过程。
  3. 打乱数据:

    • 如果启用 shuffle,则通过生成一个随机排列的索引并重新排序数据来打乱数据集。这通常用于训练数据集,以保证每次训练的随机性和泛化能力。
    • 这里使用的方法:
      • idx = torch.randperm(x.shape[0])x.shape[0]是二维张量的行数。torch.randperm即随机打乱(生成一个 0 到样本数量减一的随机排列),得到一个随机排列。
      • x = x[idx];y = y[idx],使用的是高级索引:使用多个整数索引访问多个元素

2、IrisDataset类

  • IrisDataset 类继承自 Dataset。它用于封装鸢尾花数据,使其可以通过 PyTorch DataLoader 使用。
  • 在构造函数中,根据 mode(训练、验证或测试)来划分数据:
    • 训练集 (train): 使用数据集的前 num_train 个样本。
    • 验证集 (dev): 紧随训练集之后的 num_dev 个样本。
    • 测试集 (test): 剩余的样本。
  • 这种方式的好处是简单易实现,但在实际应用中可能需要更复杂的交叉验证策略来更好地评估模型。

3、DataLoader 的使用

  • 对于每种数据集(训练、验证、测试),通过创建 DataLoader 实例来进行封装。这允许以批量方式加载数据,可选择是否打乱。
  • 批量大小 (batch_size):
    • 对于训练数据,使用较大的批量(例如 16),有助于稳定和加速训练过程。
    • 对于验证数据,也采用同样大小的批量,以保持一致性。
    • 对于测试数据,每批只有一个样本,这常用于评估模型时逐个样本进行处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/2510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小程序 rich-text 解析富文本 图片过大时如何自适应?

在微信小程序中&#xff0c;用rich-text 解析后端返回的数据&#xff0c;当图片尺寸太大时&#xff0c;会溢出屏幕&#xff0c;导致横向出现滚动 查看富文本代码 图片是用 <img 标签&#xff0c;所以写个正则匹配一下图片标签&#xff0c;手动加上样式即可 // content 为后…

Python 面向对象——5.多态

本章学习链接如下&#xff1a; Python 面向对象——1.基本概念 Python 面向对象——2.类与对象实例属性补充解释&#xff0c;self的作用等 Python 面向对象——3.实例方法&#xff0c;类方法与静态方法 Python 面向对象——4.继承 1.基本概念 多态是面向对象编程&#x…

kafka架构

kafka架构 Kafka是一种分布式流处理平台&#xff0c;由Apache软件基金会开发。它采用发布-订阅模式&#xff0c;可以持久化和高效地处理大规模数据流。 Kafka的架构主要由以下几个组成部分&#xff1a; Producer&#xff08;生产者&#xff09;&#xff1a;发送数据到Kafka集…

贪吃蛇(C语言版)

在我们学习完C语言 和单链表知识点后 我们开始写个贪吃蛇的代码 目标&#xff1a;使用C语言在Windows环境的控制台模拟实现经典小游戏贪吃蛇 贪吃蛇代码实现的基本功能&#xff1a; 地图的绘制 蛇、食物的创建 蛇的状态&#xff08;正常 撞墙 撞到自己 正常退出&#xf…

Python蜘蛛侠

目录 写在前面 蜘蛛侠 编写代码 代码分析 更多精彩 写在后面 写在前面 本期小编给大家推荐一个酷酷的Python蜘蛛侠&#xff0c;一起来看看叭~ 蜘蛛侠 蜘蛛侠&#xff08;Spider-Man&#xff09;是美国漫威漫画宇宙中的一位标志性人物&#xff0c;由传奇创作者斯坦李与艺…

探索ChatGPT在提高人脸识别与软性生物识准确性的表现与可解释性

概述 从GPT-1到GPT-3&#xff0c;OpenAI的模型不断进步&#xff0c;推动了自然语言处理技术的发展。这些模型在处理语言任务方面展现出了强大的能力&#xff0c;包括文本生成、翻译、问答等。 然而&#xff0c;当涉及到面部识别和生物特征估计等任务时&#xff0c;这些基于文…

设计模式-00 设计模式简介之几大原则

设计模式-00 设计模式简介之几大原则 本专栏主要分析自己学习设计模式相关的浅解&#xff0c;并运用modern cpp 来是实现&#xff0c;描述相关设计模式。 通过编写代码&#xff0c;深入理解设计模式精髓&#xff0c;并且很好的帮助自己掌握设计模式&#xff0c;顺便巩固自己的c…

用于车载T-BOX汽车级的RA8900CE

用于车载T-BOX等高精度计时的汽车级时钟模块RTC:RA8900CE.车载实时时钟芯片RA8900CE内置32.768Khz的晶体&#xff0c;实现年、月、日、星期、小时、分钟和秒精准计时。RA8900CE满足AEC-Q200认证&#xff0c;内置温补功能&#xff0c;保证实时时钟的稳定可靠&#xff0c;功耗低至…

【Linux】解决ubuntu20.04版本插入无线网卡没有wifi显示【无线网卡Realtek 8811cu】

ubuntu为Realtek 8811cu安装驱动&#xff0c;解决wifi连接问题 1、确认无线网卡的型号-Realtek 8810cu2、下载并配置驱动 一句话总结&#xff1a;先确定网卡的型号&#xff0c;然后根据网卡的型号区寻找对应的驱动下载&#xff0c;下载完成之后在ubuntu系统中进行编译&#xff…

LeetCode 123.买卖股票的最佳时机III 188.买卖股票的最佳时机IV

LeetCode 123.买卖股票的最佳时机III 题目链接&#xff1a; LeetCode 123.买卖股票的最佳时机III 代码&#xff1a; class Solution { public:int maxProfit(vector<int>& prices) {int size prices.size();if(size0) return 0;//dp[i][0] 不操作//dp[i][1]…

js如何模拟表单输入

jQuery时代&#xff0c;模拟表单输入很简单&#xff0c;本质上就是操作dom&#xff0c;选择对于的dom&#xff0c;给dom.value设置值即可。 到了react时代就不同了&#xff0c;虽然也可以通过js拿到dom&#xff0c;给dom.value设置&#xff0c;但是react的状态绑定下&#xff…

Java 执行 JVM Native 方法导致内存碎片

背景&#x1f69e; 由于需要调用到 C/C 的业务对外&#xff0c;使用了 Java 来封装 SDK 进行调用。 事故起因⚡&#xff1a;当 Java 使用 JNI 发生调用 JVM Native 本地方法时&#xff0c;发现内存一直飙升发生 OOM。 操作复现&#x1f50d; 使用 Jmeter 进行压测高并发环境…

C++笔记打卡第21天(map)

1.map基本概念 map中所有元素都是pairpair中第一个元素为key&#xff0c;起到索引作用&#xff0c;第二个元素为value所有元素都会根据元素的键值自动排序 本质&#xff1a; map/multimap属于关联式容器&#xff0c;底层结构是用二叉树实现 优点&#xff1a; 可以根据key值…

HTTP慢连接攻击的原理和防范措施

随着互联网的快速发展&#xff0c;网络安全问题日益凸显&#xff0c;网络攻击事件频繁发生。其中&#xff0c;HTTP慢速攻击作为一种隐蔽且高效的攻击方式&#xff0c;近年来逐渐出现的越来越多。 为了防范这些网络攻击&#xff0c;我们需要先了解这些攻击情况&#xff0c;这样…

【笔试】03

FLOPS FLOPS 是 Floating Point Operations Per Second 的缩写&#xff0c;意为每秒浮点运算次数。它是衡量计算机性能的指标&#xff0c;特别是用于衡量计算机每秒能够执行多少浮点运算。在高性能计算领域&#xff0c;FLOPS 被广泛用来评估超级计算机、CPU、GPU 和其他处理器…

2024年区块链链游即将迎来大爆发

随着区块链技术的不断发展和成熟&#xff0c;其应用领域也在不断扩展。其中&#xff0c;区块链链游&#xff08;Blockchain Games&#xff09;作为区块链技术在游戏行业中的应用&#xff0c;备受关注。2024年&#xff0c;区块链链游行业即将迎来爆发&#xff0c;这一趋势不容忽…

Windows10如何关闭Edge浏览器的Copilot

在Windows10更新后&#xff0c;打开Edge浏览器&#xff0c;无论复制什么内容&#xff0c;都会弹出Copilot人工智能插件&#xff0c;非常令人反感&#xff0c;网上搜索的关闭方法都非常麻烦&#xff0c;比如&#xff1a;组策略和注册表。自己摸索得出最简便有效的关闭方法。 1、…

【java毕业设计】 基于Spring Boot+mysql的高校心理教育辅导系统设计与实现(程序源码)-高校心理教育辅导系统

基于Spring Bootmysql的高校心理教育辅导系统设计与实现&#xff08;程序源码毕业论文&#xff09; 大家好&#xff0c;今天给大家介绍基于Spring Bootmysql的高校心理教育辅导系统设计与实现&#xff0c;本论文只截取部分文章重点&#xff0c;文章末尾附有本毕业设计完整源码及…

一致性hash

一、什么是一致性hash 普通的hash算法 (hashcode % size )&#xff0c;如果size发生变化&#xff0c;几乎所有的历史数据都需要重hash、移动&#xff0c;代价非常大&#xff0c;常见的java中的hashmap就是如此。 那如果在hash表扩容或者收缩的时候size能够保持不变&#xff0…

gitee / github 配置git, 实现免密码登录

文章目录 怎么配置公钥和私钥验证配置成功问题 怎么配置公钥和私钥 以下内容参考自 github ssh 配置&#xff0c;gitee的配置也是一样的&#xff1b; 粘贴以下文本&#xff0c;将示例中使用的电子邮件替换为 GitHub 电子邮件地址。 ssh-keygen -t ed25519 -C "your_emai…