【PyTorch简介】3.Loading and normalizing datasets 加载和规范化数据集

Loading and normalizing datasets 加载和规范化数据集

文章目录

  • Loading and normalizing datasets 加载和规范化数据集
  • Datasets & DataLoaders 数据集和数据加载器
  • Loading a Dataset 加载数据集
  • Iterating and Visualizing the Dataset 迭代和可视化数据集
  • Creating a Custom Dataset for your files 为您的文件创建自定义数据集
    • \__init__
    • \__len__
    • \__getitem__
  • Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练
  • Iterate through the DataLoader 遍历 DataLoader
  • Normalization 正则化
    • Transforms 转换
    • ToTensor()
    • Lambda transforms
  • 知识检查
  • Further Reading 进一步阅读
  • References 参考文献
  • Github

Datasets & DataLoaders 数据集和数据加载器

用于处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码分离,以获得更好的可读性和模块化性。PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset ,允许您使用预加载的数据集以及您自己的数据。 Dataset存储样本及其相应的标签,并DataLoader围绕 Dataset进行迭代,以方便访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST)。这些数据集是torch.utils.data.Dataset的子类。并且,对于特定数据,实现特定的函数。它们可用于对您的模型进行原型设计和基准测试。您可以在这里找到它们:图像数据集、 文本数据集和 音频数据集

Loading a Dataset 加载数据集

以下是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 论文的图像数据集。这个数据集由 60,000 个训练样本和 10,000 个测试样本组成。每个样本包含一个 28×28 灰度图像和来自 10 个类别之一的关联标签。

  • 每张图像的高度为 28 像素,宽度为 28 像素,总共 784 像素。
  • 这 10 个类别表示图像的类型,例如:T 恤/上衣、裤子、套头衫、连衣裙、包、踝靴等.
  • 灰度像素的值介于 0 到 255 之间,用于测量黑白图像的强度。强度值从白色增加到黑色。例如:白色为 0,黑色为 255。

我们使用以下参数,来加载FashionMNIST Dataset:

  • root 是存储训练/测试数据的路径,

  • train 指定训练或测试数据集,

  • download=True 如果root 上没有数据,则从 Internet 下载数据。

  • transformtarget_transform 指定特征和标签的转换。

%matplotlib inline
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

Iterating and Visualizing the Dataset 迭代和可视化数据集

我们可以像列表一样手动索引Datasetstraining_data[index]。我们用matplotlib来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Out:
在这里插入图片描述

Creating a Custom Dataset for your files 为您的文件创建自定义数据集

自定义 Dataset 类必须实现三个函数:__init____len____getitem__。看看这个实现:FashionMNIST 图像存储在目录img_dir中,它们的标签单独存储在CSV 文件annotations_file中。

在接下来的部分中,我们将详细介绍每个函数中实现的功能。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

_init_

__init__ 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录(下一节将更详细地介绍)。

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

_len_

__len__ 函数返回数据集中的样本数。

例子:

def __len__(self):return len(self.img_labels)

_getitem_

__getitem__ 函数从数据集加载并返回给定索引idx的的样本。基于索引,它识别图像在磁盘上的位置,使用read_image 将其转换为张量,从 self.img_labels中的 csv 数据中检索相应的标签,调用它们的转换函数(如果适用),并以元组方式返回张量图像和相应的标签。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

Preparing your data for training with DataLoaders 使用 DataLoaders 准备数据以进行训练

Dataset检索我们的数据集中一个样本的特征和标签。在训练模型时,我们通常希望以 “minibatches”方式传递样本,在每个epcoch重新整理数据以减少模型过度拟合,并使用 Python的multiprocessing来加速数据检索。

在机器学习中,您需要指定数据集中的特征和标签。输入特征,输出标签。我们训练特征,然后训练模型来预测标签。

  • 特征是图像像素中的图案
  • 标签是我们的 10 类类型:T 恤、凉鞋、连衣裙等

DataLoader是一个可迭代对象,它通过一个简单的 API 为我们抽象了这种复杂性。要使用 Dataloader,我们需要设置以下参数:

  • data 将用于训练模型的训练数据,以及评估模型的测试数据
  • batch size 每批中要处理的记录数
  • shuffle 按索引随机抽取数据样本
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

Iterate through the DataLoader 遍历 DataLoader

我们已将该数据集加载到 DataLoader 中,并且可以根据需要迭代数据集。下面的每次迭代都会返回一批train_featurestrain_labelsbatch_size=64分别包含特征和标签)。因为我们指定了shuffle=True,所以在迭代所有批次后,数据将被打乱(为了更细粒度地控制数据加载顺序,请查看Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

在这里插入图片描述

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: Ankle Boot

Normalization 正则化

正则化是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素的值在0到255之间,这是特征。如果一个像素值为17,另一个像素为197。就会出现像素重要性分布不均匀的情况,因为较高的像素量会使学习发生偏差。正则化会改变数据的范围,而不会扭曲其特征之间的区别。进行这种预处理是为了避免:

  • 预测精度降低
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms 转换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用transforms来操作数据并使其适合训练。

所有 TorchVision 数据集都有两个参数(transform 用于修改特征,target_transform 用于修改标签),它们接受包含转换逻辑的可调用对象。 torchvision.transforms 模块提供了几种开箱即用的常用转换。

FashionMNIST特征采用PIL图像格式,标签为整数。对于训练,我们需要将特征作为归一化张量,将标签作为单热编码张量。为了进行这些转换,我们将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor 并将图像的像素强度值缩放到 [0., 1.]范围。

Lambda transforms

Lambda transforms 应用任何用户定义的 lambda 函数。在这里,我们定义一个函数将整数转换为 one-hot 编码张量。它首先创建一个大小为 10(数据集中的标签数量)的零张量,并调用 scatter,它在标签 y 给定的索引上分配 value=1。您还可以使用 torch.nn.function.one_hot 作为另一个选项来执行此操作。

知识检查

1.PyTorch DataSet 和 PyTorch DataLoader 之间有什么区别

DataSet 按设计用于检索单个数据项,而 DataLoader 按设计用于处理批量数据。

2.PyTorch 中的转换旨在:

对数据执行某些操作,使其适用于训练。

Further Reading 进一步阅读

  • torch.utils.data API

References 参考文献

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

使用 PyTorch 进行机器学习的简介 - Training | Microsoft Learn

Github

storm-ice/PyTorch_Fundamentals

storm-ice/PyTorch_Fundamentals

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/622572.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker篇】从0到1搭建自己的镜像仓库并且推送镜像到自己的仓库中

文章目录 🔎docker私有仓库🍔具体步骤 🔎docker私有仓库 Docker私有仓库的存在为用户提供了更高的灵活性、控制和安全性。与使用公共镜像仓库相比,私有仓库使用户能够完全掌握自己的镜像生命周期。 首先,私有仓库允许…

力扣-盛最多水的容器

11.盛最多水的容器 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。找出其中的两条线,使得它们与 x 轴共同构成的容器可以容纳最多的水。返回容器可以储存的最大水量。 说明:你不能倾斜…

C语言程序设计——程序流程控制方法(二)

循环结构 while语句 while(表达式){代码块; }do{代码块; }while(表达式)while语句分为do-while和while两种,区别在于循环之前是不是先执行一次循环的内容,可以类似于i和i的关系,本质上来讲是相同的。当表达式为真时,则会执行一次…

超详细的 pytest 钩子函数 之初始钩子和引导钩子来啦

前几篇文章介绍了 pytest 点的基本使用,学完前面几篇的内容基本上就可以满足工作中编写用例和进行自动化测试的需求。从这篇文章开始会陆续给大家介绍 pytest 中的钩子函数,插件开发等等。 仔细去看过 pytest 文档的小伙伴,应该都有发现 pyte…

【数据结构 | 希尔排序法】

希尔排序法 思路ShellSort 思路 希尔排序法又称缩小增量法。希尔排序法的基本思想是:先选定一个整数,把待排序文件中所有记录分成个组,所有距离为的记录分在同一组内,并对每一组内的记录进行排序。然后,取&#xff0c…

ospf-gre隧道小练习

全网可达,R5路由表没有其他路由器的路由条目 注:每个路由器都添加了自己的环回,如R1就是1.1.1.1 R1可以分别ping通与R2,R3,R4之间的隧道 R1路由表上有所有路由器环回的路由条目 R5路由表上没有其他路由器的路由条目 实现代码: 首先将各个接口IP配好 边上3个路由器:[R6][R7][R…

ES API 批量操作 Bulk API

bulk 是 elasticsearch 提供的一种批量增删改的操作API。 bulk 对 JSON串 有着严格的要求。每个JSON串 不能换行 ,只能放在同一行,同时, 相邻的JSON串之间必须要有换行 (Linux下是\n;Window下是\r\n)。bul…

【谭浩强C程序设计 学习辅导第3章】最简单的C程序设计——顺序程序设计(含详细源码)

文章目录 一、 顺序程序设计题的解题思路及注意事项解题思路注意事项 二、源码讲解第3章源码文件构成:main.c 文件内容说明chap3.c源码实现chap3.h声明头文件测试结果展示源码链接 说明:本学习辅导题适用于谭浩强教辅第四版。 一、 顺序程序设计题的解题…

Programming Abstractions in C阅读笔记:p246-p247

《Programming Abstractions in C》学习第68天,p246-p247总结,总计2页。 一、技术总结 本章通过“the game of nim(尼姆游戏)”,这类以现实生活中事物作为例子进行讲解的情况,往往对学习者要求比较高,需要学习者具备…

<软考高项备考>《论文专题 - 65 质量管理(4) 》

4 过程3-管理质量 4.1 问题 4W1H过程做什么为了评估绩效,确保项目输出完整、正确且满足客户期望,而监督和记录质量管理活动执行结果的过程作用:①核实项目可交付成果和工作已经达到主要干系人的质量要求,可供最终验收;②确定项目…

C# 静态代码织入AOP组件之肉夹馍

写在前面 关于肉夹馍组件的官方介绍说明: Rougamo是一个静态代码织入的AOP组件,同为AOP组件较为常用的有Castle、Autofac、AspectCore等,与这些组件不同的是,这些组件基本都是通过动态代理IoC的方式实现AOP,是运行时…

【Web】CTFSHOW PHP特性刷题记录(全)

知其然知其所以然,尽量把每种特性都详细讲明白。 目录 web89 web90 web91 web92 web93 web94 web95 web96 web97 web98 web99 web100 web101 web102 web103 web104 web105 web106 web107 web108 web109 web110 web111 web112 web113 web…

高级分布式系统-第12讲 分布式控制经典理论

控制器基础 分布式控制系统的设计,是指在给定系统性能指标的条件下,设计出控制器的控制规律和相应的数字控制算法。 PID控制器 根据偏差的比例(Proportional)、积分(Integral)、微分(Derivati…

工作压力测试

每个职场人都会遇到工作压力,在企业人力资源管理的角度来看,没有工作压力是人力资源的低效,适当的工作压力可以促使员工不断进取,然而每个人的抗压能力是不同的,同样的工作量和工作难度,不同的人在面对相同…

漏洞复现--GitLab 任意用户密码重置漏洞(CVE-2023-7028)

免责声明: 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

class_4:car类

#include <iostream> using namespace std; class Car{ public://成员数据string color; //颜色string brand; //品牌string type; //车型int year; //年限//其实也是成员数据&#xff0c;指针变量&#xff0c;指向函数的变量&#xff0c;并非真正的成员函数void (*…

Win10(CPU)+ Anaconda3 + python3.9安装pytorch

1. 安装Anaconda3 1.1 下载Anaconda3 可以在官网下载Anaconda3-2022.05-Windows-x86_64.exe&#xff0c;这个版本对应的是python3.9。 1.2 安装Anaconda3 此步骤比较简单&#xff0c;双击.exe文件&#xff0c;一步一步执行即可&#xff0c;有不确定的可以自行百度&#xff…

C++ 对象模型 | 关于对象

一、C 对象模型 1、对象内存布局 在C中&#xff0c;有两种数据成员&#xff1a;static和nonstatic&#xff0c;以及三种成员方法static、nonstatic、virtual&#xff0c;下面从虚函数、非虚函数、静态成员变量、非静态成员变量等维度来分析&#xff0c;类对象的内存布局。例如…

聚道云软件连接器助力知名企业,提升合同管理效率

一、客户介绍 某服饰股份有限公司是一家集服装设计、生产、销售及品牌建设于一体的企业。该公司的产品线涵盖男装、女装、童装等多个领域&#xff0c;设计风格时尚、简约、大方&#xff0c;深受消费者喜爱。公司注重产品研发&#xff0c;不断推陈出新&#xff0c;紧跟时尚潮流…

【linux笔记】vim

【linux笔记】vim 启动和退出 启动 vi退出 q强制退出 q&#xff01;编辑模式 vi foo.txt创建一个文件&#xff0c;启动后&#xff0c;是命令模式&#xff0c;是不能编辑的&#xff0c;键盘上的按键对应不同的命令。 插入模式 按键盘上的i&#xff0c;进入插入模式 保…