pytorch中dataloader自定义数据集

前言

在深度学习中我们需要使用自己的数据集做训练,因此需要将自定义的数据和标签加载到pytorch里面的dataloader里,也就是自实现一个dataloader。

数据集处理

以花卉识别项目为例,我们分别做出图片的训练集和测试集,训练集的标签和测试集的标签

flower_data/
├── train_filelist/
│   ├── image_0001.jpg
│   └── ...
├── val_filelist/
│   ├── image_1001.jpg
│   └── ...
├── train.txt  # 格式:文件名 标签
└── val.txt

 数据目录的组织方式如上所示。

首先看图片的处理。图片只要做好编号放在同一个文件夹里就好了。

再看标签的处理。标签处理我们自己规定了一种形式,就是图像文件的名称+空格+分类标签。

可以看到前面第一列数据是图像名称,第二列数据是图像的分组,同样的数字为一组。比如分组为0的图像就是同一种花朵。

自定义dataset

源码

import os.path
import numpy as np
import torch
from PIL import Image  # 从PIL库导入Image类
from torch.utils.data import Datasetclass FlowerDataSet(Dataset):"""花朵分类任务数据集类,继承自torch的Dataset类"""def __init__(self, root_dir, ann_file, transform=None):"""初始化数据集实例Args:root_dir (str): 数据集根目录路径ann_file (str): 标注文件路径transform (callable, optional): 数据预处理变换函数"""self.ann_file = ann_fileself.root_dir = root_dir# 加载图片路径与标签的映射字典 {文件名: 标签}self.image_label = self.load_annotations()# 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突self.transform = transformdef __len__(self):"""返回数据集样本数量"""return len(self.image)def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

解析

1、将标签数据进行读取,组成一个哈希表,哈希表的键是图像的文件名称,哈希表的值是分组标签。

    def load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

上面的代码里,在录入标签的时候使用数组进行记录,这是为了兼容多标签的场景。如果不考虑兼容问题,仅考虑在单标签场景下的简单实现,可以用下面的代码:

def load_annotations(self):data_infos = {}with open(self.ann_file) as f:for line in f:filename, label = line.strip().split()  # 直接解包data_infos[filename] = int(label)        # 存为 Python 整数return data_infos# 在 __getitem__ 中直接转为张量
label = torch.tensor(self.labels[index], dtype=torch.long)

2、遍历哈希表,将文件名和标签分别存在两个数组里。这里注意,为了方便后面dataloader按照batch去读取图片,这里要将图片的全路径加到文件名里。

        # 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突

3、在dataloader向显卡/cpu加载数据的时候会调用getitem方法。比如一个batch里有64个数据,dataloader就会调用64次该方法,将64组图片和标签全部获取后交给运算单元去处理。

    def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, label

测试dataloader

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import FlowerDataSet  # 假设你的数据集类在dataloader.py中def denormalize(image_tensor):"""将归一化的图像张量转换为可显示的格式"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = image_tensor.numpy().transpose((1, 2, 0))  # 转换维度顺序image = std * image + mean  # 反归一化image = np.clip(image, 0, 1)  # 限制像素值范围return imagedef test_dataloader():# 定义数据预处理data_transforms = {'train': transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),transforms.CenterCrop(64),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 检查文件路径是否存在print("[1/5] 检查文件路径...")required_files = {'train_txt': './flower_data/train.txt','val_txt': './flower_data/val.txt','train_dir': './flower_data/train_filelist','val_dir': './flower_data/val_filelist'}for name, path in required_files.items():if not os.path.exists(path):print(f"❌ 文件/目录不存在: {path}")returnprint(f"✅ {name}: {path} 存在")# 初始化数据集print("\n[2/5] 加载数据集...")try:train_dataset = FlowerDataSet(root_dir=required_files['train_dir'],ann_file=required_files['train_txt'],transform=data_transforms['train'])val_dataset = FlowerDataSet(root_dir=required_files['val_dir'],ann_file=required_files['val_txt'],transform=data_transforms['valid'])print("✅ 数据集加载成功")except Exception as e:print(f"❌ 数据集加载失败: {str(e)}")return# 打印数据集信息print("\n[3/5] 数据集统计:")print(f"训练集样本数: {len(train_dataset)}")print(f"验证集样本数: {len(val_dataset)}")# 检查单个样本print("\n[4/5] 检查单个样本:")sample_idx = 0try:img, label = train_dataset[sample_idx]print(f"图像张量形状: {img.shape} (应接近 torch.Size([3, 64, 64]))")print(f"标签类型: {type(label)} (应为 torch.Tensor)")print(f"标签值: {label.item()} (应为整数)")except Exception as e:print(f"❌ 样本检查失败: {str(e)}")# 可视化样本print("\n[5/5] 可视化训练集样本...")try:plt.figure(figsize=(8, 8))img_show = denormalize(img)plt.imshow(img_show)plt.title(f"Label: {label.item()}")plt.axis('off')plt.show()except Exception as e:print(f"❌ 可视化失败: {str(e)}")# 检查DataLoaderprint("\n[附加] 检查DataLoader:")train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:print(f"\n{name} DataLoader测试:")try:batch = next(iter(loader))images, labels = batchprint(f"批次图像形状: {images.shape} (应接近 [batch, 3, 64, 64])")print(f"批次标签示例: {labels[:5].numpy()}")print(f"像素值范围: [{images.min():.3f}, {images.max():.3f}]")except Exception as e:print(f"❌ {name} DataLoader错误: {str(e)}")if __name__ == '__main__':test_dataloader()

在测试代码中,分别测试了文件路径,dataset是否正常创建,dataset样本数量,dataset样本格式,dataset数据可视化,dataloader数据样式。

在打印日志的时候需要注意,dataset和dataloader里面的变量都是张量形式的,所以需要转换成python标量再打印。比如从dataset里取出的标签label是一个一维张量,需要通过label.item()进行转换。

 在遍历的时候为了简化代码,将两个dataloader放在同一个循环语句中处理,并且通过增加name变量来区分两个dataloader。

for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899825.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Blender模型导入虚幻引擎设置

单位系统不一致 Blender默认单位是米(Meters),而虚幻引擎默认使用**厘米(Centimeters)**作为单位。 当模型从Blender导出为FBX或其他格式时,如果没有调整单位,虚幻引擎会将1米(Blen…

Docker基础详解

Docker 技术详解 一、概述 Docker官网:https://docs.docker.com/ 菜鸟教程:https://www.runoob.com/docker/docker-tutorial.html 1.1 什么是Docker? Docker 是一个开源的容器化平台,它允许开发者将应用程序和其依赖项打包到…

FastPillars:一种易于部署的基于支柱的 3D 探测器

FastPillars:一种易于部署的基于支柱的 3D 探测器Report issue for preceding element Sifan Zhou 1 , Zhi Tian 2 , Xiangxiang Chu 2 , Xinyu Zhang 2 , Bo Zhang 2 , Xiaobo Lu11{}^{1}start_FLOATSUPERSCRIPT 1 end_FLOATSUPERSCRIPT11footnotemark: 1 Chengji…

NLP语言模型训练里的特殊向量

1. CLS 向量和 DEC 向量的区别及训练方式 (1) CLS 向量与 DEC 向量是否都是特殊 token? CLS 向量([CLS] token)和 DEC 向量(Decoder Input token)都是特殊的 token,但它们出现在不同类型的 NLP 模型中&am…

字节跳动 UI-TARS 汇总整理报告

1. 摘要 UI-TARS 是字节跳动开发的一种原生图形用户界面(GUI)代理模型 。它将感知、行动、推理和记忆整合到一个统一的视觉语言模型(VLM)中 。UI-TARS 旨在跨桌面、移动和 Web 平台实现与 GUI 的无缝交互 。实验结果表明&#xf…

基于Python深度学习的鲨鱼识别分类系统

摘要:鲨鱼是海洋环境健康的指标,但受到过度捕捞和数据缺乏的挑战。传统的观察方法成本高昂且难以收集数据,特别是对于具有较大活动范围的物种。论文讨论了如何利用基于媒体的远程监测方法,结合机器学习和自动化技术,来…

【漫话机器学习系列】168.最大最小值缩放(Min-Max Scaling)

在机器学习和数据预处理中,特征缩放(Feature Scaling) 是一个至关重要的步骤,它可以使模型更稳定,提高训练速度,并优化收敛效果。最大最小值缩放(Min-Max Scaling) 是其中最常见的方…

开源测试用例管理平台

不可错过的10个开源测试用例管理平台: PingCode、TestLink、Kiwi TCMS、Squash TM、FitNesse、Tuleap、Robot Framework、SpecFlow、TestMaster、Nitrate。 开源测试用例管理工具提供了一种透明、灵活的解决方案,使团队能够在不受限的情况下适应具体的测…

鸿蒙阔折叠Pura X外屏开发适配

首先看下鸿蒙中断点分类 内外屏开合规则 Pura X开合连续规则: 外屏切换到内屏,界面可以直接接续。内屏(锁屏或非锁屏状态)切换到外屏,默认都显示为锁屏的亮屏状态。用户解锁后:对于应用已适配外屏的情况下,应用界面可以接续到外屏。折叠外屏显示展开内屏显示折叠状态…

DRM_CLIENT_CAP_UNIVERSAL_PLANES和DRM_CLIENT_CAP_ATOMIC

drmSetClientCap(fd, DRM_CLIENT_CAP_UNIVERSAL_PLANES, 1); drmSetClientCap(fd, DRM_CLIENT_CAP_ATOMIC, 1); 这两行代码用于启用 Linux DRM(Direct Rendering Manager)客户端的两个关键特性,具体作用如下: 1. drmSetClientCap…

敏捷开发10:精益软件开发和看板kanban开发方法的区别是什么

简介 精益生产起源于丰田生产系统,核心是消除浪费,而看板最初是由丰田用于物料管理的信号卡片,后来被引入软件开发。 Kanban 后来引入到敏捷开发中,强调持续交付和流程可视化。 精益软件开发原则是基于精益生产的原则&#xff0…

用matlab探索卷积神经网络(Convolutional Neural Networks)-3

5.GoogLeNet中的Filters 这里我们探索GoogLeNet中的Filters,首先你需要安装GoogLeNet.在Matlab的APPS里找到Deep Network Designer,然后找到GoogLeNet,安装后的网络是没有右下角的黄色感叹号的,没有安装的神经网络都有黄色感叹号。 一个层&a…

Verilog中X态的危险:仿真漏掉的bug

由于Verilog中X态的微妙语义,RTL仿真可能PASS,而网表仿真却会fail。 目前进行的网表仿真越来越少,这个问题尤其严重,主要是网表仿真比RTL仿真慢得多,因此对整个回归测试而言成本效益不高。 上面的例子中,用Verilog RTL中的case语句描述了一个简单的AND函数,它被综合成AN…

PyTorch中知识蒸馏浅讲

知识蒸馏 在 PyTorch 中,使用 teacher_model.eval() 和冻结教师模型参数是知识蒸馏(Knowledge Distillation)中的关键步骤。 ​1. teacher_model.eval() 的作用 目的: 将教师模型切换到评估模式,影响某些特定层(如 Dropout、BatchNorm)的行为。 ​具体影响: ​Dropo…

Odoo/OpenERP 和 psql 命令行的快速参考总结

Odoo/OpenERP 和 psql 命令行的快速参考总结 psql 命令行选项 选项意义-a从脚本中响应所有输入-A取消表数据输出的对齐模式-c <查询>仅运行一个简单的查询&#xff0c;然后退出-d <数据库名>指定连接的数据库名&#xff08;默认为当前登录用户名&#xff09;-e回显…

ChatGPT 迎来 4o模型:更强大的图像生成能力与潜在风险

OpenAI 对 ChatGPT 进行重大升级&#xff0c;图像生成功能即将迎来新的 4o 模型&#xff0c;并取代原本的 DALLE。此次更新不仅提升了图像生成质量&#xff0c;还增强了对话内容和上传文件的融合能力&#xff0c;使 AI 生成的图像更加智能化和精准化。 4o 模型带来的革新 Ope…

Python 实现的运筹优化系统代码详解(整数规划问题)

一、引言 在数学建模的广袤领域里&#xff0c;整数规划问题占据着极为重要的地位。它广泛应用于工业生产、资源分配、项目管理等诸多实际场景&#xff0c;旨在寻求在一系列约束条件下&#xff0c;使目标函数达到最优&#xff08;最大或最小&#xff09;且决策变量取整数值的解决…

Visual Studio Code配置自动规范代码格式

目录 前言1. 插件安装2. 配置个性化设置2.1 在左下角点击设置按钮 &#xff0c;点击命令面板&#xff08;或者也可以之间按快捷键CtrlShiftP&#xff09;2.2 在弹出的搜索框输入 settings.json&#xff0c;打开首选项&#xff1a;打开工作区设置&#xff1b;2.3 在settings.jso…

【分布式】Hystrix 的核心概念与工作原理​

熔断机制​ Hystrix 的熔断机制就像是电路中的保险丝。当某个服务的失败请求达到一定比例&#xff08;例如 50%&#xff09;或者在一定时间内&#xff08;如 20 秒&#xff09;失败请求数量超过一定阈值&#xff08;如 20 个&#xff09;时&#xff0c;熔断开关就会打开。此时…

TypeScript 中 await 的详解

TypeScript 中 await 的详解 1. 基本概念2. 语法要求3. 工作原理4. 与 Promise 的比较5. 实践中的注意事项总结 本文详细介绍了 TypeScript 中 await 的工作原理、语法要求、与 Promise 的关系以及实践中需要注意的问题&#xff0c;同时针对代码示例进行了优化和补充说明。 1.…