torch 的数据加载 Datasets DataLoaders

点赞收藏关注!
如需要转载,请注明出处!

torch的模型加载有两种方式:
Datasets & DataLoaders

torch本身可以提供两数据加载函数:
torch.utils.data.DataLoader()和torch.utils.data.Dataset()

其中torch.utils.data 是PyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

加载函数后可以实现数据集代码与模型训练代码分离,以获得更好的可读性和模块化
Dataset定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。制作自己的数据集必须要实现三个函数:

  • init()函数在实例化Dataset对象时运行一次
  • len()函数返回数据集中样本的数量
  • getitem()函数的作用是:从给定索引 index,从数据集中加载并返回一个样本并将其转换为张量。
import torch
from torch.utils.data import Datasetclass CreateDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):# 根据索引获取样本return self.data[index]def __len__(self):# 返回数据集大小return len(self.data)# 创建数据集对象
data = [[255,255,255],[255,245,235],[225,226,227]]
dataset = CreateDataset(data)# 根据索引获取样本
sample = dataset[1]
print(sample)
# [255,245,235]

数据处理模块其他的功能:

  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoadera = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99]])
b = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
train_ids = TensorDataset(a, b)for x_train, y_label in train_ids:print(x_train, y_label)##############################################################################################
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。

  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。定义了一个子集的索引列表indices,它可以根据需要进行调整。然后,我们使用Subset类创建了一个名为subset的子集对象,它接受两个参数:原始数据集dataset和子集的索引列表indices。

indices = [0, 2, 4]  # 子集的索引列表
subset = Subset(dataset, indices)
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader# 定义训练的设备
device = torch.device("cuda")
#读取数据
data_transform = transforms.Compose([transforms.Resize(size=(224,224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
])
full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
# length 数据集总长度
full_data_size = len(full_dataset)
print("总数据集的长度为:{}".format(full_data_size))
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size#在这里
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
#在这里train_data_size = len(train_dataset)
test_data_size = len(test_dataset)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
>>>
总数据集的长度为:100
训练数据集的长度为:80
测试数据集的长度为:20
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
#链接两个数据集
dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset]) 
#导入数据集
loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.n_workers)
  • get_worker_info: 获取当前数据加载器所在的进程信息。torch.utils.data.get_worker_info() 在worker进程中返回各种有用的信息(包括worker id、dataset replica、initial seed等),在main进程中返回None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。分片数据集特别有用。

如有帮助,点赞收藏关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/156969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解锁电力安全密码:迅软DSE助您保护机密无忧

电力行业信息化水平不断提高,明显提升了电力企业的生产运营能力,然而随着越来越多重要信息存储在终端计算机中,电力面临的信息安全挑战也越来越多。 作为关键基础设施的基础,电力企业各部门产生的资料文档涵盖着大量机密信息&…

1.Qt5.15及其以上的下载

Qt5.15及其以上的下载 简介: ​ Qt是一个跨平台的C库,允许开发人员创建在不同操作系统(如Windows、macOS、Linux/Unix)和设备上具有本地外观和感觉的应用程序。Qt提供了一套工具和库,用于构建图形用户界面&#xff0…

【Linux】 find命令使用

find find命令是一种通过条件匹配在指定目录下查找对应文件或者目录的工具。匹配的条件可以是文件名称、类型、大小、权限属性、时间戳等。find命令还可以配合相关命令对匹配到的文件作出后续处理。 语法 find [路径...] [表达式] [path...]为需要查找文件所指定的路径。如果…

智能座舱架构与芯片- (12) 软件篇 中

三、智能座舱操作系统 3.1 概述 车载智能计算平台自下而上可大致划分为硬件平台、系统软件(硬件抽象层OS内核中间件)、功能软件(库组件中间件)和应用算法软件等四个部分。狭义上的OS特指可直接搭载在硬件上的OS内核;…

【C++进阶之路】第十一篇:C++的IO流

文章目录 1. C语言的输入与输出2. 流是什么3. CIO流3.1 C标准IO流3.2 C文件IO流 4.stringstream的简单介绍 1. C语言的输入与输出 C语言中我们用到的最频繁的输入输出方式就是scanf ()与printf()。 scanf(): 从标准输入设备(键盘)读取数据,并将值存放在变量中。prin…

7-sqlalchemy快速使用和原生操作、对象映射类型和增删查改、增加和基于对象的跨表查询、scoped线程安全、g对象、基本增查改和高级查询

1 sqlalchemy快速使用 2 sqlalchemy原生操作 3 sqlalchemy操作表 3.1 对象映射类型 3.2 基本增删查改 4 一对多关系 4.1 关系建立 4.2 增加和基于对象的跨表查询 4.3 一对一关系,就是一对多,只不过多的一方只有一条 5 多对多关系 5.2 增加和基于对象跨…

【人工智能Ⅰ】8-回归 降维

【人工智能Ⅰ】8-回归 & 降维 8-1 模型评价指标 分类任务 准确率、精确率与召回率、F值、ROC-AUC、混淆矩阵、TPR与FPR 回归任务 MSE、MAE、RMSE 无监督任务(聚类) 兰德指数、互信息、轮廓系数 回归任务的评价指标 1:MSE均方误差…

易航网址引导系统 v1.9 源码:去除弹窗功能的易航网址引导页管理系统

易航自主开发了一款极其优雅的易航网址引导页管理系统,后台采用全新的光年 v5 模板开发。该系统完全开源,摒弃了后门风险,可以管理无数个引导页主题。数据管理采用易航原创的JsonDb数据包,无需复杂的安装解压过程即可使用。目前系…

SVN创建分支

一 从本地创建方式可指定版本号进行分支创建。 1、在本地目录右击 -----> 点击branch/tag(分支/标签) From: 源,可指定具体的版本号, To path: 可通过"..."选择分支路径 最后点击确定,交由服务器执行创建。 二 通过SVN客…

html实现计算器源码

文章目录 1.设计来源1.1 主界面1.2 计算效果界面 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者:xcLeigh 文章地址:https://blog.csdn.net/weixin_43151418/article/details/134532725 html实现计算器源码,计算器源码,简易计…

假ArrayList导致的线上事故......

假ArrayList导致的线上事故… 线上事故回顾 晚饭时,当我正沉迷于排骨煲肉质鲜嫩,汤汁浓郁时,产研沟通群内发出一条消息,显示用户存在可用劵,但进去劵列表却什么也没有,并附含了一个视频。于是我一边吃了排…

MATLAB算法实战应用案例精讲-【目标检测】多目标跟踪(MOT)

目录 算法原理 算法步骤 评价指标 数据集 SORT和DeepSORT 关键算法 应用案例

11.21序列检测,状态机比较与代码,按键消抖原理

序列检测 用一个atemp存储之前的所有状态,即之前出现的七位 含无关项检测 要检测011XXX110 对于暂时变量的高位,位数越高就是越早出现的数字,因为新的数字存储在TEMP的最低位 不重叠序列检测 ,一组一组 011100 timescale 1ns…

【算法】二分查找-20231122

这里写目录标题 一、1089. 复写零二、917. 仅仅反转字母三、88. 合并两个有序数组四、283. 移动零 一、1089. 复写零 提示 简单 266 相关企业 给你一个长度固定的整数数组 arr ,请你将该数组中出现的每个零都复写一遍,并将其余的元素向右平移。 注意&a…

智能座舱架构与芯片- (13) 软件篇 下

四、面向服务的智能座舱软件架构 4.1 面向信号的软件架构 随着汽车电子电气架构向中央计算-域控制器的方向演进,甚至向车云一体化的方向迈进,适用于汽车的软件平台也需要进行相应的进化。 在传统的观念中,座舱域即娱乐域,座舱软…

4.Gin HTML 模板渲染

4.Gin HTML 模板渲染 Gin HTML 模板渲染 1. 全部模板放在一个目录里面的配置方法 创建用于渲染的模板html templates/index.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> …

2021秋招-算法-哈希算法-哈希表

LeetCoe-03-无重复字符的最长字串 LeetCode链接:LeetCoe-03-无重复字符的最长字串 题目理解及描述 无重复字符的最长子串难度中等3747给定一个字符串&#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。示例 1:输入: “abcabcbb” 输出: 3 解释: 因为无重复字符的最…

【GitHub】保姆级使用教程

一、如何流畅访问GitHub 1、网易uu加速器 输入网址&#xff0c;无脑下载网易加速器&#xff1b;https://uu.163.com/ 下载安装完毕后&#xff0c;创建账号进行登录 登录后&#xff0c;在右上角搜索框中搜索“学术资源”&#xff0c;并点击&#xff1b; 稍等一会儿就会跳…

如何在3dMax中使用Python按类型选择对象?

如何在3dMax中使用Python按类型选择对象&#xff1f; 3dMax提供了pymxs API&#xff0c;这是MAXScript的Python包装器&#xff0c;可帮助您扩展和自定义3dMax&#xff0c;并更轻松地将其集成到基于Python的管道中。 pymxs模块包含一个运行时成员&#xff0c;该成员提供对MAXSc…

电子学会C/C++编程等级考试2022年09月(一级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:指定顺序输出 依次输入3个整数a、b、c,将他们以c、a、b的顺序输出。 时间限制:1000 内存限制:65536输入 一行3个整数a、b、c,以空格分隔。 0 < a,b,c < 108输出 一行3个整数c、a、b,整数之间以一个空格分隔。样例输入…