PyTorch demo of pipeline parallelism
1. When a single GPU does not have enough memory, the network can be split into several stages, with each GPU responsible for one stage. For example, GPU0 finishes its computation and sends the result to GPU1, which runs the following stage.
2. This alone leaves the GPUs underutilized: while GPU1 works, GPU0 is idle. The input batch can instead be split into smaller micro-batches that are fed to GPU0 one after another, so that as soon as GPU0 finishes micro batch 0 it can start on micro batch 1, which raises GPU utilization (see the sketch after this list).
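As a warm-up, here is a minimal single-process sketch of both ideas. It assumes two visible GPUs (cuda:0 and cuda:1), and the file name pp_toy.py is only for this illustration; the full demo below instead runs one process per stage and moves activations with torch.distributed send/recv.
tee pp_toy.py <<-'EOF'
import torch
from torch import nn

# Idea 1: cut the network into stages and place each stage on its own GPU.
stage0 = nn.Linear(512, 1024, bias=False).to("cuda:0")
stage1 = nn.Linear(1024, 512, bias=False).to("cuda:1")

x = torch.rand(8, 512, 512, device="cuda:0")

# Naive pipeline: the whole batch goes through stage0, then stage1,
# so one GPU is always idle while the other works.
y = stage1(stage0(x).to("cuda:1"))

# Idea 2: split the batch into micro-batches so the two stages can overlap.
outputs = []
for micro in torch.split(x, 2, dim=0):  # 4 micro-batches of size 2
    h = stage0(micro).to("cuda:1")
    # CUDA launches are asynchronous, so cuda:0 can already start stage0 on
    # the next micro-batch while cuda:1 is still running stage1 on this one.
    outputs.append(stage1(h))
y = torch.cat(outputs, dim=0)
print(y.shape)
EOF
python pp_toy.py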
tee pp_demo.py <<-'EOF'
import os
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')
parser.add_argument('--ffn_size', default=1024, type=int, help='')
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')
parser.add_argument('--device', default="cuda", type=str, help='')
parser.add_argument('--chunk_size', default=1, type=int, help='')


class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_size):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_size, bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size, bias=False)

    def forward(self, input):
        return self.fc2(self.fc1(input))


args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.device
chunk_size = args.chunk_size


def tp_mode():
    torch.random.manual_seed(1)
    dist.init_process_group(backend='nccl')
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Every rank hosts one pipeline stage (here simply the same FeedForward block).
    model = FeedForward(hidden_size, ffn_size)
    model.eval()
    model = model.half().to(device)

    input = torch.rand((batch_size, seq_len, hidden_size), dtype=torch.float32).half().to(device)
    # Split the batch into micro-batches of size chunk_size along dim 0.
    chunks = torch.split(input, chunk_size, dim=0)

    index = 0
    count = 0
    t0 = 0
    for epoch in range(32):
        index += 1
        if index > 1:  # skip the first (warm-up) iteration when measuring QPS
            count += 1
            if t0 == 0:
                t0 = time.time()
            if count % 10 == 0 and rank == 0:
                print("qps:{:.2f}".format(count / (time.time() - t0)))
                count = 0
                t0 = 0
        all_output = []
        for chunk in chunks:
            if rank == 0:
                # First stage: run directly on the local micro-batch.
                out = model(chunk)
            else:
                # Later stages: receive the activation from the previous rank first.
                dist.recv(chunk, rank - 1)
                out = model(chunk)
            if rank == world_size - 1:
                # Last stage: collect the micro-batch outputs.
                all_output.append(out.clone())
            else:
                # Forward the activation to the next stage.
                dist.send(out, rank + 1)
        if rank == world_size - 1:
            # Reassemble the full batch from the micro-batch outputs.
            out = torch.cat(all_output, dim=0)


if __name__ == "__main__":
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    is_distributed = num_gpus > 1
    if is_distributed:
        tp_mode()
EOF

torchrun --nnodes=1 --nproc_per_node=4 pp_demo.py \
    --hidden_size 512 --ffn_size 4096 --seq_len 512 \
    --batch_size 16 --world_size 4 --chunk_size 8
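With these arguments, the 4 processes form 4 pipeline stages, and every 16-sample batch is split into two micro-batches of 8 (chunk_size). Rank 0 feeds micro-batches into the pipeline, each intermediate rank receives activations from rank-1 and sends its output to rank+1, and the last rank concatenates the micro-batch outputs back into the full batch; rank 0 prints an approximate QPS every 10 iterations.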