【深度学习笔记】3_5 图像分类数据集fashion-mnist

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.5 图像分类数据集(Fashion-MNIST)

在介绍softmax回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集MNIST[1]。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST[2](这个数据集也比较小,只有几十M,没有GPU的电脑也能吃得消)。

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

3.5.1 获取数据集

首先导入本节需要的包或模块。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

注意: 由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug。 详见传送门2.2.4节。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的mnist_trainmnist_test都是torch.utils.data.Dataset的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1。后面两维分别是图像的高和宽。

Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):d2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()

现在,我们看一下训练数据集中前10个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

3.5.2 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

batch_size = 256
if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

我们将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。随着本书内容的不断深入,我们会进一步改进该函数。它的完整实现将在5.6节中描述。

最后我们查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:continue
print('%.2f sec' % (time.time() - start))

输出:

1.57 sec

小结

  • Fashion-MNIST是一个10类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为 h h h w w w像素的图像的形状记为 h × w h \times w h×w(h,w)

参考文献

[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.


注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/700700.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《Docker 简易速速上手小册》第1章 Docker 基础入门(2024 最新版)

文章目录 1.1 Docker 简介与历史1.1.1 Docker 基础知识1.1.2 重点案例&#xff1a;Python Web 应用的 Docker 化1.1.3 拓展案例 1&#xff1a;使用 Docker 进行 Python 数据分析1.1.4 拓展案例 2&#xff1a;Docker 中的 Python 机器学习环境 1.2 安装与配置 Docker1.2.1 重点基…

消息队列-RabbitMQ:发布确认—发布确认逻辑和发布确认的策略

九、发布确认 1、发布确认逻辑 生产者将信道设置成 confirm 模式&#xff0c;一旦信道进入 confirm 模式&#xff0c;所有在该信道上面发布的消息都将会被指派一个唯一的 ID (从 1 开始)&#xff0c;一旦消息被投递到所有匹配的队列之后&#xff0c;broker 就会发送一个确认给…

Python基础教程——17个工作必备的Python自动化代码

您是否厌倦了在日常工作中做那些重复性的任务&#xff1f;简单但多功能的Python脚本可以解决您的问题。 引言 Python是一种流行的编程语言&#xff0c;以其简单性和可读性而闻名。因其能够提供大量的库和模块&#xff0c;它成为了自动化各种任务的绝佳选择。让我们进入自动化…

K8s环境搭建

一、基础环境准备 VMware虚拟机&#xff0c;安装三台CentOS&#xff0c;网络环境选择NAT模式&#xff0c;推荐配置如下&#xff08;具体安装步骤省略&#xff0c;网上很多虚拟机安装CentOS7的教程&#xff09; 二、网络环境说明 使用NAT模式&#xff0c;我的IP分别是&#xf…

Promise相关理解记录

一、Promise基础定义相关 Promise是一个构造函数&#xff0c;调用时需要使用new关键字 Promise是解决回调地狱的一种异步解决方式 Promise有三个状态&#xff1a;pending(进行中)、fulfilled(成功)、rejected(失败) Promise的状态只会从 pending→fulfilled 或者 pending→…

300分钟吃透分布式缓存-13讲:如何完整学习MC协议及优化client访问?

协议分析 异常错误响应 接下来&#xff0c;我们来完整学习 Mc 协议。在学习 Mc 协议之前&#xff0c;首先来看看 Mc 处理协议指令&#xff0c;如果发现异常&#xff0c;如何进行异常错误响应的。Mc 在处理所有 client 端指令时&#xff0c;如果遇到错误&#xff0c;就会返回 …

信号系统之线性图像处理

1 卷积 图像卷积的工作原理与一维卷积相同。例如&#xff0c;图像可以被视为脉冲的总和&#xff0c;即缩放和移位的delta函数。同样&#xff0c;线性系统的特征在于它们如何响应脉冲。也就是说&#xff0c;通过它们的脉冲响应。系统的输出图像等于输入图像与系统脉冲响应的卷积…

pclpy 半径滤波实现

pclpy 半径滤波实现 一、算法原理背景 二、代码1.pclpy 官方给与RadiusOutlierRemoval2.手写的半径滤波&#xff08;速度太慢了&#xff0c;用官方的吧&#xff09; 三、结果1.左边为原始点云&#xff0c;右边为半径滤波后点云 四、相关数据 一、算法原理 背景 RadiusOutlier…

Linux——进程概念

目录 冯诺依曼体系结构 操作系统 管理 系统调用和库函数 进程的概念 进程控制块——PCB 查看进程 通过系统调用获取进程标示符 通过系统调用创建进程 进程状态 运行状态-R ​编辑 浅度睡眠状态-S 深度睡眠状态-D 暂停状态-T 死亡状态-X 僵尸状态-Z 僵尸进程…

AD24-PCB的DRC电气性能检查

1、 2、如果报错器件选中&#xff0c;不能跳转时&#xff0c;按下图设置 3、开始出现以下提示时处理 4、到后期&#xff0c;错误改得差不多的时候&#xff1b;出现以下的处理步骤 ①将顶层和底层铜皮选中&#xff0c;移动200mm ②执行以下操作 ③将铜皮在移动回来&#xff0c;进…

STM32_IIC_AT24C02_1_芯片简介即管脚配置

STM32的IIC总线是存在bug&#xff0c;感兴趣的可以上网搜一搜。我们可以使用两个I/O口和软件的方式来模拟stm32的iic总线的控制&#xff0c;所以就不需要使用stm32的硬件控制器了&#xff0c;同理数据手册中的I2C库函数也没有用了。 ROM&#xff08;只读存储器&#xff09;和…

黄仁勋最新专访:机器人基础模型可能即将出现,新一代GPU性能超乎想象

最近&#xff0c;《连线》的记者采访了英伟达CEO黄仁勋。 记者表示&#xff0c;与Jensen Huang交流应该带有警告标签&#xff0c;因为这位Nvidia首席执行官对人工智能的发展方向如此投入&#xff0c;以至于在经过近 90 分钟的热烈交谈后&#xff0c;我&#xff08;指代本采访的…

杰发科技AC7801——SRAM 错误检测纠正

0.概述 7801暂时无错误注入&#xff0c;无法直接进中断看错误情况&#xff0c;具体效果后续看7840的带错误注入的测试情况。 1.简介 2.特性 3.功能 4.调试 可以看到在库文件里面有ecc_sram的库。 在官方GPIO代码里面写了点测试代码 成功打开2bit中断 因为没有错误注入&#x…

Netdata:实时高分辨率监控工具 | 开源日报 No.173

netdata/netdata Stars: 63.9k License: GPL-3.0 Netdata 是一个监控工具&#xff0c;可以实时高分辨率地监视服务器、容器和应用程序。 以下是该项目的主要功能&#xff1a; 收集来自 800 多个整合方案的指标&#xff1a;操作系统指标、容器指标、虚拟机、硬件传感器等。实…

软件常见设计模式

设计模式 设计模式是为了解决在软件开发过程中遇到的某些问题而形成的思想。同一场景有多种设计模式可以应用&#xff0c;不同的模式有各自的优缺点&#xff0c;开发者可以基于自身需求选择合适的设计模式&#xff0c;去解决相应的工程难题。 良好的软件设计和架构&#xff0…

k8s的svc流量通过iptables和ipvs转发到pod的流程解析

文章目录 1. k8s的svc流量转发1.1 service 说明1.2 endpoints说明1.3 pod 说明1.4 svc流量转发的主要工作 2. iptables规则解析2.1 svc涉及的iptables链流程说明2.2 svc涉及的iptables规则实例2.2.1 KUBE-SERVICES规则链2.2.2 KUBE-SVC-EFPSQH5654KMWHJ5规则链2.2.3 KUBE-SEP-L…

css复习

盒模型相关&#xff1a; border&#xff1a;1px solid red (没有顺序) 单元格的border会发生重叠&#xff0c;如果不想要重叠设置 border-collapse:collapse (表示相邻边框合并在一起) padding padding影响盒子大小的好处使用 margin应用&#xff1a; 行内或行内块元素水…

windows Server下Let‘s Encrypt的SSL证书续期

一、手动续期方法&#xff1a; 暂停IIS服务器 --> 暂时关闭防火墙 --> 执行certbot renew --> 打开防火墙 --> 用OpenSSL将证书转换为PFX格式-->pfx文件导入到IIS --> IIS对应网站中绑定新证书 --> 重新启动IIS -->完成 1、暂停IIS服务器 2、暂时关闭…

【LeetCode每日一题】 单调栈的案例 42. 接雨水

这道题是困难&#xff0c;但是可以使用单调栈&#xff0c;非常简洁通俗。 关于单调栈可以参考单调栈总结以及Leetcode案例解读与复盘 42. 接雨水 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 …

浅析SpringBoot框架常见未授权访问漏洞

文章目录 前言Swagger未授权访问RESTful API 设计风格swagger-ui 未授权访问swagger 接口批量探测 Springboot Actuator未授权访问数据利用未授权访问防御手段漏洞自动化检测工具 CVE-2022-22947 RCE漏洞原理分析与复现漏洞自动化利用工具 其他常见未授权访问Druid未授权访问漏…