图像分类数据集
MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集。
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()#设置图表大小,具体实现过程及其底层逻辑见微积分一节
读取数据集
我们可以[通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中]。
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,# 并除以255使得所有像素的数值均在0~1之间trans = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
这段代码的主要目的是从 torchvision
库中下载并加载 Fashion - MNIST 数据集,同时对数据进行预处理,将图像转换为 PyTorch 张量。
代码主要分为三个部分:定义图像预处理操作、加载训练集数据、加载测试集数据。下面逐行进行详细解释。
1. 定义图像预处理操作
trans = transforms.ToTensor()
- 功能:创建一个图像预处理的转换对象
trans
。transforms.ToTensor()
是torchvision.transforms
模块里的一个类,专门用于将 PIL(Python Imaging Library)图像或者 NumPy 数组(一般是uint8
类型)转换为torch.FloatTensor
类型的张量。- 转换细节:
- 在转换过程中,会把图像的像素值归一化到[0.0, 1.0]
范围。例如,原始图像像素值范围是[0, 255]
,经过该转换后,像素值会除以 255,变成[0.0, 1.0]
之间的浮点数。
- 同时,转换后张量的维度也会发生变化。对于单通道的灰度图像,会从(H, W)
(高度和宽度)变为(1, H, W)
;对于三通道的彩色图像,会从(H, W, C)
变为(C, H, W)
,这里C
代表通道数。2. 加载训练集数据
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
- 功能:创建一个
FashionMNIST
数据集对象mnist_train
,用于加载 Fashion - MNIST 数据集的训练集部分。- 参数解释:
-root="../data"
:指定数据集的存储路径。若该路径下没有数据集,下载的数据会存于此;若已存在,则直接从该路径加载数据。
-train=True
:表明要加载的是训练集数据。Fashion - MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,通过此参数区分加载的是训练集还是测试集。
-transform=trans
:指定对图像数据进行的预处理操作。这里使用之前创建的trans
对象,即对每个图像应用ToTensor()
变换,将其转换为张量。
-download=True
:如果指定路径下未找到数据集,会自动从网络下载 Fashion - MNIST 数据集。3. 加载测试集数据
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
- 功能:创建一个
FashionMNIST
数据集对象mnist_test
,用于加载 Fashion - MNIST 数据集的测试集部分。- 参数解释:与加载训练集的代码基本相同,唯一区别在于
train=False
,表示加载的是测试集数据。
Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像
和测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。
len(mnist_train), len(mnist_test)
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。
mnist_train[0][0].shape
[两个可视化数据集的函数]
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。
def get_fashion_mnist_labels(labels): #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
列表推导式
[expression for item in iterable]
expression
:对每个 item 进行操作后得到的结果,它将成为新列表中的一个元素。item
:从 iterable 中取出的单个元素。iterable
:一个可迭代对象,如列表、元组、字符串等。
示例代码
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
labels = [0, 2, 4]
result = [text_labels[int(i)] for i in labels]
print(result) # 输出: ['t-shirt', 'pullover', 'coat']
我们现在可以创建一个函数来可视化这些样本。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
子图坐标轴对象
在 matplotlib 中,一个图形(Figure)可以包含多个子图(Axes),每个子图就是一个独立的绘图区域,子图坐标轴对象(Axes 对象)就代表了这些独立的绘图区域。它可以被看作是一个 “画布”,你可以在这个 “画布” 上进行各种绘图操作,比如绘制线条、散点、柱状图等,还可以设置坐标轴的范围、标签、标题等。
以下是对
show_images
函数的详细解释:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
:
- 定义了一个名为
show_images
的函数,用于将一组图像以网格形式展示出来。imgs
:是一个包含图像的列表,这些图像可以是 PyTorch 张量,也可以是 PIL(Python Imaging Library)图像对象。num_rows
:指定了要展示的图像网格的行数。num_cols
:指定了要展示的图像网格的列数。titles
:是一个可选参数,类型为列表,用于为每个图像设置对应的标题。如果不提供该参数,则默认不显示标题。scale
:同样是可选参数,是一个浮点数,用于调整图像显示的缩放比例,默认值为 1.5。figsize = (num_cols * scale, num_rows * scale):
:
- 这行代码根据
num_cols
(列数)、num_rows
(行数)和scale
(缩放比例)计算出整个图像展示窗口的大小。figsize
是一个元组,第一个元素是窗口的宽度,由列数乘以缩放比例得到;第二个元素是窗口的高度,由行数乘以缩放比例得到。_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
:
num_rows
和num_cols
分别指定了子图的行数和列数,也就是图像网格的布局。figsize=figsize
表示使用之前计算好的窗口大小。subplots
函数返回两个值,第一个是Figure
对象,这里用_
占位表示我们不关心这个返回值;第二个是一个包含所有子图坐标轴对象的数组,赋值给axes
。axes = axes.flatten()
:
axes
原本是一个二维数组,因为它对应着num_rows
行和num_cols
列的子图布局。flatten
方法将这个二维数组转换为一维数组,这样在后续遍历图像和子图时会更加方便。for i, (ax, img) in enumerate(zip(axes, imgs))
:
zip(axes, imgs)
将axes
数组(包含所有子图坐标轴对象)和imgs
列表(包含所有要展示的图像)中的元素一一对应地组合起来。enumerate
函数用于为组合后的元素添加索引,i
就是当前元素的索引。- 在每次循环中,
ax
代表当前子图的坐标轴对象,img
代表当前要展示的图像。
if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)
torch.is_tensor(img)
用于判断当前的img
是否为 PyTorch 张量。- 如果是张量,使用
img.numpy()
将其转换为 NumPy 数组,因为matplotlib
的imshow
函数更适合处理 NumPy 数组。然后使用ax.imshow
函数在当前子图上显示图像。 - 如果不是张量,说明
img
可能是 PIL 图像对象,直接使用ax.imshow
函数显示该图像。
ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)
ax.axes.get_xaxis()
获取当前子图的 x 轴对象,set_visible(False)
方法将 x 轴设置为不可见。- 同理,
ax.axes.get_yaxis()
获取当前子图的 y 轴对象,set_visible(False)
方法将 y 轴设置为不可见。这样可以使图像显示更加简洁,只专注于图像内容。
if titles:ax.set_title(titles[i])
if titles:
检查是否提供了titles
列表。- 如果提供了,使用
ax.set_title
方法为当前子图设置对应的标题,标题从titles
列表中根据当前索引i
取出。
return axes
- 最后,函数返回
axes
数组,这个数组包含了所有子图的坐标轴对象。返回它的目的是方便在调用该函数后,对图形进行进一步的操作,例如修改坐标轴属性等。
以下是训练数据集中前[几个样本的图像及其相应的标签]。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
读取小批量
为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
回顾一下,在每次迭代中,数据加载器每次都会[读取一小批量数据,大小为batch_size
]。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。
batch_size = 256def get_dataloader_workers(): #@save"""使用4个进程来读取数据"""return 4
#shuffle表示在每个训练周期开始时,对数据集进行随机打乱
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
我们看一下读取训练数据所需的时间。
timer = d2l.Timer()
for X, y in train_iter:continue
f'{timer.stop():.2f} sec'
整合所有组件
现在我们[定义load_data_fashion_mnist
函数],用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize
,用来将图像大小调整为另一种形状。
def load_data_fashion_mnist(batch_size, resize=None): #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]#trans初始化为一个包含transforms.ToTensor()的列表if resize:trans.insert(0, transforms.Resize(resize))#在 trans 列表的开头插入 transforms.Resize(resize) 操作trans = transforms.Compose(trans)#将 trans 列表中的所有变换操作组合成一个完整的变换序列 transmnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
下面,我们通过指定resize
参数来测试load_data_fashion_mnist
函数的图像大小调整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)#X.shape表示张量 X 的形状,X.dtype表示张量 X 中元素的数据类型break