在 PyTorch 中,tensor.shape
返回一个包含张量各维度大小的元组。
所以,当你执行 print(img.shape)
,你看到的 (3, 32, 32)
实际上是在告诉你:
- 这是一个三维张量
- 第一维(通道)的大小是 3
- 第二维(高度)的大小是 32
- 第三维(宽度)的大小是 32
-
import torchvision from torch.utils.data import DataLoadertest_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True,transform=torchvision.transforms.ToTensor()) test_loader = DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False) img,targert = test_set[0] print(img.shape) print(targert)
参数含义
dataset=test_set
: 指定要加载的数据集batch_size=4
: 每批加载 4 个样本shuffle=True
: 随机打乱数据顺序num_workers=0
: 不使用多进程加载数据drop_last=False
: 不丢弃最后一个不完整的批次
import torchvision
from torch.utils.data import DataLoadertest_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
img,targert = test_set[0]print(img.shape)
print(targert)for data in test_loader:imgs,targerts = dataprint(imgs.shape)print(targerts)