使用 PyTorch 自定义数据集并划分训练、验证与测试集

使用 PyTorch 自定义数据集并划分训练、验证与测试集

在图像分类等任务中,通常需要将原始训练数据进一步划分为训练集和验证集,以便在训练过程中评估模型的性能。下面将详细介绍如何组织数据与注释文件、如何分割训练集和验证集,以及如何基于自定义 Dataset 类构建 DataLoader 以加速模型训练与评估。

一、数据准备

1.1 文件结构

假设你的数据目录结构如下所示:

data/
├── train_data/
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   └── ...
├── test_data/
│   ├── img101.png
│   ├── img102.png
│   ├── img103.png
│   └── ...
├── train_annotations.csv
└── test_annotations.csv

注意:这里将 train_annotations.csvtest_annotations.csv 文件单独放在 data/ 目录下,而不放在各自图片的子文件夹中。这样当图片数量非常多时,我们也能快速找到并管理这两个 CSV 文件。

1.2 注释文件(CSV)格式示例

train_annotations.csvtest_annotations.csv 中,一般会包含两列或更多列信息,但最关键的通常是 图片文件名(filename)和 标签(label)。格式示例如下:

train_annotations.csv

filename,label
img1.png,0
img2.png,1
img3.png,0
...

test_annotations.csv

filename,label
img101.png,0
img102.png,1
img103.png,0
...
  • filename 列表示图像的文件名,需要与 train_data/test_data/ 文件夹下的文件一一对应。
  • label 列表示图像所对应的类别或标签,可以是整数,也可以是字符串,比如 catdog 等。训练时通常会将字符串映射到整数标签或独热编码。

二、将训练数据划分为训练集和验证集

在进行模型训练前,往往需要将原始训练数据(以下简称 “总训练集”)拆分成 训练集(train) 和 验证集(val)。这里我们使用 scikit-learn 提供的 train_test_split 函数来完成这一步骤。

import pandas as pd
from sklearn.model_selection import train_test_split# 读取原始训练集的注释文件(此时还未拆分)
train_annotations = pd.read_csv('data/train_annotations.csv')# 按 80%:20% 的比例拆分为 新的训练集(train_df) 和 验证集(val_df)
train_df, val_df = train_test_split(train_annotations, test_size=0.2, random_state=42, stratify=train_annotations['label']
)# 将拆分后的注释文件保存为新的 CSV 文件
train_df.to_csv('data/train_split.csv', index=False)
val_df.to_csv('data/val_split.csv', index=False)

关键参数说明:

  • test_size=0.2:表示将 20% 的样本作为验证集,其余 80% 作为新的训练集。
  • random_state=42:让划分结果可复现,方便后续对比不同实验结果。
  • stratify=train_annotations['label']:在划分时保持各类别在训练和验证集中相同比例,这在分类任务中尤为重要。

执行完以上步骤后,你的 data 目录下会多出两个新的注释文件:

data/
├── train_data/
│   ├── ...
├── test_data/
│   ├── ...
├── train_annotations.csv   # 原始,总训练集注释
├── train_split.csv         # 新的,训练集注释
└── val_split.csv           # 新的,验证集注释

三、自定义 Dataset

PyTorch 提供了 torch.utils.data.Dataset 作为数据集的抽象基类。我们可以通过继承并重写其中的方法,来实现灵活的数据加载逻辑。

下面的 CustomImageDataset 类支持通过 CSV 文件(包括你在上一步生成的 train_split.csv, val_split.csv 等)来读取图像与标签,并在取样本时进行必要的预处理操作。

import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):"""初始化数据集。参数:annotations_file (str): CSV 文件路径,包含 (filename, label) 等信息img_dir (str): 存放图像的文件夹路径transform (callable, optional): 对图像进行转换和增强的函数或 transforms 组合target_transform (callable, optional): 对标签进行转换的函数"""self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):"""返回整个数据集的样本数量。"""return len(self.img_labels)def __getitem__(self, idx):"""根据索引 idx 获取单个样本。返回:(image, label) 其中 image 可以是一个 PIL 图像或 Tensor,label 可以是整数或字符串"""# 1. 获取图像文件名与对应的标签img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])label = self.img_labels.iloc[idx, 1]# 2. 读取图像并转换为 RGB 模式(如果是灰度则可用 'L')image = Image.open(img_path).convert('RGB')# 3. 对图像和标签进行必要的变换if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

四、创建训练集、验证集、测试集对应的 DataLoader

有了自定义 Dataset 后,就可以利用 PyTorch 自带的 DataLoader 来进行批量数据加载、随机打乱以及多线程读取数据等工作。以下示例展示了如何分别实例化 训练集验证集测试集Dataset 对象,并为每个对象创建 DataLoader

from torchvision import transforms
from torch.utils.data import DataLoader# 定义训练、验证/测试时所需的数据变换
train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 数据增强transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])val_test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])# 实例化训练集 (train_dataset)
train_dataset = CustomImageDataset(annotations_file='data/train_split.csv',  # 注意这里不再是 data/train_annotations.csvimg_dir='data/train_data',transform=train_transform
)# 实例化验证集 (val_dataset)
val_dataset = CustomImageDataset(annotations_file='data/val_split.csv',img_dir='data/train_data',transform=val_test_transform
)# 实例化测试集 (test_dataset)
test_dataset = CustomImageDataset(annotations_file='data/test_annotations.csv',img_dir='data/test_data',transform=val_test_transform
)# 构建 DataLoader
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,      # 训练时常使用 shuffle=True 来打乱顺序num_workers=4,     # 根据 CPU 核心数进行调整drop_last=True     # 避免最后一个 batch 样本数不足时带来的问题
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)

通过使用 DataLoader,你就可以在训练和验证过程中以 (batch)为单位获取数据,从而显著提升训练速度,并方便进行数据增强、随机打乱等操作。

五、完整示例脚本

下面给出一个相对完整的示例脚本,整合了数据拆分、自定义数据集加载以及构建 DataLoader 的主要流程。如果你愿意,可以将这些步骤拆分到不同的 Python 文件中,以保持项目结构清晰。

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms# ========== 1. 数据集拆分函数 ========== #
def split_train_val(annotations_file, output_train_file, output_val_file, test_size=0.2, random_state=42):df = pd.read_csv(annotations_file)train_df, val_df = train_test_split(df, test_size=test_size, random_state=random_state, stratify=df['label'])train_df.to_csv(output_train_file, index=False)val_df.to_csv(output_val_file, index=False)# ========== 2. 定义自定义 Dataset 类 ========== #
class CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])label = self.img_labels.iloc[idx, 1]image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label# ========== 3. 执行划分并创建训练/验证/测试集 ========== #
# 假设原始的训练集标注文件位于 data/train_annotations.csv
split_train_val(annotations_file='data/train_annotations.csv',output_train_file='data/train_split.csv',output_val_file='data/val_split.csv',test_size=0.2,random_state=42
)train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])val_test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])train_dataset = CustomImageDataset(annotations_file='data/train_split.csv',img_dir='data/train_data',transform=train_transform
)val_dataset = CustomImageDataset(annotations_file='data/val_split.csv',img_dir='data/train_data',transform=val_test_transform
)test_dataset = CustomImageDataset(annotations_file='data/test_annotations.csv',img_dir='data/test_data',transform=val_test_transform
)train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,num_workers=4,drop_last=True
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)# ========== 4. 简单测试:读取一个 batch ========== #
for images, labels in train_loader:print(images.shape, labels.shape)break

六、在训练循环中使用验证集

构建好训练、验证和测试集的 DataLoader 之后,你就可以在模型训练过程中使用验证集来评估模型性能;并在完全训练结束后,对测试集进行最终评估。以下是一个最简化的示例,演示如何在每个 epoch 后进行验证:

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的神经网络
class SimpleNN(nn.Module):def __init__(self, num_classes=10):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(224*224*3, 128)  # 根据输入图像大小进行调整self.relu = nn.ReLU()self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNN(num_classes=2).to(device)  # 假设有 2 个类别
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程
num_epochs = 5
for epoch in range(num_epochs):# 1. 训练阶段model.train()running_loss = 0.0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()avg_train_loss = running_loss / len(train_loader)# 2. 验证阶段model.eval()correct = 0total = 0val_loss = 0.0with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()avg_val_loss = val_loss / len(val_loader)val_accuracy = 100.0 * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], 'f'Train Loss: {avg_train_loss:.4f}, 'f'Val Loss: {avg_val_loss:.4f}, 'f'Val Accuracy: {val_accuracy:.2f}%')

输出示例

Epoch [1/5], Train Loss: 1.2034, Val Loss: 0.4567, Val Accuracy: 85.32%
Epoch [2/5], Train Loss: 0.9876, Val Loss: 0.3987, Val Accuracy: 88.45%
...

总结

  1. 数据组织:将大量图片与注释文件分开存储(如 train_annotations.csvtest_annotations.csv 单独放在 data/ 目录下),可以在图片数量庞大时更方便地管理和检索。
  2. 数据集拆分:使用 train_test_split 将原始训练集拆分为训练集与验证集,以便在训练过程中监控模型的过拟合情况。
  3. 自定义 Dataset:通过继承 Dataset 并重写 __getitem____len__,可以灵活处理任意格式的数据,并在读入时执行预处理/增强操作。
  4. 构建 DataLoader:使用 PyTorch 的 DataLoader 可以轻松实现批量读取、并行加速、随机打乱等功能,大幅提升训练效率。
  5. 验证与测试:在每个 epoch 后对验证集进行评估可以及时发现过拟合和调参问题;最终对测试集进行评估可以获得模型的实际泛化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/892098.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TCP通信原理学习

TCP三次握手和四次挥手以及为什么_哔哩哔哩_bilibili

unity学习13:gameobject的组件component以及tag, layer 归类

目录 1 gameobject component 是unity的基础 1.1 类比 1.2 为什么要这么设计? 2 从空物体开始 2.1 创建2个物体 2.2 给 empty gameobject添加组件 3 各种组件和新建组件 3.1 点击 add component可以添加各种组件 3.2 新建组件 3.3 组件的操作 3.4 特别的…

MYSQL--------什么是存储过程和函数

什么是存储过程和函数 存储过程: 存储过程是一组预编译的 SQL 语句集合,存储在数据库服务器中,可通过名称调用执行。它可以包含数据操作语言(DML)、数据定义语言(DDL)、控制流语句等。存储过程主…

计算机网络 (30)多协议标签交换MPLS

前言 多协议标签交换(Multi-Protocol Label Switching,MPLS)是一种在开放的通信网上利用标签引导数据高速、高效传输的新技术。 一、基本概念 MPLS是一种第三代网络架构技术,旨在提供高速、可靠的IP骨干网络交换。它通过将IP地址映…

探索Facebook的区块链计划:未来社交网络的变革

随着区块链技术的迅速发展,社交网络领域正面临一场深刻的变革。Facebook,作为全球最大且最具影响力的社交平台之一,正在积极探索区块链技术的应用。本文将深入探讨Facebook的区块链计划,分析其潜在的变革性影响,并展望…

十年后LabVIEW编程知识是否会过时?

在考虑LabVIEW编程知识在未来十年内的有效性时,我们可以从几个角度进行分析: ​ 1. 技术发展与软件更新 随着技术的快速发展,许多编程工具和平台不断更新和改进,LabVIEW也不例外。十年后,可能会有新的编程语言或平台…

C# async和await

第一种: 多个异步任务按照顺序执行先让一个异步任务start 然后通过ContinueWith方法 在参数函数的表达式里面开启第二个任务如果要有第三个任务 需要在第二个任务ContinueWith方法中开启第三个任务 以此类推 可以实现多个异步任务顺序执行 上面这种方式绘出现地狱回…

Excel 技巧03 - 如何对齐小数位数? (★)如何去掉小数点?如何不四舍五入去掉小数点?

这几个有点儿关联,我都给放到一起了,不影响大家分别使用。 目录 1,如何对齐小数位数? 2,如何去掉小数点? 3,如何不四舍五入去掉小数点? 1,如何对齐小数位数&#xff…

node.js|浏览器插件|Open-Multiple-URLs的部署和使用,实现一键打开多个URL的强大工具

前言: 在整理各类资源的时候,可能会面临资源非常多的情况,这个时候我们就需要一款能够一键打开多个URL的浏览器插件了 说简单点,其实,迅雷就是这样的,但是迅雷是基于内置nginx浏览器实现的,并…

“AI 视频图像识别系统,开启智能新视界

咱老百姓现在的生活啊,那是越来越离不开高科技了,就说这 AI 视频图像识别系统,听起来挺高大上,实际上已经悄无声息地融入到咱们日常的方方面面,给咱带来了超多便利。 先讲讲安防领域吧,这可是 AI 图像识别的…

C语言 游动的小球

代码如下&#xff1a; 在这里插入代码片#include<stdio.h> #include<stdlib.h> #include<windows.h>int main() {int i,j;int x 5;int y 10;int height 20;int velocity_x 1;int velocity_y 1;int left 0;int right 20;int top 0;int bottom 10;while(1){…

基于SpringBoot实现的保障性住房管理系统

&#x1f942;(❁◡❁)您的点赞&#x1f44d;➕评论&#x1f4dd;➕收藏⭐是作者创作的最大动力&#x1f91e; &#x1f496;&#x1f4d5;&#x1f389;&#x1f525; 支持我&#xff1a;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;欢迎留言讨论 &#x1f525;&#x1f525;&…

安卓触摸对焦

1. 相机坐标说明 触摸对焦需要通过setFocusAreas()设置对焦区域&#xff0c;而该方法的参数的坐标&#xff0c;与屏幕坐标并不相同&#xff0c;需要做一个转换。 对Camera&#xff08;旧版相机API&#xff09;来说&#xff0c;相机的坐标区域是一个2000*2000&#xff0c;原点…

湖南引力:低代码技术助力军工企业实现设备管理系统创新

背景介绍 在核工业相关生产领域&#xff0c;随着技术的持续进步&#xff0c;生产活动对设备的依赖性日益增强。随着企业规模的不断扩大&#xff0c;所涉及的设备数量和种类也在急剧增长&#xff0c;这使得传统的设备管理模式逐渐显得力不从心。企业当前的设备管理主要依赖人工…

【701. 二叉搜索树中的插入操作 中等】

题目&#xff1a; 给定二叉搜索树&#xff08;BST&#xff09;的根节点 root 和要插入树中的值 value &#xff0c;将值插入二叉搜索树。 返回插入后二叉搜索树的根节点。 输入数据 保证 &#xff0c;新值和原始二叉搜索树中的任意节点值都不同。 注意&#xff0c;可能存在多…

VR+智慧消防一体化决策平台

随着科技的飞速发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术与智慧城市建设的结合越来越紧密。在消防安全领域&#xff0c;VR技术的应用不仅能够提升消防训练的效率和安全性&#xff0c;还能在智慧消防一体化决策平台中发挥重要作用。本文将探讨“VR智慧消防一体化…

nginx http反向代理

系统&#xff1a;Ubuntu_24.0.4 1、安装nginx sudo apt-get update sudo apt-get install nginx sudo systemctl start nginx 2、配置nginx.conf文件 /etc/nginx/nginx.conf&#xff0c;但可以在 /etc/nginx/sites-available/ 目录下创建一个新的配置文件&#xff0c;并在…

arcgisPro加载CGCS2000天地图后,如何转成米单位

1、导入加载的天地图影像服务&#xff0c;一开始是经纬度显示的。 2、右键地图&#xff0c;选择需要调整的投影坐标&#xff0c;这里选择坐标如下&#xff1a; 3、点击确定后&#xff0c;就可以调整成米单位的了。 4、切换后结果如下&#xff1a; 如有需要&#xff0c;可调整成…

计算机的错误计算(二百零四)

摘要 利用两个大模型判断&#xff1a;在(0, ) 范围内&#xff0c; 和 等价吗&#xff1f;实验表明&#xff0c;两个大模型&#xff08;其中一个是数学大模型&#xff09;均在输出幻觉&#xff0c;均说等价&#xff01; 例1. 在(0, ) 范围内&#xff0c; 和 等价吗&#xf…

简单的jmeter数据请求学习

简单的jmeter数据请求学习 1.需求 我们的流程服务由原来的workflow-server调用wfms进行了优化&#xff0c;将wfms服务操作并入了workflow-server中&#xff0c;去除了原来的webservice服务调用形式&#xff0c;增加了并发处理&#xff0c;现在想测试模拟一下&#xff0c;在一…