torch分布式通信基础
- 1. 点到点通信
- 2. 集群通信
官网文档:WRITING DISTRIBUTED APPLICATIONS WITH PYTORCH
1. 点到点通信
# 同步,peer-2-peer数据传递
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mpdef test_send_recv_sync(rank, size):tensor = torch.zeros(1)if rank == 0:tensor += 1dist.send(tensor=tensor, dst=1) # 需要指定dst,发送的目标else:dist.recv(tensor=tensor, src=0) # 需要指定src,从哪儿接收print('Rank ', rank, ' has data ', tensor[0])# 异步
def test_send_recv_async(rank, size):tensor = torch.zeros(1)req = Noneif rank == 0:tensor += 1req = dist.isend(tensor=tensor, dst=1)else:req = dist.irecv(tensor=tensor, src=0)req.wait()print('Rank ', rank, ' has data ', tensor[0])def init_process(rank, size, backend='gloo'):""" 这里初始化分布式环境,设定Master机器以及端口号 """os.environ['MASTER_ADDR'] = '127.0.0.1'os.environ['MASTER_PORT'] = '29598'dist.init_process_group(backend, rank=rank, world_size=size)#test_send_recv_sync(rank, size)test_send_recv_async(rank, size)if __name__ == "__main__":size = 2processes = []mp.set_start_method("spawn")for rank in range(size):p = mp.Process(target=init_process, args=(rank, size))p.start()processes.append(p)for p in processes:p.join()
2. 集群通信
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mpdef test_broadcast(rank, size):tensor = torch.zeros(1)if rank == 0:tensor += 2else:tensor += 1dist.broadcast(tensor=tensor,src=0) # src指定broad_cast的源print("******test_broadcast******")print('Rank ', rank, ' has data ', tensor) # 结果都是 2def test_scatter(rank, size):tensor = torch.zeros(1)if rank == 0:tensor_list = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0]), torch.tensor([4.0])]dist.scatter(tensor, scatter_list = tensor_list, src = 0)else:dist.scatter(tensor, scatter_list = [], src = 0)print("******test_scatter******")print('Rank ', rank, ' has data ', tensor) # 结果是[[1], [2], [3], [4]]def test_reduce(rank, size):tensor = torch.ones(1)dist.reduce(tensor=tensor, dst=0) # dst指定哪个进程进行reduce, 默认操作是加法print("******test_reduce******")print('Rank ', rank, ' has data ', tensor)def test_all_reduce(rank, size):tensor = torch.ones(1)dist.all_reduce(tensor=tensor,op=dist.ReduceOp.SUM)print("******test_all_reduce******")print('Rank ', rank, ' has data ', tensor) # 结果都是 4def test_gather(rank, size):tensor = torch.ones(1)if rank == 0:output = [torch.zeros(1) for _ in range(size)]dist.gather(tensor, gather_list=output, dst=0)else:dist.gather(tensor, gather_list=[], dst=0)if rank == 0:print("******test_gather******")print('Rank ', rank, ' has data ', output) # 结果是 [[1,1,1,1]]def test_all_gather(rank, size):output = [torch.zeros(1) for _ in range(size)]tensor = torch.ones(1)dist.all_gather(output, tensor)print("******test_all_gather******")print('Rank ', rank, ' has data ', output) # 结果都是 [1,1,1,1]def init_process(rank, size, backend='gloo'):""" 这里初始化分布式环境,设定Master机器以及端口号 """os.environ['MASTER_ADDR'] = '127.0.0.1'os.environ['MASTER_PORT'] = '29596'dist.init_process_group(backend, rank=rank, world_size=size)test_reduce(rank, size)test_all_reduce(rank, size)test_gather(rank, size)test_all_gather(rank, size)test_broadcast(rank, size)test_scatter(rank, size)if __name__ == "__main__":size = 4processes = []mp.set_start_method("spawn")for rank in range(size):p = mp.Process(target=init_process, args=(rank, size))p.start()processes.append(p)for p in processes:p.join()
需要注意的一点是:
这里面的调用都是同步的,可以理解为,每个进程都调用到通信api时,真正的有效数据传输才开始,然后通信完成之后,代码继续往下跑。实际上有些通信进程并不获取数据,这些进程可能并不会被阻塞。
文档最后,提供了一个简单的类似 DDP 的实现,里面核心的部分就是:
这也进一步阐释了DDP的核心逻辑:
反向计算完成之后,汇总梯度信息(求均值),然后再更新参数