【计算机视觉基础CV】03-深度学习图像分类实战:鲜花数据集加载与预处理详解

本文将深入介绍鲜花分类数据集的加载与处理方式,同时详细解释代码的每一步骤并给出更丰富的实践建议和拓展思路。以实用为导向,为读者提供从数据组织、预处理、加载到可视化展示的完整过程,并为后续模型训练打下基础。


前言

在计算机视觉的深度学习实践中,数据加载和预处理是至关重要的一步。无论你是初学者,还是有一定经验的从业者,都需要深刻理解如何将原始数据转化为神经网络可接受的输入。PyTorch中的torchvision.datasetstorchvision.transforms为我们提供了极大的便利,使图像数据的加载和处理更加高效与简洁。

本文将以“鲜花分类数据集”(一个包含5种不同花卉类别的图像数据集)为例,详细讲述如何使用ImageFolder类进行数据加载,并通过transforms对图像进行预处理和数据增强。我们还会深入讨论数据集结构、训练/验证集划分、代码注释和实践建议,并给出详细说明。


数据集简介与结构

本例使用的鲜花分类数据集共包含5种花:雏菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)和郁金香(tulips)。数据量约为:

  • 训练集(train):3306张图像

  • 验证集(val):364张图像

数据已按类别分好目录,每个类别对应一个文件夹,文件夹中存放若干图片文件。结构示意如下:

dataset/flower_datas/├─ train/│   ├─ daisy/       # 雏菊类图像若干张│   ├─ dandelion/   # 蒲公英类图像若干张│   ├─ roses/       # 玫瑰类图像若干张│   ├─ sunflowers/   # 向日葵类图像若干张│   └─ tulips/       # 郁金香类图像若干张└─ val/├─ daisy/├─ dandelion/├─ roses/├─ sunflowers/└─ tulips/

这种目录结构非常适合ImageFolder数据集类,它会根据子文件夹的名称自动分配类别标签,从0开始编号。例如:

  • daisy -> 0

  • dandelion -> 1

  • roses -> 2

  • sunflowers -> 3

  • tulips -> 4

这样无需手动编码类别映射,简化了流程。


ImageFolder和transform

ImageFolder简介

ImageFoldertorchvision.datasets中的一个实用数据类,它假设数据按如下规则组织:

  • root/class_x/xxx.png

  • root/class_x/xxy.png

  • root/class_y/xxz.png

  • ...

其中class_xclass_y是类名(字符串),ImageFolder会根据这些类名自动生成类别索引。加载后,每个样本是一个(image, label)二元组,image通常会通过transform转换为Tensorlabel为整数索引。


transforms的数据预处理功能

torchvision.transforms提供多种图像处理方法,用来改变图像格式、尺寸、颜色空间和进行数据增强。例如:

  • ToTensor():将PIL图像或Numpy数组转换为(C,H,W)格式的张量,并将像素值归一化到[0,1]之间。

  • Resize((224,224)):将图像缩放到224x224大小,这通常是预训练模型如ResNet、VGG的标准输入尺寸。

  • RandomHorizontalFlip():随机水平翻转图像,用于数据增强,提高模型对翻转不敏感。

  • Normalize(mean, std):对图像的每个通道进行归一化,使训练更稳定。

你可以根据需求灵活组合多个变换操作,使用transforms.Compose将其串联成流水线。


加载鲜花分类数据集的示例代码

下面的代码示例中,我将详细注释每个步骤,为读者提供清晰的思路。该示例以最基本的ToTensor和Resize为主,读者可按需添加更多transform。

import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt# 数据集存放路径,根据实际情况修改
flowers_train_path = '../01.图像分类/dataset/flower_datas/train/'
flowers_val_path = '../01.图像分类/dataset/flower_datas/val/'# 定义数据预处理
# 这里的transforms主要包括:
# 1. ToTensor():将PIL图片或numpy数组转为Tensor,并将像素值归一化到[0,1]区间。
# 2. Resize((224,224)):将所有图片大小统一为224x224,以匹配后续卷积神经网络的输入要求。
# 对于实际训练,更建议加入数据增强手段(如随机裁剪、翻转、归一化等),
# 但本例先展示基本流程。
dataset_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((224,224))
])# 使用ImageFolder加载训练集和验证集
# ImageFolder会扫描指定目录下的子文件夹,并以子文件夹名称作为类别。
flowers_train = ImageFolder(root=flowers_train_path, transform=dataset_transform)
flowers_val = ImageFolder(root=flowers_val_path, transform=dataset_transform)# 打印样本数量
print("训练集样本数:", len(flowers_train))
print("验证集样本数:", len(flowers_val))# flowers_train.classes属性包含类别名称列表,如['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print("类别名称列表:", flowers_train.classes)# 获取单个样本进行查看
# __getitem__(index)返回(img, label),img是Tensor,label是int
sample_index = 3000
sample_img, sample_label = flowers_train[sample_index]print("样本索引:", sample_index)
print("类别标签索引:", sample_label, "类别名称:", flowers_train.classes[sample_label])
print("图像Tensor尺寸:", sample_img.shape)  # 期望为[3,224,224]# 可视化图像
# Matplotlib的imshow要求图像为(H,W,C),而Tensor是(C,H,W),需要permute调整维度顺序。
plt.imshow(sample_img.permute(1,2,0))
plt.title(flowers_train.classes[sample_label])
plt.show()

代码输出: 


关于训练集、验证集和测试集的说明

本数据集中已提前将数据分为trainval两个目录:

  • train/:训练集,用于模型训练过程中反向传播和参数更新。

  • val/:验证集,用于在训练中间进行性能评估,不参与参数更新,仅用于选择超参数或判断训练是否过拟合。

有些数据集还会提供test/测试集,用于最终评估模型在未知数据上的表现,但本例中未提供,如有需要可自行分割数据或从其他来源获取。


DataLoader的引入

仅有ImageFolder还不够,为了在训练时批量读取数据并进行迭代,我们通常会将数据集对象传入DataLoader中。

DataLoader的作用是:

  • 按指定的batch_size从Dataset中抽取样本构成mini-batch。

  • 可设置shuffle=True来随机打乱样本顺序,防止模型记住样本顺序。

  • 使用num_workers参数并行加速数据加载。

示例(可选代码):

from torch.utils.data import DataLoaderbatch_size = 32
# 定义训练集和验证集的DataLoader
train_loader = DataLoader(flowers_train, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(flowers_val, batch_size=batch_size, shuffle=False, num_workers=2)# 测试一下加载结果
images, labels = next(iter(train_loader))
print("一个batch的图像尺寸:", images.shape)  # [batch_size, 3, 224, 224]
print("对应的标签:", labels)  # 张量形式,如tensor([0, 1, 3, ...])


有了DataLoader,我们在训练模型时,就可以轻松迭代数据:

for epoch in range(1):for batch_images, batch_labels in train_loader:# 在这里将batch_images, batch_labels输入模型进行训练print("一个batch的图像尺寸:", batch_images.shape)  # [batch_size, 3, 224, 224]print("对应的标签:", batch_labels)  # 张量形式,如tensor([0, 1, 3, ...])passbreak


我们可以打印一下第一个batch 和最后一个batch的标签

batch_count = 0
first_batch_images, first_batch_labels = None, None
last_batch_images, last_batch_labels = None, Nonefor epoch in range(1):for batch_images, batch_labels in train_loader:batch_count += 1# 保存第一个batchif batch_count == 1:first_batch_images, first_batch_labels = batch_images, batch_labelsprint("第一个batch的图像尺寸:", batch_images.shape)print("第一个batch的标签:", batch_labels)# 每次循环都会更新last_batchlast_batch_images, last_batch_labels = batch_images, batch_labelsbreak  # 只进行一次epoch的训练,移除这行会进行多个epoch的训练# 打印最后一个batch
print("最后一个batch的图像尺寸:", last_batch_images.shape)
print("最后一个batch的标签:", last_batch_labels)# 打印总共的batch数量
print("总共的batch数量:", batch_count)


数据增强策略的拓展

实际训练中,为提高模型的泛化能力,我们常加入数据增强操作。这些操作对训练集图像进行随机变换,如随机剪裁、翻转、颜色抖动、归一化等。这样模型不会过度记忆特定图像的像素分布,而会学习更有泛化性的特征。

一个常用的transform示例:

# 定义训练集的图像预处理流程
train_transform = transforms.Compose([# 随机裁剪并缩放图像到224x224的尺寸,裁剪的区域大小是随机的transforms.RandomResizedCrop(224),  # 随机进行水平翻转,用于数据增强,提升模型的泛化能力transforms.RandomHorizontalFlip(),# 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式transforms.ToTensor(),# 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,# 使得不同的通道(RGB)具有相同的尺度,便于训练。transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])# 定义验证集的图像预处理流程
val_transform = transforms.Compose([# 将图像的最短边缩放到256像素,保持长宽比例不变transforms.Resize(256),  # 从缩放后的图像中进行中心裁剪,裁剪出224x224的区域,这样图像的尺寸就一致了transforms.CenterCrop(224),# 将图像转换为Tensor类型,PyTorch要求输入为Tensor格式transforms.ToTensor(),# 进行图像的标准化处理。根据ImageNet数据集的均值和标准差进行归一化,# 使得不同的通道(RGB)具有相同的尺度,便于训练。transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])# 使用定义的transform对训练集和验证集进行图像预处理
# flowers_train_path和flowers_val_path是训练集和验证集图像所在的路径
flowers_train = ImageFolder(flowers_train_path, transform=train_transform)  # 训练集
flowers_val = ImageFolder(flowers_val_path, transform=val_transform)  # 验证集

在此示例中,Normalize的参数是使用ImageNet数据集的均值和标准差,这在使用ImageNet预训练模型时是常规操作。对于自定义数据集,你也可以先统计本数据集的均值和方差,再进行归一化。


我们可以打印一下变化前后的图像区别

import os
import random
import numpy as np  # 需要导入numpy
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.datasets import ImageFolder# 定义训练集的图像预处理流程
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 定义图像数据集路径
train_image_folder = '/Users/coyi/PycharmProjects/coyi_pythonProject/01.图像分类/dataset/flower_datas/train/'# 使用ImageFolder加载数据集
dataset = ImageFolder(train_image_folder, transform=None)# 随机选取一张图片
random_idx = random.randint(0, len(dataset) - 1)
image, label = dataset[random_idx]# 显示原始图像
plt.figure(figsize=(5,5))
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()# 应用train_transform变换
transformed_image = train_transform(image)# 反标准化(Undo normalization)以恢复图片的原始视觉效果,因为训练的时候需要标准化
inv_normalize = transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1/0.229, 1/0.224, 1/0.225])
unnormalized_image = inv_normalize(transformed_image)# 将Tensor转回PIL图像进行显示
unnormalized_image = unnormalized_image.permute(1, 2, 0).numpy()  # 转换为HWC格式
unnormalized_image = np.clip(unnormalized_image, 0, 1)  # 限制值在[0, 1]之间,以符合视觉输出# 显示变换后的图像
plt.figure(figsize=(5,5))
plt.title("Transformed Image")
plt.imshow(unnormalized_image)
plt.axis('off')  # 不显示坐标轴
plt.show()

输出: 

备注: 为了显示图片,我对处理后的图片进行了反标准化,实际上训练的时候是不需要反标准化的


为什么要反标准化?

标准化是一个常见的预处理步骤,目的是让模型训练时更稳定,通常是将像素值转换到均值为0、标准差为1的范围。这可以帮助模型更好地收敛,并且消除不同通道(例如RGB)的尺度差异。

然而,标准化后的图像不适合直接用于可视化,因为它们的像素值已经不在[0, 1]的范围内,可能会变成负数或大于1。反标准化的目的是恢复图像的原始视觉效果,让它们的像素值回到原始的视觉范围。

不反标准化可以吗?

在可视化时不反标准化是可以的,但你会看到经过标准化后的图像没有直观的可视化效果,因为图像的像素值会偏离 [0, 1] 的可视化范围。这会导致显示的图像看起来可能是“失真”的,例如图像会变得非常暗、非常亮,或者有一些不自然的颜色。

简而言之:

反标准化是为了恢复图像的原始视觉效果,使得图像显示更符合人类的感知。

• **np.clip()**是为了确保图像的像素值在[0, 1]范围内,符合图像显示的要求。

示例:

假设标准化之后,你得到了一个像素值为 -0.5 或 1.5 的图像像素。这时,如果不进行 np.clip(),直接用 matplotlib 显示,可能会看到图像出现异常的颜色或显示不出来。而通过 np.clip(),将这些像素值限制在[0, 1]的范围内,可以确保图像能正确显示。


类别分布与标签可解释性

flowers_train.classesflowers_val.classes可以查看类名列表。例如:

这意味着模型预测结果中的label=0代表daisy,label=1代表dandelion,以此类推。当我们预测模型输出为label=3时,就可以将其解释为sunflowers。这种可读性非常有助于后期分析和调试。

如果想查看具体每类样本数量,可手动统计,例如:

 

通过查看类别分布,我们可了解数据是否偏斜(某些类样本过多或过少),从而采取相应措施(如类均衡采样、权重平衡等)。


实战建议和下一步计划

  1. 数据准备完成后做什么? 通常下一步就是定义和加载模型(如预训练的ResNet18),然后编写训练循环对模型进行微调或从头训练。在训练循环中,train_loader提供批数据,val_loader则用于评估模型在验证集上的表现。

  2. 调试DataLoader是否正确工作: 在正式训练前,尝试可视化几个batch的数据样本,确保图像大小、颜色正确,标签映射无误。如果出现图像显示不正确或标签偏移,及时检查目录结构和transform流程。

  3. 善用数据增强: 当验证集精度停滞不前或出现过拟合时,尝试加入更多数据增强手段(如RandomRotationColorJitterRandomGrayscale等)提升泛化性能。

  4. 硬件加速: 在加载大规模数据时,合理增加num_workers可以提高数据读取速度(依赖操作系统和硬件条件)。同时,如果是分布式训练,也需考虑分布式Sampler和合适的数据划分策略。

  5. 定制Dataset: 如果你的数据不遵循ImageFolder的结构,也可以自行定义Dataset类,通过实现__len____getitem__方法来自定义数据加载流程。但对像本例这样已按类分文件夹的数据集,ImageFolder无疑是最简单高效的方案。


小结

在本文中,我们从零出发,详细介绍了如何使用PyTorch的ImageFoldertransforms加载和预处理鲜花分类数据集。主要点包括:

  • 数据集组织结构:子文件夹命名为类名,便于ImageFolder自动识别类别。

  • 使用transforms对图像进行ToTensor和Resize等变换,以满足神经网络输入要求。

  • 通过可视化样本和打印类别信息确认数据加载的正确性。

  • 引入DataLoader批量采样和迭代数据,为后续训练循环奠定基础。

  • 展望数据增强、Normalize以及预训练模型迁移学习等实战技巧。

数据加载与预处理是深度学习项目不可或缺的步骤。掌握这些技能,能够让你在模型开发和实验中更加得心应手。未来你可以尝试更多高级技巧,如自定义transforms、对数据集进行统计分析、探索更复杂的增强策略和分布式数据加载方法。

达成这些基础后,你就可以开始定义模型(如使用torchvision.models.resnet18(pretrained=True)加载预训练模型)、设置损失函数(如CrossEntropyLoss)、选择优化器(如Adam或SGD),并在训练循环中快速迭代提升模型性能。

希望本文介绍,能为你对CV数据加载与预处理的理解添砖加瓦,帮助你在图像分类任务中迈出稳健的一步。


如果你遇到了什么问题,或者想了解某些方面的知识,欢迎在评论区留言

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890113.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构之线性表1

2.1 线性表的定义和基本操作 1.线性结构的特点是:在数据元素的非空有限集中, (1)存在惟一的一个被称做“第一个”的数据元素; (2) 存在惟一的一个被称做“最后一个”的数据元素; &a…

信息安全实训室网络攻防靶场实战核心平台解决方案

一、引言 网络安全靶场,作为一种融合了虚拟与现实环境的综合性平台,专为基础设施、应用程序及物理系统等目标设计,旨在向系统用户提供全方位的安全服务,涵盖教学、研究、训练及测试等多个维度。随着网络空间对抗态势的日益复杂化…

关于分页的样式问题

在最近写网页的时候遇到了一个关于样式的问题,今天我来跟大家来说一下。像是分页中的颜色效果,斑马纹颜色要注意颜色不要过于深。 这种的颜色就有一点深看着很不舒服,应将当前的颜色改为淡一点的,也可以利用rgba调整透明度&#x…

一分钟快速了解什么是AEO海关认证

一分钟快速了解什么是AEO海关认证——这一术语,对于国际贸易领域的从业者而言,无疑是一个充满分量与价值的标签。AEO,即“Authorized Economic Operator”,中文译为“经认证的经营者”,是海关对信用状况、守法程度和安…

Python图注意力神经网络GAT与蛋白质相互作用数据模型构建、可视化及熵直方图分析...

全文链接:https://tecdat.cn/?p38617 本文聚焦于图注意力网络GAT在蛋白质 - 蛋白质相互作用数据集中的应用。首先介绍了研究背景与目的,阐述了相关概念如归纳设置与转导设置的差异。接着详细描述了数据加载与可视化的过程,包括代码实现与分析…

Java学习笔记(13)——面向对象编程

面向对象基础 目录 面向对象基础 方法重载 练习: 继承 继承树 protected super 阻止继承 向上转型 向下转型 区分继承和组合 练习 小结: 方法重载 如果有一系列方法,功能类似,只是参数有所不同,就可以把…

Facebook 与数字社交的未来走向

随着数字技术的飞速发展,社交平台的角色和形式也在不断演变。作为全球最大社交平台之一,Facebook(现Meta)在推动数字社交的进程中扮演了至关重要的角色。然而,随着互联网的去中心化趋势和新技术的崛起,Face…

QT:QDEBUG输出重定向和命令行参数QCommandLineParser

qInstallMessageHandler函数简介 QtMessageHandler qInstallMessageHandler(QtMessageHandler handler) qInstallMessageHandler 是 Qt 框架中的一个函数,用于安装一个全局的消息处理函数,以替代默认的消息输出机制。这个函数允许开发者自定义 Qt 应用…

穷举vs暴搜vs深搜vs回溯vs剪枝专题一>全排列II

题目&#xff1a; 解析&#xff1a; 这题设计递归函数&#xff0c;主要把看如何剪枝 代码&#xff1a; class Solution {private List<List<Integer>> ret;private List<Integer> path;private boolean[] check;public List<List<Integer>> p…

Python如何正确解决reCaptcha验证码(9)

前言 本文是该专栏的第73篇,后面会持续分享python爬虫干货知识,记得关注。 我们在处理某些国内外平台项目的时候,相信很多同学或多或少都见过,如下图所示的reCaptcha验证码。 而本文,笔者将重点来介绍在实战项目中,遇到上述中的“reCaptcha验证码”,如何正确去处理并解…

java_零钱通项目

SmallChangeSysOOP.java package com.hspedu.smallchange.oop;import java.text.SimpleDateFormat; import java.util.Date; import java.util.Scanner;/*** 该类是完成零钱通的各个功能的类* 使用OOP(面向对象编程&#xff09;*/ public class SmallChangeSysOOP {// 定义相关…

Mamba安装环境和使用,anaconda环境打包

什么是mamba Mamba是一个极速版本的conda&#xff0c;它是conda的C重新实现&#xff0c;使用多线程并行处理来加速包和依赖项的下载。 Mamba旨在提高安装、更新和卸载Python包的速度&#xff0c;同时保持与conda相同的兼容性和命令行接口。 Mamba的核心部分使用C实现&#xff…

网络多层的协议详述

网络层 1&#xff09;地址管理&#xff1a;制定一系列的规则&#xff0c;通过地址&#xff0c;在网络上描述出一个设备的位置 2&#xff09;路由选择&#xff1a;网络环境比较复杂&#xff0c;从一个节点到另一个节点&#xff0c;存在很多条不同的路径&#xff0c;需要规划出…

《算法ZUC》题目

判断题 ZUC算法LFSR部分产生的二元序列具有很低的线性复杂度。 A.正确 B.错误 正确答案A 单项选择题 ZUC算法驱动部分LFSR的抽头位置不包括&#xff08; &#xff09;。 A.s15 B.s10 C.s7 D.s0 正确答案C 单项选择题 ZUC算法比特重组BR层主要使用了软件实现友好的…

Flink SQL 从一个SOURCE 写入多个Sink端实例

一. 背景 FLINK 任务从一个数据源读取数据, 写入多个sink端. 二. 官方实例 写入多个Sink语句时&#xff0c;需要以BEGIN STATEMENT SET;开头&#xff0c;以END;结尾。--源表 CREATE TEMPORARY TABLE datagen_source (name VARCHAR,score BIGINT ) WITH (connector datagen …

.vscode配置文件备份

vscode插件 位于&#xff1a;C:\Users\用户名\AppData\Roaming\Code\User\settings.json settings.json {// "C_Cpp.intelliSenseEngine": "default",//智能查找默认值"C_Cpp.intelliSenseEngineFallback": "enabled", //需要添加的…

关于Buildroot如何配置qtwebengine [未能成功编译]

目录 前言 下载Buildroot 如何添加qtwebengine 开始make编译 编译过程中到了这些问题 前言 问题的开始就在于学习QT的过程中遇到了一个问题… Unknown module(s) in QT: webenginewidgets 我想要把qt的一个项目编译并发送到我的开发板上&#xff0c;但是qmake识别不到这…

SNP与Scheer合作助力Warsteiner Brauerei成功升级至SAP S/4HANA

德国软件和咨询公司SNP是SAP环境中数字化转型、自动化数据迁移和数据管理软件的知名提供商&#xff0c;再次与德国Scheer公司合作&#xff0c;Scheer公司是一家专门从事业务流程管理和SAP咨询的咨询公司。他们为家族企业Warsteiner Brauerei Haus Cramer KG向SAP S/4HANA升级转…

【Super Tilemap Editor使用详解】(五):图块调色板

1、图块调色板&#xff08;Tile Palette&#xff09;可以在以下位置找到&#xff1a; Tileset Inspector检视面板 STETilemap Inspector检视面板&#xff0c;并选择 "Paint" 选项卡 Tile Palette 窗口&#xff1a;"SuperTilemapEditor/Window/Tile Palette Win…

LNMP+discuz论坛

0.准备 文章目录 0.准备1.nginx2.mysql2.1 mysql82.2 mysql5.7 3.php4.测试php访问mysql5.部署 Discuz6.其他 yum源&#xff1a; # 没有wget&#xff0c;用这个 # curl -o /etc/yum.repos.d/CentOS-Base.repo https://mirrors.aliyun.com/repo/Centos-7.repo[rootlocalhost ~]#…