【PyTorch简介】3.Loading and normalizing datasets 加载和规范化数据集

Loading and normalizing datasets 加载和规范化数据集

文章目录

  • Loading and normalizing datasets 加载和规范化数据集
  • Datasets & DataLoaders 数据集和数据加载器
  • Loading a Dataset 加载数据集
  • Iterating and Visualizing the Dataset 迭代和可视化数据集
  • Creating a Custom Dataset for your files 为您的文件创建自定义数据集
    • \__init__
    • \__len__
    • \__getitem__
  • Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练
  • Iterate through the DataLoader 遍历 DataLoader
  • Normalization 正则化
    • Transforms 转换
    • ToTensor()
    • Lambda transforms
  • 知识检查
  • Further Reading 进一步阅读
  • References 参考文献
  • Github

Datasets & DataLoaders 数据集和数据加载器

用于处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码分离,以获得更好的可读性和模块化性。PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset ,允许您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其相应的标签,并DataLoader围绕 Dataset进行迭代,以方便访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST)。这些数据集是torch.utils.data.Dataset的子类。并且,对于特定数据,实现特定的函数。它们可用于对您的模型进行原型设计和基准测试。您可以在这里找到它们:图像数据集、 文本数据集和 音频数据集

Loading a Dataset 加载数据集

以下是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 论文的图像数据集。这个数据集由 60,000 个训练样本和 10,000 个测试样本组成。每个样本包含一个 28×28 灰度图像和来自 10 个类别之一的关联标签。

  • 每张图像的高度为 28 像素,宽度为 28 像素,总共 784 像素。
  • 这 10 个类别表示图像的类型,例如:T 恤/上衣、裤子、套头衫、连衣裙、包、踝靴等.
  • 灰度像素的值介于 0 到 255 之间,用于测量黑白图像的强度。强度值从白色增加到黑色。例如:白色为 0,黑色为 255。

我们使用以下参数,来加载FashionMNIST Dataset:

  • root 是存储训练/测试数据的路径,

  • train 指定训练或测试数据集,

  • download=True 如果root 上没有数据,则从 Internet 下载数据。

  • transformtarget_transform 指定特征和标签的转换。

%matplotlib inline
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

Iterating and Visualizing the Dataset 迭代和可视化数据集

我们可以像列表一样手动索引Datasetstraining_data[index]。我们用matplotlib来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Out:
在这里插入图片描述

Creating a Custom Dataset for your files 为您的文件创建自定义数据集

自定义 Dataset 类必须实现三个函数:__init____len____getitem__。看看这个实现:FashionMNIST 图像存储在目录img_dir中,它们的标签单独存储在CSV 文件annotations_file中。

在接下来的部分中,我们将详细介绍每个函数中实现的功能。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

_init_

__init__ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录(下一节将更详细地介绍)。

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

_len_

__len__ 函数返回数据集中的样本数。

例子:

def __len__(self):return len(self.img_labels)

_getitem_

__getitem__ 函数从数据集加载并返回给定索引idx的的样本。基于索引,它识别图像在磁盘上的位置,使用read_image 将其转换为张量,从 self.img_labels中的 csv 数据中检索相应的标签,调用它们的转换函数(如果适用),并以元组方式返回张量图像和相应的标签。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练

Dataset检索我们的数据集中一个样本的特征和标签。在训练模型时,我们通常希望以 “minibatches”方式传递样本,在每个epcoch重新整理数据以减少模型过度拟合,并使用 Python的multiprocessing来加速数据检索。

在机器学习中,您需要指定数据集中的特征和标签。输入特征,输出标签。我们训练特征,然后训练模型来预测标签。

  • 特征是图像像素中的图案
  • 标签是我们的 10 类类型:T 恤、凉鞋、连衣裙等

DataLoader是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。要使用 Dataloader,我们需要设置以下参数:

  • data 将用于训练模型的训练数据,以及评估模型的测试数据
  • batch size 每批中要处理的记录数
  • shuffle 按索引随机抽取数据样本
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Iterate through the DataLoader 遍历 DataLoader

我们已将该数据集加载到 DataLoader 中,并且可以根据需要迭代数据集。下面的每次迭代都会返回一批train_featurestrain_labelsbatch_size=64分别包含特征和标签)。因为我们指定了shuffle=True,所以在迭代所有批次后,数据将被打乱(为了更细粒度地控制数据加载顺序,请查看Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

在这里插入图片描述

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: Ankle Boot

Normalization 正则化

正则化是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素的值在0到255之间,这是特征。如果一个像素值为17,另一个像素为197。就会出现像素重要性分布不均匀的情况,因为较高的像素量会使学习发生偏差。正则化会改变数据的范围,而不会扭曲其特征之间的区别。进行这种预处理是为了避免:

  • 预测精度降低
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms 转换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用transforms来操作数据并使其适合训练。

所有 TorchVision 数据集都有两个参数(transform 用于修改特征,target_transform 用于修改标签),它们接受包含转换逻辑的可调用对象。 torchvision.transforms 模块提供了几种开箱即用的常用转换。

FashionMNIST特征采用PIL图像格式,标签为整数。对于训练,我们需要将特征作为归一化张量,将标签作为单热编码张量。为了进行这些转换,我们将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor 并将图像的像素强度值缩放到 [0., 1.]范围。

Lambda transforms

Lambda transforms 应用任何用户定义的 lambda 函数。在这里,我们定义一个函数将整数转换为 one-hot 编码张量。它首先创建一个大小为 10(数据集中的标签数量)的零张量,并调用 scatter,它在标签 y 给定的索引上分配 value=1。您还可以使用 torch.nn.function.one_hot 作为另一个选项来执行此操作。

知识检查

1.PyTorch DataSet 和 PyTorch DataLoader 之间有什么区别

DataSet 按设计用于检索单个数据项,而 DataLoader 按设计用于处理批量数据。

2.PyTorch 中的转换旨在:

对数据执行某些操作,使其适用于训练。

Further Reading 进一步阅读

  • torch.utils.data API

References 参考文献

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

Github

storm-ice/PyTorch_Fundamentals

storm-ice/PyTorch_Fundamentals

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/622572.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker篇】从0到1搭建自己的镜像仓库并且推送镜像到自己的仓库中

文章目录 🔎docker私有仓库🍔具体步骤 🔎docker私有仓库 Docker私有仓库的存在为用户提供了更高的灵活性、控制和安全性。与使用公共镜像仓库相比,私有仓库使用户能够完全掌握自己的镜像生命周期。 首先,私有仓库允许…

力扣-盛最多水的容器

11.盛最多水的容器 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。找出其中的两条线,使得它们与 x 轴共同构成的容器可以容纳最多的水。返回容器可以储存的最大水量。 说明:你不能倾斜…

C //练习 4-10 另一种方法是通过getline函数读入整个输入行,这种情况下可以不使用getch与ungetch函数。请运用这一方法修改计算器程序。

C程序设计语言 (第二版) 练习 4-10 练习 4-10 另一种方法是通过getline函数读入整个输入行,这种情况下可以不使用getch与ungetch函数。请运用这一方法修改计算器程序。 注意:代码在win32控制台运行,在不同的IDE环境下…

C语言程序设计——程序流程控制方法(二)

循环结构 while语句 while(表达式){代码块; }do{代码块; }while(表达式)while语句分为do-while和while两种,区别在于循环之前是不是先执行一次循环的内容,可以类似于i和i的关系,本质上来讲是相同的。当表达式为真时,则会执行一次…

【Java 干货教程】Java方法引用详解

导言 Java方法引用是Java 8引入的一项重要特性,它提供了一种简洁、可读性高的方式来直接引用已经存在的方法。方法引用使得代码更加简洁、易懂,同时提高了代码的可维护性和重用性。本文将详细介绍Java方法引用的概念、语法和使用方法,并提供…

超详细的 pytest 钩子函数 之初始钩子和引导钩子来啦

前几篇文章介绍了 pytest 点的基本使用,学完前面几篇的内容基本上就可以满足工作中编写用例和进行自动化测试的需求。从这篇文章开始会陆续给大家介绍 pytest 中的钩子函数,插件开发等等。 仔细去看过 pytest 文档的小伙伴,应该都有发现 pyte…

【数据结构 | 希尔排序法】

希尔排序法 思路ShellSort 思路 希尔排序法又称缩小增量法。希尔排序法的基本思想是:先选定一个整数,把待排序文件中所有记录分成个组,所有距离为的记录分在同一组内,并对每一组内的记录进行排序。然后,取&#xff0c…

关于Golang闭包

关于Golang闭包 1、能不用闭包的地方就不要用闭包,没必要为了炫技,写一段增加团队小伙伴负担的代码 2、for 循环代码,统一在内部用临时变量再存一下 for _, val : range values {val : val }在线代码演示:https://go.dev/play/p…

docker 支持 gpu

需求: 原先在宿主机里运行的服务需要迁移到docker 里 进程: docker 支持 gpu 需要装toolkit ,但是正常情况下没有对应的源,所以先引入源文件 distribution$(. /etc/os-release;echo $ID$VERSION_ID) \ && curl -fsSL …

ospf-gre隧道小练习

全网可达,R5路由表没有其他路由器的路由条目 注:每个路由器都添加了自己的环回,如R1就是1.1.1.1 R1可以分别ping通与R2,R3,R4之间的隧道 R1路由表上有所有路由器环回的路由条目 R5路由表上没有其他路由器的路由条目 实现代码: 首先将各个接口IP配好 边上3个路由器:[R6][R7][R…

ES API 批量操作 Bulk API

bulk 是 elasticsearch 提供的一种批量增删改的操作API。 bulk 对 JSON串 有着严格的要求。每个JSON串 不能换行 ,只能放在同一行,同时, 相邻的JSON串之间必须要有换行 (Linux下是\n;Window下是\r\n)。bul…

【谭浩强C程序设计 学习辅导第3章】最简单的C程序设计——顺序程序设计(含详细源码)

文章目录 一、 顺序程序设计题的解题思路及注意事项解题思路注意事项 二、源码讲解第3章源码文件构成:main.c 文件内容说明chap3.c源码实现chap3.h声明头文件测试结果展示源码链接 说明:本学习辅导题适用于谭浩强教辅第四版。 一、 顺序程序设计题的解题…

学习记录————

1月 1月10号 习惯这件事很重要,一个长期坚守的习惯不一定是最好的,但是是能一直坚守下去的。所以习惯不能变来变去 长期坚守的习惯是什么?①10点 && (视频后 || 聊完天后)两个小时学习。②上床不玩手机。③周末:10-12点…

Programming Abstractions in C阅读笔记:p246-p247

《Programming Abstractions in C》学习第68天,p246-p247总结,总计2页。 一、技术总结 本章通过“the game of nim(尼姆游戏)”,这类以现实生活中事物作为例子进行讲解的情况,往往对学习者要求比较高,需要学习者具备…

<软考高项备考>《论文专题 - 65 质量管理(4) 》

4 过程3-管理质量 4.1 问题 4W1H过程做什么为了评估绩效,确保项目输出完整、正确且满足客户期望,而监督和记录质量管理活动执行结果的过程作用:①核实项目可交付成果和工作已经达到主要干系人的质量要求,可供最终验收;②确定项目…

C# 静态代码织入AOP组件之肉夹馍

写在前面 关于肉夹馍组件的官方介绍说明: Rougamo是一个静态代码织入的AOP组件,同为AOP组件较为常用的有Castle、Autofac、AspectCore等,与这些组件不同的是,这些组件基本都是通过动态代理IoC的方式实现AOP,是运行时…

linux系统中线程(Thread)解读以及对IO性能的影响

线程是操作系统调度的基本单位,是进程中能够独立执行指令流的子任务。在线程模型中,多个线程共享同一进程的地址空间和其他资源,使得它们可以直接访问相同的内存区域,这样大大简化了数据共享和通信的复杂性。线程有以下几个关键特…

【Web】CTFSHOW PHP特性刷题记录(全)

知其然知其所以然,尽量把每种特性都详细讲明白。 目录 web89 web90 web91 web92 web93 web94 web95 web96 web97 web98 web99 web100 web101 web102 web103 web104 web105 web106 web107 web108 web109 web110 web111 web112 web113 web…

高级分布式系统-第12讲 分布式控制经典理论

控制器基础 分布式控制系统的设计,是指在给定系统性能指标的条件下,设计出控制器的控制规律和相应的数字控制算法。 PID控制器 根据偏差的比例(Proportional)、积分(Integral)、微分(Derivati…

SQL_DCL_管理用户

DCL英文全称 Data Control Language(数据控制语言,用来管理数据库用户,控制数据库的访问权限。 1.查询用户 USE MY SQL; SELECT * FROM USER; 2.创建用户 CREATE USER 用户名主机名 IDENTIFIED BY密码; 3.修改用户密码 ALTER USER 用户名 主机名 …