Pytorch的torch.utils.data中Dataset以及DataLoader等详解

在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解析。

文章目录

  • 前言
  • 1、自定义Dataset类
  • 2、torchvision.datasets
  • 3、DataLoader
  • 4、torchvision.transforms

前言

torch.utils.dataPyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

下面是 torch.utils.data 模块中一些常用的类和函数:

  • Dataset: 定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。Dataset 类提供了两个必须实现的方法:__getitem__ 用于访问单个样本,__len__ 用于返回数据集的大小。
  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。
  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
  • get_worker_info: 获取当前数据加载器所在的进程信息。

除了上述的类和函数之外,torch.utils.data 还提供了一些常用的数据预处理的工具,如随机裁剪、随机旋转、标准化等。

通过 torch.utils.data 模块提供的类和函数,可以方便地加载、处理和批量加载数据,为模型训练和验证提供了便利。但是,我们最常用的两个类还是DatasetDataLoader类。

1、自定义Dataset类

torch.utils.data.Dataset是 PyTorch 中用于表示数据集的抽象类,用于定义数据集的访问方式和样本数量。

Dataset 类是一个基类,我们可以通过继承该类并实现下面两个方法来创建自定义的数据集类:

getitem(self, index): 根据给定的索引 index,返回对应的样本数据。索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。
len(self): 返回数据集中样本的数量。

import torch
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):# 根据索引获取样本return self.data[index]def __len__(self):# 返回数据集大小return len(self.data)# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

上面的代码样例主要实现的是一个自定义Dataset数据集类的方法,这一般都是在我们需要训练自己的数据时候需要定义的。但是一般我们作为深度学习初学者来讲,使用的都是MNIST、CIFAR-10等内置数据集,这时候就不需要再自己定义Dataset类了。至于为什么,我们下面进行详解。

2、torchvision.datasets

如果要使用PyTorch中的内置数据集,通常是通过torchvision.datasets模块来实现。torchvision.datasets模块提供了许多常用的计算机视觉数据集,如MNIST、CIFAR10、ImageNet等。

下面是使用内置数据集的示例代码:

import torch
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

在上述代码中,我们实现的便是一个内置MNIST(手写数字)数据集的加载和使用。可以看到,我们在这里面并未用到上面所提到的torch.utils.data.Dataset类,这是为什么呢?

这是因为在 torchvision.datasets 模块中,内置的数据集类已经实现了torch.utils.data.Dataset 接口,并直接返回一个可用的数据集对象。因此,在使用内置数据集时,我们可以直接实例化内置数据集类,而不需要显式地继承 torch.utils.data.Dataset 类。

内置数据集类(如 torchvision.datasets.MNIST)的实现已经包含了对 __getitem____len__ 方法的定义,这使得我们可以直接从内置数据集对象中获取样本和确定数据集的大小。这样,我们在使用内置数据集时可以直接将内置数据集对象传递给 torch.utils.data.DataLoader 进行数据加载和批量处理。

在内置数据集的背后,它们仍然是基于 torch.utils.data.Dataset 类进行实现,只是为了方便使用和提供更多功能,PyTorch 将这些常用数据集封装成了内置的数据集类。

为此,我专门到pytorch官网去查看了该内置数据集的加载代码,如下图所示:
在这里插入图片描述
可以看出,确实以及内置了Dataset数据集类。

3、DataLoader

torch.utils.data.DataLoader 是 PyTorch 中用于批量加载数据的工具类。它接受一个数据集对象(如 torch.utils.data.Dataset 的子类)并提供多种功能,如数据加载、批量处理、数据打乱等。

以下是 torch.utils.data.DataLoader 的常用参数和功能:

  • dataset: 数据集对象,可以是 torch.utils.data.Dataset 的子类对象。
  • batch_size: 每个批次的样本数量,默认为 1。
  • shuffle: 是否对数据进行打乱,默认为 False。在每个 epoch 时会重新打乱数据。
  • num_workers: 使用多少个子进程加载数据,默认为 0,表示在主进程中加载数据。其实在Windows系统里面都设置为0,但是在Linux中可以设置成大于0的数。
  • collate_fn: 在返回批次数据之前,对每个样本进行处理的函数。如果为 None,默认使用 torch.utils.data._utils.collate.default_collate 函数进行处理。
  • drop_last: 是否丢弃最后一个样本数量不足一个批次的数据,默认为 False
  • pin_memory: 是否将加载的数据存放在 CUDA 对应的固定内存中,默认为 False
  • prefetch_factor: 预取因子,用于预取数据到设备,默认为 2。
  • persistent_workers: 如果为 True,则在每个 epoch 中使用持久的子进程进行数据加载,默认为 False

示例代码如下:

import torch
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)# 使用数据加载器迭代样本
for images, labels in train_loader:# 训练模型的代码...

4、torchvision.transforms

torchvision.transforms模块是PyTorch中用于图像数据预处理的功能模块。它提供了一系列的转换函数,用于在加载、训练或推断图像数据时进行各种常见的数据变换和增强操作。下面是一些常用的转换函数的详细解释:

  1. Resize:调整图像大小

    • Resize(size):将图像调整为给定的尺寸。可以接受一个整数作为较短边的大小,也可以接受一个元组或列表作为图像的目标大小。
  2. ToTensor:将图像转换为张量

    • ToTensor():将图像转换为张量,像素值范围从0-255映射到0-1。适用于将图像数据传递给深度学习模型。
  3. Normalize:标准化图像数据

    • Normalize(mean, std):对图像数据进行标准化处理。传入的mean和std是用于像素值归一化的均值和标准差。需要注意的是,mean和std需要与之前使用的数据集相对应。
  4. RandomHorizontalFlip:随机水平翻转图像

    • RandomHorizontalFlip(p=0.5):以给定的概率对图像进行随机水平翻转。概率p控制翻转的概率,默认为0.5。
  5. RandomCrop:随机裁剪图像

    • RandomCrop(size, padding=None):随机裁剪图像为给定的尺寸。可以提供一个元组或整数作为目标尺寸,并可选地提供填充值。
  6. ColorJitter:颜色调整

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机调整图像的亮度、对比度、饱和度和色调。可以通过设置不同的参数来调整图像的样貌。

在使用的时候,我们常常通过transforms.Compose来对这些数据处理操作进行一个组合,使用的时候,直接调用该组合即可。

示例代码如下:

from torchvision import transforms# 定义图像预处理操作
transform = transforms.Compose([transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)transforms.RandomHorizontalFlip(),  # 随机水平翻转图像transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])# 对图像进行预处理
image = transform(image)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/44433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++通过gsop调用基于https的webservice接口总结

ww哦步骤: 第一步:生成头文件 webservice接口一般会有一个对外接口文档。比如:http://www.webxml.com.cn/WebServices/WeatherWebService.asmx?WSDL 问号后面的参数表示WSDL文档,是一个XML文档,看不懂配置没关系&a…

Kotlin实战之获取本地配置文件、远程Apollo配置失败问题排查

背景 Kotlin作为一门JVM脚本语言,收到很多Java开发者的青睐。 项目采用JavaKotlin混合编程。Spring Boot应用开发,不会发生变动的配置放在本地配置文件,可能会变化的配置放在远程Apollo Server。 问题 因为业务需要,需要增加一…

Java日志框架-JUL

JUL全称Java util logging 入门案例 先来看着入门案例,直接创建logger对象,然后传入日志级别和打印的信息,就能在控制台输出信息。 可以看出只输出了部分的信息,其实默认的日志控制器是有一个默认的日志级别的,默认就…

智慧水利利用4G物联网技术实现远程监测、控制、管理

智慧水利工业路由器是集合数据采集、实时监控、远程管理的4G物联网通讯设备,能够让传统水利系统实现智能化的实时监控和远程管理。工业路由器利用4G无线网络技术,能够实时传输数据和终端信息,为水利系统的运维提供有效的支持。 智慧水利系统是…

树莓派和windows之间文件传输

方案一:FileZilla 在电脑上下载FileZilla软件并打开,输入配置信息,用户名/密码、树莓派的IP地址,点击“快速连接” 方案二:samba 树莓派安装 samba 软件 sudo apt-get install samba samba-common-bin 修改配置文件 / etc /samba…

unity 之 GetComponent 获取游戏对象上组件实例方法

GetComponent 简单介绍 GetComponent 是Unity引擎中用于获取游戏对象上组件实例的方法。它允许您从游戏对象中获取特定类型的组件&#xff0c;以便在脚本中进行操作和交互。 GetComponent< ComponentType >(): 这是一个泛型方法&#xff0c;用于从当前游戏对象上获取指定…

HTML详解连载(8)

HTML详解连载&#xff08;8&#xff09; 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽浮动-产品区域布局场景 解决方法清除浮动方法一&#xff1a;额外标签发方法二&#xff1a;单伪元素法方法三&#xff1a;双伪元素法方法四&#xff1a;overflow浮动-总结…

小程序商品如何指定打印机

有些商家&#xff0c;可能有多个仓库。不同的仓库&#xff0c;存放不同的商品。当客户下单时&#xff0c;小程序如何自动按照仓库拆分订单&#xff0c;如何让打印机自动打印对应仓库的订单呢&#xff1f;下面就来介绍一下吧。 1. 设置订单分发模式。进入管理员后台&#xff0c…

不花一分钱,利用免费电脑软件将视频MV变成歌曲音频MP3

教程 1.点击下载电脑软件下载地址&#xff0c;点击下载&#xff0c;安装。&#xff08;没有利益关系&#xff0c;没有打广告&#xff0c;只是单纯教学&#xff09; 2.安装完成后&#xff0c;点击格式工厂 3.然后如图所示依次&#xff0c;点击【音频】->【-MP3】 3.然后点击…

[机缘参悟-100] :今早的感悟:儒释道代表了不同的人生观、思维模式决定了人的行为模式、创业到处是陷阱、梦想与欺骗其实很容易辨认

目录 一、关于儒释道 二、关于成长性思维与固定性思维 三、关于创业 四、关于梦想与忽悠 一、关于儒释道 儒&#xff1a;逆势而为&#xff0c;修身齐家治国平天下 佛&#xff1a;万法皆空&#xff0c;众生皆苦&#xff0c;普度众生。 道&#xff1a;顺势无为&#xff0c;天…

低代码系列——初步认识低代码

低代码系列目录 一、初步认识低代码 二、低代码是什么 三、低代码平台的概念和分类 01.无代码开发平台 02.低代码应用平台(LCAP) 03.多重体验开发平台(MXDP) 04.智能业务流程管理套件(iBPMS) 四、低代码的能力指标 五、低代码平台jnpf 表单 报表 流程 权限 一、初步认识低代码 …

【剖析STL】vector

vector的介绍及使用 1.1 vector的介绍 cplusplus.com/reference/vector/vector/ vector是表示可变大小数组的序列容器。就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元素 进行访问&#xff0c;和数组一样高效。但是…

STM32开关输入控制220V灯泡亮灭源代码(附带PROTEUSd电路图)

//main.c文件 /* USER CODE BEGIN Header */ /********************************************************************************* file : main.c* brief : Main program body************************************************************************…

md文本学习

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

中间件的介绍

1.1 什么是中间件 中间件是介于应用系统和系统软件之间的一类软件&#xff0c;他使用系统软件所提供的基础服务&#xff0c;衔接网络上应用系统的各个部分或不同的应用&#xff0c;能够达到资源共享、功能共享的目的。 例如MySQL就可以看作是具备中间件特性的一种技术&#x…

【Ubuntu】简洁高效企业级日志平台后起之秀Graylog

简介 Graylog 是一个用于集中式日志管理的开源平台。在现代数据驱动的环境中&#xff0c;我们需要处理来自各种设备、应用程序和操作系统的大量数据。Graylog提供了一种方法来聚合、组织和理解所有这些数据。它的核心功能包括流式标记、实时搜索、仪表板可视化、告警触发、内容…

IDEA开发项目时一直出现http404错误的解决方法

系列文章目录 安装cv2库时出现错误的一般解决方法_cv2库安装失败 SQL&#xff1e; conn sys/root as sysdbaERROR:ORA-12560: TNS: 协议适配器错误的解决方案 虚拟机启动时出现“已启用侧通道缓解”的解决方法 Hypervisor launch failed&#xff1b; Processor does not pr…

Component name “Home“ should always be multi-word

错误 解决方案 在根目录找到eslintrc.js文件&#xff0c;配置关闭名称的校验&#xff0c;在该文件中&#xff0c;找到rules进行配置&#xff0c;如下代码&#xff1a; rules: {vue/multi-word-component-names: off, // 关闭名称校验}

VScode替换cmd powershell为git bash 终端,并设置为默认

效果图 步骤 1. 解决VScode缺少git bash的问题_failed to start bash - is git-bash.exe on the syst_Rudon滨海渔村的博客-CSDN博客效果解决步骤找到git安装目录下的/bin/bash.exe&#xff0c;复制其绝对路径&#xff0c;例如D:\Program Files\Git\bin\bash.exe把路径的右斜…

.netcore grpc身份验证和授权

一、鉴权和授权&#xff08;grpc专栏结束后会开启鉴权授权专栏欢迎大家关注&#xff09; 权限认证这里使用IdentityServer4配合JWT进行认证通过AddAuthentication和AddAuthorization方法进行鉴权授权注入&#xff1b;通过UseAuthentication和UseAuthorization启用鉴权授权增加…