PyTorch中定义自己的数据集

文章目录

    • 1. 简介
    • 2. 查看PyTorch自带的数据集(可视化)
    • 3. 准备材料
      • 3.1 图片数据
      • 3.2 标签数据
    • 4. 方法

1. 简介

尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集,包括数据集的内容、结构、标签等信息。对于一些复杂的数据集,用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类,您可以更好地控制数据的加载和处理过程,提高代码的灵活性、可读性和可维护性,同时更好地满足模型训练的需求。

2. 查看PyTorch自带的数据集(可视化)

为了更好的定义自己的数据集,我们首先查看PyTorch自带的数据集的内容,代码如下

# 导入所需的库
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化
import torch  # 导入PyTorch库
from torchvision.datasets import MNIST  # 从torchvision中导入MNIST数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库# 加载MNIST数据集
train_mnist_data = MNIST(root='./data',  # 数据集存储路径train=True,  # 加载训练集transform=transforms.Compose([transforms.Resize(size=(28, 28)), transforms.ToTensor()]),  # 数据预处理操作download=True)  # 如果数据集不存在,则自动下载# 设置要显示的样本数量
num_samples = 10# 创建包含多个子图的大图窗口
fig, axes = plt.subplots(1, num_samples, figsize=(10, 6))# 遍历选择要显示的样本
for i in range(num_samples):# 从数据集中获取图像数据和标签image, label = train_mnist_data[i]# 在子图中显示图像axes[i].imshow(image.squeeze().numpy(), cmap='gray')  # 使用imshow函数显示图像,将张量转换为NumPy数组axes[i].set_title(f"Label: {label}")  # 设置子图标题,显示图像对应的标签axes[i].axis('off')  # 关闭坐标轴显示# 将图像保存为PNG格式的图片文件,文件名以图像的标签命名plt.imsave(f"./data/mnist_images/{label}.png", image.squeeze().numpy(), cmap='gray')# 显示图形窗口
plt.show()

这里,我们使用MNIST类加载MNIST数据集。在加载数据集时,通过transform参数指定了数据预处理操作,包括将图像大小调整为28x28像素,并将图像转换为张量。train=True表示加载训练集,download=True表示如果数据集不存在则自动下载到指定的路径。

接下来,我们选择一些样本进行可视化。我们在一个子图中显示了10个样本,每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本,从数据集中获取图像数据和标签,并使用Matplotlib的imshow()函数将图像显示在子图中。
在这里插入图片描述

同时,使用imsave()函数将每个图像保存为PNG格式的图片文件,文件名以标签命名。最后,使用plt.show()显示图形窗口,显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像,并将这些图像保存到指定路径下。保存的图片如下所示

在这里插入图片描述

通过上面程序可以看到,数据集主要是由图片数据和对应的标签构成,那么我们就可以用这两个主要构成成分来构建自己的数据集。

3. 准备材料

3.1 图片数据

这里我们就用刚才保存的十张图片,即

在这里插入图片描述

当然,你也可以准备其它的图片,并给图片分别命名为“0.png, 1.png, …”。

这里,十张图片的相对路径为

imgs_path = "./data/mnist_images"

注:你们要根据自己存储的路径来给定。

3.2 标签数据

创建一个txt文件,为每一幅图片指定标签数据,如下所示

在这里插入图片描述

这里,txt文件的相对路径为

labels_path = "labels.txt"

4. 方法

在PyTorch中,您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset类,并且实现两个主要的方法:__len____getitem____len__方法应该返回数据集的长度,而__getitem__方法则根据给定的索引返回数据集中的样本。

下面我们展示如何创建一个自定义的数据集类:

import os  # 导入os模块,用于操作文件路径
from PIL import Image  # 导入PIL库中的Image模块,用于图像处理
import torch  # 导入PyTorch库
from torch.utils.data import Dataset  # 从torch.utils.data模块导入Dataset类,用于定义自定义数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库,用于数值处理
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化class CustomDataset(Dataset):def __init__(self, image_dir, label_file, transform=None):super().__init__()  # 调用父类的构造函数self.image_dir = image_dir  # 图像数据的路径self.label_file = label_file  # 标签文本的路径self.transform = transform  # 数据预处理操作self.samples = self._load_samples()  # 加载数据集样本信息def _load_samples(self):samples = []  # 存储样本信息的列表with open(self.label_file, 'r') as f:  # 打开标签文本文件for line in f:  # 逐行读取标签文本文件中的内容image_name, label = line.strip().split(',')  # 根据逗号分隔每行内容,获取图像文件名和标签image_path = os.path.join(self.image_dir, image_name)  # 拼接图像文件的完整路径samples.append((image_path, int(label)))  # 将图像路径和标签组成元组,加入样本列表return samples  # 返回样本列表def __len__(self):return len(self.samples)  # 返回数据集样本的数量def __getitem__(self, index):image_path, label = self.samples[index]  # 获取指定索引处的图像路径和标签image = Image.open(image_path).convert('L')  # 打开图像文件并将其转换为灰度图像if self.transform:  # 如果定义了数据预处理操作image = self.transform(image)  # 对图像进行预处理操作return image, label  # 返回预处理后的图像和标签# 设置图片数据路径和标签文本路径
image_dir = './data/mnist_images'  # 图像数据的路径
label_file = 'labels.txt'  # 标签文本的路径# 定义数据预处理操作,根据需要添加其他预处理操作
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小transforms.ToTensor(),  # 将图像转换为张量
])# 创建自定义数据集实例
custom_dataset = CustomDataset(image_dir, label_file, transform=transform)# 创建数据加载器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=1, shuffle=False)# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:# 使用squeeze()函数去除图像张量中的单维度,将图像数据转换为NumPy数组,并存储在变量image中image = batch_images.squeeze().numpy()# 使用imshow()函数显示图像,cmap='gray'指定使用灰度色彩映射plt.imshow(image, cmap='gray')# 设置图像标题,显示图像对应的标签,使用f-string格式化字符串,将batch_labels转换为Python标量并获取其值plt.title(f"Label: {batch_labels.item()}")# 关闭坐标轴显示,即不显示坐标轴plt.axis('off')# 显示图形窗口plt.show()

这段代码实现了加载自定义数据集,并使用 PyTorch 的 DataLoader 将数据加载成批次,然后逐批次地展示图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/12414.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【启程Golang之旅】环境设置、工具安装与代码实践

欢迎来到Golang的世界!在当今快节奏的软件开发领域,选择一种高效、简洁的编程语言至关重要。而在这方面,Golang(又称Go)无疑是一个备受瞩目的选择。在本文中,带领您探索Golang的世界,一步步地了…

【Web后端】MVC模式

1、简介 MVC模式,全称Model-View-Controller(模型-视图-控制器)模式,是一种软件设计典范,它将应用程序的用户界面(视图)和业务逻辑(模型)分离,同时提供了一个…

K8S内容

K8S介绍 1、故障迁移:当某一个node节点关机或挂掉后,node节点上的服务会自动转移到另一个node节点上,这个过程所有服务不中断。这是docker或普通云主机是不能做到的 2、资源调度:当node节点上的cpu、内存不够用的时候,可以扩充node节点&…

​​​【收录 Hello 算法】6.2 哈希冲突

目录 6.2 哈希冲突 6.2.1 链式地址 6.2.2 开放寻址 1. 线性探测 2. 平方探测 3. 多次哈希 6.2.3 编程语言的选择 6.2 哈希冲突 上一节提到,通常情况下哈希函数的输入空间远大于输出空间,因此理论上哈希冲突是不可避免的。比如&a…

LeetCode题练习与总结:不同的二叉搜索树--96

一、题目描述 给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 示例 1: 输入:n 3 输出:5示例 2: 输入:n 1 输出&…

从需求角度介绍PasteSpider(K8S平替部署工具)

你是否被K8S的强大而吸引,我相信一部分人是被那复杂的配置和各种专业知识而劝退,应该还有一部分人是因为K8S太吃资源而放手! PasteSpider是一款使用c#编写的linux容器部署工具,简单易上手,非常节省资源,支持…

shell脚本实现linux系统自动化配置免密互信

目录 背景脚本功能脚本内容及使用方法 1.背景 进行linux自动化运维时需要先配置免密,但某些特定场景下,做了互信的节点需要取消免密,若集群庞大节点数量多时,节点两两之间做互信操作非常麻烦,比如有五个节点&#x…

C++——动态规划

公共子序列问题 ~待补充 最长公共子序列 对于两个字符串A和B,A的前i位和B的前j位的最大公共子序列必然是所求解的一部分,设dp[i][j]为串A前i位和B串前j位的最长公共子序列的长度,则所求答案为dp[n][m],其中n,m分别为…

微信小程序主体变更的操作教程

小程序迁移变更主体有什么作用?进行小程序主体迁移变更,那可是益处多多呀!比方说,能够解锁更多权限功能;在公司变更或注销时,还能保障账号的正常使用;此外,收购账号后,也…

详解xlsxwriter 操作Excel的常用API

我们知道可以通过pandas 对excel 中的数据进行处理分析,但是pandas本身对格式化数据方面提供了很少的支持,如果我们想对pandas进行数据分析后的数据进行格式化相关操作,我们可以使用xlsxwriter,本文就对xlsxwriter的常见excel格式…

Salesforce AI研究: 从奖励建模到在线RLHF工作流

摘要 该研究在本技术报告中介绍了在线迭代基于人类反馈的强化学习(Online Iterative Reinforcement Learning from Human Feedback, RLHF)的工作流程,在最近的大语言模型(Large Language Model, LLM)文献中,这被广泛报道为大幅优于其离线对应方法。然而,现有的开源RLHF项目仍然…

Android存储文件路径的区别

一、Android存储简介 Android系统分为内部存储和外部存储 从Android6.0开始不断在更新存储权限 外部存储路径的开头:storage/emulated/0 内部存储文件路径的开头:/data/user/0/应用的包名(packageName) 在设备上对应的目录为/data…

Linux的命名管道 共享内存

目录 命名管道 mkfifo函数 unlink函数 命名管道类 服务端 客户端 共享内存 shmget函数 ftok函数 key和shmid的区别 snprintf函数 ipcs指令 ipcrm指令 shmctl函数 shmat函数 void*做返回值 创建共享内存空间 服务端 客户端 命名管道 基本概念&#xff1…

笔记本黑屏,重新开机主板没有正常运作的解决办法

拆开笔记本后壳,打开看到主板,将主板上的这颗纽扣电池拆下来,如果是带连接线的(如下图),可以将接口处线头拔出,等1分钟再把线接上。 ------------- 以下是科普 首先,电脑主板上的这…

力扣例题(循环队列)

链接 . - 力扣(LeetCode) 描述 思路 我们使用数组来创建循环队列 数组的大小我们就额外对开辟一块空间 MyCircularQueue(k) 开辟一个结构体,存放队列的相关数据 分别为size,数组指针_a,起始位置head,结束位置tail 注意:我们…

移动端自动化测试工具 Appium 之持续集成

文章目录 一、背景二、前置条件三、代码部分1、pom.xml文件配置2、main入口代码 四、Jenkins 部分1、下载Jenkins2、安装插件3、job配置4、选择构建 五、工程目录六、报告示例七、总结 一、背景 持续集成是老生话谈的事情,用的好不好,看自己公司与使用场…

能播放SWF文件的FlashPlayer播放器

问题: 你是不是遇到了 flash 动画 放不了了? 以前的flash游戏玩不了了 在网上很难找到好用的,免费Flashplayer播放器, 找到的也没法保存.exe 以前买的课件放不了了 一打开就更新提示: 再不就是意外能打开了但【创建…

IBM Granite模型开源:推动软件开发领域的革新浪潮

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

高中数学:平面向量-加减运算

一、向量的加法运算 三角形法则(推荐) 两个或多个向量收尾相连的加法运算,用三角形法则 简便算法 首尾相连的多个向量,去掉中间点,就是最终的和。 也可以用三角形法则证明 向量加法交换律 向量加法结合律 平行四…

讲解SSM的xml文件

概述&#xff1a;这些配置文件很烦&#xff0c;建议直接复制粘贴 springMVC.xml文件 <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XM…