深度学习pytorch--MNIST数据集

图像分类数据集(Fashion-MNIST)

在介绍softmax回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集MNIST[1]。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST[2](这个数据集也比较小,只有几十M,没有GPU的电脑也能吃得消)。

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

获取数据集

首先导入本节需要的包或模块。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

注意: 由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug。 本人就被这点坑过,详见这个博客2.2.4节。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的mnist_trainmnist_test都是torch.utils.data.Dataset的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1。后面两维分别是图像的高和宽。

Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):d2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()

在这里插入图片描述

现在,我们看一下训练数据集中前10个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

batch_size = 256
if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

我们将获取并读取Fashion-MNIST数据集的逻辑封装在d2lzh_pytorch.load_data_fashion_mnist函数中供后面章节调用。该函数将返回train_itertest_iter两个变量。随着本书内容的不断深入,我们会进一步改进该函数。

最后我们查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:continue
print('%.2f sec' % (time.time() - start))

输出:

1.57 sec

小结

  • Fashion-MNIST是一个10类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为hhhwww像素的图像的形状记为h×wh \times wh×w(h,w)

参考文献

[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.


注:本节除了代码之外与原书基本相同,原书传送门

转载至:动手学习深度学习pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/333771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

html 元素的属性

全局属性 全局属性是可与所有 HTML 元素一起使用的属性。 事件属性 用来定义某个事件的操作的属性叫事件属性&#xff0c;例如&#xff0c;οnclick“script”&#xff0c;元素上发生鼠标点击时触发 click 事件&#xff0c;click 事件被触发就会执行对应的脚本代码。事件属性…

nosql和rdnms_用于SaaS和NoSQL的Jdbi

nosql和rdnms一个自然的接口&#xff0c;用于与CRM&#xff0c;ERP&#xff0c;会计&#xff0c;营销自动化&#xff0c;NoSQL&#xff0c;平面文件等基于Java的数据集成 Jdbi是用于JavaSQL便利库&#xff0c;它为JDBC提供更自然的Java数据库接口&#xff0c;该接口易于绑定到…

matlab 功率谱密度 汉宁窗_如何理解随机振动的功率谱密度?

一、随机信号和正太分布有什么关系&#xff1f; 二、时域、频域之间功率守恒&#xff1f; 三、自相关又是个什么玩意&#xff1f;作为一个工程师&#xff0c;很多人对随机振动看着熟悉&#xff0c;却又实际陌生。熟悉是因为几乎每个产品在出厂时都要求要做随机振动试验&#xf…

深度学习pytorch--softmax回归(二)

softmax回归的从零开始实现实验前思考获取和读取数据获取数据集查看数据集查看下载后的.pt文件查看mnist_train和mnist_test读取数据集查看数据迭代器内容初始化模型参数定义softmax函数定义模型定义损失函数计算分类准确率模型评价--准确率开始训练可视化总结完整代码实验前思…

HTML块级元素/块标签/块元素

文章目录块元素的特点块元素清单block level element. 块级元素想在同一行显示需浮动或者 display:inline。 块元素的特点 每个块级元素都是独自占一行&#xff0c;其后的元素也只能另起一行&#xff0c;并不能两个元素共用一行&#xff1b; 元素的高度、宽度、行高、顶边距、…

物联卡查询流量_电信物联卡官网是多少?如何快速查询流量信息?

高速率设备的使用场景需要用到5G&#xff0c;中速率LET-Cat1应用范围更广&#xff0c;而低速率则要靠窄带物联网NB-IOT去维护了。这三种网络制式全都与物联网息息相关&#xff0c;这就能知道为什么国家层面对物联网基础设施建设这么重视了。电信物联卡在智能化硬件中有优秀表现…

java8日期转时间戳_Java 8日期和时间

java8日期转时间戳如今&#xff0c;一些应用程序仍在使用java.util.Date和java.util.Calendar API&#xff0c;包括使我们的生活更轻松地使用这些类型的库&#xff0c;例如JodaTime。 但是&#xff0c;Java 8引入了新的API来处理日期和时间&#xff0c;这使我们可以对日期和时间…

HTML行内元素/行级元素/内联元素/行标签/内联标签/行内标签/行元素

文章目录行内元素的特点行内元素清单可变元素列表inline element. 也叫行级元素、内联元素。行内元素默认设置宽度是不起作用&#xff0c;需设置 display:inline-block 或者 block 才行。 行内元素的特点 可以和其他元素处于一行&#xff0c;不用必须另起一行&#xff1b; 元…

深度学习pytorch--softmax回归(三)

softmax回归的简洁实现获取和读取数据定义和初始化模型softmax和交叉熵损失函数定义优化算法模型评价训练模型小结完整代码前两篇链接: 深度学习pytorch–softmax回归(一) 深度学习pytorch–softmax回归(二) 本文使用框架来实现模型。 获取和读取数据 我们仍然使用Fashion-M…

正则表达式的分类

文章目录一、正则表达式引擎二、正则表达式分类三、正则表达式比较四、Linux/OS X 下常用命令与正则表达式的关系一、正则表达式引擎 正则引擎大体上可分为不同的两类&#xff1a;DFA 和 NFA&#xff0c;而 NFA 又基本上可以分为传统型 NFA 和 POSIX NFA。 DFA(Deterministic …

spock测试_使用Spock测试您的代码

spock测试Spock是针对Java和Groovy应用程序的测试和规范框架。 Spock是&#xff1a; 极富表现力 简化测试的“给定/何时/然后” 语法 与大多数IDE和CI服务器兼容。 听起来不错&#xff1f; 通过快速访问Spock Web控制台&#xff0c;您可以非常快速地开始使用Spock。 当您有…

深度学习pytorch--多层感知机(一)

多层感知机隐藏层激活函数ReLU函数sigmoid函数tanh函数多层感知机小结我们已经介绍了包括线性回归和softmax回归在内的单层神经网络。然而深度学习主要关注多层模型。在本节中&#xff0c;我们将以多层感知机&#xff08;multilayer perceptron&#xff0c;MLP&#xff09;为例…

太阳能板如何串联_光伏板清洁专用的清洁毛刷

光伏发电是利用半导体界面的光生伏特效应将光能直接转变为电能的一种技术。主要由太阳电池板&#xff08;组件&#xff09;、控制器和逆变器三大部分组成。主要部件由电子元器件构成。太阳能电池经过串联后进行封装保护可形成大面积的太阳电池组件&#xff0c;再配合上功率控制…

java 异步等待_Java中的异步等待

java 异步等待编写异步代码很困难。 试图了解异步代码应该做什么的难度更大。 承诺是尝试描述延迟执行流程的一种常见方式&#xff1a;首先做一件事&#xff0c;然后再做另一件事&#xff0c;以防万一出错时再做其他事情。 在许多语言中&#xff0c;承诺已成为协调异步行为的实…

cass生成曲线要素_干货在线 | CASS入门指南——道路断面计算土方

CASS操作指南——道路断面计算土方法小伙伴们赶紧学起来&#xff01;道路类的土方工程&#xff0c;主要用CASS的断面法土方计算之道路断面来计算。整个计算过程主要分为以下四步&#xff1a;菜单截图第一步&#xff1a;绘制道路中心线道路的中心线&#xff0c;一般由直线段和缓…

正则表达式的捕获性分组/反向引用

文章目录分组捕获性分组和反向引用分组 正则的分组主要通过小括号来实现&#xff0c;括号包裹的子表达式作为一个分组&#xff0c;括号后可以紧跟限定词表示重复次数。如下&#xff0c;小括号内包裹的 abc 便是一个分组: // (abc) 表示匹配一个或多个"abc"&#xf…

深度学习pytorch--多层感知机(二)

多层感知机的从零开始实现获取和读取数据定义模型参数定义激活函数定义模型定义损失函数训练模型小结我们已经从上一节里了解了多层感知机的原理。下面&#xff0c;我们一起来动手实现一个多层感知机。首先导入实现所需的包或模块。 import torch import numpy as np获取和读取…

jwt同一会话_在会话中使用JWT

jwt同一会话这个话题已经在黑客新闻&#xff0c;reddit和博客上讨论了很多次。 共识是–请勿使用JWT&#xff08;用于用户会话&#xff09;。 而且我在很大程度上同意对JWT的典型论点 &#xff0c; 典型的“但我可以使其工作……”的解释以及JWT标准的缺陷的批评 。 。 我不会…

表必须要有主键吗_玄关隔断什么材质好?玄关隔断必须要做吗?

为了避免一到门口就能够看到全部室内的东西&#xff0c;为了更好的保护家居的隐私&#xff0c;目前有很多人都会在玄关的位置加一个隔断&#xff0c;而玄关隔断什么材质好?在做玄关隔断的时候&#xff0c;有些人觉得做了隔断会太浪费空间了&#xff0c;而玄关隔断必须要做吗?…

深度学习pytorch--多层感知机(三)

使用pytorch框架实现多层感知机和实现softmax回归唯一的不同在于我们多加了一个全连接层作为隐藏层。它的隐藏单元个数为256&#xff0c;并使用ReLU函数作为激活函数。#模型的核心代码为:nn.Linear(num_inputs, num_hiddens),nn.ReLU(),nn.Linear(num_hiddens, num_outputs),