Megatron-LM源码系列(八): Context Parallel并行

1. Context Parallel并行原理介绍

megatron中的context并行(简称CP)与sequence并行(简称SP)不同点在于,SP只针对LayernormDropout输出的activation在sequence维度上进行切分,CP则是对所有的input输入和所有的输出activation在sequence维度上进行切分,可以看成是增强版的SP。除了Attention模块以外,其他的模块(Layernorm、Dropout)由于没有多token的处理,在CP并行时都不用任何修改。

为什么Attention模块是个例外? 因为Attention计算过程中每个token的Q(query)要跟同一个sequence中其他token的K(key)和V(value)一起进行计算,存在计算上的依赖,所以通过CP并行后,在计算Attention前要通过allgather通信拿到所有token的KV向量,在反向计算时对应需要通过reduce_scatter分发gradient梯度。

为了减少显存占用,在前向时每个gpu只用保存一部分KV块,反向时通过allgather通信拿到所有的KV数据。KV的通信发生在相邻TP通信组相同位置的rank之间。allgather和reduce_scatter在ring拓扑架构实现时,底层会通过send和recv来进行实现。
在这里插入图片描述

以上图TP2-CP2的transformer网络为例,在Attention前的是CP的通信算子,其他都是TP的通信算子。AG表示allgather, RS表示reduce_scatter, AG/RS表示前向allgather反向reduce_scatter, RS/AG表示前向reduce_scatter反向allgather。

这里TP2对应为[GPU0, GPU1], [GPU2, GPU3], CP2对应为TP组相同位置的rank号,也就是[GPU0, GPU2], [GPU1, GPU3]。CP并行与Ring Attention类似,但是提供了新的OSS与FlashAttention版本,也去除了low-triangle causal masking的冗余计算。

LLM经常由于sequence长度过长导致显存OOM,这时之前的一种方式是通过重计算的方式保存中间的activation产出,全量重计算的劣势会带来30%的计算代价;另外一种方式是扩大TP(tensor parallel)的大小,扩大TP的劣势在于会对tensor切的更小,从而导致linear fc的计算时间变少,从而与通信很难进行计算的掩盖。

通过CP可以更好解决OOM的问题,每个GPU只用处理一部分的sequence, 同时减少CP倍的通信和计算,但保持TP不变,同时activation也会减少CP倍。CP优化的性能参考如下图,在Megatron中(Megatron-Core>=0.5.0 && Transformer Engine >=1.1)通过指定--context-parallel-size可以进行使用。 t o t a l _ g p u _ c o u n t = T P × C P × P P × D P total\_gpu\_count = TP \times CP \times PP \times DP total_gpu_count=TP×CP×PP×DP

在这里插入图片描述

2. 源码

以Megatron-Core 0.5.0为例进行介绍

  • 首先在megatron/arguments.py中定义了--context-parallel-size参数, 同时也要求了world_size能要整除TP*PP*CP
    group.add_argument('--context-parallel-size', type=int, default=1,help='Degree of context parallelism.')....model_parallel_size = args.pipeline_model_parallel_size * \args.tensor_model_parallel_sizeassert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \'world size ({}) is not divisible by tensor parallel size ({}) times ' \'pipeline parallel size ({}) times context parallel size ({})'.format(args.world_size, args.tensor_model_parallel_size,args.pipeline_model_parallel_size, args.context_parallel_size)args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
  • megatron/core/parallel_state.py中初始化通信组时会初始化相关CP通信组, 以TP-PP-DP-CP=8-1-1-2为例,TP通信组为[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15], CP通信组为[0,8],[1,9],[2,10],[3,11],[4,12],[5,13],[6,14],[7,15]
def initialize_model_parallel(...):...for i in range(pipeline_model_parallel_size):for j in range(data_parallel_size):start_rank = (i * num_pipeline_model_parallel_groups+ j * tensor_model_parallel_size * context_parallel_size)end_rank = (i * num_pipeline_model_parallel_groups+ (j + 1) * tensor_model_parallel_size * context_parallel_size)for k in range(tensor_model_parallel_size):ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)group = torch.distributed.new_group(ranks, pg_options=get_nccl_options('cp', nccl_comm_cfgs))if rank in ranks:_CONTEXT_PARALLEL_GROUP = group_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
  • megatron/core/transformer/custom_layers/transformer_engine.pyTEDotProductAttention会初始化相关CP通信组相关参数,TEDotProductAttention继承自te.pytorch.DotProductAttention,在前向中直接调用父类的的forward函数。
class TEDotProductAttention(te.pytorch.DotProductAttention):def __init__(...):...if te_version >= packaging.version.Version("1.0.0"):if getattr(TEDotProductAttention, "cp_stream") is None:TEDotProductAttention.cp_stream = torch.cuda.Stream()extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream...def forward(...):...core_attn_out = super().forward(query,key,value,attention_mask,attn_mask_type=attn_mask_type.name,**packed_seq_kwargs,)...
  • Transformer Engine中DotProductAttention定义在transformer_engine/pytorch/attention.py中,CP相关参数通过attn_kwargs进行传入。接下来会调用FlashAttention的前反向。
class DotProductAttention(torch.nn.Module):def __init__(...):...if self.use_flash_attention:self.flash_attention = FlashAttention(norm_factor,attention_type=attention_type,layer_number=layer_number,deterministic=self.deterministic,**attn_kwargs)...class FlashAttention(torch.nn.Module):def forward(...):...if context_parallel:with self.attention_dropout_ctx():output = attn_forward_func_with_cp(self.training, query_layer, key_layer, value_layer,cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,self.attention_dropout if self.training else 0.0,cp_group, cp_global_ranks, cp_stream,softmax_scale=1.0/self.norm_factor,qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,attn_mask_type=attn_mask_type,deterministic=self.deterministic)
  • 在FlashAttention中会通过函数attn_forward_func_with_cp进行调用,最终Attn前的all_gather通信是在AttnFuncWithCP中通过send、recv通信来实现的, 执行完通信就执行对应的flash_attention算子的调用。
def attn_forward_func_with_cp(...):out = AttnFuncWithCP.apply(is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention)return outclass AttnFuncWithCP(torch.autograd.Function):def forward(...):for i in range(cp_size+1):if i < cp_size:with torch.cuda.stream(flash_attn_streams[i%2]):# wait until KV is receivedfor req in send_recv_reqs[(i+1)%2]:req.wait()if i < (cp_size-1):p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,p2p_comm_buffers[i],send_dst,p2p_comm_buffers[i+1],recv_src,cp_group,batch_p2p_comm)...fused_attn_fwd(is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],kv_inputs[i%2][1], TE_DType[q.dtype],tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,attn_scale=softmax_scale, dropout=dropout_p,qkv_layout=qkv_layout, attn_mask_type="causal",attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],)

3. 参考

  • Context parallelism overview

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17159.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

M00238-固定翼无人机集群飞行仿真平台MATLAB完整代码含效果

一个小型无人机集群仿真演示平台&#xff0c;使用matlab和simulink搭建。 给出的例子是5架的&#xff0c;当然如果你愿意花时间&#xff0c;也可以把它扩展到10架&#xff0c;20架甚至更多。 输入&#xff1a;5架飞机的规划路径 输出&#xff1a;每架无人机每个时刻的13个状态量…

Docker环境安装并使用Elasticsearch

1、拉取es docker pull elasticsearch:7.10.12、查看镜像 docker images3、启动es docker run -d --name esearch -p 9200:9200 -p 9300:9300 elasticsearch:7.10.14、如果启动ES时出现一下问题 Unable to find image docker.elastic.co/elasticsearch/elasticsearch:7.10.…

python max_min标准化

python max_min标准化 max_min标准化sklearn实现max_min标准化手动实现max_min标准化 max_min标准化 Max-Min标准化&#xff08;也称为归一化或Min-Max Scaling&#xff09;是一种将数据缩放到特定范围&#xff08;通常是0到1&#xff09;的标准化方法。这种方法通过线性变换将…

用PhpStudy在本地电脑搭建WordPress网站教程(2024版)

对新手来说&#xff0c;明白了建站3要素后&#xff0c;如果直接购买域名、空间去建站&#xff0c;因为不熟练&#xff0c;反复测试主题、框架、插件等费时费力&#xff0c;等网站建成可能要两三个月&#xff0c;白白损失这段时间的建站费用。那么新手怎么建测试网站来练手呢&am…

06.部署jpress

安装mariadb数据 yum -y install mariadb-server #启动并设置开启自启动 systemctl start mariadb.service systemctl enable mariadb.service数据库准备 [rootweb01 ~]# mysql Welcome to the MariaDB monitor. Commands end with ; or \g. Your MariaDB connection id…

OpenAI 再次刷新认知边界:GPT-4 颠覆语音助手市场,流畅度直逼真人互动?

前言 近日&#xff0c;美国人工智能研究公司 OpenAI 发布了其最新旗舰模型 GPT-4o&#xff0c;这一革命性的进展不仅标志着人工智能领域的新突破&#xff0c;更预示着即将步入一个全新的交互时代&#xff1f;GPT-4o 的发布&#xff0c;对于我们来说&#xff0c;意味着人工智能…

冯喜运:5.28黄金今日走势分析及黄金原油操作策略

【黄金消息面分析】&#xff1a;周一&#xff08;5月27日&#xff09;美盘时段&#xff0c;现货黄金止跌回稳&#xff0c;缓慢回升&#xff0c;盘中最高触及2358.4美元。美国商品期货交易委员会(Commodity Futures Trading Commission)的最新交易数据显示&#xff0c;对黄金的投…

空压机的热回收原理介绍

空压机运行时会产生大量的压缩热&#xff0c;通常这部分能量通过机组的风冷或水冷系统释放到大气当中。压缩机的热回收是持续降低空气系统损耗&#xff0c;提高客户生产力的必要手段。 余热回收的节能技术目前研究很多&#xff0c;但大多只针对喷油螺杆式空压机的油路改造而言…

Eureka全面解析:轻松实现高效服务发现与治理!

一、引言 Eureka是Netflix开源的一款服务发现框架&#xff0c;它提供了一种高效的服务注册和发现机制&#xff0c;适用于大规模分布式系统。本文将详细介绍Eureka的相关知识。 二、Eureka简介 Eureka是一个基于REST的服务发现框架&#xff0c;它提供了一种简单的服务注册和发…

如果创办Google

本文是一篇演讲稿&#xff0c;来自于《黑客与画家》一书的作者保罗*格雷厄姆&#xff0c;被称为硅谷创业之父。这是他为14至15岁的孩子们做的一次演讲&#xff0c;内容是关于如果他们将来想创立一家创业公司&#xff0c;现在应该做些什么。很多学校认为应该向学生们传授一些有关…

ADS基础教程15 - 设计加密保护IP

设计加密保护IP 一、引言二、IP的生成与调用1.IP生成2.IP的调用 一、引言 介绍如何ADS中如何对设计好的原理图进行加密形成IP&#xff0c;然偶进行调用的过程。 二、IP的生成与调用 1.IP生成 (1)选择一个已经调试好的原理图&#xff0c;在菜单栏中选择Tools–>Encode De…

python中import的搜索路径

文章目录 前言 一 python中import的搜索路径1. python中import的搜索路径先判断是否内置模块根据sys.path查找1.1 脚本当前目录和所属项目目录1.2 环境变量1.3 标准库1.4 .pth 文件1.5 第三方库 2. 解决ModuleNotFoundError 前言 码python时经常会遇到找不到包或者找不到模块的…

OpenWrt 23.05 安装之后默认空间小 磁盘扩容 教程 软路由实测 系列六

1 安装fdisk opkg update opkg install fdisk #查看磁盘 rootOpenWrt:~# fdisk -l GPT PMBR size mismatch (246303 ! 250069679) will be corrected by write. The backup GPT table is not on the end of the device. Disk /dev/sda: 119.24 GiB, 128035676160 bytes, 25006…

开源远程协助:分享屏幕,隔空协助!

&#x1f5a5;️ 星控远程协助系统 &#x1f5b1;️ 一个使用Java GUI技术实现的远程控制软件&#xff0c;你现在就可以远程查看和控制你的伙伴的桌面&#xff0c;接受星星的指引吧&#xff01; 支持系统&#xff1a;Windows / Mac / Linux &#x1f31f; 功能导览 &#x1f…

【Flutter】KeyAnimatedList组件

&#x1f525; 本文由 程序喵正在路上 原创&#xff0c;CSDN首发&#xff01; &#x1f496; 系列专栏&#xff1a;Flutter学习 &#x1f320; 首发时间&#xff1a;2024年5月28日 &#x1f98b; 欢迎关注&#x1f5b1;点赞&#x1f44d;收藏&#x1f31f;留言&#x1f43e; 目…

10个最佳人物素材网站推荐,免费获取第一个PNG文件!

人物素材是设计中应用最广泛的元素之一。无论是网页设计还是移动终端设计&#xff0c;人物素材的插画设计都比文字信息更容易吸引用户的注意力。作为内容呈现&#xff0c;还可以增加设计的艺术属性。为了节省大家寻找人物素材的时间成本&#xff0c;本文立即为大家整理了10个宝…

Java 实验12 线程同步与通信

&#xff08;一&#xff09;实验目的 1、掌握JAVA中多线程的实现方法&#xff1b; 2、重点掌握多线程的同步与通信机制&#xff1b; 3、熟悉JAVA中有关多线程同步与通信的方法 &#xff1b; 4、能使用多线程机制解决实际应用中的线程同步与通信问题。 &#xff08;二&…

行为设计模式之职责链模式

文章目录 概述原理代码实现小结 概述 职责链模式(chain of responsibility pattern) 定义: 避免将一个请求的发送者与接收者耦合在一起,让多个对象都有机会处理请求.将接收请求的对象连接成一条链,并且沿着这条链传递请求,直到有一个对象能够处理它为止. 在职责链模式中&…

宝塔:如何在宝塔面板做301重定向

如何在宝塔面板做301重定向?301重定向对于网站来说非常重要。如果你的网站以www开头&#xff0c;我们应该把没有www的域名重定向到有www的域名&#xff0c;反之亦然。 1、我们进入宝塔管理后台 2、登录面板并单击添加站点。既然要把xxx.com 301发到www.xxx.com&#xff0c;我…

JS 中怎么删除数组元素?有哪几种方法?

正文开始之前推荐一位宝藏博主免费分享的学习教程,学起来! 编号学习链接1Cesium: 保姆级教程+源码示例2openlayers: 保姆级教程+源码示例3Leaflet: 保姆级教程+源码示例4MapboxGL: 保姆级教程+源码示例splice() JavaScript中的splice()方法是一个内置的数组对象函数, 用于…