pytorch 演示 tensor并行

pytorch 演示 tensor并行

  • 一.原理
  • 二.实现代码

本文演示了tensor并行的原理。如何将二个mlp切分到多张GPU上分别计算自己的分块,最后做一次reduce。
1.为了避免中间数据产生集合通信,A矩阵只能列切分,只计算全部batch*seqlen的部分feature
2.因为上面的步骤每张GPU只有部分feature,只因B矩阵按行切分,可与之进行矩阵乘,生成部分和
3.最后把每张GPU上的部分和加起来,就是最张的结果
以下demo,先实现了非分块的模型,然后模拟nccl分块,最后是分布式的实现

一.原理

在这里插入图片描述

二.实现代码

# torch_tp_demo.py
import os
import torch
from torch import nn
import torch.nn.functional as F 
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOpimport time
import argparseparser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')
parser.add_argument('--ffn_size', default=1024, type=int, help='')
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')
parser.add_argument('--device', default="cuda", type=str, help='')class FeedForward(nn.Module): def __init__(self,hidden_size,ffn_size): super(FeedForward, self).__init__() self.fc1 = nn.Linear(hidden_size, ffn_size,bias=False)self.fc2 = nn.Linear(ffn_size, hidden_size,bias=False)def forward(self, input): return self.fc2(self.fc1(input))class FeedForwardTp(nn.Module):def __init__(self,hidden_size,ffn_size,tp_size,rank): super(FeedForwardTp, self).__init__() self.fc1 = nn.Linear(hidden_size, ffn_size//tp_size,bias=False)self.fc2 = nn.Linear(ffn_size//tp_size, hidden_size,bias=False)self.fc1.weight.data=torch.from_numpy(np.fromfile(f"fc1_{rank}.bin",dtype=np.float32)).reshape(self.fc1.weight.data.shape)self.fc2.weight.data=torch.from_numpy(np.fromfile(f"fc2_{rank}.bin",dtype=np.float32)).reshape(self.fc2.weight.data.shape)def forward(self, input): return self.fc2(self.fc1(input))args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.devicedef native_mode():print(args)torch.random.manual_seed(1)model = FeedForward(hidden_size,ffn_size)model.eval()input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)for idx,chunk in enumerate(torch.split(model.fc1.weight, ffn_size//world_size, dim=0)):chunk.data.numpy().tofile(f"fc1_{idx}.bin")for idx,chunk in enumerate(torch.split(model.fc2.weight, ffn_size//world_size, dim=1)):chunk.data.numpy().tofile(f"fc2_{idx}.bin")model=model.half().to(device)usetime=[]for i in range(32):t0=time.time()    out = model(input)torch.cuda.synchronize()t1=time.time()if i>3:usetime.append(t1-t0)print("[INFO] native: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))result=[]for rank in range(world_size):model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)model.eval()out=model(input)torch.cuda.synchronize()result.append(out)sum_all=result[0]for t in result[1:]:sum_all=sum_all+tprint("[INFO] tp_simulate: shape:{},sum:{:.5f}".format(sum_all.shape,sum_all.sum().item()))def tp_mode():torch.random.manual_seed(1)dist.init_process_group(backend='nccl')world_size = torch.distributed.get_world_size()rank=rank = torch.distributed.get_rank()local_rank=int(os.environ['LOCAL_RANK'])torch.cuda.set_device(local_rank)device = torch.device("cuda",local_rank)input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)  model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)model.eval()if rank==0:print(args)usetime=[]for i in range(32):        dist.barrier()t0=time.time()out=model(input)#dist.reduce(out,0, op=ReduceOp.SUM) dist.all_reduce(out,op=ReduceOp.SUM)torch.cuda.synchronize()if rank==0:t1=time.time()if i>3:usetime.append(t1-t0)if rank==0:print("[INFO] tp: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))if __name__ == "__main__":num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1is_distributed = num_gpus > 1if is_distributed:tp_mode()else:native_mode()

运行命令:

python3 torch_tp_demo.py --hidden_size 512 \--ffn_size 4096 --seq_len 512 \--batch_size 8 --world_size 4 --device "cuda"
torchrun -m --nnodes=1 --nproc_per_node=4 \torch_tp_demo --hidden_size 512 \--ffn_size 4096 --seq_len 512 \--batch_size 8 --world_size 4 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/798745.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用dotnet-dump 查找 .net core 3.0 占用CPU 100%的原因解析

这篇文章介绍了3个工具 •dotnet-counters: 实时统计runtime的状况, 包括 CPU、内存、GC、异常等 •dotnet-trace: 类似性能探测器 •dotnet-dump: 程序崩溃时使用该工具 这次使用的是dotnet-dump, 即使程序没有崩溃, 也可以dump程序快照, 用于分析 实验环境 ubuntu-16.04.5-…

「PHP系列」PHP 循环详解

文章目录 一、while - 只要指定的条件成立,则循环执行代码块二、do...while - 首先执行一次代码块,然后在指定的条件成立时重复这个循环三、for - 循环执行代码块指定的次数四、foreach - 根据数组中每个元素来循环代码块五、相关链接 一、while - 只要指…

2024 Tuxera NTFS for Mac功能介绍及如何安装使用

随着科技的发展,我们的日常生活和工作越来越依赖于电子设备。而在这些设备中,Mac由于其出色的稳定性和易用性,成为了许多用户的首选。然而,尽管Mac自带的文件系统已经足够强大,但仍有一些用户希望获得更加高效、稳定的…

Java学习笔记NO.30

1. ArrayList ArrayList是Java中最常用的动态数组实现。它可以自动扩展以容纳任意数量的元素,并提供了快速的随机访问能力。 import java.util.ArrayList; public class ArrayListExample { public static void main(String[] args) { // 创建 ArrayList Array…

【氮化镓】在轨实验研究辐射对GaN器件的影响

【Pioneering evaluation of GaN transistors in geostationary satellites】 摘要: 这篇论文介绍了一项为期6年的空间实验结果,该实验研究了在地球静止轨道上辐射对氮化镓(GaN)电子元件的影响。实验使用了四个GaN晶体管&#xf…

如何水出第一篇SCI:SCI发刊历程,从0到1全过程经验分享!!!

如何水出第一篇SCI:SCI发刊历程,从0到1全路程经验分享!!! 详细的改进教程以及源码,戳这!戳这!!戳这!!!B站:Ai学术叫叫兽e…

代码随想录算法训练营第四十天|leetcode139题

一、leetcode第139题 本题是完全背包问题&#xff0c;由于可以重复使用&#xff0c;因此需要先遍历背包再遍历物品&#xff0c;dp[i]的含义是在长度为i处能否从数组中找到元素组成。 具体代码如下&#xff1a; class Solution { public:bool wordBreak(string s, vector<…

WPS解决插入公式在正文带来行间距变大问题

问题描述 写论文解释公式时&#xff0c;插入对应的变量&#xff0c;导致行间距变大&#xff0c;如图 显然上文与下文行间距不等。但无法通过修改数值修改下文行间距。 解决办法

java - 读取配置文件

文章目录 1. properties2. XML(1) dom4j(2) XPath 1. properties // 创建properties对象用于读取properties文件Properties properties new Properties();properties.load(new FileReader("src/main/resources/test.properties"));String name properties.getPrope…

消息队列之RabbitMQ的安装配置

一&#xff0c;前言 RabbitMQ是由erlang语言开发&#xff0c;基于AMQP&#xff08;Advanced Message Queue 高级消息队列协议&#xff09;协议实现的消息队列&#xff0c;它是一种应用程序之间的通信方法&#xff0c;消息队列在分布式系统开发中应用非常广泛。点击跳转RabbitM…

3、计算机的执行过程

三、存储器 1、存储器的分类 按存储器介质份分类 半导体存储器&#xff08;TTL&#xff08;集成度低、功耗高、速度快&#xff09;、MOS&#xff08;功耗低&#xff0c;集成度高&#xff09;&#xff09;。U盘等&#xff1b;易失 磁表面存储器&#xff08;磁头、载磁体&#xf…

90天玩转Python—05—基础知识篇:Python基础知识扫盲,使用方法与注意事项

90天玩转Python系列文章目录 90天玩转Python—01—基础知识篇:C站最全Python标准库总结 90天玩转Python--02--基础知识篇:初识Python与PyCharm 90天玩转Python—03—基础知识篇:Python和PyCharm(语言特点、学习方法、工具安装) 90天玩转Python—04—基础知识篇:Pytho…

SSM整合----第一个SSM项目

文章目录 前言一、使用步骤1.引入库2.建表3 项目结构4 web.xml的配置5 配置数据源6 SpringMVC配置7 配置MyBatis Mapper8 书写控制类 总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; SSM整合是指Spring、SpringMVC和MyBatis这三个框架的整合使用。…

MTK i500p AIoT解决方案

一、方案概述 i500p是一款强大而高效的AIoT平台&#xff0c;专为便携式、家用或商用物联网应用而设计&#xff0c;这些应用通常需要大量的边缘计算&#xff0c;需要强大的多媒体功能和多任务操作系统。该平台集成了Arm Cortex-A73 和 Cortex-A53 的四核集群&#xff0c;工作频…

【论文速读】| 大语言模型平台安全:将系统评估框架应用于OpenAI的ChatGPT插件

本次分享论文为&#xff1a;LLM Platform Security: Applying a Systematic Evaluation Framework to OpenAI’s ChatGPT Plugins 基本信息 原文作者&#xff1a;Umar Iqbal, Tadayoshi Kohno, Franziska Roesner 作者单位&#xff1a;华盛顿大学圣路易斯分校&#xff0c;华盛…

web安全学习笔记(7)

记一下第十一节课的内容。 这节课主要学习post传参和js弹窗与跳转 一、post传参 1.简单的post传参介绍 将index.php重命名为login.php&#xff0c;并将login.html从template文件夹下拿到根目录下&#xff0c;并删除template目录。 将login.html中内容改为如下所示&#xf…

Ubuntu下TexStudio如何兼容中文

怎么就想起来研究一下这个&#xff1f; 我使用大名鼎鼎的3Blue1Brown数学动画引擎Manim&#xff0c;制作了一个特别小的动画视频克里金插值。在视频中&#xff0c;绘制文字时&#xff0c;Manim使用到了texlive texlive-latex-extra这些库。专业的关系&#xff0c;当年的毕设没…

一个更难破解的加密算法 Bcrypt

BCrypt是由Niels Provos和David Mazires设计的密码哈希函数&#xff0c;他是基于Blowfish密码而来的&#xff0c;并于1999年在USENIX上提出。 除了加盐来抵御rainbow table 攻击之外&#xff0c;bcrypt的一个非常重要的特征就是自适应性&#xff0c;可以保证加密的速度在一个特…

设计模式:生活中的迭代器模式

迭代器模式可以通过日常生活中的餐厅菜单遍历来类比。想象一下&#xff0c;你走进一家餐厅&#xff0c;服务员给了你一本菜单。这本菜单就像是一个聚合对象&#xff0c;它包含了各种菜品。你可以一页一页地翻阅菜单&#xff0c;这个翻阅的过程就像是使用迭代器来遍历聚合对象的…

linux学习:gcc编译

编译.c gcc hello.c -o hello 用gcc 这个工具编译 hello.c&#xff0c;并且使之生成一个二进制文件 hello。 其中 –o 的意义是 output&#xff0c;指明要生成的文件的名称&#xff0c;如果不写 –o hello 的话会生成默 认的一个 a.out 文件 获得 C 源程序经过预处理之后的文…