1. dataset: loading the dataset
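import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert each PIL image to a tensor when it is read from the dataset
dataset_tranform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# download=True fetches CIFAR10 into root if it is not already there
train_set = torchvision.datasets.CIFAR10(root="./train_dataset", train=True, transform=dataset_tranform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./train_dataset", train=False, transform=dataset_tranform, download=True)

# Each sample is an (image, target) pair
print(test_set[0])

# Write the first 10 test images to the 'p10' log directory
writer = SummaryWriter('p10')
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()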
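This downloads the CIFAR10 dataset and writes the first 10 test images to TensorBoard so we can take a look at them (run tensorboard --logdir=p10 to view).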
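Why ToTensor is applied before add_image: ToTensor turns an H x W x C image with values in 0..255 into a C x H x W float tensor with values in 0.0..1.0, and C x H x W is the layout add_image expects by default. The snippet below is only a pure-Python sketch of that conversion on nested lists. The function name to_chw_float and the 2 x 2 sample image are made up for illustration; this is not how torchvision implements it.

```python
def to_chw_float(image_hwc):
    """Convert an H x W x C nested list of 0..255 ints
    into a C x H x W nested list of floats in [0.0, 1.0]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# Hypothetical 2 x 2 RGB image: red, green on the top row; blue, white below
tiny_image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = to_chw_float(tiny_image)
print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
print(chw[0])  # the red channel
```

A real CIFAR10 sample goes through the same change of layout, from a 32 x 32 RGB image to a tensor of shape 3 x 32 x 32.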
2. dataloader: loading data from the dataset
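import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(root="./train_dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)

# batch_size=64: 64 images per batch; shuffle=True: new random order each epoch;
# drop_last=False: keep the final, smaller batch
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    # imgs is a batch of images, targets holds the matching labels
    imgs, targets = data
    # add_images (plural) logs the whole batch at once
    writer.add_images("test_data", imgs, step)
    step = step + 1
writer.close()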
Here we load 64 images at a time from the CIFAR10 dataset.
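How many steps this loop produces can be worked out by hand. The CIFAR10 test set has 10,000 images, so with batch_size=64 and drop_last=False the last batch is smaller than the rest. Below is a small pure-Python sketch of that arithmetic. The helper names count_batches and last_batch_size are made up for illustration and are not part of PyTorch.

```python
import math


def count_batches(num_samples, batch_size, drop_last):
    """Number of batches one pass over the data yields."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size, drop_last):
    """Size of the final batch (0 if there are no batches at all)."""
    if count_batches(num_samples, batch_size, drop_last) == 0:
        return 0
    remainder = num_samples % batch_size
    if remainder == 0 or drop_last:
        return batch_size
    return remainder


NUM_TEST_IMAGES = 10_000  # size of the CIFAR10 test split
BATCH_SIZE = 64

print(count_batches(NUM_TEST_IMAGES, BATCH_SIZE, drop_last=False))    # 157
print(last_batch_size(NUM_TEST_IMAGES, BATCH_SIZE, drop_last=False))  # 16
print(count_batches(NUM_TEST_IMAGES, BATCH_SIZE, drop_last=True))     # 156
```

So in TensorBoard the "test_data" tag should show steps 0 through 156, with the final step holding only 16 images. Setting drop_last=True would discard that short batch.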