《动手学深度学习》-3.5-学习笔记
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()#用于将图像数据从 PIL 图像格式(Python Imaging Library,Python 的图像处理库)转换为 PyTorch 张量(Tensor)。
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)#加载训练数据集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)#加载测试数据集
-
torchvision.datasets.FashionMNIST
是 PyTorch 提供的用于加载 FashionMNIST 数据集的类。 -
参数解释:
-
root="../data"
:指定数据集的存储路径。如果数据集不存在,PyTorch 会自动下载到这个路径。 -
train=True
:表示加载训练数据集。 -
transform=trans
:指定对图像数据应用的预处理操作,这里是transforms.ToTensor()
,即将图像转换为归一化的张量。 -
download=True
:如果指定路径下没有数据集,会自动从网络下载。 - 了解基础情况:在 PyTorch 中,
mnist_train
是一个torchvision.datasets.FashionMNIST
数据集对象,它是一个可迭代的集合,包含了所有训练样本的图像和标签。mnist_train[3]
表示获取数据集中的第四个样本(索引从 0 开始),包括第四个样本的图像和标签。 -
image.shape
输出torch.Size([1, 28, 28])
,表示图像是一个张量(Tensor),形状为:-
1:表示图像有 1 个通道(灰度图)。
-
28:图像的宽度为 28 像素。
-
28:图像的高度为 28 像素。
-
-
label
输出的是一个整数,表示图像的类别标签。FashionMNIST 数据集有 10 个类别,每个类别对应一个整数标签(从 0 到 9)。
-
-
打印出来看了一下
def get_fashion_mnist_labels(labels): """返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
-
这是一个列表推导式,用于将输入的整数标签列表
labels
转换为对应的文本标签列表。 -
对于
labels
中的每个元素i
:-
int(i)
确保i
是整数(虽然通常labels
已经是整数,但这里加了保险)。 -
text_labels[int(i)]
从text_labels
列表中获取对应的文本标签。
对text_labels -
列表的索引(从 0 到 9)对应于数据集中的整数标签。例如:
-
0
对应't-shirt'
-
1
对应'trouser'
-
9
对应'ankle boot'
下面这段 仅仅是 使用这个函数,应用场景 -
-
-
-
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
show_images
是一个用于批量显示图像的工具函数,
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
从 FashionMNIST 数据集中加载一批图像,使用 show_images
函数将图像以 2 行 9 列的网格形式显示,并为每张图像添加文本标签。
创建Dataloader
batch_size = 256def get_dataloader_workers(): """使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
def load_data_fashion_mnist(batch_size, resize=None): """下载Fashion-MNIST数据集"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
用于下载并加载 FashionMNIST 数据集,并将其转换为适合训练和测试的 DataLoader
对象。