将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

  • 一.测试数据
  • 二.测试环境
  • 三.普通实现
  • 四.分块实现

本文演示了如何将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

一.测试数据

  • 1.测试规模:8192*8192 world_size=2
  • 2.单算子:all_gather:0.03508s matmul:0.05689s e2e:0.09197s。matmul耗时最长
  • 3.按输入和权值切分成8份,async_op=True。e2e:0.75ms
  • 4.e2e耗时从91ms缩短到75ms 缩短了17%。耗时为纯matmul算子的:1.34倍

二.测试环境

docker run --gpus all --shm-size=32g -ti -e NVIDIA_VISIBLE_DEVICES=all \--privileged --net=host -v $PWD:/home \-w /home --name all_gather_mm \nvcr.io/nvidia/pytorch:23.07-py3 /bin/bash

三.普通实现

tee all_gather_mm_native.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
from torch.profiler import profile
import nvtxdev_type="cuda"
dist.init_process_group(backend='nccl')torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)
shape=(8192,8192)input_tensor=torch.rand((shape[0],shape[1]),dtype=torch.float).to(device)
weight=torch.rand((shape[1],8192),dtype=torch.float).to(device)
all_gather_buffer=torch.zeros((shape[0]*world_size,shape[1]),dtype=torch.float).to(device)for i in range(10):with nvtx.annotate(f"iter:{i}", color="blue"): dist.barrier()t0=time.time()torch.distributed._all_gather_base(all_gather_buffer, input_tensor)dist.barrier()torch.cuda.synchronize()t1=time.time()output = torch.matmul(all_gather_buffer, weight)torch.cuda.synchronize()t2=time.time()if rank==0:print(f"iter:{i} all_gather:{t1-t0:.5f} matmul:{t2-t1:.5f} e2e:{t2-t0:.5f} data:{output.mean()}")
EOF
export NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
export CUDA_VISIBLE_DEVICES="1,3"
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_nativensys profile --stats=true -o all_gather_mm_native.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_native

输出

iter:0 all_gather:0.03809 matmul:0.84971 e2e:0.88780 data:2047.62548828125
iter:1 all_gather:0.03327 matmul:0.06595 e2e:0.09922 data:2047.62548828125
iter:2 all_gather:0.03720 matmul:0.06082 e2e:0.09802 data:2047.62548828125
iter:3 all_gather:0.03682 matmul:0.05644 e2e:0.09326 data:2047.62548828125
iter:4 all_gather:0.03382 matmul:0.05648 e2e:0.09030 data:2047.62548828125
iter:5 all_gather:0.03404 matmul:0.05635 e2e:0.09039 data:2047.62548828125
iter:6 all_gather:0.03657 matmul:0.05701 e2e:0.09359 data:2047.62548828125
iter:7 all_gather:0.03840 matmul:0.05695 e2e:0.09535 data:2047.62548828125
iter:8 all_gather:0.03721 matmul:0.05685 e2e:0.09406 data:2047.62548828125
iter:9 all_gather:0.03508 matmul:0.05689 e2e:0.09197 data:2047.62548828125

在这里插入图片描述

四.分块实现

tee all_gather_mm_tiling.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
import nvtx# 分几块
num_blocks = 8dev_type="cuda"
dist.init_process_group(backend='nccl')torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)streams = [torch.cuda.Stream(device=device) for _ in range(num_blocks)]def all_gather_matmul(rank, world_size, input, weight,gathered_buffer,output_buffer, num_blocks, device):input_chunk_size = input.size(0) // num_blocks  # 每块的大小weight_chunk_size = weight.size(1) // num_blockshandles = []for i in range(num_blocks):with torch.cuda.stream(streams[i]):# 划分块并进行 all_gatherinput_chunk = input[i * input_chunk_size: (i + 1) * input_chunk_size]gather_start_idx = i * input_chunk_size * world_size  # 起始索引handle = dist.all_gather_into_tensor(gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size], input_chunk, async_op=True)handles.append((handle, gather_start_idx))outputs = torch.zeros_like(output_buffer)for i in range(num_blocks):with torch.cuda.stream(streams[i]):handle, gather_start_idx = handles[i]handle.wait()  # 等待通信完成# 直接在通信结果上进行矩阵乘法gathered_input = gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size]for j in range(num_blocks):weight_chunk = weight[:, j * weight_chunk_size: (j + 1) * weight_chunk_size]output_chunk = outputs[i * input_chunk_size * world_size: (i + 1) * input_chunk_size * world_size, j * weight_chunk_size: (j + 1) * weight_chunk_size]             # 进行局部矩阵相乘output_chunk.add_(torch.matmul(gathered_input, weight_chunk))torch.cuda.synchronize(device)return outputs# 初始化
input = torch.rand((8192, 8192),dtype=torch.float).to(device) 
weight = torch.rand((8192, 8192),dtype=torch.float).to(device) 
all_gather_buffer = torch.zeros((8192 * world_size, 8192),dtype=torch.float).to(device)for i in range(10):output = torch.zeros(input.size(0) * world_size, weight.size(1),dtype=torch.float,device=device)dist.barrier()t0=time.time()with nvtx.annotate(f"iter:{i}", color="blue"):output = all_gather_matmul(rank, world_size, input, weight,all_gather_buffer,output,num_blocks,device)torch.cuda.synchronize()t1=time.time()if rank == 0:print(f"iter:{i} e2e:{t1-t0:.5f} data:{output.mean()}")
EOFexport NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tilingnsys profile --stats=true -o all_gather_mm_tiling.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tiling

输出

iter:0 e2e:0.13553 data:2047.62548828125
iter:1 e2e:0.07687 data:2047.62548828125
iter:2 e2e:0.07717 data:2047.62548828125
iter:3 e2e:0.07645 data:2047.62548828125
iter:4 e2e:0.07724 data:2047.62548828125
iter:5 e2e:0.07586 data:2047.62548828125
iter:6 e2e:0.07587 data:2047.62548828125
iter:7 e2e:0.07589 data:2047.62548828125
iter:8 e2e:0.07626 data:2047.62548828125
iter:9 e2e:0.07549 data:2047.62548828125

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代理IP的10大误区:区分事实与虚构

在当今的数字时代&#xff0c;代理已成为在线环境不可或缺的一部分。它们的用途广泛&#xff0c;从增强在线隐私到绕过地理限制。然而&#xff0c;尽管代理无处不在&#xff0c;但仍存在许多围绕代理的误解。在本博客中&#xff0c;我们将探讨和消除一些最常见的代理误解&#…

人脑网络的多层建模与分析

摘要 了解人类大脑的结构及其与功能的关系&#xff0c;对于各种应用至关重要&#xff0c;包括但不限于预防、处理和治疗脑部疾病(如阿尔茨海默病或帕金森病)&#xff0c;以及精神疾病(如精神分裂症)的新方法。结构和功能神经影像学方面的最新进展&#xff0c;以及计算机科学等…

OBS 免费的录屏软件

一、下载 obs 【OBS】OBS Studio 的安装、参数设置和录屏、摄像头使用教程-CSDN博客 二、使用 obs & 输出无黑屏 【OBS任意指定区域录屏的方法-哔哩哔哩】 https://b23.tv/aM0hj8A OBS任意指定区域录屏的方法_哔哩哔哩_bilibili 步骤&#xff1a; 1&#xff09;获取区域…

012-GeoGebra基础篇-构造圆的切线

前边文章对于基础内容已经悉数覆盖了&#xff0c;这一篇我就不放具体的细节&#xff0c;若有需要可以复刻一下 目录 一、成品展示二、算式内容三、正确性检查五、文章最后 一、成品展示 二、算式内容 A(0,0) B(3,0) c: Circle(A,B) C(5,4) sSegment(A,C) DMidpoint(s) d: Circ…

k8s部署单节点redis

一、configmap # cat redis-configmap.yaml apiVersion: v1 kind: ConfigMap metadata:name: redis-single-confignamespace: redis data:redis.conf: |daemonize nobind 0.0.0.0port 6379tcp-backlog 511timeout 0tcp-keepalive 300pidfile /data/redis-server.pidlogfile /d…

全网小视频去水印接口使用说明

一、请求地址&#xff1a; https://www.lytcreate.com/api/qsy/ 二、请求方式&#xff1a;POST 三、请求体&#xff1a;JSON body {"token": "个人中心的token","url": "视频分享地址"} token获取地址&#xff0c;访问&#xff…

uniapp微信小程序使用xr加载模型

1.在根目录与pages同级创建如下目录结构和文件&#xff1a; // index.js Component({properties: {modelPath: { // vue页面传过来的模型type: String,value: }},data: {},methods: {} }) { // index.json"component": true,"renderer": "xr-frame&q…

Element-plus点击当前行之后获取数据显示跟随行数据

要实现点击当前行后&#xff0c;在当前行的下方显示数据&#xff0c;可以通过以下步骤来实现&#xff1a; 在表格的行点击事件中获取当前点击行的位置信息。根据位置信息动态计算并设置需要显示数据区域的位置。 下面是一个更新后的示例代码&#xff0c;演示如何在 Element-P…

Unity 引擎收费模式变革:游戏开发者的挑战与机遇

Unity 引擎作为游戏开发领域中的重要工具&#xff0c;近日宣布将在 2024 年 1 月 1 日起根据游戏安装量对开发者进行收费。这一决定引起了业界的广泛关注和讨论。据 Unity 技术博客发布的《Unity 收费模式和配套服务更新》一文&#xff0c;他们选择这种计费方式是基于每次游戏被…

PHP和phpSpider:如何应对网站变动导致的数据爬取失败?

php和phpspider&#xff1a;如何应对网站变动导致的数据爬取失败&#xff1f; 导语&#xff1a; 网络爬虫是一种自动化程序&#xff0c;用于从网站上获取数据并进行处理。PHP是一种广泛使用的编程语言&#xff0c;而phpSpider是一个基于PHP的开源网络爬虫框架。然而&#xff0…

软降工程学系统实现

一、程序编码 程序编码是设计的继续&#xff0c;将软件设计的结果翻译成用某种程序设计语言描述的源代码。 程序编码涉及到方法、工具和过程。 程序设计风格和程序设计语言的特性会深刻地影响软件的质量和可维护性。 要求源程序具有良好的结构性和设计风格。 程序设计风格…

开启IT世界的探索之旅——致有志于踏入IT领域的高考少年们

高考已成过去&#xff0c;而前方是无限可能的未来。对于那些有志于进入IT领域的高考生来说&#xff0c;这个暑假是你们开启探索IT世界的绝佳时机。作为一名从事C#软件开发的专业人员&#xff0c;我希望能通过这篇文章&#xff0c;分享一些学习路线图和经验心得&#xff0c;帮助…

【web3】分享一个web入门学习平台-HackQuest

前言 一直想进入web3行业&#xff0c;但是没有什么途径&#xff0c;偶然在电鸭平台看到HackQuest的共学营&#xff0c;发现真的不错&#xff0c;并且还接触到了黑客松这种形式。 链接地址&#xff1a;HackQuest 平台功能 学习路径&#xff1a;平台有完整的学习路径&#xff…

【聊聊原子性,中断,以及nodejs中的具体示例】

什么是原子性 从一个例子说起&#xff0c; x &#xff0c;读和写 &#xff0c; 如图假设多线程&#xff0c;线程1和线程2同时操作变量x&#xff0c;进行x的操作&#xff0c;那么由于写的过程中&#xff0c;都会先读一份x数据到cpu的寄存器中&#xff0c;所以这个时候cpu1 和 c…

MyBatis-plus(下)

目录 静态工具 逻辑删除 枚举处理器 ​编辑​编辑JSON处理器 分页插件 案例 静态工具 只有save与update不需要传class字节码 UserController: MyServiceImpl: 改造根据id批量查询用户的接口&#xff0c;查询用户的同时&#xff0c;查询出用户对应的所有地址 Overrid…

容器内存

一、容器内存概述 容器本质上还是一个进程&#xff0c;是一个被隔离和限制的进程。因此容器内存和进程内存在表现形式上其实是一样的&#xff0c;这块主要涉及三部分内容&#xff1a;RSS&#xff0c;page cache和swap这三部分&#xff0c;容器基于memory Cgroup对内存进行限制…

用国内镜像安装docker 和 docker-compose (ubuntu)

替代方案&#xff0c;改用国内的镜像站(网易镜像&#xff09; 1.清除旧版本&#xff08;可选操作&#xff09; for pkg in docker.io docker-doc docker-compose podman-docker containerd runc; do apt-get remove $pkg; done 2.安装docker apt-get update 首先安装依赖 apt-g…

Linux驱动开发实战宝典:设备模型、模块编程、I2C/SPI/USB外设精讲

摘要: 本文将带你走进 Linux 驱动开发的世界,从设备驱动模型、内核模块开发基础开始,逐步深入 I2C、SPI、USB 等常用外设的驱动编写,结合实际案例,助你掌握 Linux 驱动开发技能。 关键词: Linux 驱动,设备驱动模型,内核模块,I2C,SPI,USB 一、Linux 设备驱动模型 Li…

mysql创建表的规范

名称 建表的时候&#xff0c;给表&#xff0c;字段和索引起个好名字 见名知意&#xff1a;好的名字能够降低沟通和维护的成本名字不宜过长&#xff0c;尽量控制在30个字符以内 大小写 名字尽量都用小写字母&#xff0c;因为从视觉上&#xff0c;小写字母更容易让人读懂全部大写…

Linux嵌入式中MQTT的使用

MQTT是什么&#xff1f; MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息队列遥测传输协议&#xff09;&#xff0c;是一种基于发布/订阅&#xff08;Publish/Subscribe&#xff09;模式的轻量级通讯协议&#xff0c;该协议构建于TCP/IP协议上&#xff0…