多进程加载(num_workers)
- 在 DataLoader 中指定
num_workers > 0
,即可启用多个子进程并行加载数据集(DataLoader 的 worker 是子进程而非线程),最大可设置为 CPU 核数 - 设置
pin_memory = True
, 将批数据放入锁页(pinned)内存,加快从主机内存到 GPU 的拷贝,节约内存调度时间 - 示例如下:
# Build the training loader: worker subprocesses fetch samples in parallel and
# batches are collated into pinned (page-locked) host memory.
loader = DataLoader(
    dataset,
    # Each loader batch holds `group_size` training batches' worth of samples.
    batch_size=batch_size * group_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    # Number of worker subprocesses used to load samples.
    num_workers=2,
    pin_memory=True,
)
预加载数据集
说别的都没大用,还得是预加载
- 原理:将整个数据集预先 load 到内存中,之后读取数据时直接访问内存,不再产生磁盘 I/O;前提是机器内存足以容纳整个数据集
- 构建自己的dataset类
- 示例如下:
class My_Dataset(Dataset):
    """Dataset that preloads every acoustic feature array into memory.

    ``__init__`` reads the mel / pitch / energy / duration ``.npy`` file of
    every utterance exactly once and keeps the arrays in per-feature lists, so
    ``__getitem__`` is a plain list lookup and performs no disk I/O.

    NOTE(review): ``process_meta`` is called by ``__init__`` but is not defined
    in this snippet, and ``text_to_sequence`` is a module-level name this class
    relies on; both must be supplied by the surrounding project. A
    ``collate_fn`` method, if the loader needs one, is likewise not part of
    this snippet.
    """

    def __init__(self, filename, preprocess_config, train_config, sort=False, drop_last=False):
        """Read the metadata and preload all feature arrays.

        Args:
            filename: Metadata file handed to ``self.process_meta``.
            preprocess_config: Mapping providing ``["dataset"]``,
                ``["path"]["preprocessed_path"]`` and
                ``["preprocessing"]["text"]["text_cleaners"]``.
            train_config: Mapping providing ``["optimizer"]["batch_size"]``.
            sort: Stored on the instance as ``self.sort``; not used here.
            drop_last: Stored on the instance as ``self.drop_last``; not used
                here.
        """
        super().__init__()
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(filename)

        # JSON is UTF-8 by specification; do not depend on the locale default.
        speaker_map_path = os.path.join(self.preprocessed_path, "speakers.json")
        with open(speaker_map_path, encoding="utf-8") as f:
            self.speaker_map = json.load(f)

        self.sort = sort
        self.drop_last = drop_last

        # Preload: one in-memory list per feature, indexed like the metadata.
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        # basename / speaker are treated as parallel lists, one entry per
        # utterance, as returned together by process_meta.
        for basename, speaker in zip(self.basename, self.speaker):
            self.mel_list.append(self._load_feature("mel", speaker, basename))
            self.pitch_list.append(self._load_feature("pitch", speaker, basename))
            self.energy_list.append(self._load_feature("energy", speaker, basename))
            self.duration_list.append(self._load_feature("duration", speaker, basename))

    def _load_feature(self, kind, speaker, basename):
        """Load ``<preprocessed_path>/<kind>/<speaker>-<kind>-<basename>.npy``."""
        feature_path = os.path.join(
            self.preprocessed_path,
            kind,
            "{}-{}-{}.npy".format(speaker, kind, basename),
        )
        return np.load(feature_path)

    def __len__(self):
        """Return the number of utterances listed in the metadata."""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the sample dict for utterance ``idx`` from the preloaded lists."""
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        # Text is still converted on every access; only the arrays are cached.
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        # No disk access: the arrays were loaded once in __init__.
        mel = self.mel_list[idx]
        pitch = self.pitch_list[idx]
        energy = self.energy_list[idx]
        duration = self.duration_list[idx]

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }
        return sample
- 在
__init__
函数里,将所有数据一次性 load 进内存 - 在 __getitem__(self, idx)
函数里,则直接通过列表下标 idx 访问每一条数据