深度学习:Pytorch分布式训练
- 简介
- 模型并行
- 数据并行
- 参考文献
简介
在深度学习领域,模型越来越庞大、数据量不断增加,训练这些大型模型越来越耗时。通过在多个GPU或多个节点上并行地训练模型,我们可以显著减少训练时间。此外,某些模型因为巨大的参数量,单个设备可能无法容纳其整个模型和数据。在这种情况下,分布式训练不仅能提高训练速度,更是必要的手段来训练大模型。为此,PyTorch 分布式训练提供了两种基本的并行方法:
-
模型并行(Model Parallel):模型并行是指将模型的不同部分放到不同的设备上。这种方式通常用于当一个单独的模型太大而无法放到单个GPU上时。
-
数据并行(Data Parallel):数据并行是将训练数据分割并在多个设备上同时训练的方法。PyTorch提供了
torch.nn.DataParallel
和torch.nn.parallel.DistributedDataParallel
用于在多个GPU上并行化模型训练。
模型并行
模型并行主要利用to(device)
函数将模型和数据(Tensor张量)放置在适当设备上,其余代码基本无需额外改动。
以下是一个简单的模型并行的代码示例:
import torch
import torch.nn as nn
import torch.optim as optimclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = torch.nn.Linear(10, 10).to('cuda:0')self.relu = torch.nn.ReLU()self.net2 = torch.nn.Linear(10, 5).to('cuda:1')def forward(self, x):x = self.relu(self.net1(x.to('cuda:0')))return self.net2(x.to('cuda:1'))model = DemoModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()
注意调用损失函数时,您只需要确保标签与输出位于同一设备上。不难看出,此模型并行的方法效率相对较低,因为在任何时间点,两个 GPU 中只有一个在工作,而另一个则处于闲置状态。而且中间过程变量从cuda:0
复制到cuda:1
,又会需要额外的开销。因此可以引入流水线并行来进行加速。
在以下代码示例中,采取将输入数据批次划分为 20 组。由于 PyTorch 异步启动 CUDA 操作,因此可以不需要生成多个线程来实现并发。值得注意的是,使用较小的结果split_size会导致许多微小的 CUDA 内核启动,而使用较大的split_size
会导致在第一次和最后一次数据划分期间存在相对较长的空闲时间。因此split_size
对于特定实验可能有一个最佳配置,可以多次尝试最佳的超参数。
class PipelineParallelResNet50(ModelParallelResNet50):def __init__(self, split_size=20, *args, **kwargs):super(PipelineParallelResNet50, self).__init__(*args, **kwargs)self.split_size = split_sizedef forward(self, x):splits = iter(x.split(self.split_size, dim=0))s_next = next(splits)s_prev = self.seq1(s_next).to('cuda:1')ret = []for s_next in splits:# A. ``s_prev`` runs on ``cuda:1``s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))# B. ``s_next`` runs on ``cuda:0``, which can run concurrently with As_prev = self.seq1(s_next).to('cuda:1')s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))return torch.cat(ret)
数据并行
DataParallel是单进程、多线程,仅适用于单机,而是DistributedDataParallel多进程,适用于单机和多机训练。由于跨线程的 GIL 争用、每次迭代复制模型以及分散输入和收集输出带来的额外开销,DataParallel通常比DistributedDataParallel在单台机器上更慢。
一般地,数据并行的流程为:
- 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时初始化 distributed 包。
- 如果需要进行组内集体通信,用 new_group 创建子分组
- 创建分布式并行模型 DDP(model, device_ids=device_ids)
- 为数据集创建 Sampler
- 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练
- 使用 destory_process_group() 销毁进程组
以下是一个简单的数据并行的代码示例:
# demo_ddp.py
# 在init_process_group()时,一般可设置为Gloo、NCCL或mpi后端,Gloo目前在GPU上运行速度比 NCCL慢。所以经验法则是:
# 分布式GPU训练使用 NCCL 后端
# 分布式CPU训练使用 Gloo 后端import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optimfrom torch.nn.parallel import DistributedDataParallel as DDPclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = nn.Linear(10, 10)self.relu = nn.ReLU()self.net2 = nn.Linear(10, 5)def forward(self, x):return self.net2(self.relu(self.net1(x)))def demo_basic():dist.init_process_group("nccl")rank = dist.get_rank()print(f"Start running basic DDP example on rank {rank}.")# create model and move it to GPU with id rankdevice_id = rank % torch.cuda.device_count()model = DemoModel().to(device_id)ddp_model = DDP(model, device_ids=[device_id])loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)optimizer.zero_grad()outputs = ddp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(device_id)loss_fn(outputs, labels).backward()optimizer.step()dist.destroy_process_group()if __name__ == "__main__":demo_basic()
然后使用torchrun
命令进行启动,其中,nnodes表示总节点数,nproc_per_node表示每个节点运行的进程数,rdzv_id表示用户定义的ID,唯一标识作业的工作组, rdzv_backend表示集合点的后端,rdzv_endpoint表示rendezvous后端运行的地址
# 需要应用 slurm 等集群管理工具来实际在 2 个节点上运行此命令。
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 demo_ddp.py
此命令表示在两台服务器上运行 DDP 脚本,每台服务器运行 8 个进程,即在 16 个 GPU 上运行。
参考文献
- https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
- https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
- https://medium.com/deelvin-machine-learning/model-parallelism-vs-data-parallelism-in-unet-speedup-1341bc74ff9e