PyTorch中的 Dataset、DataLoader 和 enumerate()

PyTorch:关于Dataset,DataLoader 和 enumerate()

本博文主要参考了 Pytorch中DataLoader的使用方法详解 和 pytorch:关于enumerate,Dataset和Dataloader 两篇文章进行总结和归纳。

DataLoader 隶属 PyTorch 中 torch.utils.data 下的一个类,任何继承 torch.utils.data.Data 类的子类均需要重载__getitem__()及__len__()两个函数,且子类在__init__()函数产生的数据路径,将作为 DataLoader 参数 DataSets 的实参。该类将自定义的 Dataset 根据 batch size 大小、是否 shuffle 等封装成一个 Batch Size 大小的 Tensor,用于后面的训练。

Dataset 类构建

在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。这里的 Dateset 可以指整个数据集,也可以是训练集,测试集等。

class Dataset:def __init__(self,...):...def __len__(self,...):return ndef __getitem__(self,item):return data[item]

正常情况下,该数据集是要继承 Pytorch 中 Dataset 类的,但实际操作中,即使不继承,数据集类构建后仍可以用 Dataloader() 加载的。

在dataset类中,len(self)返回数据集中数据的总个数,getitem(self,item)表示每次返回第 item 条(个)数据。
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个 item
③__getitem__:返回一条(个)训练样本的数据,并将其转换成 tensor

在 dataset 实例化时一般要传入数据集的路径,一般在__init__() 函数中指定数据集路径等相关信息(可以通过相关路径读取包含图像名称、标签等相关信息的 json 或者 csv 等类型的文件);通过__getitem__(self,item) 得到对应的图像并将进行 transform 转换(缩放、裁剪、转换成 tensor 等操作),最终以 tensor 的形式返回。

DataLoader 使用

在构建 Dataset 类后,即可使用 DataLoader 加载。DataLoader 中常用参数如下:

  1. dataset:需要载入的数据集,如前面构造的 dataset 类。
  2. batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个 batch 进行训练。
  3. shuffle:是否在打乱数据集样本顺序。True 为打乱,False 反之。
  4. num_workers:这个参数决定了有几个进程来处理 data loading。0 意味着所有的数据都会被 load 进主进程。(默认num_workers=0,在 Windows 系统下需要设置为 0
  5. drop_last:是否舍去最后一个batch的数据(很多情况下数据总数 N 与 batch size 不整除,导致最后一个 batch 不为 batch size)。True 为舍去,False 反之。

注意:使用 DataLoader 读取数据时,为了加快效率,所以使用了多个线程,即 num_workers 不为0,在 windows 系统下报如下的错误。
RuntimeError: Couldn’t open shared file mapping: <torch_16716_3565374679>, error code: <1455>

DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 

参照 DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support() 教程中提到,在 https://github.com/pytorch/pytorch/pull/5585 中给出了一些官方解释,应该是 Windows下的一些线程文件读写的问题。
在 Windows 上,FileMapping 对象应必须在所有相关进程都关闭后,才能释放。启用多线程处理时,子进程将创建 FileMapping,然后主进程将打开它。 之后当子进程将尝试释放它的时候,因为父进程还在引用,所以它的引用计数不为零,无法释放。 但是当前代码没有提供在可能的情况下再次关闭它的机会。这个版本官方说 num_workers=1 是可以用的,更多的线程还在解决,不过现在即便是用 2 个子进程也已经可以了。

加载数据的过程

pytorch 中加载数据的顺序是:

  1. 创建一个 dataset 对象
  2. 创建一个 dataloader 对象
  3. 循环 dataloader 对象,将 data, label 拿到模型中去训练

enumerate() 函数

在对 Dataloader 进行读取时,通常使用 enumerate() 函数,enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。调用 enumerate(dataloader) 时每次都会读出一个 batch_size 大小的数据。例如,数据集中总共包含 245 张图像,train_loader = dataloader(dataset, batch_size=32, drop_last=True) 被实例化时,经过以下代码后输出的 count 为 224(正好等于32*7),而多出来的 245-224=21 张图像不够一个 batch 因此被 drop 掉了。下面展示了如何从 dataloader 中通过 enumerate() 返回一个batch_size的数据。

for k, images, target in enumerate(dataloader):

其中,k代表下标值,images, target 代表可遍历的数据对象。因为 enumerate(dataloader) 一次会返回一个 batch 的数据,所以返回的 images 为 batch_size 长度的list,target 也为 batch_size 长度的 list。

通常,dataloader 里包含很多个数据对象,那么我们应该怎么保证 batch 就是我们所需要的数据呢?通过 Dataset 的定义可以实现我们需要的数据。Dataset 是用来定义数据从哪里读取,以及如何读取的问题,通过重写 Dataset 抽象类的__getitem__()函数。enumerate(dataloader) 得到的数据就是 __getitem__() 函数返回的数据,只不过 enumerate(dataloader) 一次会得到 batch_size 个不同 item 的数据组成的 list。

def __getitem__(self, item):images = self.data[item]target = self.label[item]return images, target

返回 item 对应的数据,就是 enumerate(dataloader) 得到的数据的一部分。

def __len__(self):return len(self.data)

返回 dataset 中总的数据个数,用于控制返回多少个 batch 的数据,enumerate(dataloader) 一次会返回 batch_size 大小的 list。

Reference

Pytorch中DataLoader的使用方法详解
pytorch:关于enumerate,Dataset和Dataloader
DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/584861.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构之树 --- 二叉树

目录 定义二叉树的结构体 二叉树的遍历 递归遍历 非递归遍历 链式二叉树的实现 二叉树的功能接口 先序遍历创建二叉树 后序遍历销毁二叉树 先序遍历查找树中值为x的节点 层序遍历 上篇我们对二叉树的顺序存储堆进行了讲述&#xff0c;本文我们来看链式二叉树。 定…

SpringCloud(H版alibaba)框架开发教程之nacos做配置中心——附源码(2)

上篇主要讲了使用eureka&#xff0c;zk&#xff0c;nacos当注册中心 这篇内容是nacos配置中心 代码改动部分mysql驱动更新到8.0&#xff0c;数据库版本升级到了8.0&#xff0c;nacos版本更新到了2.x nacos2.x链接 链接&#xff1a;https://pan.baidu.com/s/11nObzgTjWisAfOp…

探秘交互设计:深入了解五大核心维度!

交互式设计是用户体验&#xff08;UX&#xff09;设计的重要组成部分。本文将解释什么是交互设计&#xff0c;并分享一些有用的交互设计模型&#xff0c;并简要描述交互设计师通常做什么。 如何解释交互设计 交互式设计可以用一个简单的术语来理解&#xff1a;它是用户和产品…

探索深度学习在自然语言处理中的应用

摘要&#xff1a; 随着人工智能技术的不断发展&#xff0c;深度学习在自然语言处理领域的应用越来越广泛。本文将探讨深度学习在自然语言处理中的各种应用&#xff0c;包括文本分类、情感分析、机器翻译等&#xff0c;并分析其优缺点。 一、引言 自然语言处理&#xff08;NLP…

借贷协议 Tonka Finance:铭文资产流动性的新破局者

“Tonka Finance 是铭文赛道中首个借贷协议&#xff0c;它正在为铭文资产赋予捕获流动性的能力&#xff0c;并为其构建全新的金融场景。” 在 2023 年的 1 月&#xff0c;比特币 Ordinals 协议被推出后&#xff0c;包括 BRC20&#xff0c;Ordinals 等在内的系列铭文资产在包括比…

nginx源码分析-3

这一章内容讲述nginx中的事件是如何一步步添加到epoll实例中的。 在初始化http连接的函数ngx_http_init_connection中&#xff0c;nginx为http连接初始化了处理请求的回调函数&#xff0c;之后调用ngx_handle_read_event函数对可读数据进行处理。这里只为连接设置read而没有设…

Ubuntu22.04 安装教程

系统下载 Ubuntu官网下载 清华源镜像 安装流程 1. 选择安装语言 2. 选择是否在安装时更新 为了系统安装速度一般选择安装时不更新&#xff0c;安装后自行更新 3. 选择系统语言和键盘布局 4. 选择安装模式 5. 配置网络信息 6. 设置静态IP 7. 配置代理信息 8. 配置Ubuntu镜像…

非常好用的ocr图片文字识别技术,识别图片中的文字

目录 一.配置环境 二.应用 2.1常见图片识别 2.2排版简单的印刷体截图图片识别 2.3竖排文字识别 2.4英文识别 2.5繁体中文识别 2.6单行文字的图片识别 三.参考 一.配置环境 pip3 install cnocr -i https://pypi.tuna.tsinghua.edu.cn/simple pip3 install onnxruntime…

在电脑上免费分区的 5 个有效磁盘分区软件工具

磁盘分区可能是一个脆弱而复杂的过程&#xff0c;磁盘崩溃或用户设备受到病毒攻击的风险很高。因此&#xff0c;它们很难由用户单独或手动管理。本文详细介绍了可以帮助简化磁盘分区过程的不同软件工具、它们的功能和优点。那么让我们开始吧。 什么是磁盘分区工具&#xff1f; …

在STM32中集成TSL2561光强传感器的开发和调试

在STM32中集成TSL2561光强传感器的开发和调试是一个常见的应用场景。TSL2561是一款数字光传感器&#xff0c;能够测量可见光和红外光的光强&#xff0c;并通过I2C接口将数据传输给微控制器。下面将为您介绍在STM32中集成TSL2561传感器的开发步骤&#xff0c;并附上相应的代码示…

Web常用的编码和解码技术

文章目录 一、URI的编码与解码1.1 URI介绍1.2 什么是encodeURI1.3 什么是encodeURIComponent1.4 应用场景1.5 URI解码1.6 扩展&#xff1a;内置对象URL 二、字符串的Base64编码与解码2.1 ASCII字符编解码2.2 非ASCII字符编解码 一、URI的编码与解码 1.1 URI介绍 URI指的是统一…

【音视频 ffmpeg 学习】 RTMP推流 mp4文件

1.RTMP(实时消息传输协议)是Adobe 公司开发的一个基于TCP的应用层协议。 2.RTMP协议中基本的数据单元称为消息&#xff08;Message&#xff09;。 3.当RTMP协议在互联网中传输数据的时候&#xff0c;消息会被拆分成更小的单元&#xff0c;称为消息块&#xff08;Chunk&#xff…

Linux系统下隧道代理HTTP

在Linux系统下配置隧道代理HTTP是一个涉及网络技术的话题&#xff0c;主要目的是在客户端和服务器之间建立一个安全的通信通道。下面将详细解释如何进行配置。 一、了解基本概念 在开始之前&#xff0c;需要了解几个关键概念&#xff1a;代理服务器、隧道代理和HTTP协议。代理…

VsCode的介绍和入门详细讲解

VS Code&#xff08;Visual Studio Code&#xff09;是由 Microsoft 开发的一款轻量级开源编辑器&#xff0c;支持多种语言和框架的编写、调试和测试。它拥有丰富的扩展生态系统&#xff0c;可以满足不同开发者的需求。 下面是 VS Code 的入门详细讲解&#xff1a; 下载和安装…

使用Python绘制各种图表

1、折线图&#xff08;Line Chart&#xff09; import matplotlib.pyplot as plt # 数据 x [1, 2, 3, 4, 5] y [2, 4, 1, 3, 7] # 绘制折线图 plt.plot(x, y) plt.title(折线图示例) plt.xlabel(X轴) plt.ylabel(Y轴) plt.show() 2、柱状图&#xff08;Bar…

算法专题四:前缀和

前缀和 一.一维前缀和(模板)&#xff1a;1.思路一&#xff1a;暴力解法2.思路二&#xff1a;前缀和思路 二. 二维前缀和(模板)&#xff1a;1.思路一&#xff1a;构造前缀和数组 三.寻找数组的中心下标&#xff1a;1.思路一&#xff1a;前缀和 四.除自身以外数组的乘积&#xff…

3、Git分支操作与团队协作

Git分支操作 1.什么是分支2. 分支的好处3. 分支的操作3.1 查看分支3.2 创建分支3.3 切换分支3.4 修改分支3.5 合并分支3.6 产生和解决冲突 4. 创建分支和切换分支图解5. Git团队协作机制团队内协作跨团队协作 均在git bash中进行操作。事先建好本地工作库 1.什么是分支 在版本…

GLTF 编辑器实现逼真3D动物毛发效果

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 要实现逼真的3D动物毛发效果&#xff0c;可以采用以下技术和方法&…

接口和抽象类

在Java编程语言中&#xff0c;抽象类和接口都是用于定义抽象概念的重要工具。它们都提供了一种方式来创建可重用的代码&#xff0c;并且都可以被其他类继承或实现。然而&#xff0c;尽管它们有一些相似之处&#xff0c;但也存在一些显著的区别。本文将探讨抽象类和接口的相同点…

Vue学习day_03

普通组件的注册 局部注册: 创建一个components的文件夹 在里面写上对应的.vue文件 在对应的vue里面写上对应的3部分 template写上对应的核心代码 盒子等 style 写上对应的css修饰 在App.vue里面进行引用 import 导包 格式是 import 起个名字 from 位置 在写一个component…