【深度学习】Pytorch 教程(十四):PyTorch数据结构:6、数据集(Dataset)与数据加载器(DataLoader):自定义鸢尾花数据类

文章目录

  • 一、前言
  • 二、实验环境
  • 三、PyTorch数据结构
    • 1、Tensor(张量)
      • 1. 维度(Dimensions)
      • 2. 数据类型(Data Types)
      • 3. GPU加速(GPU Acceleration)
    • 2、张量的数学运算
      • 1. 向量运算
      • 2. 矩阵运算
      • 3. 向量范数、矩阵范数、与谱半径详解
      • 4. 一维卷积运算
      • 5. 二维卷积运算
      • 6. 高维张量
    • 3、张量的统计计算
    • 4、张量操作
      • 1. 张量变形
      • 2. 索引
      • 3. 切片
      • 4. 张量修改
    • 5、张量的梯度计算
    • 6、数据集(Dataset)与数据加载器(DataLoader)
      • 1. 数据集(Dataset)
      • 2. 数据加载器(DataLoader)
      • 3. 实战——鸢尾花数据集

一、前言

  本文将介绍PyTorch中数据集(Dataset)与数据加载器(DataLoader),并实现自定义鸢尾花数据类

二、实验环境

  本系列实验使用如下环境

conda create -n DL python==3.11
conda activate DL
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

三、PyTorch数据结构

1、Tensor(张量)

  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。

1. 维度(Dimensions)

  Tensor(张量)的维度(Dimensions)是指张量的轴数或阶数。在PyTorch中,可以使用size()方法获取张量的维度信息,使用dim()方法获取张量的轴数。

在这里插入图片描述

2. 数据类型(Data Types)

  PyTorch中的张量可以具有不同的数据类型:

  • torch.float32或torch.float:32位浮点数张量。
  • torch.float64或torch.double:64位浮点数张量。
  • torch.float16或torch.half:16位浮点数张量。
  • torch.int8:8位整数张量。
  • torch.int16或torch.short:16位整数张量。
  • torch.int32或torch.int:32位整数张量。
  • torch.int64或torch.long:64位整数张量。
  • torch.bool:布尔张量,存储True或False。

【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)

3. GPU加速(GPU Acceleration)

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

2、张量的数学运算

  PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。

1. 向量运算

【深度学习】Pytorch 系列教程(三):PyTorch数据结构:2、张量的数学运算(1):向量运算(加减乘除、数乘、内积、外积、范数、广播机制)

2. 矩阵运算

【深度学习】Pytorch 系列教程(四):PyTorch数据结构:2、张量的数学运算(2):矩阵运算及其数学原理(基础运算、转置、行列式、迹、伴随矩阵、逆、特征值和特征向量)

3. 向量范数、矩阵范数、与谱半径详解

【深度学习】Pytorch 系列教程(五):PyTorch数据结构:2、张量的数学运算(3):向量范数(0、1、2、p、无穷)、矩阵范数(弗罗贝尼乌斯、列和、行和、谱范数、核范数)与谱半径详解

4. 一维卷积运算

【深度学习】Pytorch 系列教程(六):PyTorch数据结构:2、张量的数学运算(4):一维卷积及其数学原理(步长stride、零填充pad;宽卷积、窄卷积、等宽卷积;卷积运算与互相关运算)

5. 二维卷积运算

【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理

6. 高维张量

【深度学习】pytorch教程(八):PyTorch数据结构:2、张量的数学运算(6):高维张量:乘法、卷积(conv2d~ 四维张量;conv3d~五维张量)

3、张量的统计计算

【深度学习】Pytorch教程(九):PyTorch数据结构:3、张量的统计计算详解

4、张量操作

1. 张量变形

【深度学习】Pytorch教程(十):PyTorch数据结构:4、张量操作(1):张量变形操作

2. 索引

3. 切片

【深度学习】Pytorch 教程(十一):PyTorch数据结构:4、张量操作(2):索引和切片操作

4. 张量修改

【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)

5、张量的梯度计算

【深度学习】Pytorch教程(十三):PyTorch数据结构:5、张量的梯度计算:变量(Variable)、自动微分、计算图及其可视化

6、数据集(Dataset)与数据加载器(DataLoader)

  数据集(Dataset)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。
  数据加载器(DataLoader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

以下是一个具体案例,介绍如何使用PyTorch中的数据集和数据加载器:

import torch
from torch.utils.data import Dataset, DataLoader# 定义自定义数据集类
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]# 创建数据集实例
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)# 创建数据加载器实例
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)# 遍历数据加载器
for batch in dataloader:print(batch)

在上面的案例中,首先定义了一个自定义数据集类MyDataset,包含了__len____getitem__两个方法。__len__方法返回数据集的长度,__getitem__方法根据给定的索引返回数据集中的样本。

然后,创建了一个数据集实例dataset,传入了一组示例数据。再创建数据加载器实例dataloader,设置了批量大小为2,并开启了数据的随机打乱。

最后,在遍历数据加载器的过程中,每次打印出的batch是一个批量大小为2的数据。在实际应用中,可以根据具体的需求对每个批次进行进一步的处理和训练。

1. 数据集(Dataset)

  PyTorch中,Dataset(数据集)是用于存储和管理训练、验证或测试数据的抽象类。它是一个可迭代的对象,可以通过索引或迭代方式访问数据样本。
  PyTorch提供了torch.utils.data.Dataset类,可以通过继承该类来创建自定义的数据集。自定义数据集时需要实现下面两个主要的方法:

  • __len__()方法:返回数据集中样本的数量
  • __getitem__(index)方法:根据给定的索引index,返回对应位置的数据样本
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):sample = self.data[index]# 可以继续添加对数据样本进行预处理或转换操作# 返回经过处理的数据样本return sample# 自定义数据
data = [1, 2, 3, 4, 5]
dataset = CustomDataset(data)# 访问数据集中的样本
sample = dataset[0]
print(sample)  # 输出: 1

  使用自定义数据集时,可以将其与 torch.utils.data.DataLoader结合使用,以便更方便地进行数据的批量加载和处理。

2. 数据加载器(DataLoader)

  DataLoader(数据加载器)是用于批量加载和处理数据的实用工具。它提供了对数据集的迭代器,并支持按照指定的批量大小、随机洗牌等方式加载数据

  • 批量加载数据:DataLoader可以从数据集中按照指定的批量大小加载数据。每个批次的数据可以作为一个张量或列表返回,便于进行后续的处理和训练。
  • 数据随机洗牌:通过设置shuffle=True,DataLoader可以在每个迭代周期中对数据进行随机洗牌,以减少模型对数据顺序的依赖性,提高训练效果。
  • 多线程数据加载:DataLoader支持使用多个线程来并行加载数据,加快数据加载的速度,提高训练效率。
  • 数据批次采样:除了按照批量大小加载数据外,DataLoader还支持自定义的数据批次采样方式。可以通过设置batch_sampler参数来指定自定义的批次采样器,例如按照指定的样本顺序或权重进行采样。
import torch
from torch.utils.data import Dataset, DataLoader# 自定义数据集类
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]# 自定义数据加载器类
class MyDataLoader(DataLoader):def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):super().__init__(dataset, batch_size, shuffle, num_workers=num_workers)def collate_fn(self, batch):# 自定义的数据预处理、合并等操作# 这里只是简单地将样本转换为Tensor,并进行堆叠return torch.stack(batch)# 自定义数据集类
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)# 创建数据加载器实例
dataloader = MyDataLoader(dataset, batch_size=2, shuffle=True)# 遍历数据加载器
for batch in dataloader:# batch是一个包含多个样本的张量(或列表)# 这里可以对批次数据进行处理print(batch)

  在创建DataLoader时,指定了批量大小batch_size和是否随机洗牌shuffle。
通过DataLoader加载数据集后,使用for循环迭代加载数据批次。每个批次的数据将作为一个张量或列表返回,可以根据需要在循环中对批次数据进行处理。
在这里插入图片描述

3. 实战——鸢尾花数据集

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y# 自定义鸢尾花数据类
class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)batch_size = 16# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/708015.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高性能图表组件LightningChart .NET v11.0发布——增强DPI感知能力

LightningChart完全由GPU加速,并且性能经过优化,可用于实时显示海量数据-超过10亿个数据点。 LightningChart包括广泛的2D,高级3D,Polar,Smith,3D饼/甜甜圈,地理地图和GIS图表以及适用于科学&am…

华为---RSTP(四)---RSTP的保护功能简介和示例配置

目录 1. 技术背景 2. RSTP的保护功能 3. BPDU保护机制原理和配置命令 3.1 BPDU保护机制原理 3.2 BPDU保护机制配置命令 3.3 BPDU保护机制配置步骤 4. 根保护机制原理和配置命令 4.1 根保护机制原理 4.2 根保护机制配置命令 4.3 根保护机制配置步骤 5. 环路保护机…

php基础学习之错误处理(其二)

在实际应用中,开发者当然不希望把自己开发的程序的错误暴露给用户,一方面会动摇客户对己方的信心,另一方面容易被攻击者抓住漏洞实施攻击,同时开发者本身需要及时收集错误,因此需要合理的设置错误显示与记录错误日志 一…

Linux笔记--用户与用户组

Linux系统是一个多用户多任务的操作系统,任何一个要使用系统资源的用户,都必须首先向系统管理员(root)申请一个账号,然后以这个账号的身份进入系统。 用户的账号一方面可以帮助系统管理员对使用系统的用户进行跟踪,并控制他们对系…

3D数字孪生

数字孪生(Digital Twin)是物理对象、流程或系统的虚拟复制品,用于监控、分析和优化现实世界的对应物。 这些数字孪生在制造、工程和城市规划等领域变得越来越重要,因为它们使我们能够在现实世界中实施改变之前模拟和测试不同的场景…

从业务角度出发,实现UniApp二次开发的最佳实践

UniApp作为一款跨平台的移动应用开发框架,为开发者提供了在多个平台上构建应用的便利性。在这篇文章中,我们将深入探讨UniApp的二次开发,以及如何通过定制化来满足你的独特需求。 1.了解UniApp基础 1.1项目结构和文件 熟悉UniApp的项目结构…

动态规划(题目提升)

[NOIP2012 普及组] 摆花 方法一:记忆化搜索 何为记忆化搜素:就是使用递归函数对每次得到的结果进行保存,下次遇到就直接输出即可 那么这个题目使用递归(DFS)是怎样的? 首先我们需要搞清楚几个坑点&#x…

C/C++语言文字小游戏(荒岛求生)

游戏背景 玩家在荒岛上,需要寻找食物、水源、避难所等资源来生存。 玩家需要避免危险,如野兽、植物、天气等,否则会失去血量或生命。 玩家可以在荒岛上遇到其他生存者,可以选择合作或对抗。 游戏目标是生存一定时间或找到生存的出…

Javaweb之SpringBootWeb案例之 Bean管理的第三方Bean的详细解析

2.3 第三方Bean 学习完bean的获取、bean的作用域之后,接下来我们再来学习第三方bean的配置。 之前我们所配置的bean,像controller、service,dao三层体系下编写的类,这些类都是我们在项目当中自己定义的类(自定义类)。当我们要声…

将任何网页变成桌面应用,全平台支持 | 开源日报 No.184

tw93/Pake Stars: 20.9k License: MIT Pake 是利用 Rust 轻松构建轻量级多端桌面应用的工具。 与 Electron 包大小相比几乎小了 20 倍(约 5M!)使用 Rust Tauri,Pake 比基于 JS 的框架更轻量和更快内置功能包括快捷方式传递、沉浸…

小程序中使用echarts地图

一、下载并安装echarts 1、下载echarts-for-weixin组件 echarts-for-weixin项目提供了一个小程序组件,用这种方式可以在小程序中方便地使用 ECharts。 下载ec-canvas项目(下载地址) ​​ 注意:下载的 ec-canvas 中的echarts的版本…

【Linux】协程简介

【Linux】协程简介 一、什么是协程?简介优点 二、为什么使用协程?三、协程的种类1、对称协程2、非对称协程 四、协程栈1、静态栈2、分段栈3、共享栈4、虚拟内存栈 五、协程调度1、栈式调度2、星切调度3、环切调度 六、常见协程库参考文献 一、什么是协程…

机器学习 | 模型评估和选择 各种评估指标总结——错误率精度-查准率查全率-真正例率假正例率 PR曲线ROC曲线

文章目录 1. 如何产生训练集和测试集呢?1.1 留出法1.2 K折交叉验证法1.3 自助法 2. 模型评估指标2.1 错误率和精度2.2 查准率和查全率与F12.2.1 PR曲线及其绘制 2.3 正例率和假例率2.3.1 ROC曲线图绘制及AUC 3 假设检验 1. 如何产生训练集和测试集呢? 1…

【ACW 服务端】k8s部署

k8s部署 --- apiVersion: apps/v1 kind: Deployment metadata:annotations:k8s.kuboard.cn/displayName: 【wu-smart-acw-server】后台服务端labels:k8s.kuboard.cn/layer: svck8s.kuboard.cn/name: wu-smart-acw-servername: wu-smart-acw-servernamespace: defaultresourceV…

记autodl跑模型GPU CPU利用率骤变为0问题

目录 问题 解决 问题 实验室服务器资源紧张,博主就自己在autodl上租卡跑了,autodl有一个网络共享存储,可挂载至同一地区的不同实例中,当我们在该地区创建实例开机后,将会挂载文件存储至实例的/root/autodl-fs目录…

韩国量子之梦:将量子计算纳入新增长 4.0战略

内容来源:量子前哨(ID:Qforepost) 编辑丨王珩 编译/排版丨沛贤 深度好文:1500字丨9分钟阅读 据《朝鲜邮报》报道,韩国将推出由量子计算加速的云服务,并在首尔地区启动城市空中交通的试飞&…

微信小程序订阅消息前后端示例

微信小程序的订阅消息&#xff0c; 必须是由弹框&#xff0c;弹框&#xff0c;弹框来调起了&#xff0c;单纯的在页面上调用 wx.requestSubscribeMessage是没有效果的 小程序端的代码 <view class"sub" bindtap"dinyuxiaoxi">订阅消息</view>…

Leetcoder Day27| 贪心算法part01

语言&#xff1a;Java/Go 理论 贪心的本质是选择每一阶段的局部最优&#xff0c;从而达到全局最优。 什么时候用贪心&#xff1f;可以用局部最优退出全局最优&#xff0c;并且想不到反例到情况 贪心的一般解题步骤 将问题分解为若干个子问题找出适合的贪心策略求解每一个子…

【Linux系统化学习】信号概念和信号的产生

目录 信号的概念 从生活中的例子中感知信号 前台进程和后台进程 前台进程 后台进程 操作系统如何知道用户向键盘写入数据了&#xff1f; 进程如何得知自己收到了信号&#xff1f; 信号捕捉 signal函数 Core Dump&#xff08;核心转储&#xff09; 信号产生的方式 通…

LeetCode 刷题 [C++] 第102题.二叉树的层序遍历

题目描述 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 题目分析 题目中要求层序遍历二叉树&#xff0c;即二叉树的广度优先搜索(BFS)。BFS一般使用队列的先入先出特性实现&#…