Large Language Model系列之三:大模型并行训练(Parallel Training of Large Language Models)

Large Language Model系列之三:大模型并行训练(Parallel Training of Large Language Models)

1 各类并行算法

在这里插入图片描述

参考资料:
1 大模型并行训练

2 ZeRO(Zero Redundancy Optimizer)零冗余优化

ZeRO(Zero Redundancy Optimizer)是由微软研究院开发的一种内存优化技术,专门设计用于优化大规模深度学习模型的训练过程。ZeRO的核心原理是通过减少内存冗余来提高训练效率,使得可以在有限的硬件资源上训练更大的模型。

以常用的Adam优化器为例,
在这里插入图片描述

GPU显存存储内容主要分为两大块:Model StatesResidual States:
Model States指和模型本身息息相关的,必须存储的内容,具体包括:

  • optimizer states:Adam优化算法中的m(梯度的一阶矩)和v(梯度的二阶矩)
  • gradients:模型梯度(g)
  • parameters:模型参数 Θ \Theta Θ

Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:

  • activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
  • temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
  • unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
2-1 优化模型状态内存

模型状态占用了主要的机器内存,针对该问题提出的数据并行方法ZeRO-DP在实现数据并行高效计算的同时,拥有模型并行的内存节省优势。如下图所示,ZeRO-DP主要有三个优化阶段,分别对应了模型状态中优化器状态、梯度、以及模型参数的切分,也就是通常所说的ZeRO-1/2/3。
在这里插入图片描述

  • 优化器状态分区(Optimizer State Partitioning) P o s P_{os} Pos :将optimizer states分成若干份,每块GPU上各自维护一份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,每块GPU上用一份数据计算出一个完整的梯度值,然后计算出这n个GPU上的一个梯度均值,参数的更新都用这个梯度均值,这里注意:参数的更新是由optimizer states和梯度值共同所决定的,由于我们在这个阶段已经对optimizer states进行了分割,分别存储在了不同的GPU上,所以这里的参数只能更新一部分。分区优化器状态到各个计算卡中,在享有与普通数据并行相同通信量的情况下,可降低4倍的内存占用。
  • 添加梯度分区(Gradient Partitioning) P o s + g P_{os+g} Pos+g :在这一步中除了将optimizer states分成若干份,梯度也分成若干份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,这里注意:每块GPU上用一份数据计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新全参数中对应的那部分参数,同样再相互通信获得完整的更新后的参数。这一步骤参数和梯度值都需要通信交互。在 P o s P_{os} Pos的基础上,进一步将模型梯度切分到各个计算卡中,在享有与普通数据并行相同通信量的情况下,拥有8倍的内存降低能力。
  • 添加参数分区(Parameter Partitioning) P o s + g + p P_{os+g+p} Pos+g+p :这一步除了将optimizer states、梯度分成若干份,参数也要分区。每块GPU上只保存部分参数,前向反向传播时需要用到完整的参数的话相互通信获取全参数,用完立马从显存移除。梯度计算时也是计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新自己维护的那部分参数。在 P o s + g + p P_{os+g+p} Pos+g+p 的基础上,将模型参数也切分到各个计算卡中,内存降低能力与并行数量 N d N_{d} Nd成线性比例,通信量大约有50%的增长。

典型的以时间换空间的优化思想,为了节省显存的空间,增加了通信的时间消耗。当三阶段的ZeRO-DP优化全部启动以后,使用混合精度和Adam优化器的千亿模型(总共占用约16T内存)可以成功基于1024卡上使用常规的32G显卡训练(每卡占用约16G内存)

2-2 优化剩余状态内存

除了ZeRO-DP外,作者还设计了ZeRO-R来解决剩余状态带来的内存瓶颈问题。
剩余状态的冗余主要集中在两个方面:

  • 临时缓冲区:在模型训练过程中,临时缓冲区会累积大量的中间数据,这些数据在不再需要时若未能及时清理,便会造成内存资源的浪费。
  • 内存碎片:由于PyTorch等深度学习框架在变量生命周期管理中的特性,频繁地分配和释放内存会导致内存碎片化,这不仅减少了可用内存空间,还增加了内存分配的时间成本。

针对上述问题,作者创新性地提出了ZeRO-R方法,该方法通过以下策略来优化内存使用效率:

  1. 激活值分区检查点(Partitioned Activation Checkpointing)
  • 问题描述:在模型并行训练中,为了支持跨设备的数据交换,activation数据需要被复制,这导致了冗余。
  • 解决方案:ZeRO通过激活值分区技术来减少这种冗余。具体地,它在正向传播完成后立即对activation进行切分,并在需要时(如在反向传播中)通过all-gather操作重新组合这些分片,从而避免了不必要的数据复制。
  1. 固定大小缓冲区(Constant Size Buffers)
  • 问题描述:某些运算的效率与输入数据的大小密切相关,大型all-gather等操作在处理大批量数据时更为高效。然而,随着模型并行复杂度的增加,对内存的需求也急剧上升。
  • 解决方案:ZeRO引入固定大小的缓冲区策略,以优化这类操作的内存使用。这一策略类似于网络通信中的窗口大小调整,通过预先分配固定大小的内存块来减少内存管理的开销,同时提高数据处理的效率。
  1. 内存碎片整理(Memory Defragmentation)
  • 问题描述:PyTorch等框架在变量生命周期管理中,由于频繁的内存分配和释放,导致内存碎片化问题严重,影响了内存的有效利用率和分配效率。
  • 解决方案:ZeRO通过为activation checkpoint和gradients等关键数据预先分配连续的内存块,有效避免了内存碎片的产生。这种策略不仅提高了内存的利用率,还减少了内存分配时的搜索时间,从而提升了整体训练性能。

ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率,核心思路如下:
在这里插入图片描述
ZeRO-1/2/3 以及 ZeRO++的汇总:
在这里插入图片描述

参考资料
1 ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
2 大模型数据并行训练之DeepSpeed-ZeRO(零冗余优化)
3 ZeRO(零冗余优化器)

3 FlashAttention

Flash Attention 是一种针对 Transformer 模型中 Attention 机制的高效实现方式,旨在减少高带宽内存(如 HBM)的访问次数,同时利用 SRAM(静态随机存取存储器,通常指 GPU 上的 L1/L2 缓存)的高带宽特性来加速计算。

标准Attention算法执行过程如下:

在这里插入图片描述
在标准Attention算法的执行过程中,首先涉及从高带宽内存(HBM)中读取两个关键矩阵Q和K,它们各自具有N x d的维度。随后,这两个矩阵通过点积运算生成相似度得分矩阵S,其维度为N x N。这一步骤的HBM访问次数主要由读取Q和K矩阵以及写入S矩阵组成,总体上是O(Nd + N^2)次访问。

接下来,为了计算注意力权重P(同样为N x N矩阵),需要对相似度得分矩阵S进行softmax操作。softmax操作需要访问S矩阵中的所有元素以计算归一化权重,因此这一步也产生了O(N^2)次的HBM访问。

最后,在生成输出向量O(N x d矩阵)时,算法会将注意力权重P与值向量V(N x d矩阵)进行加权求和。此步骤的HBM访问主要集中在读取P和V矩阵以及写入O矩阵,共计O(Nd)次访问。

综上所述,标准Attention算法在整个执行过程中,其HBM访问的总次数达到了O(Nd + N^2)的复杂度。当处理的数据规模N非常大时,这种高频次的HBM访问会成为性能瓶颈,显著增加计算成本和时间消耗。因此,针对大规模数据,优化Attention算法以减少HBM访问次数成为了一个重要的研究方向。

FlashAttention的优化策略:

在Attention计算中,由于存在三个独立的核(kernel),每个核在处理时都要从HBM读取数据,并在计算后将结果写回HBM。通过将这三个核合并为一个,可以减少对HBM的访问次数。

在计算过程中,应优先利用SRAM进行计算,以减少对HBM的访问。尽管SRAM带宽较高,但其存储容量有限。采用分而治之的策略,通过Tiling将数据适配到SRAM容量。然而,当序列长度较大时,SRAM的限制可能导致序列被分割,这可能会干扰标准Softmax操作。

FlashAttention的优化策略如下:

Tiling(平铺):

采用“分治”策略,将大的注意力矩阵(如NxN的softmax/scores矩阵)分割成多个小得多的子矩阵。这些子矩阵的大小被精心设计,以确保它们能够完全存储在SRAM中,从而在计算过程中减少对HBM的依赖。

Block Softmax(分块Softmax):

然而,Attention机制中的softmax操作要求所有列(或行)的分数都必须参与归一化计算,这意味着子矩阵之间并非完全独立。为了解决这个问题,Flash Attention引入了分块SoftMax算法。这一算法在保持全局归一化的同时,对每个子矩阵独立进行softmax计算。通过一些巧妙的数学变换(如log-sum-exp技巧),能够确保分块SoftMax的结果与全局SoftMax高度一致,从而保证了Flash Attention的正确性。

Recomputation(重算):

为了进一步优化内存使用,Flash Attention还采用了Recomputation(重算)技术。这是一种在计算反向传播时减少内存占用的策略,通过避免存储所有正向传播的中间结果,并在需要时重新计算它们来节省内存。虽然这会增加一些计算成本,但相比于节省的内存和减少的HBM访问次数而言,这一代价通常是值得的。特别是在处理大规模数据集时,Recomputation技术能够显著提升训练效率和可扩展性。

FlashAttention的计算过程:

  1. 数据平铺(Tiling):
    将输入序列Q、K、V分割成较小的块,每块大小适合在快速访问的SRAM中处理。例如,如果序列长度为N,可以将其分割成t个大小为N/t的块。
  2. 分块计算相似度(Score Calculation):
    对于每个Q的块,计算与K的所有块的点积,得到局部相似度分数。这不是标准的自注意力计算,因为只计算了部分K对的相似度。
  3. 局部Softmax:
    对每个局部相似度分数块应用Softmax,得到局部注意力权重。这些权重是针对每个块内部的,可能不会反映整个序列的全局关系。
  4. 加权求和(Weighted Sum):
    使用局部注意力权重加权求和对应的V块,得到每个Q块的输出。
  5. 重算(Recomputation):
    在反向传播中,不存储所有中间状态。当需要计算梯度时,重新计算正向传播中的中间状态,从而减少内存占用。

示例说明:
假设有一个Transformer模型,输入序列长度为N=1024,特征维度为d=512。使用FlashAttention进行优化:

  1. Tiling:
    将Q、K、V矩阵分割成16个大小为64x512的块。
  2. 分块计算相似度:
    计算每个Q块与所有K块的点积,得到64x64的局部相似度矩阵。
  3. 局部Softmax:
    对每个64x64的局部相似度矩阵应用Softmax,得到局部注意力权重。
  4. 加权求和:
    使用局部注意力权重加权求和对应的V块,得到64x512的输出块。
  5. 拼接输出:
    将所有输出块拼接起来,形成最终的输出序列。
  6. 重算:
    在反向传播中,当需要计算某个Q块的梯度时,重新计算与该块相关的所有中间状态。

核心优势:
减少HBM访问:通过在SRAM中进行计算,FlashAttention减少了对HBM的访问次数。
内存效率:通过分块处理,FlashAttention适应了有限的内存资源,特别是对于大型序列。
灵活性:FlashAttention可以适应不同的硬件配置,通过调整块大小来优化性能。

FlashAttention通过这些策略在保持Transformer模型性能的同时,提高了模型的计算效率和内存效率。然而,这种优化可能需要特定的硬件支持,并且可能需要对模型架构进行调整以充分利用其优势。

参考资料
1 Fast and Memory-Efficient Exact Attention with IO-Awareness
2 通俗易懂聊flashAttention的加速原理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48019.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【.NET全栈】ASP.NET开发Web应用——计算器

文章目录 一、简单计算器二、复杂计算器 一、简单计算器 新建Web应用项目&#xff0c;窗体页面 窗体设计代码&#xff1a; <% Page Language"C#" AutoEventWireup"true" CodeBehind"Default.aspx.cs" Inherits"AdoDemo.Default"…

以Zookeeper为例 浅谈脑裂与奇数节点问题

一、脑裂现象的定义与影响 脑裂&#xff08;split-brain&#xff09;是指在分布式系统中&#xff0c;因网络分区或其他故障导致系统被切割成两个或多个相互独立的子系统&#xff0c;每个子系统可能独立选举出自己的领导节点。这一现象在依赖中心领导节点&#xff08;如Elastic…

亚信安全终端一体化解决方案入选应用创新典型案例

近日&#xff0c;由工业和信息化部信息中心主办的2024信息技术应用创新发展大会暨解决方案应用推广大会成功落幕&#xff0c;会上集中发布了一系列技术水平先进、应用效果突出、产业带动性强的信息技术创新工作成果。其中&#xff0c;亚信安全“终端一体化安全运营解决方案”在…

【漏洞复现】Next.js框架存在SSRF漏洞(CVE-2024-34351)

0x01 产品简介 ZEIT Next.js是ZEIT公司的一款基于Vue.js、Node.js、Webpack和Babel.js的开源Web应用框架。 0x02 漏洞概述 ZEIT Next.js 13.4版本至14.1.1之前版本存在代码问题漏洞&#xff0c;该漏洞源于存在服务器端请求伪造 (SSRF) 漏洞 0x03 搜索引擎 body"/_nex…

Keil开发IDE

Keil开发IDE 简述Keil C51Keil ARMMDK DFP安装 简述 Keil公司是一家业界领先的微控制器&#xff08;MCU&#xff09;软件开发工具的独立供应商。Keil公司由两家私人公司联合运营&#xff0c;分别是德国慕尼黑的Keil Elektronik GmbH和美国德克萨斯的Keil Software Inc。Keil公…

【06】LLaMA-Factory微调大模型——微调模型评估

上文【05】LLaMA-Factory微调大模型——初尝微调模型&#xff0c;对LLama-3与Qwen-2进行了指令微调&#xff0c;本文则介绍如何对微调后的模型进行评估分析。 一、部署微调后的LLama-3模型 激活虚拟环境&#xff0c;打开LLaMA-Factory的webui页面 conda activate GLM cd LLa…

elasticsearch, kibana, 6.8.18 版本下的创建索引,指定timestamp,java CRUD,maven版本等

ELK 这一套的版本更迭很快&#xff0c; 而且es常有不兼容的东西出现&#xff0c; 经常是搜一篇文章&#xff0c;看似能用&#xff0c;拿到我这边就不能用了。 很是烦恼。 我这边的ELK版本目前是 6.8.18&#xff0c;这次的操作记录一下。 &#xff08;涉密内容略有删改&#xf…

关闭 Linux 服务器上的 IPv6

虽然 IPv6 已经逐渐普及&#xff0c;但在某些 Linux 服务器上的业务系统仍然可能遇到一些奇怪的问题。特别是在集群场景中&#xff0c;因为集群各个节点之间需要互相通信&#xff0c;如果 IPv6 没有正确配置网络&#xff0c;可能导致一些未知问题&#xff0c;解决起来相当麻烦。…

YOLOV5学习记录

前言&#xff1a; 计算机视觉 什么是目标检测&#xff1f; 物体分类和目标检测的区别 目标检测&#xff0c;物体的类别和位置 学习选题&#xff0c;口罩检查&#xff0c;人脸识别 算法原理&#xff1a;知乎&#xff0c;csdn&#xff0c;目前还没到这种程度 大大滴崩溃&am…

Java文件管理

文件管理 Java中的对文件的管理&#xff0c;通过java.io包中的File类实现。Java中文件的管理&#xff0c;主要是针对文件或是目录路径名的管理&#xff0c;包括文件的属性信息&#xff0c;文件的检查&#xff0c;文件的删除等&#xff0c;但不包括文件的访问 file类 Java中的…

人工智能算法工程师(中级)课程17-模型的量化与部署之剪枝技巧与代码详解

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能算法工程师(中级)课程17-模型的量化与部署之剪枝技巧与代码详解。模型剪枝是深度学习领域中一项关键的技术&#xff0c;旨在减少神经网络中的冗余权重&#xff0c;从而降低计算成本和内存占用&#xff0c;同…

Linux--实现线程池(万字详解)

目录 1.概念 2.封装原生线程方便使用 3.线程池工作日志 4.线程池需要处理的任务 5.进程池的实现 6.线程池运行测试 7.优化线程池&#xff08;单例模式 &#xff09; 单例模式概念 优化后的代码 8.测试单例模式 1.概念 线程池:* 一种线程使用模式。线程过多会带来调度…

FastAPI(六十五)实战开发《在线课程学习系统》基础架构的搭建

在之前三篇&#xff0c;我们分享的就是需求的分析&#xff0c;基本接口的整理&#xff0c;数据库链接的配置。这次我们分享项目的基本框架&#xff0c;目录结构大致如下&#xff1a; common目录&#xff1a; 通用目录&#xff0c;放一些通用的处理 models目录&#xf…

【基础】模拟题 角色授权类

3413. DHCP服务器 题目 提交记录 讨论 题解 视频讲解 动态主机配置协议&#xff08;Dynamic Host Configuration Protocol, DHCP&#xff09;是一种自动为网络客户端分配 IP 地址的网络协议。 当支持该协议的计算机刚刚接入网络时&#xff0c;它可以启动一个 DHCP 客户…

【Git远程操作】克隆远程仓库 https协议 | ssh协议

目录 前言 克隆远程仓库https协议 克隆远程仓库ssh协议 前言 这四个都是Git给我们提供的数据传输的协议&#xff0c;最常使用的还是https和ssh协议。本篇主要介绍还是这两种协议。 ssh协议&#xff1a;使用的公钥加密和公钥登录的机制&#xff08;体现的是实用性和安全性&am…

Nginx的HA高可用的搭建

1. 什么是高可用 高可用&#xff08;High Availability, HA&#xff09;是一种系统设计策略&#xff0c;旨在确保服务或应用在面对硬件故障、软件缺陷或任何其他异常情况时&#xff0c;仍能持续稳定地运行。它通过实现冗余性、故障转移、负载均衡、数据一致性、监控自动化、预防…

Java并发04之线程同步机制

文章目录 1 线程安全1.1 线程安全的变量1.2 Spring Bean1.3 如果保证线程安全 2 synchronized关键字2.1 Java对象头2.1.1 对象组成部分2.1.2 锁类型2.1.3 锁对象 2.2 synchronized底层实现2.2.1 无锁状态2.2.2 偏向锁状态2.2.3 轻量级锁状态2.2.4 重量级锁2.2.5 锁类型总结2.2.…

C++11 容器emplace方法刨析

如果是直接插入对象 push_back()和emplace_back()没有区别但如果直接传入构造函数所需参数&#xff0c;emplace_back()会直接在容器底层构造对象&#xff0c;省去了调用拷贝构造或者移动构造的过程 class Test { public:Test(int a){cout<<"Test(int)"<<…

链表(4) ----跳表

跳表&#xff08;Skip List&#xff09;是一种随机化的数据结构&#xff0c;用于替代平衡树&#xff08;如 AVL 树或红黑树&#xff09;。它是基于多层链表的&#xff0c;每一层都是上一层的子集。跳表可以提供与平衡树相似的搜索性能&#xff0c;即在最坏情况下&#xff0c;搜…

zlgcan,周立功Can设备,Qt中间件,QtCanBus插件,即插即用

新增zlgcan插件&#xff0c;需要请看下方视频回复联系&#xff01; 视频链接地址&#xff1a; Qt,canbus manager,周立功,zlgcan插件演示,需要请留言_哔哩哔哩_bilibili