浅谈torch.utils.data.TensorDataset和torch.utils.data.DataLoader

1.torch.utils.data.TensorDataset

功能定位

torch.utils.data.TensorDataset 是一个将多个张量(Tensor)数据进行简单包装整合的数据集类,它主要的作用是将相关联的数据(比如特征数据和对应的标签数据等)组合在一起,形成一个方便后续用于训练等操作的数据集对象。

例如,如果你有输入特征数据 x(形状为 [n_samples, feature_dim])和对应的标签数据 y(形状为 [n_samples]),且它们都是 torch.Tensor 类型,可以这样创建 TensorDataset

import torch
from torch.utils.data import TensorDatasetx = torch.randn(100, 10)  # 模拟100个样本,每个样本特征维度为10
y = torch.randint(0, 2, (100,))  # 模拟二分类标签dataset = TensorDataset(x, y)
特点
  • 简单包装:只是把传入的张量按照样本维度进行了对应组合,并没有对数据做复杂的预处理、采样等额外操作。

  • 索引支持:支持像普通列表那样通过索引访问其中的数据元素,例如 dataset[0] 会返回由对应索引的特征和标签组成的元组(按照传入构造函数的张量顺序)。

  • 适用于小型数据集直接使用:当数据量不大且数据格式已经整理为张量形式时,可以直接基于它来进行简单的模型训练循环等操作,不过对于批量处理等更复杂的情况支持有限,需要进一步配合其他工具。

2.torch.utils.data.DataLoader

功能定位

torch.utils.data.DataLoader 是一个用于加载数据的工具类,它围绕着给定的数据集(比如 TensorDataset 或者自定义的继承自 Dataset 的类实例等),实现了诸如批量加载数据、打乱数据顺序、并行加载数据等功能,旨在让数据能够以合适的方式、合适的批量大小等被送入到模型中进行训练、验证或测试等操作。

示例:

from torch.utils.data import DataLoaderbatch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for batch_x, batch_y in dataloader:# 这里的batch_x和batch_y就是每次迭代取出的一个批量的特征和标签数据pass
特点
  • 批量处理:可以按照设定的 batch_size 参数,将数据集中的数据划分为一个个的小批量(mini-batch),方便模型以批量的方式进行梯度计算更新,有助于优化训练过程和提升效率,尤其在大数据集场景下优势明显。

  • 数据打乱:通过设置 shuffle=True 可以在每个训练轮次(epoch)开始时对数据集里面的数据顺序进行随机打乱,使得数据的输入顺序具有随机性,这有助于提升模型训练的泛化能力,避免模型因数据顺序固定而产生过拟合等问题。

  • 并行加载:支持多进程加载数据(通过设置 num_workers 参数大于 0),能够利用多核 CPU 的优势加快数据读取和预处理的速度,特别是在处理大规模数据集或者数据加载比较耗时的情况下,能显著提升整体训练效率。

  • 灵活性和通用性:它可以适配各种不同类型的数据集,只要这些数据集继承自 torch.utils.data.Dataset 抽象类并实现了必要的 __len____getitem__ 等方法,因此无论是简单的 TensorDataset 还是复杂的自定义数据集都可以用它来加载数据。

总的来说,TensorDataset 侧重于对已有张量数据进行简单的整合包装形成数据集;而 DataLoader 侧重于围绕数据集实现数据的批量加载、打乱顺序、并行化等复杂的数据加载相关功能,它们通常配合使用,先使用 TensorDataset 组织好数据,再通过 DataLoader 按照训练需求来加载和处理这些数据并送入模型中。

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65095.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Go】运行自己的第一个Go程序

运行自己的第一个Go程序 一、Go语言的安装Go环境安装查看是否安装成功配置GOPROXY(代理) 二、Goland安装三、Goland破解四、新建项目 开一篇专栏记录学习Go的过程,一门新语言从hello world开始,这篇文章详细讲解Go语言环境搭建及hello world实现 一、Go语…

计算机的错误计算(二百零一)

摘要 用两个大模型计算 ,结果保留 10位有效数字。实验表明,两个大模型的输出均只有1位正确数字;并它们几乎相同:仅最后1位数字不同。 例1. 计算 , 结果保留 10位有效数字。 下面是与一个数学解题器的对话。 以上为与一个数学解…

2024 年度时序数据库 IoTDB 论文总结

论文成果总结 2024 年度,时序数据库 IoTDB 在数据库领域 CCF-A 类国际会议上共发表论文 8 篇,包括:SIGMOD 3 篇、VLDB 3 篇、ICDE 2 篇,涵盖存储、引擎、查询、分析等方面。 2024 最后一天,我们将分类盘点 IoTDB 本年的…

ImportError: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32‘ not found

这个问题之前遇到过,没有记录,导致今天又花了2小时 原因是没有GLIBC——2.32 使用以下命令查一下有哪些版本: strings /lib/x86_64-linux-gnu/libm.so.6 | grep GLIBC_ 我已经安装好了,所有有2.32版本 原因是当前的ubuntu版本…

海南省大数据发展中心:数据资产场景化评估案例手册(第二期)

2025年1月3日,海南省数据产品超市印发《数据资产场景化评估案例手册(第二期)》(以下简称《手册》),该手册是基于真实数据要素典型应用场景进行数据资产评估操作的指导性手册,为企业在数据资产入…

python3GUI--智慧交通监控与管理系统 By:PyQt5

文章目录 一.前言二.预览三.软件组成&技术难点1.软件组成结构2.技术难点3.项目结构 四.总结 大小:35.5 M,软件安装包放在了这里! 一.前言 博主高产,本次给大家带来一款我自己使…

Linux高并发服务器开发 第八天(makefile的规则 wildcard/patsubst函数 普通变量/自动变量/其他关键字)

目录 1.makefile 1.1makefile的规则 1.2两个函数 1.3三个自动变量 1.3.1普通变量 (自定义变量) 1.3.2自动变量 1.3.3其他关键字 - ALL/all - clean 1.makefile - 作用:进行项目管理。 - 初步学习:1个规则、2个函数、3个自动变量。 - 要想使用默…

Vue动态控制disabled属性

参考:https://blog.csdn.net/guhanfengdu/article/details/126082781 在Vue中disabled:的值是受布尔值影响的,false为关闭禁用,true为开启禁用效果。 结果就是true会让按钮禁用 相反false会让按钮重新可以使用 那如果想要通过id属性值来判断是否禁用…

【DevOps】Jenkins项目发布

Jenkins项目发布 文章目录 Jenkins项目发布前言资源列表基础环境一、Jenkins发布静态网站1.1、项目介绍1.2、部署Web1.3、准备gitlab1.4、配置gitlab1.5、创建项目1.6、推送代码 二、Jenkins中创建gitlab凭据2.1、创建凭据2.2、在Jenkins中添加远程主机2.3、获取gitlab项目的UR…

Vue2/Vue3使用DataV

Vue2 注意vue2与3安装DataV命令命令是不同的Vue3 DataV - Vue3 官网地址 注意vue2与3安装DataV命令命令是不同的 vue3vite 与 Vue3webpack 对应安装也不同vue3vite npm install kjgl77/datav-vue3全局引入 // main.ts中全局引入 import { createApp } from vue import Da…

【AI学习】Transformer深入学习(二):从MHA、MQA、GQA到MLA

前面文章: 《Transformer深入学习(一):Sinusoidal位置编码的精妙》 一、MHA、MQA、GQA 为了降低KV cache,MQA、GQA作为MHA的变体,很容易理解。 多头注意力(MHA): 多头注…

trendFinder - 利用 AI 掌握社交媒体上的热门话题

1600 Stars 177 Forks 7 Issues 2 贡献者 MIT License Javascript 语言 代码: https://github.com/ericciarla/trendFinder 更多AI开源软件:AI开源 - 小众AI Trend Finder 收集并分析来自关键影响者的帖子,然后在检测到新趋势或产品发布时发送 Slack 通知…

以图像识别为例,关于卷积神经网络(CNN)的直观解释

大家读完觉得有意义记得关注和点赞!!! 作者以图像识别为例,用图文而非数学公式的方式解释了卷积神经网络的工作原理, 适合初学者和外行扫盲。 目录 1 卷积神经网络(CNN) 1.1 应用场景 1.2 起…

vim 的基础使用

目录 一:vim 介绍二:vim 特点三:vim 配置四:vim 使用1、vim 语法格式2、vim 普通模式(1)保存退出(2)光标跳转(3)文本删除(4)文本查找&…

HP 电脑开机黑屏 | 故障判断 | BIOS 恢复 | BIOS 升级

注:本文为 “HP 电脑开机黑屏 | 故障判断 | BIOS 恢复 | BIOS 升级” 相关文章合辑。 引文图片 csdn 转储异常,重传。 篇 1:Smart-Baby 回复中给出故障现象判断参考 篇 2、篇3 :HP 官方 BIOS 恢复、升级教程 开机黑屏&#xff0c…

JAVA:利用 Redis 实现每周热评的技术指南

1、简述 在现代应用中,尤其是社交媒体和内容平台,展示热门评论是常见的功能。我们可以通过 Redis 的高性能和丰富的数据结构,轻松实现每周热评功能。本文将详细介绍如何利用 Redis 实现每周热评,并列出完整的实现代码。 2、需求分…

VSCode下配置Blazor环境 断点调试Blazor项目

VSCode下使用Blazor的环境配置和插件推荐 Blazor是一种用于构建交互式Web UI的.NET框架,它可以让你使用C#、Razor和HTML进行Web开发,而不需要JavaScript。在这篇文章中,我们将介绍如何在VSCode中配置Blazor环境,并推荐一些有用的…

《Rust权威指南》学习笔记(一)

基本介绍 1.Rust使用场景 :需要运行速度、需要内存安全、更好的利用多处理器。程序员无法在安全的Rust代码中执行任何非法的内存操作。相对于C#等带有垃圾回收机制的语言来讲,Rust遵循了零开销抽象(Zero-Cost Abstraction)规则&a…

STM32-笔记26-WWDG窗口看门狗

一、简介 窗口看门狗用于监测单片机程序运行时效是否精准,主要检测软件异常,一般用于需要精准检测程序运行时间的场合。 窗口看门狗的本质是一个能产生系统复位信号和提前唤醒中断的6位计数器(有的地方说7位。其实都无所谓&#xff0…

机组的概述

计算机系统组成 硬件系统和软件系统 计算机硬件 1.冯诺依曼机基本思想 特点 1.采用“存储程序”工作方式 2.硬件系统由运算器,存储器,控制器,输入输出设备组成 3.指令和数据存在存储器中,形式无区别 4.指令和数据用二进制代…