TorchAcc:基于 TorchXLA 的分布式训练框架

演讲人:林伟,阿里云研究员,阿里云人工智能平台 PAI 技术负责人

本文旨在探讨阿里云 TorchAcc,这是一个基于 PyTorch/XLA 的大模型分布式训练框架。

过去十年 AI 领域的显著进步,关键在于训练技术的革新和模型规模的快速攀升。尽管大模型展现了堪比人类的理解力,但其训练却对算力提出了极高的要求。唯有配备充足的计算资源,方能在海量数据上有效训练大模型,确保其在有限时间内实现优质收敛。

图片

图片来源于 GTC 2024大会China AI Day 线上专场的演讲《TorchAcc:基于TorchXLA的分布式训练框架》

根据上图左侧图表显示,过去五年,大模型规模的增长态势尤为突出,平均每两年大小翻 15 倍;而对于 Transformer 为代表的语言模型以及多模态模型而言,其规模膨胀速度更加惊人,每隔两年以 750 倍剧增。对比之下,右侧图表揭示了一个明显的矛盾点:不论是单个 GPU 的计算能力抑或是 GPU 显存容量的发展速度,都无法跟上模型规模如此急剧的扩张步伐。这一现实状况直接催生了对分布式训练的迫切需求。分布式训练不再局限于以往单纯的数据并行模式,而是在此基础上,更加重视并采取模型并行策略,以弥补单个计算单元算力与存储提升速度相对于模型规模增长的滞后性。

在分布式训练实践中,开发人员普遍认同,构建模型并行的分布式训练系统相比数据并行更为复杂。数据并行从分布式角度来看,其逻辑相对直接和简洁,因为每个计算节点执行的任务本质上是对等且一致的。在这种情况下,只需在训练过程末尾插入 AllReduce 步骤,将各个工作节点(worker)独立计算出的梯度差异累加整合,然后求平均值,并将最终梯度结果广播至所有参与工作的节点,用以同步更新全局模型参数。

这类简单的分布式训练范式,确实呈现出类似单机计算的特点,主要涉及全局梯度同步的 AllReduce。然而步入大模型时代,由于模型规模过大,已无法容纳于单个 GPU 之内,我们就必须采用模型并行策略,其开发难度也就陡然上升了。

原因是,模型并行需要根据模型的规模和结构来决定如何恰当地“分割”模型,即将其分割为多个可以平衡计算负载的模块。在不同的分割策略下,模型在各个节点上算子的算法实现方式会发生变化,同时,不同分割方法还会引起节点间通信原语的差异,需要精心选择最优分割方案以及配套的通信原语。

在模型分割完成后,接下来的任务就是选用适合的通信原语,并精细地调度各个算子及其相关的通信操作,力求最大化计算与网络通信的重叠(overlap),以充分发挥底层计算资源的效率。正是由于存在多种可能的分割选项与调度决策,寻求最优模型并行策略的复杂性明显高于数据并行,对开发者的技巧和经验提出了更高的要求。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

本文将围绕四个核心方面展开。首个议题是如何在 TorchAcc 中实现多样化的并行策略,涵盖了常规的数据并行,以及当下备受关注的 FSDP(Fully Sharded Data Parallel,又称 ZeRO (Zero Redundancy Optimizer)) 。此外,还包括了模型并行的各种形态,诸如算子并行,即 Tensor Parallelism,以及流水线并行(Pipeline Parallelism)等。

TorchAcc 的一大亮点在于其能够自动探寻并有机整合各类并行策略,并为用户提供高度自动化的分布式策略配置方案;与此同时,为了满足高级开发者的定制化需求,TorchAcc 还提供了半自动化的控制接口,允许用户介入并调整自动探索并行策略的过程,从而在兼顾灵活性的同时,最大程度地提升训练效率和资源利用率。

通过上述方式,TorchAcc 有效地助力算法开发者将精力集中于模型自身的结构设计、训练方法的优化,以及追求模型收敛性能的提升上,而非花费精力在分布式训练的具体实现细节。TorchAcc 将智能化地协助开发者探寻并实现最佳的分布式训练方案,从而显著提升计算资源利用效率和算法迭代效率。

其次,模型并行技术的必要性是因为大模型尺寸超出单个 GPU 显存容量的限制。显存容量对于模型训练至关重要,如何打破显存瓶颈,对于提升分布式训练的整体效率来说至关重要。因此,TorchAcc 提供了一种显存智能分配器,通过对显存资源的精细化调度与地址分配策略,最大限度地提高模型并行训练时的效率,确保模型能充分利用现有的显存地址空间。

再者,随着模型结构日益复杂,且规模不断增大,用户对计算资源的需求也在持续攀升,因此,进一步优化模型在训练过程中的计算密集度及减少访存开销也非常关键。

最后,考虑到当前数据中心基础设施的发展趋势,大模型训练对网络条件的要求日渐严苛。现代数据中心服务器间的互联带宽已达到 TB 级别,以满足大规模模型并行训练对高速数据交换的需求。然而,模型并行所带来的复杂通信模式与高频次的数据交互亦会对整体训练效率构成挑战。因此,如何有效利用网络带宽,减少通信过程在迭代计算中占据的时间比例,也就成了训练效率提升的另一重要因素。

在具体实现上,TorchAcc 通过一系列技术手段,成功地将用户在前端,无论是基于 PyTorch 还是 TensorFlow 构建的模型训练过程转化为统一的中间表示层(Model IR)的 graph。其中,对于 TensorFlow 而言,因其自身就是一种计算图模型,转化过程相对直接,而对于 PyTorch,我们采用了符号式追踪(symbolic tracing)以及 LazyTensor 等技术捕获计算图,进而转化为 IR Graph。

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

基于中间表示层(IR Graph)的构建,TorchAcc 实施了一系列多元化的优化策略,涵盖计算优化、存储优化、通信优化以及分布式策略优化,IR Graph 以各类组合并反复执行这些优化的 Pass 后,最终得到一个最优的执行 Plan。然后交由底层 Backend 执行,以实现模型训练性能的最大化提升。

通过这一整套方案,TorchAcc 在多个模型的分布式训练场景中表现出了显著的性能优势。部分模型的训练过程得以实现高达 3 倍的性能提速,充分证明了 TorchAcc 在解决分布式训练难题上的高效性和实用性。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

这张图片主要展示了 TorchAcc 的框架总体架构。TorchAcc 以 Pytorch/XLA 为基础,并 TorchAcc 依托于 OpenXLA,构建了一套大模型训练加速框架。TorchAcc 在处理使用不同前端构建的模型时,会灵活采用适宜的图捕获技术,如 Symbolic Trace 和 LazyTensor,进而生成两种不同层级的图表示:FX Graph 和 HLO Graph。其中,FX Graph 位于较高抽象层次,而 HLO Graph 则更为底层。

基于捕获到的模型计算图,TorchAcc 即可进一步展开了四类优化工作,即前文提及的计算优化、存储优化、通信优化以及分布式策略优化。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在分布式策略优化层面,TorchAcc 支持业界广泛使用的各种并行策略,并能够灵活地结合这些策略对给定模型进行有效的并行化处理。具体而言,对于数据并行 DP(Data Parallelism)、流水并行 PP(Pipeline Parallelism)以及 FSDP(Fully Sharded Data Parallel, 也称为 ZeRO)这三种分布式策略,其实现和优化都是在 FX Graph 这一较高抽象层次上完成的。

选择在 FX Graph 层面对并行策略进行操作的原因在于,这一层级所包含的关于计算图结构和操作的信息已足够丰富,足以支撑开发人员设计出适应不同并行策略的优化方案。相较于在更低层的 HLO Graph 上直接进行优化,由于 FX Graph 具有更高的抽象性和概括性,在这一层面上进行优化的成本通常较低,更容易实施高效且针对性强的分布式策略调整。

以流水并行作为例子,系统能够自动检测 FX Graph 层级上的不同阶段,并确定合适的分割点,从而有效地将模型分割为多个连续执行的阶段,实现流水线并行化。在此过程中,我们可以利用 FX Graph 提供的详细计算结构信息来进行智能分割。

至于 Tensor Parallelism (张量并行)和 Sequence Parallelism (序列并行)这两种更为复杂的并行策略,它们要求更为细致精确的信息以便进行决策。为了实现这一点,系统需要对前向传播和反向传播的整个计算图的执行计划来进行分析。这时的工作主要在 HLO 这一低级别表示层面上进行。

通过利用 PyTorch/XLA 提供的 mark sharding 接口,系统能够在模型参数上添加相应的拆分标记,然后将这些拆分信息传递给 OpenXLA 的 SPMD 优化 Pass,进而触发计算图的拆分、优化、推导和重写过程,最终实现自动的 Tensor Parallelism 和 Sequence Parallelism 功能。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在算子优化层面,TorchAcc 引入 FlashAttention 技术来提升 Attention 模块的执行效率。首先,通过 XLA 的 custom call 功能,将 FlashAttention 的实现无缝地融入到了 OpenXLA 编译器和运行时框架中。这意味着 FlashAttention 可以直接在 XLA 内核层级被执行,从而充分利用硬件加速能力。

在整合过程中,要处理好在 PyTorch 与 XLA 之间 Tensor 数据的传递问题,确保在两个系统间转换时的数据一致性与性能优化,同时,还要妥善处理 FlashAttention内部参数传递等细节问题,保证在并行计算和优化的过程中,这些关键参数能够正确且高效地应用到计算中,进一步提升模型在执行注意力机制部分的运算速度和资源利用率。

为了用户能便捷地使用 FlashAttention 优化功能,我们提供了两种接口,用户也可以直接通过 Python 接口调用预先写好的 FlashAttention 算子,第三种方法是用户可以使用我们在 OpenXLA 上写好的 Pattern Match Pass,该 Pass 能够自动识别计算图中的 Attention Block,并将这部分计算结构提取出来,替换为FlashAttention 的 custom call。这样设计的优势在于,既能充分利用 XLA 原本就十分出色的 Kernel fusion 等算子优化功能,又能结合 FlashAttention 带来的先进计算优化技术。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在 Llama 2-7B 模型的性能测试中,我们能够明显观察到上述计算优化带来的效果。通过利用 XLA 自身的优化技术,尤其是 kernel fusion,我们将大量的访存密集型算子做了有效合并,从而大幅减少其数量,在叠加 FlashAttention 后,优化性能进一步提升。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在通信优化层面,我们主要完成了三项核心任务以提升分布式训练效率:首先,我们合并了一些零散的 collective 通讯算子,通过减少算子数量来降低通讯开销和调度复杂度,其次,我们将合并的 collective 通讯算子移至独立的 CUDA Stream 上执行,这样一来,就能够异步实现计算与通讯的重叠执行。最后,我们充分利用了 OpenXLA 的 Latency Hiding Scheduler 功能,对通讯算子的调度进行了精细优化,使其尽早启动和执行,从而增强通讯与计算之间的重叠效果。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

通过在 Llama2 -7B 模型上进行的端到端多机性能测试,我们发现,应用了通讯优化策略后,在 128 张 GPU 卡上进行分布式训练,优化后的加速比从原来的 88 提升到了 116,通过 timeline 图我们也可以直观地看到,优化后的通讯算子更加有序,并且能够更好地和计算重叠执行。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

本文最后一个章节绍 TorchAcc 的显存优化功能,该功能通过优化计算图中算子的执行顺序以及 Tensor 在显存中的地址分配,来降低显存开销。

如图举例说明,假设有一个包含四个算子 V0、V1、V2、V3 的计算图,如果不控制算子执行顺序,如左图所示按照 V0-V1-V2-V3 的顺序执行,若每个 Tensor 按照默认方式进行显存地址申请,则可能出现如 B 图左半部分所示的情况,即显存容量不足以容纳所有 Tensor,导致 out of memory 错误。

然而,如果我们能够预判并精细管理内存分配,即在分配地址时预知后续执行的算子序列,即可如 B 图右半部分所示进行更优的显存布局,使得整体计算可在有限显存内顺利完成。更进一步,通过精确控制执行顺序,比如按照 V0-V2-V1-V3 的方式执行,可以进一步压缩显存需求至原始需求的 70% 左右。

这一理念是基于 XLA 中间表示层已有的 scheduler 和 buffer 管理机制,我们在此基础上提出了更先进的显存优化方法。目前业界存在多种优化显存分配的方法,如启发式算法、约束求解等,但这些方法往往难以兼顾时效性和高效性,在实际生产环境的集群中应用时可能存在局限性。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

在训练场景中实现有效且高效的显存优化是一项极具挑战的任务,原因主要包括以下几个方面:

  1. NP-Hard 问题本质:由于模型的规模、算子的种类繁多,以及算子间显存分配的复杂性,显存优化问题成为一个典型的 NP-hard 问题,即找到全局最优解在计算上通常是不可行的。

  2. 算子执行灵活性:训练过程中,前向传播、反向传播和权重更新等操作具有很高的灵活性,特别是在权重更新方面,梯度产生后随时可以被用于权重更新,但不同的执行时机会影响显存的申请和释放,增加了优化难度。

  3. 显存复用复杂性:在训练过程中,前向和反向传播可以通过复用显存减少重新计算,但 Tensor 生命周期的多样性和尺寸的变化使得显存复用变得极为复杂,这对启发式算法等传统优化手段构成了严峻挑战。

为了解决上述难题,我们采取了一种分治策略:

  1. Memory-aware Weight Update Scheduler:引入了显存感知的权重更新调度器,它会根据梯度产生的时机、使用的优化器类型以及当前显存资源状况,选择合适的权重更新时间点,避免即时更新加重显存压力,特别是对于复杂的优化器如 Adam,需考虑动量和其他变量的存储。

  2. Graph 分割与局部优化:将大计算图根据关键节点 (memory insensitive operator) 分割成多个内存无关性的子图,子图间执行顺序固定,而子图内部的执行顺序则可以多样化。通过这种方式,可以将复杂的全局线性规划问题分解成多个局部问题,在子图范围内采用高效的优化方法,如线性规划求解最优执行顺序。

通过上述分治策略,最终我们能够聚合这些子图的求解结果,这也就是我们提出的 ROAM (Reorder Operators and Arrange Tensors Address to Reduce Memory Usage) 这一内存优化探索方式。

上述方法可以成功实现对显存优化问题的高效处理。实验结果显示,与原生 PyTorch、启发式算法以及 Facebook 近期基于整数线性规划的优化方法等 baseline 相比,ROAM 分别节省了约 16%、13% 和 27% 的显存开销,且在优化时长和可扩展性方面表现出色,证实了这种方法的有效性。

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

图片

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

从另一个维度衡量效果,我们考察了算法求解的时间开销。实验证明,在常见的深度学习场景中,我们的优化算法能够在短短几分钟内得出优化结果。从右图所示对比中可以看出,相较于 Facebook 最近提出的 MODeL(一种基于线性规划的优化方法),我们的方法在求解时间上实现了显著的缩减。原因在于,MODeL 在处理大规模图时并未对其进行有效分割,而我们的方法通过引入 memory-aware weight update scheduler 和子图划分策略,有效地降低了优化问题的空间复杂度,从而提高了求解效率。

综上所述,TorchAcc 在显存优化、计算优化、通信优化以及并行策略优化等方面均取得显著成效,全方位提升了分布式训练的效率与性能。


以上内容来源于 GTC 2024 大会 China AI Day 线上中文演讲专场。扫描图片二维码或登录大会官网,即可观看演讲视频,并可下载讲义。

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/768992.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详细剖析多线程2----线程安全问题(面试高频考点)

文章目录 一、概念二、线程不安全的原因三、解决线程不安全问题--加锁(synchronized)synchronized的特性 四、死锁问题五、内存可见性导致的线程安全问题 一、概念 想给出⼀个线程安全的确切定义是复杂的,但我们可以这样认为: 在多…

立体统计图表绘制方法(凸显式环图)

立体统计图表绘制方法(凸显式环图) 记得我学统计学的时候,那些统计图表大都是平面的框框图,很呆板,就只是表现出统计的意义就好了。在网络科技发展进步的当下,原来一些传统的统计图表都有了进一步的创新。在…

RDGCN翻译

RDGCN翻译 Relation-Aware Entity Alignment for Heterogeneous Knowledge Graphs 面向异质知识图谱的关系感知实体对齐 阅读时间:2024.03.24 领域:知识图谱,知识对齐 作者:Yuting Wu等人 PKU 出处:IJCAI Abstract…

HarmonyOS NEXT应用开发之听歌识曲水波纹特效案例

介绍 在很多应用中,会出现点击按钮出现水波纹的特效。 效果图预览 使用说明 进入页面,点击按钮,触发水波纹动画。再次点击按钮,停止水波纹动画。 实现思路 本例涉及的关键特性和实现方案如下: 要实现存在两个连续…

C++ - 类和对象(上)

目录 一、类的定义 二、访问限定符 public(公有) protected(保护) private(私有) 三、类声明和定义分离 四、外部变量和成员变量的区别与注意 五、类的实例化 六、类对象的模型 七、类的this指针…

TCP详解

一、TCP报文段结构 1、源端口号和目的端口号都是16位,范围从(1-65535,0不可用) 2、序列号:在建立连接时由内核生成的随机数作为其初始值,通过 SYN 报文传给接收端主机,每发送一次数据&#xff0…

C语言数据结构易错知识点(5)(插入排序、选择排序)

插入排序:直接插入排序、希尔排序 选择排序:直接选择排序、堆排序 上述排序都是需要掌握的,但原理不会讲解,网上有很多详尽地解释,本文章主要分享一下代码实现上应当注意的事项 1.直接插入排序: 代码实…

拥抱C++的深度和复杂性,挖掘更多可能 !——《C++20高级编程(第5版)》

,C难以掌握,但其广泛的功能使其成为游戏和商业软件应用程序中最常用的语言。即使是有经验的用户通常也不熟悉许多高级特性,但C20的发布提供了探索该语言全部功能的绝佳机会。《C20高级编程(第5版)》为C的必要内容提供了一个代码密集型、面向解…

(AtCoder Beginner Contest 325) ---- D - Printing Machine -- 题解

目录 D - Printing Machine: 题目大意: 思路解析: 代码实现: D - Printing Machine: 题目大意: 思路解析: 打印一次后,需要充电一微秒后才能再次打印就可以看作每微妙只能打印一…

【文献阅读】AlphaFold touted as next big thing for drug discovery — but is it?

今天来精读2023年10月发在《Nature》上的一篇新闻:AlphaFold touted as next big thing for drug discovery — but is it? (nature.com)https://www.nature.com/articles/d41586-023-02984-w Questions remain about whether the AI tool for predicting protein …

蓝桥杯基础练习详细讲解二(具体代码、解题思路、Python)

试题 基础练习 回文数 提交此题 评测记录 资源限制 内存限制:512.0MB C/C时间限制:1.0s Java时间限制:3.0s Python时间限制:5.0s 问题描述 1221是一个非常特殊的数,它从左边读和从右边读是一样的&#x…

C语言从入门到实战----C语言中内存函数的使用和模拟实现

目录 前言 1.memcpy 使用和模拟实现 2. memmove 使用和模拟实现 3. memset 函数的使用 4. memcmp 函数的使用 前言 在编程领域,内存管理是至关重要的一环,它确保了程序能够高效、稳定地运行。 C语言作为一门底层的编程语言,提供了一系…

Redis 教程系列之Redis 集群配置(十三)

1.Redis集群方案比较 主从模式 在软件的架构中,主从模式(Master-Slave)是使用较多的一种架构。主(Master)和从(Slave)分别部署在不同的服务器上,当主节点服务器写入数据时,同时也会将数据同步至从节点服务器,通常情况下,主节点负责写入数据,而从节点负责读取数据。…

橘子疾病检测4种YOLOV8

橘子检测YOLOV8,检测4种疾病,采用YOLOV8-NANO,训练得到PT模型转换成ONNX,最后OPENCV调用,支持C/PYTHON/ANDROID 橘子检测YOLOV8,检测4种疾病

2025汤家凤考研数学视频,基础网课百度网盘课程+PDF讲义资料

2025汤家凤大神及数学全程 docs.qq.com/doc/DTmtOa0Fzc0V3WElI 复制粘贴到浏览器,可以见所有的Ke 第一轮 夯实基础 1.阅读大纲考查要求,明确每章的学习目标; 2.按节学习数学理论基础知识,吃透书中例题; 3.学习每章…

【C语言】数组(一维、二维数组的简单介绍)

数组(Array) 数组概念 数组是一组相同数据类型元素的集合,属于一种简单的数据结构,从中可以得到三个有效信息 数组元素是同一数据类型的变量数组存放一个或者多个数据,但是数组元素个数不能为0数组中各元素可独立作为…

【Web APIs】DOM节点

目录 1.节点操作 1.1DOM节点 1.2查找节点 1.2.1父节点查找 1.2.2子节点查找 1.2.3兄弟节点查找 1.3增加节点 1.4克隆节点 1.5删除节点 2.时间对象 2.1实例化 2.2时间对象方法 2.3时间戳 3.重绘和回流 1.节点操作 1.1DOM节点 DOM节点:DOM树中的每一个…

CHAT~(持续更新)

CHAT(持续更新) 实现一个ChatGPT创建API设计页面布局业务操作技术架构 编码其他 实现一个ChatGPT 创建API 最简单也最需要信息的一步 继续往下做的前提 此处省略,想要获取接口创建方式联系 设计 页面布局 按照官网布局 业务操作 注册登…

绝地求生:PUBG七周年庆典开启!参与周年话题投稿赢丰厚奖励

为庆祝七周年,闲游盒PUBG官方准备了众多活动与奖励,一起在庆典中创造难忘的回忆吧!七周年庆典期间游玩PUBG,参与 #乐在7中鸡味无穷# 周年话题投稿,即有机会赢取魔力甜心萨莉套装 2奖励。 参与方式 在小黑盒PUBG社区中…

贪心算法相关题目

文章目录 1. 什么是贪心?2. 分发饼干3. 摆动序列4. 最大子数组和5. 买卖股票的最佳时机 II6. 跳跃游戏7. 跳跃游戏 II8.K 次取反后最大化的数组和9.加油站10.分发糖果11.柠檬水找零 1. 什么是贪心? 贪心的本质是选择每一阶段的局部最优,从而…