Pytorch的数据读取主要包含三个类:
- Dataset
- DataLoader
- DataLoaderIter
这三者是一个依次封装的关系: 1.被装进2., 2.被装进3.
Dataset类
Pytorch 读取数据,主要通过Dataset类,Dataset类是所有dataset类的基类,自定义的dataset类要继承它,并且实现它的两个最重要的方法 __getitem__()
和 __len__()
具体的使用:
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, path): # 可以写一些文件的读取self.trainUserList = self.load_train_rating_as_list(path + ".train.rating")def __getitem__(self, index): # 根据index返回一条数据user= self.trainUserList[index]return userdef __len__(self): # 样本数据的长度return len(self.trainUserList)
注意:dataset中应尽量只包含只读对象
,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。
DataLoader类
Dataset 负责数据集,每次可以用 __getitem__()
返回一个样本,而 DataLoader 提供了对数据的批量处理。
Dataloader 的构造函数:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
参数解释:
-
num_workers
:使用的子进程数,0为不使用多进程。 -
worker_init_fn
: 默认为None,如果不是None,这个函数将被每个子进程以子进程id([0, num_workers - 1]之间的数)调用 -
sample
:采样策略,若这个参数有定义,则shuffle必须为False -
pin_memory
:是否将tensor数据复制到CUDA pinned memory中,pin memory中的数据转到GPU中会快一些 -
drop_last
:当dataset中的数据数量不能整除batch size时,是否把最后不够batch size数据丢掉 -
collate_fn
:把一组samples打包成一个mini-batch的函数。可以自定义这个函数以处理损坏数据的情况(先在__getitem__函数中将这样的数据返回None,然后再在collate_fn中处理,如丢掉损坏数据or再从数据集里随机挑一张),但最好还是确保dataset里所有数据都能用。
具体的使用:
dataset = MyDataset('EPINION2/epinion2') # 初始化自定义类
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=20) # 使用DataLoader对自定义类进行包装,使其能够批量获取数据for epoch in range(20):for data in dataloader: # data 是获取到的 batch_size 个 user# training...
DataLoaderIter
Dataset、Dataloader 和 DataLoaderIter 是层层封装的关系,最终在内部使用 DataLoaderIter 进行迭代。