PyTorch框架学习八——PyTorch数据读取机制(简述)
- 一、数据
- 二、DataLoader与Dataset
- 1.torch.utils.data.DataLoader
- 2.torch.utils.data.Dataset
- 三、数据读取整体流程
琢磨了一段时间,终于对PyTorch的数据读取机制有了一点理解,并自己实现了简单数据集(猫狗分类数据集)的读入和训练,这里简单写一写自己的理解,以备日后回顾。
一、数据
简单来说,一个机器学习或深度学习问题可以拆解为五个主要的部分:数据、模型、损失函数、优化器和迭代过程,这五部分每个都可以详细展开,都有非常多的知识点,而一切的开始,都源于数据。
一般数据部分可以分为四个主要的内容去学习:
- 数据收集:即获取Img和相应的Label。
- 数据划分:划分为训练集、验证集和测试集。
- 数据读取:DataLoader。
- 数据预处理:transforms。
在PyTorch框架的学习中,前两个不是重点,它们是机器学习基础和Python基础的事。而PyTorch的数据预处理transforms方法在前几次笔记进行了很详细地介绍,这次笔记重点是写一点对数据读取机制的理解,这也是最折磨的一部分,经过了很多次的步进演示,终于对整个数据读取过程有了一个较为完整的印象。
总的来说,DataLoader里比较重要的是Sampler和Dataset,前者负责获取要读取的数据的索引,即读哪些数据,后者决定数据从哪里读取以及如何读取。
二、DataLoader与Dataset
1.torch.utils.data.DataLoader
功能:构建可迭代的数据装载器。
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None)
参数比较多,如下所示:
介绍几个主要的:
- dataset:Dataset类,决定数据从哪读取以及如何读取。
- batch_size:批大小,默认为1。
- num_works:是否多进程读取数据。
- shuffle:每个epoch是否乱序。
- drop_last:当样本数不能被batch_size整除时,是否舍弃最后一批数据。
上面涉及到一个小知识点,顺带介绍一下,即Epoch、Iteration、Batchsize之间的关系:
- Epoch:所有训练样本都输入到模型中,称为一个epoch。
- Iteration:一个Batch的样本输入到模型中,称为一个Iteration。
- Batchsize:批大小,决定一个epoch有多少个iteration。
举个栗子:
若样本总数:80,Batchsize:8,则 1 Epoch = 10 Iterations。
若样本总数:87,Batchsize:8,且 drop_last = True,则1 Epoch = 10 Iterations;而drop_last = False时,1 Epoch = 11 Iterations。
2.torch.utils.data.Dataset
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()函数。
这里__getitem__()函数的功能是:接收一个索引,返回一个样本。
三、数据读取整体流程
经过上面简单的介绍,下面来看一下数据读取的整体流程:
- 从DataLoader这个命令开始。
- 然后进入到DataLoaderIter里,判断是单进程还是多进程。
- 然后进入到Sampler里进行采样,获得一批一批的索引,这些索引就指引了要读取哪些数据。
- 然后进入到DatasetFetcher中要依据Sampler获得的Index对数据进行获取。
- 在DatasetFetcher调用Dataset类,这里是我们自定义的数据集,数据集一般放在硬盘中,Dataset里面一般都有数据的路径,所以也就能知道了从哪读取数据。
- 自定义的Dataset类里再调用__getitem__函数,这里有我们编写的如何读取数据的代码,依据这里的代码读取数据。
- 读取出来后可能需要进行图像预处理或数据增强,所以紧接着是transforms方法。
- 经过上述的读取,已经得到了图像及其标签,但是还需要将它们组合成batch,就是下面的collate_fn,最后得到了一个batch一个batch的数据。
这个过程中的三个主要问题:
- 读哪些数据:Sampler输出要读取的数据的Index。
- 从哪读数据:Dataset类中的data_dir,即数据的存放路径。
- 怎么读数据:Dataset类中编写的__getitem__()函数。
精力有限,就不在这里写一个具体读取数据的代码了,这里有很多有价值的课程和资料可以学习:深度之眼PyTorch框架