【PyTorch攻略(2/7)】 加载数据集

一、说明

        PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,允许您使用预加载的数据集以及您自己的数据。数据集存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象,以便轻松访问样本。

        PyTorch域库提供了许多示例预加载数据集,例如FashionMNIST,它子类torch.utils.data.Dataset并实现特定于特定数据的函数。可以在此处找到它们并用作原型设计和基准测试模型的示例:

  • 图像数据集
  • 文本数据集
  • 音频数据集

二、加载数据集

        我们将从TorchVision加载FashionMNIST数据集。FashionMNIST是Zalando文章图像的数据集,由60,000个训练示例和10,000个测试示例组成。每个示例包含一个 28x28 灰度图像和一个来自 10 个类之一的相关标签。

  • 每张图片高 28 像素,宽 28 像素,共 784 像素。
  • 这 10 个类告诉它是什么类型的图像。例如,T型短裤/上衣,裤子,套头衫,连衣裙,包,踝靴等。
  • 灰度是介于 0 到 255 之间的值,用于测量黑白图像的强度。强度值从白色增加到黑色。例如,白色为 0,黑色为 255。

        我们使用以下参数加载 FashionMNIST 数据集:

  • 是存储训练/测试数据的路径。
  • 训练指定训练或测试数据集。
  • download = 如果数据在根目录中不可用,则 True 从互联网下载数据。
  • 转换指定特征和标注转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

三、迭代和可视化数据集

我们可以像列表一样手动索引数据集:training_data[index]。我们使用 matplotlib 来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx] # Iterate training datafigure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

四、准备数据以使用数据加载程序进行训练

        数据集检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”方式传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的多处理来加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。要素是输入,标注是输出。我们训练特征并训练模型来预测标签。

        DataLoader 是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。要使用数据加载器,我们需要设置以下参数:

  • 数据是将用于训练模型的训练数据;以及用于评估模型的测试数据。
  • 批大小是每个批中要处理的记录数。
  • 随机播放是按索引随机抽取的数据。
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

五、遍历数据加载器

        我们已将数据集加载到数据加载器中,并可以根据需要循环访问数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含 batch_size = 64 个要素和标注)。由于我们指定了 shuffle = True,因此在我们遍历所有批次后,数据将被洗牌。

# Display image and label
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

        NOrmalization是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素都有一个介于 0 到 255 之间的值,这些值是特征。如果一个像素值为 17,另一个像素值为 197。像素重要性的分布将不均匀,因为较高的像素体积会偏离学习。归一化会更改数据的范围,而不会扭曲其在我们的功能之间的区别。进行此预处理是为了避免:

  • 预测精度降低
  • 模型学习的难度
  • 特征数据范围的不利分布

六、变换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用 transform 对数据执行一些操作,并使其适合训练。

所有 TorchVision 数据集都有两个参数:用于修改特征的转换和用于修改接受包含转换逻辑的可调用对象的标签target_transformtorchvision.transform模块提供了几种开箱即用的常用变换。

FashionMNIST 功能采用 PIL 图像格式,标签为整数。对于训练,我们需要特征作为规范化张量,标签作为独热编码张量。为了进行这些转换,我们使用ToTensorLamda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

七、ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor,并将图像的像素强度值缩放在 [0., 1.] 范围内。

八、Lambda()

        Lambda 应用任何用户定义的 lambda 函数。在这里,我们定义了一个函数来将整数转换为独热编码张量。它首先创建一个大小为 10(我们数据集中的标签数量)的张量并调用 scatter,它在索引上分配一个 value=1,如标签 y 给出的那样。您也可以将torch.nn.functional.one_hot用作其他选项。

target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

下一>> PyTorch 简介 (3/7)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/82061.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring注解家族介绍: @RequestMapping

前言: 今天我们来介绍RequestMapping这个注解,这个注解的内容相对来讲比较少,篇幅会比较短。 目录 前言: RequestMapping 应用场景: 总结: RequestMapping RequestMapping 是一个用于映射 HTTP 请求…

【ELFK】之zookeeper

一、Zookeeper是什么? zooleeper是一个分布式服务管理框架。存储业务服务节点元数据及信息,并复制;通知客户端在zookeeper上注册的服务节点状态,通过文件系统通知机制 1、Zookeeper工作机制 Zookeeper从设计模式角度来理解 是…

Vue3 封装 element-plus 图标选择器

一、实现效果 二、实现步骤 2.1. 全局注册 icon 组件 // main.ts import App from ./App.vue; import { createApp } from vue; import * as ElementPlusIconsVue from element-plus/icons-vueconst app createApp(App);// 全局挂载和注册 element-plus 的所有 icon app.con…

Python Web开发:构建动态Web应用

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 Python已经成为一门流行…

FPGA纯verilog实现8路视频拼接显示,提供工程源码和技术支持

目录 1、前言版本更新说明免责声明 2、我已有的FPGA视频拼接叠加融合方案3、设计思路框架视频源选择OV5640摄像头配置及采集静态彩条视频拼接算法图像缓存视频输出 4、vivado工程详解5、工程移植说明vivado版本不一致处理FPGA型号不一致处理其他注意事项 6、上板调试验证并演示…

PHP8的类与对象的基本操作之成员方法-PHP8知识详解

成员方法是指在类中声明的函数。 在类中可以声明多个函数,所以对象中可以存在多个成员方法。类的成员方法可以通过关键字进行修饰,从而控制成员方法的商用权限。 函数和成员方法唯一的区别就是,函数实现的是某个独立的功能,而成…

Java8实战-总结30

Java8实战-总结30 并行数据处理与性能并行流正确使用并行流高效使用并行流 小结 并行数据处理与性能 并行流 正确使用并行流 错用并行流而产生错误的首要原因,就是使用的算法改变了某些共享状态。下面是另一种实现对前n个自然数求和的方法,但这会改变…

RocketMQ 发送顺序消息

文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…

Hbase工作原理

Hbase:HBase 底层原理详解(深度好文,建议收藏) - 腾讯云开发者社区-腾讯云 Hbase架构图 同一个列族如果有多个store,那么这些store在不同的region Hbase写流程(读比写慢) MemStore Flush Hbas…

Python运算符、函数与模块和程序控制结构

给我家憨憨写的python教程 ——雁丘 Python运算符、函数与模块和程序控制结构 关于本专栏一 运算符1.1 位运算符1.1.1 按位取反1.1.2 按位与1.1.3 按位或1.1.4 按位异或1.1.5 左移位 1.2 关系运算符1.3 运算顺序1.4 运算方向 二 函数与模块2.1 内建函数2.2 库函数2.2.1 标准库…

【pytest】 pytest拓展功能 PermissionError问题

目录 1. pytest-html 1.1 PermissionError: [Errno 13] Permission denied: D:\\software\\python3\\anconda3\\Lib\\site-packages\\pytest_html\\__pycache__\\tmp_ttoasov 1.2错误原因 2. 失败用例重试 3. 用例并行执行 pytest-parallel 1. pytest-html 管理员打开 A…

「聊设计模式」之建造者模式(Builder)

🏆本文收录于《聊设计模式》专栏,专门攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎持续关注&&收藏&&订阅! 前言 设计模式是众多优秀软件开发实践的总结和提炼,…

前端VUE---JS实现数据的模糊搜索

实现背景 因为后端实现人员列表返回&#xff0c;每次返回的数据量在100以内&#xff0c;要求前端自己进行模糊搜索 页面实现 因为是实时更新数据的&#xff0c;就不需要搜索和重置按钮了 代码 HTML <el-dialogtitle"团队人员详情":visible.sync"centerDi…

C#通过重写Panel改变边框颜色与宽度的方法

在C#中,Panel控件是一个容器控件,用于在窗体或用户控件中创建一个可用于容纳其他控件的面板。Panel提供了一种将相关控件组合在一起并进行布局的方式。以下是Panel控件的详细使用方法: 在窗体上放置 Panel 控件: 在 Visual Studio 的窗体设计器中,从工具箱中拖动并放置一…

WebGL 视图矩阵、模型视图矩阵

目录 立方体由三角形构成 视点和视线 视点、观察目标点和上方向 视点&#xff1a; 观察目标点&#xff1a; 上方向&#xff1a; 在WebGL中&#xff0c;观察者的默认状态应该是这样的&#xff1a; 视图矩阵程序&#xff08;LookAtTriangles.js&#xff09; 实际上&…

Matlab论文插图绘制模板第114期—带图形标记的图

之前的文章中&#xff0c;分享了Matlab带线标记的图&#xff1a; 带阴影标记的图&#xff1a; 带箭头标记的图&#xff1a; 进一步&#xff0c;分享一下带图形标记的图&#xff0c;先来看一下成品效果&#xff1a; 特别提示&#xff1a;本期内容『数据代码』已上传资源群中&…

flutter开发实战-自定义长按TextField输入框剪切、复制、选择全部菜单AdaptiveTextSelectionToolba样式UI效果

flutter开发实战-自定义长按TextField输入框剪切、复制、选择全部菜单样式UI效果 在开发过程中&#xff0c;需要长按TextField输入框cut、copy设置为中文“复制、粘贴”&#xff0c;我首先查看了TextField中的源码&#xff0c;看到了ToolbarOptions、AdaptiveTextSelectionToo…

深度学习中安装了包但是依然导入(import)失败这一问题,例如pytorch环境下已经安装了scikit-learn但是import不了

在跑深度学习模型的时候我们要先搭建pytorch环境&#xff0c;这个环境跟windows环境是不同的&#xff0c;我们默认在windows中安装的包在当前的虚拟环境中读取不到&#xff0c;所以导致我们明明安装了包但是依然在实际的导入中(import)报错。解决办法就是我们去虚拟环境中安装包…

linux驱动开发day6--(epoll实现IO多路复用、信号驱动IO、设备树以及节点和属性解析相关API使用)

一、IO多路复用--epoll实现 1.核心&#xff1a; 红黑树、一张表以及三个接口、 2.实现过程及API 1&#xff09;创建epoll句柄/创建红黑树根节点 int epfdepoll_create(int size--无意义&#xff0c;>0即可)----------成功&#xff1a;返回根节点对应文件描述符&#xff…

构建无缝的服务网格体验:分享在生产环境中构建和管理服务网格的最佳实践

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…