【深度学习】图像分类数据集

图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()#设置图表大小,具体实现过程及其底层逻辑见微积分一节

读取数据集

我们可以[通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中]。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,# 并除以255使得所有像素的数值均在0~1之间trans = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

这段代码的主要目的是从 torchvision 库中下载并加载 Fashion - MNIST 数据集,同时对数据进行预处理,将图像转换为 PyTorch 张量。
代码主要分为三个部分:定义图像预处理操作、加载训练集数据、加载测试集数据。下面逐行进行详细解释。

1. 定义图像预处理操作

trans = transforms.ToTensor()

  • 功能:创建一个图像预处理的转换对象 transtransforms.ToTensor()torchvision.transforms 模块里的一个类,专门用于将 PIL(Python Imaging Library)图像或者 NumPy 数组(一般是 uint8 类型)转换为 torch.FloatTensor 类型的张量。
  • 转换细节
    - 在转换过程中,会把图像的像素值归一化到 [0.0, 1.0] 范围。例如,原始图像像素值范围是 [0, 255],经过该转换后,像素值会除以 255,变成 [0.0, 1.0] 之间的浮点数。
    - 同时,转换后张量的维度也会发生变化。对于单通道的灰度图像,会从 (H, W)(高度和宽度)变为 (1, H, W);对于三通道的彩色图像,会从 (H, W, C) 变为 (C, H, W),这里 C 代表通道数。

2. 加载训练集数据

mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_train,用于加载 Fashion - MNIST 数据集的训练集部分。
  • 参数解释
    - root="../data":指定数据集的存储路径。若该路径下没有数据集,下载的数据会存于此;若已存在,则直接从该路径加载数据。
    - train=True:表明要加载的是训练集数据。Fashion - MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,通过此参数区分加载的是训练集还是测试集。
    - transform=trans:指定对图像数据进行的预处理操作。这里使用之前创建的 trans 对象,即对每个图像应用 ToTensor() 变换,将其转换为张量
    - download=True:如果指定路径下未找到数据集,会自动从网络下载 Fashion - MNIST 数据集。

3. 加载测试集数据

mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_test,用于加载 Fashion - MNIST 数据集的测试集部分。
  • 参数解释:与加载训练集的代码基本相同,唯一区别在于 train=False,表示加载的是测试集数据。

Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像
测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

len(mnist_train), len(mnist_test)

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。

mnist_train[0][0].shape

在这里插入图片描述
[两个可视化数据集的函数]

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

列表推导式
[expression for item in iterable]

  • expression:对每个 item 进行操作后得到的结果,它将成为新列表中的一个元素。
  • item:从 iterable 中取出的单个元素。
  • iterable:一个可迭代对象,如列表、元组、字符串等。

示例代码

text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
labels = [0, 2, 4]
result = [text_labels[int(i)] for i in labels]
print(result)  # 输出: ['t-shirt', 'pullover', 'coat']

我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

子图坐标轴对象
在 matplotlib 中,一个图形(Figure)可以包含多个子图(Axes),每个子图就是一个独立的绘图区域,子图坐标轴对象(Axes 对象)就代表了这些独立的绘图区域。它可以被看作是一个 “画布”,你可以在这个 “画布” 上进行各种绘图操作,比如绘制线条、散点、柱状图等,还可以设置坐标轴的范围、标签、标题等。

以下是对 show_images 函数的详细解释:

  • def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    • 定义了一个名为 show_images 的函数,用于将一组图像以网格形式展示出来。
    • imgs:是一个包含图像的列表,这些图像可以是 PyTorch 张量,也可以是 PIL(Python Imaging Library)图像对象。
    • num_rows:指定了要展示的图像网格的行数。
    • num_cols:指定了要展示的图像网格的列数。
    • titles:是一个可选参数,类型为列表,用于为每个图像设置对应的标题。如果不提供该参数,则默认不显示标题。
    • scale:同样是可选参数,是一个浮点数,用于调整图像显示的缩放比例,默认值为 1.5。
  • figsize = (num_cols * scale, num_rows * scale):
    • 这行代码根据 num_cols(列数)、num_rows(行数)和 scale(缩放比例)计算出整个图像展示窗口的大小。
    • figsize 是一个元组,第一个元素是窗口的宽度,由列数乘以缩放比例得到;第二个元素是窗口的高度,由行数乘以缩放比例得到。
  • _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    • num_rowsnum_cols 分别指定了子图的行数和列数,也就是图像网格的布局。
    • figsize=figsize 表示使用之前计算好的窗口大小。
    • subplots 函数返回两个值,第一个是 Figure 对象,这里用 _ 占位表示我们不关心这个返回值;第二个是一个包含所有子图坐标轴对象的数组,赋值给 axes
  • axes = axes.flatten()
    • axes 原本是一个二维数组,因为它对应着 num_rows 行和 num_cols 列的子图布局。
    • flatten 方法将这个二维数组转换为一维数组,这样在后续遍历图像和子图时会更加方便。
  • for i, (ax, img) in enumerate(zip(axes, imgs))
    • zip(axes, imgs)axes 数组(包含所有子图坐标轴对象)和 imgs 列表(包含所有要展示的图像)中的元素一一对应地组合起来。
    • enumerate 函数用于为组合后的元素添加索引,i 就是当前元素的索引。
    • 在每次循环中,ax 代表当前子图的坐标轴对象,img 代表当前要展示的图像。
        if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)
  • torch.is_tensor(img) 用于判断当前的 img 是否为 PyTorch 张量。
  • 如果是张量,使用 img.numpy() 将其转换为 NumPy 数组,因为 matplotlibimshow 函数更适合处理 NumPy 数组。然后使用 ax.imshow 函数在当前子图上显示图像。
  • 如果不是张量,说明 img 可能是 PIL 图像对象,直接使用 ax.imshow 函数显示该图像。
        ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)
  • ax.axes.get_xaxis() 获取当前子图的 x 轴对象,set_visible(False) 方法将 x 轴设置为不可见。
  • 同理,ax.axes.get_yaxis() 获取当前子图的 y 轴对象,set_visible(False) 方法将 y 轴设置为不可见。这样可以使图像显示更加简洁,只专注于图像内容。
        if titles:ax.set_title(titles[i])
  • if titles: 检查是否提供了 titles 列表。
  • 如果提供了,使用 ax.set_title 方法为当前子图设置对应的标题,标题从 titles 列表中根据当前索引 i 取出。
    return axes
  • 最后,函数返回 axes 数组,这个数组包含了所有子图的坐标轴对象。返回它的目的是方便在调用该函数后,对图形进行进一步的操作,例如修改坐标轴属性等。

以下是训练数据集中前[几个样本的图像及其相应的标签]。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
回顾一下,在每次迭代中,数据加载器每次都会[读取一小批量数据,大小为batch_size]。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4
#shuffle表示在每个训练周期开始时,对数据集进行随机打乱
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

我们看一下读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:continue
f'{timer.stop():.2f} sec'

整合所有组件

现在我们[定义load_data_fashion_mnist函数],用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]#trans初始化为一个包含transforms.ToTensor()的列表if resize:trans.insert(0, transforms.Resize(resize))#在 trans 列表的开头插入 transforms.Resize(resize) 操作trans = transforms.Compose(trans)#将 trans 列表中的所有变换操作组合成一个完整的变换序列 transmnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)#X.shape表示张量 X 的形状,X.dtype表示张量 X 中元素的数据类型break

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68580.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL — 数据库增删改查操作】深入解析MySQL的 Retrieve 检索操作

Retrieve 检索 示例 1. 构造数据 创建表结构 create table exam1(id bigint, name varchar(20) comment同学姓名, Chinesedecimal(3,1) comment 语文成绩, Math decimal(3,1) comment 数学成绩, English decimal(3,1) comment 英语成绩 ); 插入测试数据 insert into ex…

Ansible自动化运维实战--通过role远程部署nginx并配置(8/8)

文章目录 1、准备工作2、创建角色结构3、编写任务4、准备配置文件(金甲模板)5、编写变量6、编写处理程序7、编写剧本8、执行剧本Playbook9、验证-游览器访问每台主机的nginx页面 在 Ansible 中,使用角色(Role)来远程部…

RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章 原作者:K同学啊 一 导入数据 import torch.nn as nn import torch.nn.functional as F import torchvision,torch from sklearn.preprocessing import StandardScaler from torch.utils.data import TensorDatase…

编程题-最长的回文子串(中等)

题目: 给你一个字符串 s,找到 s 中最长的回文子串。 示例 1: 输入:s "babad" 输出:"bab" 解释:"aba" 同样是符合题意的答案。示例 2: 输入:s &…

CNN-GRU卷积门控循环单元时间序列预测(Matlab完整源码和数据)

CNN-GRU卷积门控循环单元时间序列预测(Matlab完整源码和数据) 目录 CNN-GRU卷积门控循环单元时间序列预测(Matlab完整源码和数据)预测效果基本介绍CNN-GRU卷积门控循环单元时间序列预测一、引言1.1、研究背景与意义1.2、研究现状1…

HTML-新浪新闻-实现标题-样式1

用css进行样式控制 css引入方式: --行内样式:写在标签的style属性中(不推荐) --内嵌样式:写在style标签中(可以写在页面任何位置,但通常约定写在head标签中) --外联样式&#xf…

2024年终总结

回顾 今年过年没回老家,趁着有时间,总结一下24年吧。 我把23年看做是打基础的一年,而24年主要是忙于项目的一年,基本上大部分时间都是忙着交付软件,写的一些文章也大部分都是项目中遇到的问题和解决方案,虽…

[c语言日寄]越界访问:意外的死循环

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋:这是一个专注于C语言刷题的专栏,精选题目,搭配详细题解、拓展算法。从基础语法到复杂算法,题目涉及的知识点全面覆盖,助力你系统提升。无论你是初学者,还是…

使用 KNN 搜索和 CLIP 嵌入构建多模态图像检索系统

作者:来自 Elastic James Gallagher 了解如何使用 Roboflow Inference 和 Elasticsearch 构建强大的语义图像搜索引擎。 在本指南中,我们将介绍如何使用 Elasticsearch 中的 KNN 聚类和使用计算机视觉推理服务器 Roboflow Inference 计算的 CLIP 嵌入构建…

maven的打包插件如何使用

默认的情况下,当直接执行maven项目的编译命令时,对于结果来说是不打第三方包的,只有一个单独的代码jar,想要打一个包含其他资源的完整包就需要用到maven编译插件,使用时分以下几种情况 第一种:当只是想单纯…

Golang Gin系列-7:认证和授权

在本章中,我们将探讨Gin框架中身份验证和授权的基本方面。这包括实现基本的和基于令牌的身份验证,使用基于角色的访问控制,应用中间件进行授权,以及使用HTTPS和漏洞防护保护应用程序。 实现身份认证 Basic 认证 Basic 认证是内置…

CTF-web: phar反序列化+数据库伪造 [DASCTF2024最后一战 strange_php]

step 1 如何触发反序列化? 漏洞入口在 welcome.php case delete: // 获取删除留言的路径,优先使用 POST 请求中的路径,否则使用会话中的路径 $message $_POST[message_path] ? $_POST[message_path] : $_SESSION[message_path]; $msg $userMes…

C语言自定义数据类型详解(一)——结构体类型(上)

什么是自定义数据类型呢?顾名思义,就是我们用户自己定义和设置的类型。 在C语言中,我们的自定义数据类型一共有三种,它们分别是:结构体(struct),枚举(enum),联合(union)。接下来,我…

SpringCloud系列教程:微服务的未来(十八)雪崩问题、服务保护方案、Sentinel快速入门

前言 在分布式系统中,雪崩效应(Avalanche Effect)是一种常见的故障现象,通常发生在系统中某个组件出现故障时,导致其他组件级联失败,最终引发整个系统的崩溃。为了有效应对雪崩效应,服务保护方…

升级到Mac15.1后pod install报错

升级Mac后,Flutter项目里的ios项目运行 pod install报错, 遇到这种问题,不要着急去百度,大概看一下报错信息,每个人遇到的问题都不一样。 别人的解决方法并不一定适合你; 下面是报错信息: #…

STM32 PWM驱动舵机

接线图: 这里将信号线连接到了开发板的PA1上 代码配置: 这里的PWM配置与呼吸灯一样,呼吸灯连接的是PA0引脚,输出比较单元用的是OC1通道,这里只需改为OC2通道即可。 完整代码: #include "servo.h&quo…

使用 concurrently 实现前后端一键启动

使用 concurrently 实现前后端一键启动 本文适合: 前后端分离项目(如 React Node.js),希望通过一条命令同时启动前端和后端服务。 工具链: Node.js、npm、concurrently。 耗时: 3 分钟。 文章目录 使用 c…

【NLP251】NLP RNN 系列网络

NLP251 系列主要记录从NLP基础网络结构到知识图谱的学习 1.原理及网络结构 1.1RNN 在Yoshua Bengio论文中( http://proceedings.mlr.press/v28/pascanu13.pdf )证明了梯度求导的一部分环节是一个指数模型…

OpenCV:在图像中添加噪声(瑞利、伽马、脉冲、泊松)

目录 简述 1. 瑞利噪声 2. 伽马噪声 3. 脉冲噪声 4. 泊松噪声 总结 相关阅读 OpenCV:在图像中添加高斯噪声、胡椒噪声-CSDN博客 OpenCV:高通滤波之索贝尔、沙尔和拉普拉斯-CSDN博客 OpenCV:图像处理中的低通滤波-CSDN博客 OpenCV&…

小智 AI 聊天机器人

小智 AI 聊天机器人 (XiaoZhi AI Chatbot) 👉参考源项目复现 👉 ESP32SenseVoiceQwen72B打造你的AI聊天伴侣!【bilibili】 👉 手工打造你的 AI 女友,新手入门教程【bilibili】 项目目的 本…