pytorch升级打怪(三)

数据集合数据加载器

  • 简介
  • 加载数据集
  • 迭代和可视化数据集
  • 为您的文件创建自定义数据集
    • ```__init__```
    • ```__len__```
    • ```__getitem__```
  • 准备您的数据以使用DataLoaders进行训练
  • 通过DataLoader进行遍载

简介

处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与模型训练代码解耦,以提高可读性和模块化。PyTorch提供了两个数据原语:torch.utils.data.DataLoader和torch.utils.data.Dataset,允许您使用预加载的数据集以及您自己的数据。Dataset存储样本及其相应的标签,DataLoader在Dataset周围包装一个可以可以方便地访问样本。

PyTorch域库提供一些预加载的数据集(如FashionMNIST),该子类为torch.utils.data.Dataset,并实现特定于特定数据的功能。它们可用于原型和基准测试您的模型。您可以在这里找到它们:图像数据集、文本数据集和音频数据集

加载数据集

以下是如何从TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST是Zalando文章图像的数据集,包括60,000个训练示例和10,000个测试示例。每个示例都包括一个28×28的灰度图像和来自10个班级之一的相关标签。

我们用以下参数加载FashionMNIST数据集:

  • root是存储火车/测试数据的路径,
  • train指定训练或测试数据集,
  • download=True如果root上没有数据,则从互联网上下载数据。
  • transform和target_transform指定功能和标签转换

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz0%|          | 0/26421880 [00:00<?, ?it/s]0%|          | 65536/26421880 [00:00<01:12, 363720.69it/s]1%|          | 229376/26421880 [00:00<00:38, 682917.83it/s]3%|3         | 917504/26421880 [00:00<00:12, 2109774.93it/s]12%|#2        | 3211264/26421880 [00:00<00:03, 6286038.17it/s]28%|##8       | 7438336/26421880 [00:00<00:01, 14838321.45it/s]41%|####      | 10747904/26421880 [00:00<00:00, 16477772.21it/s]57%|#####7    | 15138816/26421880 [00:01<00:00, 22904288.96it/s]71%|#######   | 18644992/26421880 [00:01<00:00, 21979092.87it/s]92%|#########2| 24346624/26421880 [00:01<00:00, 30077676.52it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 18141478.99it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 327742.46it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz0%|          | 0/4422102 [00:00<?, ?it/s]1%|1         | 65536/4422102 [00:00<00:11, 363330.31it/s]5%|5         | 229376/4422102 [00:00<00:06, 684189.84it/s]21%|##1       | 950272/4422102 [00:00<00:01, 2195763.19it/s]87%|########6 | 3833856/4422102 [00:00<00:00, 7634326.84it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 6105857.14it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 37228063.78it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

迭代和可视化数据集

我们可以像列表一样手动索引Datasets:training_data[index]。我们使用matplotlib在训练数据中可视化一些样本。


labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

为您的文件创建自定义数据集

自定义数据集类必须实现三个函数:

__init__、__len__和__getitem__

。看看这个实现;FashionMNIST图像存储在目录img_dir中,其标签单独存储在CSV文件annotations_file。

在接下来的章节中,我们将分解每个函数中发生的事情。


import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

__init__

实例化数据集对象时,__init__函数运行一次。我们初始化包含图像、注释文件和两个转换的目录(下一节将更详细地介绍)。


def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

__len__

__len__函数返回我们数据集中的样本数。


def __len__(self):return len(self.img_labels)

__getitem__

__getitem__函数加载并返回给定索引idx的数据集的样本。基于索引,它识别图像在磁盘上的位置,使用read_image将其转换为张量,从self.img_labels中的csv数据中检索相应的标签,调用其上的转换函数(如果适用),并在元组中返回张量图像和相应标签。


def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

准备您的数据以使用DataLoaders进行训练

Dataset检索我们数据集的功能,并一次标记一个样本。在训练模型时,我们通常希望以“迷你批次”传递样本,在每个时代重新洗牌数据以减少模型过拟合,并使用Pythonmultiprocessing来加快数据检索速度。

DataLoader是一个可以在一个简单的API中为我们抽象这种复杂性的可以进行的。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

通过DataLoader进行遍载

我们已经将该数据集加载到DataLoader,可以根据需要迭代数据集。下面的每个迭代都会返回一批train_features和train_labels(分别包含batch_size=64特征和标签)。因为我们指定了shuffle=True,在我们遍复所有批次后,数据被洗牌(为了更精细地控制数据加载顺序,请查看采样器)。


# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/748413.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

N7977A 先进电源系统:动态直流电源,160 V,12.5 A,2000 W

N7977A 先进电源系统&#xff1a;动态直流电源 160 V&#xff0c;12.5 A&#xff0c;2000 W Keysight N7977A 专为需要高速动态电源和测量功能的自动化测试设备&#xff08;ATE&#xff09;应用而设计。 简述 支持 BenchVue 软件。 无需编程&#xff0c;即可结合使用您的计…

NCP1271D65R2G中文资料规格书PDF数据手册引脚图参数图片价格功能特性描述

产品描述&#xff1a; NCP1271 是成功的 7 引脚电流模式 NCP12XX 系列的新一代引脚-引脚兼容新产品。该控制器通过使用可调节 Soft Skip 模式和集成的高电压启动 FET&#xff0c;实现了卓越的待机功耗。此专属 Soft Skip 还大大降低了噪音的风险。 因此可以在箝位网络中使用不…

模型蒸馏--一起学习吧之人工智能

一、定义 模型蒸馏&#xff08;Model Distillation&#xff09;是一种模型压缩技术&#xff0c;旨在将一个复杂的大型模型&#xff08;通常称为教师模型&#xff09;的知识转移到另一个更小、更简单的模型&#xff08;通常称为学生模型&#xff09;中。这种技术通过训练学生模…

如何在代理的IP被封后立刻换下一个IP继续任务

目录 前言 1. IP池准备 2. 使用代理IP进行网络请求 3. 处理IP被封的情况 4. 完整代码示例 总结 前言 当进行某些网络操作时&#xff0c;使用代理服务器可以帮助我们隐藏真实IP地址以保护隐私&#xff0c;或者绕过一些限制。然而&#xff0c;经常遇到的问题是代理的IP可能…

BlenderGIS 快捷键E 报错问题 Report: Error

最新版的Blender4.0 对于 BlenderGIS2.28版本的插件不兼容&#xff0c;BlenderGIS2.28兼容Blender3.6.9及之前的版本&#xff0c;应该是BlenderGIS插件很久没更新了导致的。

C#构建类库

类库程序集能将类型组合成易于部署的单元&#xff08;DLL文件&#xff09;&#xff0c;为了使编写的代码能够跨多个项目重用&#xff0c;应该将他们放在类库程序集中。 一、创建类库 在C#中&#xff0c;构建类库是指创建一个包含多个类的项目&#xff0c;这些类可以被其他应用…

热流道融合3D打印技术正在成为模具制造新利器

在模具领域中&#xff0c;3D打印技术与热流道技术联手&#xff0c;能迸发出更耀眼的光芒。两种技术虽然各有特点&#xff0c;但两者结合将形成互补作用&#xff0c;从而实现11&#xff1e;2”的跨越式提升。 将增材制造的灵活思维融入传统模具设计时&#xff0c;不仅能够突破传…

王勇:硬科技的下一站 | 演讲嘉宾公布

一、智能耳机与可穿戴专题论坛 智能耳机与可穿戴专题论坛将于3月27日同期举办&#xff01; 智能耳机、可穿戴设备已经逐渐融入我们的生活&#xff0c;它们不仅带来了便捷与舒适&#xff0c;更在悄然改变着我们的生活方式和工作模式。在这里&#xff0c;我们将分享最新的研究成果…

别再手动拼接 SQL 了,MyBatis 动态 SQL 写法应有尽有,建议收藏!

一、MyBatis动态 sql 是什么 动态 SQL 是 MyBatis 的强大特性之一。在 JDBC 或其它类似的框架中&#xff0c;开发人员通常需要手动拼接 SQL 语句。根据不同的条件拼接 SQL 语句是一件极其痛苦的工作。 例如&#xff0c;拼接时要确保添加了必要的空格&#xff0c;还要注意去掉…

[SaaS] 淘宝设AI

“淘宝设计AI” 让国际大牌造世界双11超级品牌 超级发布https://mp.weixin.qq.com/s/xFVDARQHxlweKAYG91DtYw下面是一个完整的品牌营销海报设计流程&#xff0c;AIGC起到了巨大作用&#xff0c;但是仍然很难去一步解决这个问题&#xff0c;还是逐步修改的一个过程。 Midjouner…

分布式与集群,二者区别是什么?

&#x1f413;分布式 分布式系统是由多个独立的计算机节点组成的系统&#xff0c;这些节点通过网络协作完成任务。每个节点都有自己的独立计算能力和存储能力&#xff0c;可以独立运行。分布式系统的目标是提高系统的可靠性、可扩展性和性能。 分布式服务包含的技术和理论 负…

LabVIEW多表位数字温湿度计图像识别系统

LabVIEW多表位数字温湿度计图像识别系统 解决数字温湿度计校准过程中存在的大量需求和长时间校准问题&#xff0c;通过LabVIEW开发平台设计了一套适用于20多个表位的数字温度计图像识别系统。该系统能够通过图像采集、提取和处理&#xff0c;进行字符训练&#xff0c;从而实现…

中小企业的智能化,不能再拖了!

在当今时代&#xff0c;新质生产力已然成为了国内最热门的话题。它代表着先进生产力的涌现和发展&#xff0c;正逐渐成为推动国家经济社会持续发展的核心力量。今年的两会更是将“新质生产力”写入政府工作报告&#xff0c;并将其列为2024年政府工作十大任务之首&#xff0c;足…

【JS】parseInt与Math.floor的区别

获取两数区间随机整数的函数如下 function getRandom(min,max){return Math.floor(Math.random() * (max - min) min) }这个函数中&#xff0c;只可以使用Math.random&#xff0c;parseInt会出问题&#xff0c;二者虽然都是取整&#xff0c;但又有一些区别。 parseInt是「向…

力扣大厂热门面试算法题 30-32

30. 串联所有单词的子串&#xff0c;31. 下一个排列 &#xff0c;32. 最长有效括号&#xff0c;每题做详细思路梳理&#xff0c;配套Python&Java双语代码&#xff0c; 2024.03.15 可通过leetcode所有测试用例。 目录 30. 串联所有单词的子串 解题思路 完整代码 Java …

算法笔记 连载中。。。

HashMap&#xff08;会根据key值自动排序&#xff09; HashMap<String, Integer> hash new HashMap<>() hash.put(15,18) hash.getOrDefault(ts, -1) //如果ts(key)存在&#xff0c;返回对应的value 否则返回-1 hashMap1.get(words1[i])1会报错&#xff0c;因…

AcWing 848. 有向图的拓扑序列

#include<iostream> #include<cmath> #include<queue> #include<cstring> #include<cstdlib> #include<algorithm> using namespace std; const int N1e510; int n,m,a,b; int e[N],ne[N],h[N],idx; int d[N],top[N],cnt1;//top是拓扑排序…

Linux学习笔记:什么是文件描述符

什么是文件描述符 C语言的文件接口文件的系统调用什么是文件描述符,文件描述符为什么是int类型?为什么新打开的文件的文件描述符不是从0开始? 文件描述符 fd (file descriptor) C语言的文件接口 当时学习C语言的时候,学习了文件接口 具体可以查看之前的文章: 链接:C语言的文…

flask库

文章目录 flask库1. 基本使用2. 路由路径和路由参数3. 请求跳转和请求参数4. 模板渲染1. 模板变量2. 过滤器3. 测试器 5. 钩子函数与响应对象 flask库 flask是python编写的轻量级框架&#xff0c;提供Werkzeug&#xff08;WSGI工具集&#xff09;和jinjia2&#xff08;渲染模板…

【PyTorch】基础学习:在Pycharm等IDE中打印或查看Pytorch版本信息

【PyTorch】基础学习&#xff1a;在Pycharm等IDE中打印或查看Pytorch版本信息 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1…