一、概念
Pytorch的标准数据集包括很多种类型,如CIFAR,COCO,KITTI,MNIST等,我们可以在官网查看。当然我们也可以做数据集,但需要自己标注。
二、如何调用数据集
一、调用torchvision
在程序中调用torchvision.datasets,下面用程序示例如何下载CIFAR10数据集。
import torchvisiontrain_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
也可以复制路径通过其他方式下载,然后将下载文件放入py文件路径下,可运行程序可自动解压。
如果想显示数据集的图片,可以直接调用imshow方法。
img, target = test_set[0]
img.show()
如果想通过tensorboard显示图片,需要先将图片格式转化为tensor,然后调用SummaryWriter类。
二、调用dataset类
dataset类属于抽象类,需要通过创建子类来继承,从而创建数据集。
from torch.utils.data import Dataset
from PIL import Image
import osclass MyData(Dataset):def __init__(self, root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir, self.label_dir)self.img_path = os.listdir(self.path)def __getitem__(self, idx):img_name = self.img_path[idx]img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path)label = self.label_dirreturn img, labeldef __len__(self):return len(self.img_path)
_ init _可以初始化子类的基础参数,可以自定义,相当于构造函数。
_ getitem _根据索引返回数据和标签。
_ len _返回数据大小
三、加载数据集
一般用DataLoader类来加载数据集。常见的参数包括:batch_size, shuffle num_workers。
这些参数的意义如下:
batch_size:指批大小,在训练时每次在训练集中取batchsize个样本。
epoch:指使用所有训练集的样本训练一次。
shuffle :指将训练集进行打乱的操作,一般生成数据集的时候要shuffle一下图片顺序,防止过拟合。
num_workers:设定DataLoader要使用多少个子进程进行加载。
drop_last:指训练集经过批处理后剩余的部分数据的处理模式。ture代表丢弃,false代表继续执行,只是batch_size会相对变小。
简单例子:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())test_load = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)