pytorch 演示 tensor并行

pytorch 演示 tensor并行

  • 一.原理
  • 二.实现代码

本文演示了tensor并行的原理。如何将二个mlp切分到多张GPU上分别计算自己的分块,最后做一次reduce。
1.为了避免中间数据产生集合通信,A矩阵只能列切分,只计算全部batch*seqlen的部分feature
2.因为上面的步骤每张GPU只有部分feature,只因B矩阵按行切分,可与之进行矩阵乘,生成部分和
3.最后把每张GPU上的部分和加起来,就是最张的结果
以下demo,先实现了非分块的模型,然后模拟nccl分块,最后是分布式的实现

一.原理

在这里插入图片描述

二.实现代码

# torch_tp_demo.py
import os
import torch
from torch import nn
import torch.nn.functional as F 
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOpimport time
import argparseparser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')
parser.add_argument('--ffn_size', default=1024, type=int, help='')
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')
parser.add_argument('--device', default="cuda", type=str, help='')class FeedForward(nn.Module): def __init__(self,hidden_size,ffn_size): super(FeedForward, self).__init__() self.fc1 = nn.Linear(hidden_size, ffn_size,bias=False)self.fc2 = nn.Linear(ffn_size, hidden_size,bias=False)def forward(self, input): return self.fc2(self.fc1(input))class FeedForwardTp(nn.Module):def __init__(self,hidden_size,ffn_size,tp_size,rank): super(FeedForwardTp, self).__init__() self.fc1 = nn.Linear(hidden_size, ffn_size//tp_size,bias=False)self.fc2 = nn.Linear(ffn_size//tp_size, hidden_size,bias=False)self.fc1.weight.data=torch.from_numpy(np.fromfile(f"fc1_{rank}.bin",dtype=np.float32)).reshape(self.fc1.weight.data.shape)self.fc2.weight.data=torch.from_numpy(np.fromfile(f"fc2_{rank}.bin",dtype=np.float32)).reshape(self.fc2.weight.data.shape)def forward(self, input): return self.fc2(self.fc1(input))args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.devicedef native_mode():print(args)torch.random.manual_seed(1)model = FeedForward(hidden_size,ffn_size)model.eval()input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)for idx,chunk in enumerate(torch.split(model.fc1.weight, ffn_size//world_size, dim=0)):chunk.data.numpy().tofile(f"fc1_{idx}.bin")for idx,chunk in enumerate(torch.split(model.fc2.weight, ffn_size//world_size, dim=1)):chunk.data.numpy().tofile(f"fc2_{idx}.bin")model=model.half().to(device)usetime=[]for i in range(32):t0=time.time()    out = model(input)torch.cuda.synchronize()t1=time.time()if i>3:usetime.append(t1-t0)print("[INFO] native: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))result=[]for rank in range(world_size):model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)model.eval()out=model(input)torch.cuda.synchronize()result.append(out)sum_all=result[0]for t in result[1:]:sum_all=sum_all+tprint("[INFO] tp_simulate: shape:{},sum:{:.5f}".format(sum_all.shape,sum_all.sum().item()))def tp_mode():torch.random.manual_seed(1)dist.init_process_group(backend='nccl')world_size = torch.distributed.get_world_size()rank=rank = torch.distributed.get_rank()local_rank=int(os.environ['LOCAL_RANK'])torch.cuda.set_device(local_rank)device = torch.device("cuda",local_rank)input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)  model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)model.eval()if rank==0:print(args)usetime=[]for i in range(32):        dist.barrier()t0=time.time()out=model(input)#dist.reduce(out,0, op=ReduceOp.SUM) dist.all_reduce(out,op=ReduceOp.SUM)torch.cuda.synchronize()if rank==0:t1=time.time()if i>3:usetime.append(t1-t0)if rank==0:print("[INFO] tp: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))if __name__ == "__main__":num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1is_distributed = num_gpus > 1if is_distributed:tp_mode()else:native_mode()

运行命令:

python3 torch_tp_demo.py --hidden_size 512 \--ffn_size 4096 --seq_len 512 \--batch_size 8 --world_size 4 --device "cuda"
torchrun -m --nnodes=1 --nproc_per_node=4 \torch_tp_demo --hidden_size 512 \--ffn_size 4096 --seq_len 512 \--batch_size 8 --world_size 4 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/798745.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024 Tuxera NTFS for Mac功能介绍及如何安装使用

随着科技的发展,我们的日常生活和工作越来越依赖于电子设备。而在这些设备中,Mac由于其出色的稳定性和易用性,成为了许多用户的首选。然而,尽管Mac自带的文件系统已经足够强大,但仍有一些用户希望获得更加高效、稳定的…

【氮化镓】在轨实验研究辐射对GaN器件的影响

【Pioneering evaluation of GaN transistors in geostationary satellites】 摘要: 这篇论文介绍了一项为期6年的空间实验结果,该实验研究了在地球静止轨道上辐射对氮化镓(GaN)电子元件的影响。实验使用了四个GaN晶体管&#xf…

如何水出第一篇SCI:SCI发刊历程,从0到1全过程经验分享!!!

如何水出第一篇SCI:SCI发刊历程,从0到1全路程经验分享!!! 详细的改进教程以及源码,戳这!戳这!!戳这!!!B站:Ai学术叫叫兽e…

WPS解决插入公式在正文带来行间距变大问题

问题描述 写论文解释公式时,插入对应的变量,导致行间距变大,如图 显然上文与下文行间距不等。但无法通过修改数值修改下文行间距。 解决办法

消息队列之RabbitMQ的安装配置

一,前言 RabbitMQ是由erlang语言开发,基于AMQP(Advanced Message Queue 高级消息队列协议)协议实现的消息队列,它是一种应用程序之间的通信方法,消息队列在分布式系统开发中应用非常广泛。点击跳转RabbitM…

90天玩转Python—05—基础知识篇:Python基础知识扫盲,使用方法与注意事项

90天玩转Python系列文章目录 90天玩转Python—01—基础知识篇:C站最全Python标准库总结 90天玩转Python--02--基础知识篇:初识Python与PyCharm 90天玩转Python—03—基础知识篇:Python和PyCharm(语言特点、学习方法、工具安装) 90天玩转Python—04—基础知识篇:Pytho…

SSM整合----第一个SSM项目

文章目录 前言一、使用步骤1.引入库2.建表3 项目结构4 web.xml的配置5 配置数据源6 SpringMVC配置7 配置MyBatis Mapper8 书写控制类 总结 前言 提示:这里可以添加本文要记录的大概内容: SSM整合是指Spring、SpringMVC和MyBatis这三个框架的整合使用。…

MTK i500p AIoT解决方案

一、方案概述 i500p是一款强大而高效的AIoT平台,专为便携式、家用或商用物联网应用而设计,这些应用通常需要大量的边缘计算,需要强大的多媒体功能和多任务操作系统。该平台集成了Arm Cortex-A73 和 Cortex-A53 的四核集群,工作频…

【论文速读】| 大语言模型平台安全:将系统评估框架应用于OpenAI的ChatGPT插件

本次分享论文为:LLM Platform Security: Applying a Systematic Evaluation Framework to OpenAI’s ChatGPT Plugins 基本信息 原文作者:Umar Iqbal, Tadayoshi Kohno, Franziska Roesner 作者单位:华盛顿大学圣路易斯分校,华盛…

web安全学习笔记(7)

记一下第十一节课的内容。 这节课主要学习post传参和js弹窗与跳转 一、post传参 1.简单的post传参介绍 将index.php重命名为login.php,并将login.html从template文件夹下拿到根目录下,并删除template目录。 将login.html中内容改为如下所示&#xf…

Ubuntu下TexStudio如何兼容中文

怎么就想起来研究一下这个? 我使用大名鼎鼎的3Blue1Brown数学动画引擎Manim,制作了一个特别小的动画视频克里金插值。在视频中,绘制文字时,Manim使用到了texlive texlive-latex-extra这些库。专业的关系,当年的毕设没…

一个更难破解的加密算法 Bcrypt

BCrypt是由Niels Provos和David Mazires设计的密码哈希函数,他是基于Blowfish密码而来的,并于1999年在USENIX上提出。 除了加盐来抵御rainbow table 攻击之外,bcrypt的一个非常重要的特征就是自适应性,可以保证加密的速度在一个特…

linux学习:gcc编译

编译.c gcc hello.c -o hello 用gcc 这个工具编译 hello.c,并且使之生成一个二进制文件 hello。 其中 –o 的意义是 output,指明要生成的文件的名称,如果不写 –o hello 的话会生成默 认的一个 a.out 文件 获得 C 源程序经过预处理之后的文…

书生·浦语训练营二期第三次笔记-茴香豆:搭建你的 RAG 智能助理

RAG学习文档1: https://paragshah.medium.com/unlock-the-power-of-your-knowledge-base-with-openai-gpt-apis-db9a1138cac4 RAG学习文档2: https://blog.demir.io/hands-on-with-rag-step-by-step-guide-to-integrating-retrieval-augmented-generation-in-llms-a…

C#/.NET/.NET Core推荐学习书籍(24年4月更新,已分类)

前言 古人云:“书中自有黄金屋,书中自有颜如玉”,说明了书籍的重要性。作为程序员,我们需要不断学习以提升自己的核心竞争力。以下是一些优秀的C#/.NET/.NET Core相关学习书籍(包含了C#、.NET、.NET Core、Linq、EF/E…

云原生安全当前的挑战与解决办法

云原生安全作为一种新兴的安全理念,不仅解决云计算普及带来的安全问题,更强调以原生的思维构建云上安全建设、部署与应用,推动安全与云计算深度融合。所以现在云原生安全在云安全领域越来受到重视,云安全厂商在这块的投入也是越来…

34-4 CSRF漏洞 - CSRF跨站点请求伪造

一、漏洞定义 CSRF(跨站请求伪造)是一种客户端攻击,又称为“一键式攻击”。该漏洞利用了Web应用程序与受害用户之间的信任关系,通过滥用同源策略,使受害者在不知情的情况下代表攻击者执行操作。与XSS攻击不同,XSS利用用户对特定网站的信任,而CSRF则利用了网站对用户网页…

HiveSQL如何生成连续日期剖析

HiveSQL如何生成连续日期剖析 情景假设: 有一结果表,表中有start_dt和end_dt两个字段,,想要根据开始和结束时间生成连续日期的多条数据,应该怎么做?直接上结果sql。(为了便于演示和测试这里通过…

C++:MySQL的事务概念与使用(四)

1、事务的概念 定义:事务是构成单一逻辑工作单元的操作集合,要么完整的执行,要么完全不执行。无论发生何种情况,DBS必须保证事务能正确、完整的执行。 性质:事务的四大ACID性质。 原子性(Atomicity):一个事…

2.网络编程-HTTP和HTTPS

目录 HTTP介绍 HTTP协议主要组成部分 GET 和 POST有什么区别 常见的 HTTP 状态码有哪些 http状态码100 HTTP1.1 和 HTTP1.0 的区别有哪些 HTTPS 和 HTTP 的区别是什么 HTTP2 和 HTTP1.1 的区别是什么 HTTP3 和 HTTP2 的区别是什么 HTTPS的请求过程 对称加密和非对称…