CIFAR10
CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每个类有 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次包含每个类中随机选择的 1000 张图像。训练批次包含按随机顺序排列的剩余图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次包含来自每个类的 5000 张图像。
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0])
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1F5B55DD5E0>, 3)
test_set[]存放两个数据,一个是图像本身,一个是标签
图片显示
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])#将图片都转为tensor数据类型
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
print(test_set[0])
writer=SummaryWriter("p10")
for i in range(10):img,target=test_set[i]writer.add_image("test_set",img,i)
writer.close()
dataloader
参数
dataset (Dataset) – 从中加载数据的数据集。
batch_size (int, optional) – 每批要加载的样本数 (默认值:)。1
shuffle (bool, optional) – 设置为重新洗牌数据 在每个 epoch (默认值: )。TrueFalse
sampler (Sampler 或 Iterable,可选) – 定义绘制的策略 数据集中的样本。可以是任何已实施的。如果指定,则不得指定。Iterable__len__shuffle
batch_sampler (Sampler 或 Iterable,可选) – 类似于 ,但 一次返回一批索引。与 、 、 互斥 和。batch_sizeshuffledrop_last
num_workers (int, optional) – 用于数据的子进程数 装载。 表示数据将在主进程中加载。 (默认:00)
collate_fn (Callable, optional) – 合并样本列表以形成 小批量的 Tensor 中。当使用 batch loading from 地图样式数据集。
pin_memory (bool, optional) – 如果 ,数据加载器将复制 Tensor 放入 device/CUDA 固定内存中。如果您的数据元素 是自定义类型,或者您返回的批次是自定义类型, 请参阅下面的示例。Truecollate_fn
drop_last (bool, optional) – 设置为 以删除最后一个未完成的批次, 如果数据集大小不能被批量大小整除。If 和 数据集的大小不能被批次大小整除,然后是最后一个批次 会更小。(默认:TrueFalseFalse)
timeout (numeric, optional) – 如果为正数,则为收集批次的超时值 从工人。应始终为非负数。(默认:0)
worker_init_fn (Callable, optional) – 如果不是 ,则将在每个 worker 子进程,其中 worker id ( int in ) 为 input、seeding 之后和 data loading 之前。(默认:None[0, num_workers - 1]None)
multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) – 如果 ,则操作系统的默认多处理上下文将 被使用。(默认:NoneNone)
发电机(Torch.生成器,可选) – 如果没有,将使用此 RNG 通过 RandomSampler 生成随机索引,并通过 multiprocessing 为 worker 生成。(默认:Nonebase_seedNone)
prefetch_factor (int, optional, keyword-only arg) – 加载的批次数 由每个 worker 提前完成。 表示总共会有 2 * num_workers 个批次,在所有工作程序中预取。(默认值取决于 在 num_workers 的 Set 值上。如果值 num_workers=0,则默认值为 。 否则,如果 default 的值为 )。2Nonenum_workers > 02
persistent_workers (bool, optional) – 如果 ,则数据加载器不会关闭 工作程序在 dataset 被使用一次后进行处理。这允许 保持 worker Dataset 实例处于活动状态。(默认:TrueFalse)
pin_memory_device (str, optional) – 如果设备为 。pin_memorypin_memoryTrue
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_data=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
#测试数据集第一张图片
img,target=test_data[0]
#writer=SummaryWriter("dataloader")
i=0
for data in test_loader:imgs,target=datawriter.add_images("testdata",imgs,i)i=i+1
writer.close()