pytorch基础2-数据集与归一化

专题链接:https://blog.csdn.net/qq_33345365/category_12591348.html

本教程翻译自微软教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/

初次编辑:2024/3/2;最后编辑:2024/3/2


本教程第一篇介绍pytorch基础和张量操作

这是本教程的第二篇,介绍以下内容:

  1. 数据集与数据加载器
  2. 归一化

另外本人还有pytorch CV相关的教程,见专题:

https://blog.csdn.net/qq_33345365/category_12578430.html


数据集与数据加载器 Datasets and Dataloaders

处理数据样本的代码可能会变得复杂且难以维护。通常希望数据集代码与模型训练代码解耦,以获得更好的可读性和模块化性。PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,它们使您能够使用预加载的数据集以及您自己的数据。Dataset存储样本及其对应的标签,而DataLoader则在Dataset周围包装了一个可迭代对象,以便轻松访问样本。

PyTorch库提供了许多预加载的样本数据集(例如FashionMNIST),它们是torch.utils.data.Dataset的子类,并实现了特定于特殊数据的功能。用于模型原型设计(prototyping)和基准测试(benchmark)的样本包括:

  1. 图像数据集
  2. 文本数据集
  3. 音频数据集

加载数据集

此处使用TorchVision来加载Fashion-MNIST数据集。Fashion-MNIST是Zalando的服装图像数据集,包含60,000个训练样本和10,000个测试样本。每个样本包括一个28×28的灰度图像和一个来自10个类别中的关联标签。

  • 每个图像高度为28个像素,宽度为28个像素,总共有784个像素。
  • 这10个类别告诉图像是什么类型,例如:T恤/上衣、裤子、套头衫、连衣裙、包、短靴等。
  • 灰度像素的值介于0到255之间,用于衡量黑白图像的强度。强度值从白色到黑色递增。例如:白色的颜色值为0,而黑色的颜色值为255。

我们使用以下参数加载FashionMNIST数据集:

  • root 是存储训练/测试数据的路径。
  • train 指定训练或测试数据集。
  • download=True 如果数据在root中不可用,则从互联网下载数据。
  • transformtarget_transform 指定特征和标签的转换。

加载训练和测试数据集的代码如下所示:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

数据集的迭代与可视化

可以像列表一样手动索引Datasets: training_data[index]。此处使用matplotlib对训练数据中的一些样本进行可视化。代码如下所示:

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

得到如下图片:

在这里插入图片描述

使用数据加载器(DataLoaders)准备自定义数据

Dataset 逐个样本检索数据集的特征和标签。在训练模型时,通常希望以“minibatch”的形式传递样本,在每个周期重混洗(reshuffle)数据以减少模型过拟合,并使用Python的多进程加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。**特征(Feature)**是输入,**标签(label)**是输出。我们训练特征,然后训练模型以预测标签。

  • 特征是图像像素中的模式。
  • 标签是10个类别类型:T恤,凉鞋,连衣裙等。

DataLoader是一个可迭代对象,用简单的API抽象了这种复杂性。为了使用DataLoader,我们需要设置以下参数:

  • data 用于训练模型的训练数据,以及用于评估模型的测试数据。
  • batch size 每个批次要处理的记录数。
  • shuffle 按索引随机抽取数据样本。

使用DataLoader处理两个数据集的代码如下所示:

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

使用DataLoader进行迭代

我们已将数据集加载到 DataLoader 中,现在可以根据需要迭代数据集。 下面的每次迭代都会返回一个批次的 train_featurestrain_labels(分别包含 batch_size=64 个特征和标签)。由于我们指定了 shuffle=True,在遍历完所有批次后,数据将被混洗(shuffle),以更精细地控制数据加载顺序。

展示图片和标签的代码如下所示:

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

得到图片:

在这里插入图片描述

归一化

标准化是一种常见的数据预处理技术,用于对数据进行缩放或转换,以确保每个特征都有相等的学习贡献。例如,灰度图像中的每个像素的值都介于0和255之间,这些是特征。如果一个像素值为17,另一个像素为197,则像素重要性的分布将不均匀,因为较高的像素值会偏离学习。标准化改变了数据的范围,而不会扭曲其特征之间的区别。这种预处理是为了避免:

  • 减少预测精度
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms

数据并不总是以最终处理过的形式呈现,这种形式适合训练机器学习算法。可以使用transforms来操作数据,使其适合训练。

所有 TorchVision 数据集都有两个参数(transform用于修改特征,target_transform用于修改标签),这些参数接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的转换。

FashionMNIST 的特征是 PIL 图像格式,标签是整数。 对于训练,需要将特征转换为归一化的张量,将标签转换为 one-hot 编码的张量。 为了进行这些转换,将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

解释上述代码:

ToTensor()

ToTensor 将PIL图像或NumPy ndarray转换为FloatTensor,并将图像的像素强度值缩放到范围 [0., 1.]内。

Lambda transforms

Lambda transforms 应用任何用户定义的lambda函数。在这里,我们定义了一个函数来将整数转换为一个one-hot编码的张量。它首先创建一个大小为10的零张量(数据集中标签的数量),然后调用scatter,根据标签y的索引为其分配值为1。也可以使用torch.nn.functional.one_hot的方式来实现这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/714681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

迪杰斯特拉算法的具体应用

fill与memset的区别介绍 例一 #include <iostream> #include <algorithm> using namespace std; const int maxn500; const int INF1000000000; bool isin[maxn]{false}; int G[maxn][maxn]; int path[maxn],rescue[maxn],num[maxn]; int weight[maxn]; int cityn…

探索数据宇宙:深入解析大数据分析与管理技术

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua&#xff0c;在这里我会分享我的知识和经验。&#x…

第六课:NIO简介

一、传统BIO的缺点 BIO属于同步阻塞行IO,在服务器的实现模型为&#xff0c;每一个连接都要对应一个线程。当客户端有连接请求的时候&#xff0c;服务器端需要启动一个新的线程与之对应处理&#xff0c;这个模型有很多缺陷。当客户端不做出进一步IO请求的时候&#xff0c;服务器…

《Spring Security 简易速速上手小册》第4章 授权与角色管理(2024 最新版)

文章目录 4.1 理解授权4.1.1 基础知识详解授权的核心授权策略方法级安全动态权限检查 4.1.2 主要案例&#xff1a;基于角色的页面访问控制案例 Demo 4.1.3 拓展案例 1&#xff1a;自定义投票策略案例 Demo测试自定义投票策略 4.1.4 拓展案例 2&#xff1a;使用方法级安全进行细…

c语言数据结构(5)——栈

欢迎来到博主的专栏——C语言数据结构 博主id&#xff1a;代码小豪 文章目录 栈栈的顺序存储结构栈的插入空栈的初始化栈的删除判断空栈读取栈顶元素数据 实现顺序栈的所有代码栈的链式存储结构链式栈的初始化链式栈的入栈操作链式栈的出栈操作 实现链式栈的所有代码 栈 栈是…

学习网络编程No.11【传输层协议之UDP】

引言&#xff1a; 北京时间&#xff1a;2023/11/20/9:17&#xff0c;昨天成功更文&#xff0c;上周实现了更文两篇&#xff0c;所以这周再接再厉。当然做题任在继续&#xff0c;而目前做题给我的感觉以套路和技巧偏多&#xff0c;还是那句话很多东西不经历你就是不懂&#xff…

【Python】2. 基础语法

常量和表达式 我们可以把 Python 当成一个计算器, 来进行一些算术运算. 注意: print 是一个 Python 内置的 函数, 这个稍后详细介绍. 可以使用 - * / ( ) 等运算符进行算术运算. 先算乘除, 后算加减. 运算符和数字之间, 可以没有空格, 也可以有多个空格. 但是一般习惯上写一…

LDR6328芯片:智能家居时代的小家电充电革新者

在当今的智能家居时代&#xff0c;小家电的供电方式正变得越来越智能化和高效化。 利用PD&#xff08;Power Delivery&#xff09;芯片进行诱骗取电&#xff0c;为后端小家电提供稳定电压的技术&#xff0c;正逐渐成为行业的新宠。在这一领域&#xff0c;LDR6328芯片以其出色的…

Qt下使用modbus-c库实现PLC线圈/保持寄存器的读写

系列文章目录 提示&#xff1a;这里是该系列文章的所有文章的目录 第一章&#xff1a;Qt下使用ModbusTcp通信协议进行PLC线圈/保持寄存器的读写&#xff08;32位有符号数&#xff09; 第二章&#xff1a;Qt下使用modbus-c库实现PLC线圈/保持寄存器的读写 文章目录 系列文章目录…

前端Vue3项目如何打包成Docker镜像运行

将前端Vue3项目打包成Docker镜像并运行包括几个主要步骤&#xff1a;项目打包、编写Dockerfile、构建镜像和运行容器。下面是一个基本的流程&#xff1a; 1. 项目打包 首先&#xff0c;确保你的Vue3项目可以正常运行和打包。在项目根目录下执行以下命令来打包你的Vue3项目&am…

nest.js使用nest-winston日志一

nest-winston文档 nest-winston - npm 参考&#xff1a;nestjs中winston日志模块使用 - 浮的blog - SegmentFault 思否 安装 cnpm install --save nest-winston winstoncnpm install winston-daily-rotate-file 在main.ts中 import { NestFactory } from nestjs/core; im…

【5G 接口协议】GTP-U协议介绍

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 本人就职于国际知名终端厂商&#xff0c;负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作&#xff0c;目前牵头6G算力网络技术标准研究。 博客…

基础小白快速入门Python------>模块的作用和意义

模块&#xff0c; 这个词听起来是如此的高大威猛&#xff0c;以至于萌新小白见了瑟瑟发抖&#xff0c;本草履虫见了都直摇头&#xff0c;好像听上去很难的样子&#xff0c;但是但是&#xff0c;年轻人&#xff0c;请听本少年细细讲述&#xff0c;他只是看起来很难&#xff0c;实…

GO-接口

1. 接口 在Go语言中接口&#xff08;interface&#xff09;是一种类型&#xff0c;一种抽象的类型。 interface是一组method的集合&#xff0c;接口做的事情就像是定义一个协议&#xff08;规则&#xff09;&#xff0c;只要一台机器有洗衣服和甩干的功能&#xff0c;我就称它…

【go语言开发】swagger安装和使用

本文主要介绍go-swagger的安装和使用&#xff0c;首先介绍如何安装swagger&#xff0c;测试是否成功&#xff1b;然后列出常用的注释和给出使用例子&#xff1b;最后生成接口文档&#xff0c;并在浏览器上测试 文章目录 安装注释说明常用注释参考例子 文档生成格式化文档生成do…

大模型生成,Open API调用

大模型是怎么生成结果的 通俗原理 其实&#xff0c;它只是根据上文&#xff0c;猜下一个词&#xff08;的概率&#xff09;…… OpenAI 的接口名就叫【completion】&#xff0c;也证明了其只会【生成】的本质。 下面用程序演示【生成下一个字】。你可以自己修改 prompt 试试…

【C++】类的转换函数

使用场景 C中当你创建了一个类&#xff0c;你想把这个类对象转换成基本类型的函数。类对象->基本类型对象 原理 如下实例&#xff0c;设计一个分数类&#xff0c;实现分数转换成double 浮点数的转换函数。并在mian函数隐式调用。 #include<iostream> class Fractio…

Python爬虫Cookies 池的搭建

Cookies 池的搭建 很多时候&#xff0c;在爬取没有登录的情况下&#xff0c;我们也可以访问一部分页面或请求一些接口&#xff0c;因为毕竟网站本身需要做 SEO&#xff0c;不会对所有页面都设置登录限制。 但是&#xff0c;不登录直接爬取会有一些弊端&#xff0c;弊端主要有…

南京师范大学计电院数据结构课设——排序算法

1 排序算法 1.1 题目要求 编程实现希尔、快速、堆排序、归并排序算法。要求首先随机产生10000个数据存入磁盘文件&#xff0c;然后读入数据文件&#xff0c;分别采用不同的排序方法进行排序并将结果存入文件中。 1.2 算法思想描述 1.2.1 随机数生成 当需要生成一系列随机数…

windows 11 前后端项目部署

目录 1.准备环境&#xff1a; 2.安装jdk 测试&#xff1a;winr 输入cmd 3.安装tomcat 4.安装mysql 远程导入数据&#xff1a; 外部后台访问&#xff1a;192.168.232.1:8080/crm/sys/loginAction.action?usernamezs&password123 5.安装nginx 前后端部署&#xff1…