pytorch 演示 tensor并行
- 一.原理
- 二.实现代码
本文演示了 tensor 并行的原理：如何将两个 MLP 线性层切分到多张 GPU 上，让每张 GPU 分别计算自己的分块，最后做一次 reduce。
1. 为了避免中间结果产生集合通信，第一个矩阵 A 只能按列切分：每张 GPU 对全部 batch*seq_len 个 token 只计算部分 feature。
2. 由于上一步之后每张 GPU 只持有部分 feature，第二个矩阵 B 需要按行切分，才能与之进行矩阵乘，得到部分和。
3. 最后把每张 GPU 上的部分和加起来（all-reduce），就是最终的结果，即 $Z = \sum_{i} (X A_i) B_i$，其中 $A_i$、$B_i$ 分别是第 $i$ 张 GPU 上 A 的列分块和 B 的行分块。
以下 demo 先实现了非分块的模型，然后在单卡上模拟分块计算（同时保存各分块的权重），最后是基于 NCCL 的分布式实现。
一.原理
二.实现代码
# torch_tp_demo.py
import os
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import argparse

# Command-line options shared by the single-process run and the torchrun launch.
parser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')  # model width
parser.add_argument('--ffn_size', default=1024, type=int, help='')  # MLP inner width, split across shards
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')  # number of shards the native run writes/simulates
parser.add_argument('--device', default="cuda", type=str, help='')  # device used by the native (non-distributed) run


class FeedForward(nn.Module):
    """Reference, unsharded two-layer MLP: ``fc2(fc1(x))`` (no bias, no activation)."""

    def __init__(self, hidden_size, ffn_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_size, bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size, bias=False)

    def forward(self, input):
        return self.fc2(self.fc1(input))


class FeedForwardTp(nn.Module):
    """One tensor-parallel shard of :class:`FeedForward`.

    ``fc1`` maps ``hidden_size -> ffn_size // tp_size`` and ``fc2`` maps back to
    ``hidden_size``, so ``forward`` returns this rank's *partial sum* of the full
    MLP output. Summing the outputs of all ``tp_size`` shards reproduces the
    unsharded result.

    Weights are loaded from ``fc1_{rank}.bin`` / ``fc2_{rank}.bin`` in the
    current directory. These are raw float32 dumps, presumably written by the
    native (single-process) run with the same ``tp_size`` -- confirm before
    reusing them with a different configuration.

    Raises:
        ValueError: if ``ffn_size`` is not divisible by ``tp_size``, or a
            weight file does not hold exactly as many floats as the layer needs.
    """

    def __init__(self, hidden_size, ffn_size, tp_size, rank):
        super().__init__()
        # An uneven split would leave some FFN columns in no shard and silently
        # produce a wrong summed output, so reject it up front.
        if ffn_size % tp_size:
            raise ValueError(
                f"ffn_size ({ffn_size}) must be divisible by tp_size ({tp_size})"
            )
        shard = ffn_size // tp_size
        self.fc1 = nn.Linear(hidden_size, shard, bias=False)
        self.fc2 = nn.Linear(shard, hidden_size, bias=False)
        self._load_shard(self.fc1.weight, f"fc1_{rank}.bin")
        self._load_shard(self.fc2.weight, f"fc2_{rank}.bin")

    @staticmethod
    def _load_shard(weight, path):
        """Copy the raw float32 contents of ``path`` into ``weight`` (same shape)."""
        data = np.fromfile(path, dtype=np.float32)
        if data.size != weight.numel():
            raise ValueError(
                f"{path}: expected {weight.numel()} float32 values for a weight of "
                f"shape {tuple(weight.shape)}, found {data.size}"
            )
        # In-place copy under no_grad is the supported way to overwrite a Parameter
        # (rather than rebinding `.data`).
        with torch.no_grad():
            weight.copy_(torch.from_numpy(data).view_as(weight))

    def forward(self, input):
        return self.fc2(self.fc1(input))


args = parser.parse_args()
# Unpack the parsed CLI options into module-level globals used by the run functions.
hidden_size, ffn_size, seq_len, batch_size, world_size = (
    args.hidden_size,
    args.ffn_size,
    args.seq_len,
    args.batch_size,
    args.world_size,
)
device = args.device


def _sync(dev):
    """Wait for queued CUDA work so wall-clock timings are meaningful.

    No-op when ``dev`` is not a CUDA device, so ``--device cpu`` does not crash
    on ``torch.cuda.synchronize()``.
    """
    if str(dev).startswith("cuda"):
        torch.cuda.synchronize()


def _make_input(dev, seed=1):
    """Return the fp16 benchmark input of shape (batch_size, seq_len, hidden_size) on ``dev``.

    The values come from a private, fixed-seed generator. They therefore do not
    depend on how much of the global RNG stream was consumed earlier (e.g. by
    ``nn.Linear`` weight initialisation). This keeps the input identical in
    native_mode() and on every rank of tp_mode(), so their printed sums are
    directly comparable.
    """
    gen = torch.Generator().manual_seed(seed)
    x = torch.rand((batch_size, seq_len, hidden_size), dtype=torch.float32, generator=gen)
    return x.half().to(dev)


def native_mode():
    """Single-process run: benchmark the full model, write weight shards, simulate TP.

    1. Build FeedForward (global seed 1) and time 32 forward passes, discarding
       the first 4 as warm-up.
    2. Split fc1.weight along dim 0 and fc2.weight along dim 1 into
       ``world_size`` shards, saved as fc1_{i}.bin / fc2_{i}.bin (float32).
    3. Run every shard through FeedForwardTp on the same device and add the
       partial outputs; the sum should match the native output.

    Raises:
        ValueError: if ``ffn_size`` is not divisible by ``world_size``.
    """
    print(args)
    if ffn_size % world_size:
        # torch.split would otherwise emit more than world_size chunks and the
        # simulation below would silently drop the last one.
        raise ValueError(
            f"ffn_size ({ffn_size}) must be divisible by world_size ({world_size})"
        )
    shard = ffn_size // world_size

    torch.random.manual_seed(1)
    model = FeedForward(hidden_size, ffn_size)
    model.eval()
    x = _make_input(device)

    # Dump shards while the weights are still float32 on the CPU.
    for idx, chunk in enumerate(torch.split(model.fc1.weight, shard, dim=0)):
        chunk.data.numpy().tofile(f"fc1_{idx}.bin")
    for idx, chunk in enumerate(torch.split(model.fc2.weight, shard, dim=1)):
        chunk.data.numpy().tofile(f"fc2_{idx}.bin")

    model = model.half().to(device)
    usetime = []
    with torch.no_grad():  # inference only: skip building the autograd graph
        for i in range(32):
            t0 = time.perf_counter()
            out = model(x)
            _sync(device)
            t1 = time.perf_counter()
            if i > 3:  # first 4 iterations are warm-up
                usetime.append(t1 - t0)
    # Sum in float32: an fp16 .sum() over millions of elements is too coarse to compare.
    print("[INFO] native: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(
        out.shape, out.float().sum().item(),
        np.mean(usetime), np.min(usetime), np.max(usetime)))

    result = []
    with torch.no_grad():
        for rank in range(world_size):
            tp_model = FeedForwardTp(hidden_size, ffn_size, world_size, rank).half().to(device)
            tp_model.eval()
            result.append(tp_model(x))
    _sync(device)
    # Left-to-right accumulation, mirroring what all_reduce(SUM) computes across ranks.
    sum_all = sum(result[1:], result[0])
    print("[INFO] tp_simulate: shape:{},sum:{:.5f}".format(
        sum_all.shape, sum_all.float().sum().item()))


def tp_mode():
    """Distributed run (one process per GPU under torchrun) with an all-reduce of partial sums.

    Reads LOCAL_RANK from the environment and loads this rank's shard files,
    which must already exist (written by native_mode()). Every rank uses the
    same input, computes its partial output and all-reduces it, so every rank
    ends with the full result; rank 0 prints timings and the sum.
    """
    torch.random.manual_seed(1)
    dist.init_process_group(backend='nccl')
    try:
        tp_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        dev = torch.device("cuda", local_rank)

        x = _make_input(dev)
        model = FeedForwardTp(hidden_size, ffn_size, tp_size, rank).half().to(dev)
        model.eval()
        if rank == 0:
            print(args)

        usetime = []
        with torch.no_grad():
            for i in range(32):
                dist.barrier()  # start every rank's timed iteration together
                t0 = time.perf_counter()
                out = model(x)
                dist.all_reduce(out, op=ReduceOp.SUM)
                torch.cuda.synchronize()
                if rank == 0 and i > 3:  # first 4 iterations are warm-up
                    usetime.append(time.perf_counter() - t0)

        if rank == 0:
            print("[INFO] tp: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(
                out.shape, out.float().sum().item(),
                np.mean(usetime), np.min(usetime), np.max(usetime)))
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    # torchrun exports WORLD_SIZE; a plain `python3` launch does not.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    if num_gpus > 1:
        tp_mode()
    else:
        native_mode()
运行命令（需先执行第一条单卡命令，生成各 rank 的权重分块文件 fc1_*.bin / fc2_*.bin，再用 torchrun 启动分布式版本）:
python3 torch_tp_demo.py --hidden_size 512 \
    --ffn_size 4096 --seq_len 512 \
    --batch_size 8 --world_size 4 --device "cuda"
torchrun -m --nnodes=1 --nproc_per_node=4 \
    torch_tp_demo --hidden_size 512 \
    --ffn_size 4096 --seq_len 512 \
    --batch_size 8 --world_size 4