我自己的原文哦~ https://blog.51cto.com/whaosoft/12459590
#大模型~微调~用带反馈的自训练
面对当前微调大模型主要依赖人类生成数据的普遍做法,谷歌 DeepMind 探索出了一种减少这种依赖的更高效方法。大模型微调非得依赖人类数据吗?用带反馈的自训练更好
如你我所见,大语言模型(LLM)正在改变深度学习的格局,在生成人类质量的文本和解决各种语言任务方面展现出了卓越的能力。虽然业界通过对人类收集的数据进行监督微调进一步提升了在具体任务上的性能,但获取高质量人类数据却面临着重大瓶颈。这对于要解决复杂问题的任务来说尤为明显,需要大量资源和专业知识。
怎么解决呢?模型生成得合成数据是一种有潜力的替代方案,只要能保证数据的质量,就能实现可扩展性和成本效益。
虽然 LLM 能够自我评估生成的数据,但在本文中,谷歌 DeepMind 探索了一种更简单的设置,将外部标量反馈信号用作每个生成样本的质量指标。
论文地址:https://arxiv.org/pdf/2312.06585.pdf
为了研究在模型生成数据上的训练,研究者考虑了一种简单但强大的语言模型自训练方法,仅需要两项功能,一是基于模型生成样本,二是利用评分机制对这些样本进行评估。
为了确保清晰度和一致性,研究者采用了一种强化自训练方法 ReST^𝐸𝑀,并证明该方法可以将期望最大化(expectation-maximization,EM)用于强化学习。具体来讲,ReST^𝐸𝑀在期望和最大化步骤之间交替进行。
- 生成(E-step):语言模型为每个输入上下文生成多个输出样本,然后使用二元奖励过滤这些样本以收集训练数据集。
- 改进(M-step):原始语言模型在来自前一个 E-step 的训练数据集上进行监督微调,然后在下一个 E-step 中使用。
研究者证实,ReST^𝐸𝑀及变体在增强各个领域的语言模型方面取得了成功,包括机器翻译、语义分析、偏好对齐和基础推理。
此外,以往工作主要将 ReST^𝐸𝑀用于相对较小的模型(最高 70 亿参数),对于较大模型的可扩展性受限。因此,本文旨在探究模型生成的合成数据与人类生成的数据在以下两个具有挑战性但研究较少领域的有效性和可扩展性,这两个领域分别是竞争水平数学解题(MATH)和代码生成(APPS)。
实证结果表明,当将 ReST^𝐸𝑀用于不同规模的 PaLM 2 模型时,在数学推理和代码生成任务中实现了显著的能力改进。与在人类编写数据上训练的模型相比,在模型生成的合成数据上微调的模型取得了更大的性能增益。有趣的是,超过了一定数量的 ReST^𝐸𝑀 迭代后,性能会降低,这表明了在少量训练问题上可能会出现过拟合。
此外,使用 ReST^𝐸𝑀微调的模型提升了 pass@k 指标和多数投票性能。这些微调后的模型在相关但 held-out 的基准上也表现出了性能增强,包括数学题(GSM8K 和 Hungarian HS finals)、编码(HumanEval)和 Big-Bench Hard 任务。
总之,本文研究结果表明,具有反馈的自训练是减少对人类数据依赖的一种有潜力的方法。
用于强化自训练的期望最大值(EM)
首先,该研究基于 Dayan 和 Hinton 之前的研究,用语言模型描述了基于 EM 的强化学习框架。具体而言,他们先是定义了一个二进制最优变量 O,使得𝑝(𝑂= 1|𝒙,𝒚)∝𝑓(𝑟(𝒙,𝒚));然后对非递减函数 𝑓 : ℝ → ℝ+ ,实现最大化观察𝑂= 1(获得高奖励),得到如下公式:
然而,求解上式中的序列 𝒚 的和很棘手。因而本文考虑相对于参数 𝜃 和变分分布 𝑞( 𝑦|𝑥) 最大化其 ELBO 𝐿( 𝑝𝜃, 𝑞),而不是最大化 log 𝑝(𝑂 = 1; 𝒙)。具体来说:
公式(2)中的 EM 算法在 E-step(Expectation) 和 M-step(Maximization)之间交替进行。
ReST^𝐸𝑀:受 EM 框架的启发,接下来论文讨论了 Gulcehre 等人提出的 ReST 方法的简化版本。为了清楚起见,本文将这种方法称为 ReST^𝐸𝑀,它将 RL pipeline 中的数据收集 (E-step) 和策略优化 (M-step) 进行解耦。如算法 1 所示:
实验和分析
本文进行实验的主要目标是回答以下问题:
- 与人类生成的数据进行微调相比,ReST^𝐸𝑀的效果如何?
- 需要多少次迭代才能获得最佳性能?ReST^𝐸𝑀多长时间会导致训练集过度拟合?
- ReST^𝐸𝑀如何影响 pass@k 和多数投票表现?
- 如果用户在特定任务上使用模型生成的数据进行微调,是否会迁移到其他任务上?在广泛的任务中评估本文的微调模型时,与基本模型相比,性能是否会下降?
- 大约需要多少输入数据才能从 ReST^𝐸𝑀 获得大部分性能提升?ReST^𝐸𝑀的一次迭代是否足够?
该研究使用 PaLM 2 模型和 Google Cloud 上的公共 API 进行实验,包括 PaLM 2-S (Bison)、PaLM 2-S* (Codey) 和 PaLM 2-L (Unicorn)。训练数据集采用 MATH 数据集和 APPS 数据集。
图 2 和图 3 分别显示了 ReST^𝐸𝑀在 MATH 和 APPS 数据集上训练的性能。可以得出 MATH 受益于 ReST^𝐸𝑀 的多次迭代,无论是在 MATH 测试集上的性能还是迁移到 GSM8K 方面。另一方面可以看到 APPS 的大部分收益来自第一次迭代,而执行更多次迭代会导致 APPS 和 HumanEval 的性能下降。
训练和测试性能的差距。图 4 显示,虽然训练集性能随着 ReST^𝐸𝑀迭代次数线性增加,但测试集性能却没有。对于 MATH,第一次迭代后测试性能改进很小,而对于 APPS,在第二次迭代中观察到性能回归。该研究猜测性能的回归可能是由于过度拟合造成的。由于 APPS 数据集的大小约为 MATH 数据集的三分之一,因此它更容易受到此问题的影响。
图 5 显示了 Palm-2-L 模型在 pass@K 指标上的性能。结果显示,微调后获得的 ReST^𝐸𝑀 模型对于所有 K 值都更强,其中性能差距通常在 K=1 时最大。
#大模型训练~数据并行
大模型场景里巨大的存储和GPU间通讯量是系统设计时需要考虑的重点,本文递进介绍了三种主流数据并行的实现方法:DP、DD皮、ZeRo。
当模型太大,一块GPU放不下时,流水线并行将模型的不同层放到不同的GPU上,通过切割mini-batch实现对训练数据的流水线处理,提升GPU计算通讯比。同时通过re-materialization机制降低显存消耗。
但在实际应用中,流水线并行并不特别流行,主要原因是模型能否均匀切割,影响了整体计算效率,这就需要算法工程师做手调。因此,今天我们来介绍一种应用最广泛,最易于理解的并行范式:数据并行。
数据并行的核心思想是:在各个GPU上都拷贝一份完整模型,各自吃一份数据,算一份梯度,最后对梯度进行累加来更新整体模型。理念不复杂,但到了大模型场景,巨大的存储和GPU间的通讯量,就是系统设计要考虑的重点了。在本文中,我们将递进介绍三种主流数据并行的实现方式:
DP(Data Parallelism):最早的数据并行模式,一般采用参数服务器(Parameters Server)这一编程框架。实际中多用于单机多卡
DDP(Distributed Data Parallelism):分布式数据并行,采用Ring AllReduce的通讯方式,实际中多用于多机场景
ZeRO:零冗余优化器。由微软推出并应用于其DeepSpeed框架中。严格来讲ZeRO采用数据并行+张量并行的方式,旨在降低存储。
数据并行(DP)整体架构
一个经典数据并行的过程如下:
若干块计算GPU,如图中GPU0~GPU2;1块梯度收集GPU,如图中AllReduce操作所在GPU。
在每块计算GPU上都拷贝一份完整的模型参数。
把一份数据X(例如一个batch)均匀分给不同的计算GPU。
每块计算GPU做一轮FWD和BWD后,算得一份梯度G。
每块计算GPU将自己的梯度push给梯度收集GPU,做聚合操作。这里的聚合操作一般指梯度累加。当然也支持用户自定义。
梯度收集GPU聚合完毕后,计算GPU从它那pull下完整的梯度结果,用于更新模型参数W。更新完毕后,计算GPU上的模型参数依然保持一致。
聚合再下发梯度的操作,称为AllReduce。
前文说过,实现DP的一种经典编程框架叫“参数服务器”,在这个框架里,计算GPU称为Worker,梯度聚合GPU称为Server。在实际应用中,为了尽量减少通讯量,一般可选择一个Worker同时作为Server。比如可把梯度全发到GPU0上做聚合。需要再额外说明几点:
1个Worker或者Server下可以不止1块GPU。
Server可以只做梯度聚合,也可以梯度聚合+全量参数更新一起做
在参数服务器的语言体系下,DP的过程又可以被描述下图:
通讯瓶颈与梯度异步更新
DP的框架理解起来不难,但实战中确有两个主要问题:
存储开销大。每块GPU上都存了一份完整的模型,造成冗余。关于这一点的优化,我们将在后文ZeRO部分做讲解。
通讯开销大。Server需要和每一个Worker进行梯度传输。当Server和Worker不在一台机器上时,Server的带宽将会成为整个系统的计算效率瓶颈。
我们对通讯开销再做详细说明。如果将传输比作一条马路,带宽就是马路的宽度,它决定每次并排行驶的数据量。例如带宽是100G/s,但每秒却推给Server 1000G的数据,消化肯定需要时间。那么当Server在搬运数据,计算梯度的时候,Worker们在干嘛呢?当然是在:
人类老板不愿意了:“打工系统里不允许有串行存在的任务!”,于是梯度异步更新这一管理层略诞生了。
上图刻画了在梯度异步更新的场景下,某个Worker的计算顺序为:
在第10轮计算中,该Worker正常计算梯度,并向Server发送push&pull梯度请求。
但是,该Worker并不会实际等到把聚合梯度拿回来,更新完参数W后再做计算。而是直接拿旧的W,吃新的数据,继续第11轮的计算。这样就保证在通讯的时间里,Worker也在马不停蹄做计算,提升计算通讯比。
当然,异步也不能太过份。只计算梯度,不更新权重,那模型就无法收敛。图中刻画的是延迟为1的异步更新,也就是在开始第12轮对的计算时,必须保证W已经用第10、11轮的梯度做完2次更新了。
参数服务器的框架下,延迟的步数也可以由用户自己决定,下图分别刻划了几种延迟情况:
(a) 无延迟
(b) 延迟但不指定延迟步数。也即在迭代2时,用的可能是老权重,也可能是新权重,听天由命。
(c) 延迟且指定延迟步数为1。例如做迭代3时,可以不拿回迭代2的梯度,但必须保证迭代0、1的梯度都已拿回且用于参数更新。
总结一下,异步很香,但对一个Worker来说,只是等于W不变,batch的数量增加了而已,在SGD下,会减慢模型的整体收敛速度。异步的整体思想是,比起让Worker闲着,倒不如让它多吃点数据,虽然反馈延迟了,但只要它在干活在学习就行。
batch就像活,异步就像画出去的饼,且往往不指定延迟步数,每个Worker干越来越多的活,但模型却没收敛取效,这又是刺伤了哪些打工仔们的心(狗头
分布式数据并行(DDP)
受通讯负载不均的影响,DP一般用于单机多卡场景。因此,DDP作为一种更通用的解决方案出现了,既能多机,也能单机。DDP首先要解决的就是通讯问题:将Server上的通讯压力均衡转到各个Worker上。实现这一点后,可以进一步去Server,留Worker。
前文我们说过,聚合梯度 + 下发梯度这一轮操作,称为AllReduce。接下来我们介绍目前最通用的AllReduce方法:Ring-AllReduce。它由百度最先提出,非常有效地解决了数据并行中通讯负载不均的问题,使得DDP得以实现。
Ring-AllReduce
如下图,假设有4块GPU,每块GPU上的数据也对应被切成4份。AllReduce的最终目标,就是让每块GPU上的数据都变成箭头右边汇总的样子。
Ring-ALLReduce则分两大步骤实现该目标:Reduce-Scatter和All-Gather。
- Reduce-Scatter
定义网络拓扑关系,使得每个GPU只和其相邻的两块GPU通讯。每次发送对应位置的数据进行累加。每一次累加更新都形成一个拓扑环,因此被称为Ring。看到这觉得困惑不要紧,我们用图例把详细步骤画出来。
一次累加完毕后,蓝色位置的数据块被更新,被更新的数据块将成为下一次更新的起点,继续做累加操作。
3次更新之后,每块GPU上都有一块数据拥有了对应位置完整的聚合(图中红色)。此时,Reduce-Scatter阶段结束。进入All-Gather阶段。目标是把红色块的数据广播到其余GPU对应的位置上。
- All-Gather
如名字里Gather所述的一样,这操作里依然按照“相邻GPU对应位置进行通讯”的原则,但对应位置数据不再做相加,而是直接替换。All-Gather以红色块作为起点。
以此类推,同样经过3轮迭代后,使得每块GPU上都汇总到了完整的数据,变成如下形式:
建议读者们手动推一次,加深理解。
Ring-AllReduce通讯量分析
在多Server的模式下,进一步,每个Server可以只负责维护和更新某一块梯度(也可以某块梯度+参数一起维护),此时虽然每个Server仍然需要和所有Worker通讯,但它的带宽压力会小非常多。经过调整设计后,依然可以用来做DDP。虽然这篇文章是用递进式的方式来介绍两者,但不代表两者间一定要决出优劣。我想表达的观点是,方法是多样性的。 对参数服务器有兴趣的朋友,可以阅读参考的第1个链接。
最后,请大家记住Ring-AllReduce的方法,因为在之后的ZeRO,Megatron-LM中,它将频繁地出现,是分布式训练系统中重要的算子。
总结
1、在DP中,每个GPU上都拷贝一份完整的模型,每个GPU上处理batch的一部分数据,所有GPU算出来的梯度进行累加后,再传回各GPU用于更新参数
2、DP多采用参数服务器这一编程框架,一般由若个计算Worker和1个梯度聚合Server组成。Server与每个Worker通讯,Worker间并不通讯。因此Server承担了系统所有的通讯压力。基于此DP常用于单机多卡场景。
3、异步梯度更新是提升计算通讯比的一种方法,延迟更新的步数大小决定了模型的收敛速度。
4、Ring-AllReduce通过定义网络环拓扑的方式,将通讯压力均衡地分到每个GPU上,使得跨机器的数据并行(DDP)得以高效实现。
5、DP和DDP的总通讯量相同,但因负载不均的原因,DP需要耗费更多的时间搬运数据
由微软开发的ZeRO(零冗余优化),它是DeepSpeed这一分布式训练框架的核心,被用来解决大模型训练中的显存开销问题。ZeRO的思想就是用通讯换显存。 如果初读ZeRO,觉得它逻辑跳跃,晦涩难懂,那么下文或许可以帮到你~
存储消耗
存储分类
首先,我们来看在大模型训练的过程中,GPU都需要存什么内容。
存储主要分为两大块:Model States和Residual StatesModel States指和模型本身息息相关的,必须存储的内容,具体包括:
- optimizer states:Adam优化算法中的momentum和variance
- gradients:模型梯度
- parameters:模型参数W
Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
- activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
- temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
- unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
精度混合训练
在分析这个问题前,我们需要来了解精度混合训练。
对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?于是,混合精度训练就产生了,它的步骤如下图:
- 存储一份fp32的parameter,momentum和variance(统称model states)
- 在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter。
- 正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储。
- 用fp16 gradients去更新fp32下的model states。
- 当模型收敛后,fp32的parameter就是最终的参数输出。
通过这种方式,混合精度训练在计算开销和模型精度上做了权衡。如果不了解fp32,fp16和bf16的细节也没关系,不影响下文的阅读。只要记住它们所占的存储空间和精度表达上的差异即可。
存储大小
另外,这里暂不将activation纳入统计范围,原因是:
- activation不仅与模型参数相关,还与batch size相关
- activation的存储不是必须的。存储activation只是为了在用链式法则做backward的过程中,计算梯度更快一些。但你永远可以通过只保留最初的输入X,重新做forward来得到每一层的activation(虽然实际中并不会这么极端)。
- 因为activation的这种灵活性,纳入它后不方便衡量系统性能随模型增大的真实变动情况。因此在这里不考虑它,在后面会单开一块说明对activation的优化。
ZeRO-DP
知道了什么东西会占存储,以及它们占了多大的存储之后,我们就可以来谈如何优化存储了。
注意到,在整个训练中,有很多states并不会每时每刻都用到,举例来说;
- Adam优化下的optimizer states只在最终做update时才用到
- 数据并行中,gradients只在最后做AllReduce和updates时才用到
- 参数W只在做forward和backward的那一刻才用到
- 诸如此类
所以,ZeRO想了一个简单粗暴的办法:如果数据算完即废,等需要的时候,我再想办法从个什么地方拿回来,那不就省了一笔存储空间吗?
沿着这个思路,我们逐一来看ZeRO是如何递进做存储优化的。
(3)得到完整梯度G,就可以对W做更新。我们知道W的更新由optimizer states和梯度共同决定。由于每块GPU上只保管部分optimizer states,因此只能将相应的W(蓝色部分)进行更新。(2)和(3)可以用下图表示:
此时,数据并行的整体流程如下:
(1)每块GPU上存一份完整的参数W。将一个batch的数据分成3份,每块GPU各吃一份,做完一轮foward和backward后,算得一份完整的梯度(下图中绿色+白色)。
再次比对下显存和通讯量:
和朴素DP相比,存储降了8倍,单卡通讯量持平,好像更牛皮了呢!那么,还可以优化吗?
数据并行的流程如下:
(1)每块GPU上只保存部分参数W。将一个batch的数据分成3份,每块GPU各吃一份。
到这一步,我们用1.5倍的通讯开销,换回近120倍的显存。只要梯度计算和异步更新做的好,通讯时间大部分可以被计算时间隐藏,因此这样的额外通讯开销,也是划算的。
到这里,我们可以放出原始论文中的说明图了,经过以上分析,这张说明图是不是瞬间就能看懂了。不得不吐槽下,虽然ZeRO的设计不复杂,但对应论文写得真是逻辑跳跃,晦涩难懂....
仔细一想,ZeRO其实掌握了降本增效的精髓:用完即弃,需要再补。反正我补一个和你差不多的,也不会花费很多通(找)讯(人)时间,还大大降低了我的成本。模型的每一层多算(造)几(轮)遍(子)有啥关系呢,反正在我的预算里每个人都一刻不停地干活,就行啦!
ZeRO VS 模型并行
知道模型并行的朋友,可能会想,既然ZeRO都把参数W给切了,那它应该是个模型并行呀?为什么要归到数据并行里呢?
其实ZeRO是模型并行的形式,数据并行的实质。
模型并行,是指在forward和backward的过程中,我只需要用自己维护的那块W来计算就行。即同样的输入X,每块GPU上各算模型的一部分,最后通过某些方式聚合结果。
但对ZeRO来说,它做forward和backward的时候,是需要把各GPU上维护的W聚合起来的,即本质上还是用完整的W进行计算。它是不同的输入X,完整的参数W,最终再做聚合。
因为下一篇要写模型并行Megatron-LM,因此现在这里罗列一下两者的对比。
ZeRO-R
说完了以上对model states的显存优化,现在来看对residual states的优化。
在前文提过,设置机制,对碎片化的存储空间进行重新整合,整出连续的存储空间。防止出现总存储足够,但连续存储不够而引起的存储请求fail
ZeRO-Offload与ZeRO-Infinity
最后,简单介绍一下ZeRO-Offload。它的核心思想是:显存不够,内存来凑。如果我把要存储的大头卸载(offload)到CPU上,而把计算部分放到GPU上,这样比起跨机,是不是能既降显存,也能减少一些通讯压力呢? ZeRO-Offload的做法是:
- forward和backward计算量高,因此和它们相关的部分,例如参数W(fp16),activation,就全放入GPU。
- update的部分计算量低,因此和它相关的部分,全部放入CPU中。例如W(fp32),optimizer states(fp32)和gradients(fp16)等。
具体切分如下图:
ZeRO-infinity也是同理,它们在解决的事情都是:找个除GPU之外的地方,存数据。感兴趣的朋友可以深入研究,这里就不展开了。
#大模型训练~显卡
为什么用 A100 不用 4090? 结论,大模型的训练用 4090 是不行的,但推理(inference/serving)用 4090 不仅可行,在性价比上还能比 H100 稍高。4090 如果极致优化,性价比甚至可以达到 H100 的 2 倍。
事实上,H100/A100 和 4090 最大的区别就在通信和内存上,算力差距不大。
NVIDIA 的算力表里面油水很多,比如 H100 TF16 算力写的是 1979 Tflops,但那是加了 sparsity(稀疏)的,稠密的算力只有一半;4090 官方宣传 Tensor Core 算力高达 1321 Tflops,但那是 int8 的,FP16 直只有 330 Tflops。这篇文章的第一版就是用了错的数据,H100 和 4090 的数据都用错了,得到的结论非常离谱。
H100 这个售价其实是有 10 倍以上油水的。
2016 年我在 MSRA 的时候,见证了微软给每块服务器部署了 FPGA,把 FPGA 打到了沙子的价格,甚至成为了供应商 Altera 被 Intel 收购的重要推手。2017 年我还自己挖过矿,知道什么显卡最划算。后来在华为,我也是鲲鹏、昇腾生态软件研发的核心参与者。因此,一个芯片成本多少,我心里大概是有数的。
鲲鹏的首席架构师夏 Core 有一篇知名文章《谈一下英伟达帝国的破腚》,很好的分析了 H100 的成本:
据说微软和 OpenAI 包下了 H100 2024 年产能的一半,猜猜他们会不会发挥当年跟 Altera 砍价的传统艺能?会真的花 $40,000 * 500,000 = 200 亿美金去买卡?
可以说,H100 就像是中国一线城市的房子,本身钢筋水泥不值多少钱,房价完全是被供求关系吹起来的。我在 LA 已经住了两周,公司租的房子使用面积是我北京房子的 4 倍,但售价只贵了 30%,还带个小院,相当于单位面积的房价是北京的 1/3。我跟本地的老外聊天,他们都很吃惊,你们的平均收入水平比 LA 低这么多,怎么买得起北京的房子的?
问题来了,如果 4090 这么香的话,为啥大家还要争着买 H100,搞得 H100 都断货了?甚至 H100 都要对华禁售,搞出个 H800 的阉割版?
大模型训练为什么不能用 4090
GPU 训练性能和成本对比
LambdaLabs 有个很好的 GPU 单机训练性能和成本对比,在此摘录如下。
首先看吞吐量,看起来没有什么违和的,在单卡能放下模型的情况下,确实是 H100 的吞吐量最高,达到 4090 的两倍。看算力和内存也能看出来,H100 的 FP16 算力大约是 4090 的 3 倍,内存带宽是 3.35 倍,训练过程中由于 batch size 比较大,大多数算子是 compute bound(计算密集型),少数算子是 memory bound(内存密集型),这个结果是不意外的。
LambdaLabs PyTorch 单卡训练吞吐量对比图
LambdaLabs PyTorch 单卡训练吞吐量对比表
然后看性价比,就有意思了,原来排在榜首的 H100 现在几乎垫底了,而且 4090 和 H100 的差距高达接近 10 倍。这就是因为 H100 比 4090 贵太多了。
由于 H100 货源紧张,云厂商的 H100 租用价格就更黑了,按照标价大约 7 个月就可以回本。就算大客户价能便宜一半,一年半也足够回本了。
在价格战中过惯了苦日子的 IaaS 云服务商看到这样的 H100 回本速度,估计要感叹,这真是比区块链回本还快呐。
LambdaLabs PyTorch 单卡训练单位成本吞吐量对比图
LambdaLabs PyTorch 单卡训练单位成本吞吐量对比表
大模型训练的算力需求
既然 4090 单卡训练的性价比这么高,为啥不能用来做大模型训练呢?抛开不允许游戏显卡用于数据中心这样的许可证约束不谈,从技术上讲,根本原因是大模型训练需要高性能的通信,但 4090 的通信效率太低。
大模型训练需要多少算力?训练总算力(Flops)= 6 * 模型的参数量 * 训练数据的 token 数。
我今年初第一次看到有人煞有介事地讲这个公式的时候,觉得这不是显然的吗?又看到 OpenAI 的高级工程师能拿 90 多万美金的年薪,顿时整个人都不好了,还是 AI 香呀。之前我也面试过一些做 AI 的工程师,包括一些做 AI 系统优化的专家,连 Q、K、V 是啥都说不清楚,LLaMA 每个 tensor 的大小也算不出来,就这样还能拿到 offer。
APNet 2023 panel 的主题是 Network, AI, and Foundational Models: Opportunties and Challenges。前面几个问题都中规中矩的,panelists 有点放不开,我就提了一个问题,网络历史上的重要成就基本上都基于对应用场景深刻的理解,但我们现在做网络的很多都不了解 AI,甚至连每个 tensor 的大小和每个 step 传输的数据量都不知道,如何让 network community 更了解 AI 呢?
这下热闹了,台下的谭博首先发言,说我在华为肯定能知道所有这些东西;然后传雄老师也跟了一句,要是做网络的懂了太多 AI,那可能他就变成一个 AI guy 了。接着主持人陈凯教授问,你们有谁真的训练过大模型?沉默了一会儿,阿里的兄弟先说,我算是半个训练过大模型的,我们做的东西是支撑阿里大模型 infra 的。后面又有 panelist 说,做 AI 系统的网络优化是否有必要自己懂 AI 呢,是不是只要会做 profiling 就行了?
我个人观点仍然是,AI 并不难学,要想做好 AI 系统优化,可以不懂 attention 的 softmax 里面为什么要除以 sqrt(d_k),但不能不会计算模型所需的算力、内存带宽、内存容量和通信数据量。Jeff Dean 就有个很有名的 Numbers Every Programmer Should Know,数量级的估算对任何系统优化来说都很关键,不然根本不知道瓶颈在哪里。
回到大模型训练所需的总算力,其实很简单,6 * 模型的参数量 * 训练数据的 token 数就是所有训练数据过一遍所需的算力。这里的 6 就是每个 token 在模型正向传播和反向传播的时候所需的乘法、加法计算次数。
一堆矩阵相乘,简单来想就是左边若干个神经元,右边若干个神经元,组成一个完全二分图。选出其中任意一个左边的神经元 l 和右边的神经元 r,正向传播的时候:
- l 把它的输出乘上 l 和 r 之间的权重 w,发给 r;
- r 不可能只连一个神经元吧,总要把多个 l 的加到一起,这就是 reduce,需要一次加法。
反向传播的时候:
- r 把它收到的梯度乘上 l 和 r 之间的权重 w,发给 l;
- l 也不可能只连一个 r,需要把梯度 reduce 一下,做个加法;
- 别忘了权重 w 需要更新,那就要计算 w 的梯度,把 r 收到的梯度乘上 l 正向传播的输出(activation);
- 一个 batch 一般有多个 sample,权重 w 的更新需要把这些 sample 的梯度加到一起。
一共 3 次乘法,3 次加法,不管 Transformer 多复杂,矩阵计算就是这么简单,其他的向量计算、softmax 之类的都不是占算力的主要因素,估算的时候可以忽略。
想起来我 2019 年刚加入 MindSpore 团队的时候,领导让我开发一个正向算子的反向版本,我求导给求错了,搞得算子的计算结果总是不对,还以为是我们的编译器出 bug 了。当发现求导求错的时候,领导像以为我没学过微积分一样看着我,确实我的微积分学的不好,这也是我从数学专业转到计算机专业的原因之一。
在 MindSpore 的时候,自动微分一共就不到 1000 行代码,按照微分公式递归计算下去就行了,但自动微分作为一个重要特性被吹了半天,我都感觉不好意思了。
模型的参数量和训练数据的 token 数之间也有个比例关系,这也很容易理解,只要把模型想象成数据的压缩版本就行了,压缩比总是有极限的。模型的参数量太小,就吃不下训练数据里面所有的知识;模型的参数量如果大于训练数据的 token 数,那又浪费,还容易导致 over-fitting。
训练 LLaMA-2 70B 需要多少张卡
有了模型训练所需的总算力,除以每个 GPU 的理论算力,再除以 GPU 的有效算力利用比例,就得到了所需的 GPU-hours,这块已经有很多开源数据。LLaMA 2 70B 训练需要 1.7M GPU hours(A100),要是用 1 个 GPU,那得算 200 年。要在一个月这种比较能接受的时间周期内训练出来,就得至少有 2400 块 A100。
如果用 4090,单卡 FP16 算力是跟 A100 差不多(330 vs 312 Tflops),但是内存带宽比 A100 低一半(1 vs 2 TB/s),内存容量更是差好几倍(24 vs 80 GB),计算梯度时需要使用的 TF32 算力也低一半(83 vs 156 Tflops),综合起来 4090 单卡的训练速度还比 A100 稍低(参考前面 LambdaLabs 的评测)。
就按照 2048 块 4090 算吧,这 2048 块 4090 之间的通信就成了最大的问题。
为什么?一般有 tensor parallelism、pipeline parallelism、data parallelism 几种并行方式,分别在模型的层内、模型的层间、训练数据三个维度上对 GPU 进行划分。三个并行度乘起来,就是这个训练任务总的 GPU 数量。
三种并行方式从三个维度划分计算空间的示意图,来源:DeepSpeed
Data parallelism(数据并行)
数据并行是最容易想到的并行方式。每个 GPU 分别计算不同的输入数据,计算各自的梯度(也就是模型参数的改变量),再把梯度汇总起来,取个平均值,广播给各个 GPU 分别更新。
Data Parallelism 示意图,来源:Colossal AI
但只用数据并行是肯定不行的,因为一块 GPU 放不下整个 LLaMA 70B 模型。
就模型训练需要多少 GPU 内存,我发现能算清楚的人就不多。有的人甚至以为只需要把模型的参数和反向传播的梯度存下来就够了。事实上,训练需要的内存包括模型参数、反向传播的梯度、优化器所用的内存、正向传播的中间状态(activation)。
优化器所用的内存其实也很简单,如果用最经典的 Adam 优化器,它需要用 32 位浮点来计算,否则单纯使用 16 位浮点来计算的误差太大,模型容易不收敛。因此,每个参数需要存 4 字节的 32 位版本(正向传播时用 16 位版本,优化时用 32 位版本,这叫做 mixed-precision),还需要存 4 字节的 momentum 和 4 字节的 variance,一共 12 字节。如果是用类似 SGD 的优化器,可以不存 variance,只需要 8 字节。
正向传播的中间状态(activation)是反向传播时计算梯度必需的,而且跟 batch size 成正比。Batch size 越大,每次读取模型参数内存能做的计算就越多,这样对 GPU 内存带宽的压力就越小。可是不要忘了,正向传播的中间状态数量是跟 batch size 成正比的,GPU 内存容量又会成为瓶颈。
大家也发现正向传播中间状态占的内存太多了,可以玩一个用算力换内存的把戏,就是不要存储那么多梯度和每一层的正向传播的中间状态,而是在计算到某一层的时候再临时从头开始重算正向传播的中间状态,这样这层的正向传播中间状态就不用保存了。如果每一层都这么干,那么就只要 2 个字节来存这一层的梯度。但是计算中间状态的算力开销会很大。因此实际中一般是把整个 Transformer 分成若干组,一组有若干层,只保存每组第一层的中间状态,后面的层就从该组第一层开始重新计算,这样就平衡了算力和内存的开销。
如果还是算不清楚,可以读读这篇论文:Reducing Activation Recomputation in Large Transformer Models。
当然有人说,GPU 内存放不下可以换出到 CPU 内存,但是就目前的 PCIe 速度,换出到 CPU 内存的代价有时候还不如在 GPU 内存里重算。如果是像 Grace Hopper 那种极高带宽的统一内存,那么换入换出倒是一个不错的主意,不管训练的正向传播中间状态还是 KV Cache,都有很多优化的空间。
Pipeline parallelism(流水线并行)
既然一块 GPU 放不下,用多块 GPU 总行了吧?这就是 model parallelism(模型并行),可以大致分为 pipeline parallelism 和 tensor parallelism。
大家最容易想到的并行方式就是 pipeline parallelism,模型不是有很多层吗,那就分成几组,每组算连续的几层,穿成一条链。
Pipeline Parallelism 示意图,来源:Colossal AI
这样就有个问题,一条链上只有一个 GPU 在干活,剩下的都在干等。当然聪明的你一定也想到了,既然叫 pipeline,那就可以流水线处理,可以把一个 batch 分为若干个 mini-batch,每个 mini-batch 分别计算。
Pipeline Parallelism 示意图,来源:GPipe
这可好,是不是把 pipeline 搞的越深越好,每个 GPU 只算一层?
首先,正向传播中间状态(activation)的存储容量会成倍增加,加剧内存容量不足的问题。比如流水线的第一级算出了正向传播的中间状态,如果有 N 个流水级,那就要正向流过后面的 N - 1 个流水级,再等反向传播 N - 1 个流水级,也就是 2N - 2 轮之后才能用到这个正向传播的中间状态。不要忘了每一轮都会产生这么多中间状态,因此一共是保存了 2N - 1 个中间状态。如果 N 比较大,这个存储容量是非常恐怖的。
其次,pipeline 的相邻流水级(pipeline stage)之间是要通信的,级数越多,通信的总数据量和总时延就越高。
最后,要让这样的 pipeline 流起来,batch size 需要等于 Transformer 里面的层数,一般是几十,再乘以 data parallelism 的并行数,batch size 会很大,影响模型收敛的速度或模型收敛后的精度。
因此,在内存容量足够的情况下,最好还是少划分一些流水级。
对于 LLaMA-2 70B 模型,模型参数需要 140 GB,反向传播的梯度需要 140 GB,优化器的状态(如果用 Adam)需要 840 GB。
正向传播的中间状态跟 batch size 和选择性重新计算的配置有关,我们在算力和内存之间取一个折中,那么正向传播的中间状态需要 token 长度 * batch size * hidden layer 的神经元数量 * 层数 * (10 + 24/张量并行度) 字节。假设 batch size = 8,不用张量并行,那么 LLaMA-2 70B 模型的正向传播中间状态需要 4096 * 8 * 8192 * 80 * (10 + 24) byte = 730 GB,是不是很大?
总共需要 140 + 140 + 840 + 730 = 1850 GB,这可比单放模型参数的 140 GB 大多了。一张 A100/H100 卡也只有 80 GB 内存,这就至少要 24 张卡;如果用 4090,一张卡 24 GB 内存,就至少需要 78 张卡。
LLaMA-2 模型一共就只有 80 层,一张卡放一层,是不是正好?这样就有 80 个流水级,单是流水线并行就有 80 个并行的 batch 才能填满流水线。
这样,正向传播的中间状态存储就会大到无法忍受,这可是 80 * 2 = 160 轮的中间状态,翻了 160 倍。就算是使用选择性重新计算,比如把 80 层分成 8 组,每组 10 层,中间状态存储仍然是翻了 16 倍。
除非是用最极端的完全重新计算,反向传播到每一层都重新从头开始计算正向传播的中间结果,但这样计算开销可是随模型层数平方级别的增长,第 1 层算 1 层,第 2 层算 2 层,一直到第 80 层算 80 层,一共算了 3240 层,计算开销可是比正常算一次 80 层翻了 40 倍,这还能忍?
中间状态存储的问题就已经够大了,再看这 2048 张卡之间的通信开销。按照一张卡放一层,并且用不同的输入数据让它完全流水起来的做法,这 2048 张卡分别在计算自己的 mini-batch,可以认为是独立参与到 data parallelism 里面了。前面讲过,在数据并行中,每一轮需要传输的是它计算出的梯度和全局平均后的梯度,梯度的数据量就等于模型的参数数量。
把 70B 模型分成 80 层,每一层大约有 1B 参数,由于优化器用的是 32 bit 浮点数,这就需要传输 4 GB 数据。那么一轮计算需要多久呢?总的计算量 = batch size * token 数量 * 6 * 参数量 = 8 * 4096 * 6 * 1B = 196 Tflops,在 4090 上如果假定算力利用率 100%,只需要 0.6 秒。而通过 PCIe Gen4 传输这 4 GB 数据就已经至少需要 0.12 秒了,还需要传两遍,也就是先传梯度,再把平均梯度传过来,这 0.24 秒的时间相比 0.6 秒来说,是占了比较大的比例。
当然我们也可以做个优化,让每个 GPU 在 pipeline parallelism 中处理的 80 组梯度数据首先在内部做个聚合,这样理论上一个 training step 就需要 48 秒,通信占用的时间不到 1 秒,通信开销就可以接受了。当然,通信占用时间不到 1 秒的前提是机器上插了足够多的网卡,能够把 PCIe Gen4 的带宽都通过网络吐出去,否则网卡就成了瓶颈。假如一台机器上插了 8 块 GPU,这基本上需要 8 块 ConnectX-6 200 Gbps RDMA 网卡才能满足我们的需求。
最后再看 batch size,整个 2048 张卡的集群跑起来,每个 GPU 的 mini-batch 我们刚才设置为 8,那可真是 batch size = 16384,已经是大规模训练中比较大的 batch size 了,如果再大,可能就影响模型的收敛速度或收敛后的精度了。
因此,单纯使用流水线并行和数据并行训练大模型的最大问题在于流水线并行级数过多,导致正向传播中间状态(activation)存储容量不足。
Tensor parallelism(张量并行)
那就没办法了吗?我们还有最后一招,就是 Tensor parallelism(张量并行)。它也是模型并行的一种,但不像流水线并行那样是在模型的层间划分,而是在模型的层内划分,也就是把一层内的 attention 计算和 Feed Forward Network 划分到多个 GPU 上处理。
有了张量并行,就可以缓解 GPU 放不下模型导致的流水级太多的问题。分到 80 个 GPU 才能放下的模型,如果用单机 8 卡张量并行,就只需要划分 10 个流水级。同时,张量并行还可以降低 batch size,因为张量并行的几个 GPU 是在算同一个输入数据。
Tensor、Pipeline、Data 三种并行方式从模型层内、模型层间、训练数据三个维度上划分计算空间,来源:DeepSpeed
Attention 的计算过程是比较容易并行的,因为有多个 head,用来关注输入序列中的不同位置的,那么把这些 head 分别拆开就行了。
Attention 的计算过程,来源:The Illustrated Transformer
但是我们做任何并行计算的时候都不要忘记通信开销。
每个 head 里面的 Q、K 两个矩阵的大小是 batch size * token 长度 * key 的大小,V 矩阵的大小是 batch size * token 长度 * value 的大小。key/value 的大小一般等于 embedding size / heads 数量,例如在 LLaMA-2 70B 中就是 8192 / 64 = 128,矩阵大小是 batch size * 4096 * 8192 / 64(注意,这只是一个 head 的)。而 Q、K、V 参数矩阵在每个 head 上的大小是 embedding size * embedding size / heads num = 8192 * 8192 / 64。
我们前面推导过,正向的计算量基本上就是每个 token 过一遍所有参数的计算量,2 * 3 (Q, K, V) * batch size * token 长度 * 参数个数 = 2 * 3 * batch size * 4096 * 8192 * 8192 / 64。可以跟矩阵的大小对一下,看看有没有算错。
那么通信量是多少呢?输出矩阵 Z 是由每个 head 拼起来的,每个 head 的大小是 batch size * token 长度 * embedding size / heads num = batch size * 4096 * 8192 / 64。输入矩阵 X 的大小是 batch size * token 长度 * embedding size = batch size * 4096 * 8192。注意这里的 X 大小跟所有 heads 合并在一起后的 Z 大小是一致的,而我们在这里算的是每个 head 的 Z 大小。这里的单位是参数数量,如果按照字节算,还要乘以每个参数的大小。
如果我们采用最极端的方式,每个 head 交给一个 GPU 去算,那么计算量和通信量的比例是多少?大概是 2 * 3 * embedding size / heads num / bytes per param = 2 * 3 * 8192 / 64 / 2 = 384。代入 4090 的 330 Tflops,如果想让通信不成为瓶颈,那么通信带宽至少需要是 330T / 384 = 859 GB/s,发送接收双向还得乘以 2,就是 1.7 TB/s。太大了,远远超过 PCIe Gen4 x16 的 64 GB/s,就算 NVLink 的 900 GB/s 都撑不住。
所以,tensor parallelism 不能切得太细,每个 GPU 需要多算几个 heads。如果每个 GPU 多算几个 attention heads,输入矩阵 X 就是这些 heads 共享的了,因此输入矩阵的通信开销就被多个 heads 平摊了,计算量和通信量的比例就可以提高。
还是按照 4090 的算力 / 单向通信带宽 = 330T / (64GB/s / 2) 来算,计算量和通信量的比例最少需要是 10000,也就是 2 * 3 * (embedding size / 张量并行 GPU 数量) / bytes per param = 2 * 3 * 8192 / 张量并行 GPU 数量 / 2 >= 10000,解得:张量并行 GPU 数量 <= 2.4。也就是告诉你,要是用了张量并行,最多用 2 个 GPU,如果用更多的 GPU,算力就肯定跑不满理论值。这让我怎么玩?
但是,如果把 H100 的参数代入进去,马上就不一样了。H100 的峰值算力是 989 Tflops,NVLink 双向带宽是 900 GB/s,计算量和通信量的比例最少需要是 1100,也就是 2 * 3 * (embedding size / 张量并行 GPU 数量) / bytes per param = 2 * 3 * 8192 / 张量并行 GPU 数量 / 2 >= 1100,解得:张量并行 GPU 数量 <= 11,也就是单机 8 卡做张量并行,对于 embedding size = 8192 的模型,刚刚好,通信不会成为瓶颈!
阉割版的 H800 相比 H100 卡的就是网络带宽,把网络带宽从 900 GB/s 降到 400 GB/s 了。我们再代入一次,计算量和通信量比例最少需要是 5000,那么张量并行 GPU 数量 <= 4.8。这样单机 8 卡做张量并行,就会导致网络成为瓶颈。当然,计算量 989 Tflops 是理论值,并行切分方式也可以优化,因此实际训练 70B 的模型 8 卡 H800 网络不一定真的是瓶颈。这就是 H800 精准打击大模型训练,让张量并行过得不舒服。
Feed Forward Network 的计算过程,虽然这是 encoder 的,但 decoder 也差不多,来源:Step-by-Step Illustrated Explanations of Transformer
如果在 Feed Forward Network 这里做张量并行,也是可以做类似的推导,在这里就不赘述了。大凡神经网络里的矩阵乘法,M*N 的矩阵乘上 N*K 的矩阵,总的计算量是 M*N*K,输入输出的总大小是 (M*N + N*K),多摞几个矩阵那也是常数(就像 Q、K、V),也就是计算和通信的比例跟矩阵的边长(dimension)是一个量级的。
这么分析完了,如果你是要做大规模大模型训练,你还会买 A100/H100/H800 的 PCIe 版吗?PCIe Gen5 虽然比 Gen 4 快一倍,但对 H100 而言,计算量和通信量的比例仍然最少需要是 989T / (128G / 2) = 15000,解出来张量并行 GPU 数量 <= 1.6,也就是只要用了张量并行,就是损失算力的!
等到 H100 的下一代出来了,比如 GH200,算力又翻了一倍,NVLink 还是 900 GB/s,这时候 NVLink 就也开始有点吃力了。所以 GH200 不失时机的推出了统一大内存,号称 144 TB,就是为了更好的做换入换出,用内存换网络通信。如果禁令保持不变,国内版本还是卡住 400 GB/s 的通信,那性能差距会有多大?
上面的推导当然都是简化的,实际上可能不会这么夸张,但数量级是差不多的。
训练部分小结
4090 不容易做大模型训练的原因除了前面分析的内存小,通信慢,license 不支持数据中心,还有很多其他问题。
比如,A100/H100 支持 ECC 显存容错,据说 4090 也支持 ECC,但是不知道故障率会不会比 A100/H100 更高。不要小看了容错,2048 张卡的集群就算每张卡 1 个月出一次故障,平均 20 分钟就会有一张卡出故障!要是没有自动化的故障恢复方式,炼丹师就别想睡觉了。
就算是自动从上一个 checkpoint 恢复,这可是要时间的,如果不考虑丢弃故障 GPU 梯度这种比较暴力的方式,当前这个 step 就算是白算了,还要从上一个 checkpoint 加载梯度,一般需要 10 来分钟的时间才能搞定。这样,每 20 分钟就浪费 10 分钟,这 10 分钟恢复过程中可能又有新的卡故障,总的算下来要浪费掉一半的有效算力。
因此,保持大规模训练集群的低故障率是非常重要的,这些 GPU 卡都非常金贵,可不能像挖矿机房那样,动不动就过热死机了。
据说 3090 是支持 NVLink 的,但 4090 就把 NVLink 给砍掉了。更老的卡,甚至还有支持 PCIe P2P 的,现在也都被砍掉了。谁感兴趣可以测一测 3090 的 NVLink 性能怎么样,是不是真的能达到标称的 600 GB/s,如果真的能达到的话,是否又可以用来做大模型训练了呢。
我们年会的时候,海哥讲了个段子,我们找老婆都希望又漂亮,又能挣钱,还一心一意爱自己。可同时满足这三个条件的老婆就很难找到了。类似的,在分布式系统中,我们都希望性能又高,通用性又强,成本还低。这三个条件的交集也很小。海哥讲到这里,谭博补充了一句,同时满足这三个条件的分布式系统根本就不存在。
Tensor、Pipeline、Data Parallelism 就像是这样的不可能三角,相互牵制,只要集群规模够大,模型结构仍然是 Transformer,就很难逃出内存容量和网络带宽的魔爪。
大模型推理为什么 4090 很香
推理和训练有什么区别?
首先,训练不仅需要存储模型参数,还需要存储梯度、优化器状态、正向传播每一层的中间状态(activation),后面几个比参数更大,对模型内存的需求量也更大。
其次,训练任务是一个整体,流水线并行的正向传播中间结果是需要存下来给反向传播用的。为了节约内存而使用流水线并行,流水级越多,要存储的中间状态也就更多,反而加剧内存的不足。而推理任务中的各个输入数据之间并没有关系,正向传播每一层的中间状态也不需要保存下来,因此流水线并行不需要存储很多中间状态。
首先我们需要计算一下推理需要多少算力。前面针对训练算力的估算,为了简单起见,忽略了两个事情,首先是没有考虑 KV Cache,其次是没有考虑内存带宽。
KV Cache
什么是 KV Cache?对于每个输入的 prompt,在计算第一个 token 输出的时候,每个 token 的 attention 肯定是都要从头计算。但是在后续 token 的生成中,都需要计算 self-attention,也就是输入 prompt 以及前面输出的 token 的 attention。这是就需要用到前面每一个 token 的 K 和 V,由于每一层的参数矩阵是不变的,此时只有刚生成的那个 token 的 K 和 V 需要从头计算,输入 prompt 和之前生成的 token 的 K 和 V 其实是跟上一轮一样的。
这时,我们就可以把每一层的 K、V 矩阵缓存起来,生成下一个 token 的时候不再需要重新计算,这就是所谓的 KV Cache。Q 矩阵每次都不一样,没有缓存的价值。前面讲的训练中的选择性保存正向 activation 是个拿计算换内存的把戏,这里的 KV Cache 就是一个拿内存换计算的把戏。
KV Cache 需要多少存储容量呢?每一层,每个 token 的 K、V 矩阵都是 embedding size 这么大,再乘上 token 数量和 batch size,就是这一层的 KV Cache 所需的存储容量了。一定要记住 batch size,在正向和反向传播的几乎所有阶段,都不会涉及到对 batch size 中各个 sample 的合并处理,因此它始终是存储量和计算量计算中的一个系数。
例如,如果 batch size = 8,在 LLaMA 2 70B 中,假设输入和输出的 token 数量达到了模型的极限 4096,80 层的 KV Cache 一共需要 2 (K, V) * 80 * 8192 * 4096 * 8 * 2B = 80 GB。如果 batch size 更大,那么 KV Cache 占据的空间将超过参数本身占的 140 GB。
KV Cache 能省下来多少计算量?每一层计算 K、V 矩阵一共需要 2 (K, V) * 2 (mult, add) * embedding size * embedding size = 4 * 8192 * 8192 这么多计算量,乘以之前输入过的 token 数量、层数和 batch size,就是 4096 * 80 * 8 * 4 * 8192 * 8192 = 640 Tflops。相当于每存储 1 个字节,节约了 16K 次计算,还是很划算的。
事实上,KV Cache 节约的远远不止这些。计算 K、V 矩阵的过程是个典型的内存密集型过程,它需要加载每一层的 K、V 参数矩阵。也就是如果不做任何缓存,假设 prompt 长度很短而输出长度接近 token 的最大长度 4096,到了最后一个 token 的时候,单是重复计算前面每个 token 的 K、V 矩阵,就需要读取内存 4096 * 80 * 2 * 8192 * 8192 = 40T 次,每次 2 个字节,要知道 H100 的内存带宽只有 3.35 TB/s,4090 更是只有 1 TB/s,这单是最后一个 token 就得耗掉一张卡几十秒的时间来做重复计算。这样,token 的输出就会越来越慢,整个输出时间是输出长度平方级别的,根本没法用。
推理是计算密集还是存储密集
接下来我们就可以计算推理所需的计算量了。总的算力很好算,前面讲过,大概就是 2 * 输出 token 数量 * 参数数量 flops。如果想看细节,可以看下面这张图,[来源:https://le.qun.ch/en/blog/2023/05/13/transformer-batching/]。
Transformer 推理过程中每一步的矩阵形状、所需算力和内存访问量,来源:Lequn Chen,Dissecting Batching Effects in GPT Inference
但算力并不能说明一切,模型还需要访问 GPU 内存,内存带宽也可能成为瓶颈。至少需要把参数从内存里面读出来吧?事实上,内存带宽的估算就这么简单,内存访问量 = 参数数量 * 2 bytes。中间结果有一部分是可以放在缓存里面的,缓存放不下的部分也需要占内存带宽,我们先不算。
如果不做任何批量输入,也就是模型专门服务一个 prompt,batch size = 1,整个 context 的长度很短(例如只有 128),那么整个推理过程中,每载入一个参数(2 字节),就只进行 128 次乘法和加法计算,那么计算 flops 和访问内存 bytes 的比例就只有 128。基本上任何 GPU 在这种情况下都会变成 memory bound,时间都耗在加载内存上了。
对于 4090 来说,计算 flops 和内存带宽之比是 330 / 1 = 330;对于 H100 来说,计算 flops 和内存带宽之比是 989 / 3.35 = 295。也就是说,如果 context 中的 token 数量小于 330 或者 295,那么内存访问就会成为瓶颈。
虽然 LLaMA 2 的理论上限是 4096 个 token,但很多输入 prompt 用不了这么多,因此内存访问是有可能成为瓶颈的。此时,就需要靠 batch size 来补足了。推理中的批量处理,就是把几乎同时到达后端服务的 prompt 放到一起处理。不用担心,batch 里面的不同 prompt 的处理是完全独立的,不用担心会互相干扰。但这些 prompt 的输出是步调整齐划一的,每一轮整个 batch 中的每个 prompt 都会输出一个 token,因此如果有的 prompt 先输出完了,那就只能等其他的输出结束,造成一定的算力浪费。
有的人问,批量处理所需的算力跟分别单独处理所需的算力是一样的呀,那推理时为什么需要批量处理?答案就在访问内存的带宽上。
如果同时到达服务器的 prompt 很多,是不是 batch size 越大越好?也不是,因为 KV Cache 的大小可是正比于 batch size 的,batch size 大了,KV Cache 占据的 GPU 内存容量就很可观,比如在 LLaMA-2 70B 中,每个 prompt 都要占据 5 GB 的 KV Cache,如果 batch size 搞到 32,那么 KV Cache 就会占掉 160 GB 的 GPU 内存,比参数都大了。
70B 推理需要多少张卡?
总的存储容量也很好算,推理的时候最主要占内存的就是参数、KV Cache 和当前层的中间结果。当 batch size = 8 时,中间结果所需的大小是 batch size * token length * embedding size = 8 * 4096 * 8192 * 2B = 0.5 GB,相对来说是很小的。
70B 模型的参数是 140 GB,不管 A100/H100 还是 4090 都是单卡放不下的。那么 2 张 H100 够吗?看起来 160 GB 是够了,但是剩下的 20 GB 如果用来放 KV Cache,要么把 batch size 压缩一半,要么把 token 最大长度压缩一半,听起来是不太明智。因此,至少需要 3 张 H100。
对于 4090,140 GB 参数 + 40 GB KV Cache = 180 GB,每张卡 24 GB,8 张卡刚好可以放下。
推理用流水线并行可以吗?
推理使用流水线并行,最主要的问题是串行处理的推理延迟,网络延迟倒是小问题。
首先是推理延迟。虽然流水线的不同阶段可以塞进不同的 prompt,但同一个 prompt 的处理仍然永远在单个 GPU 上轮转,这样相比 Tensor parallelism 而言,单个 prompt 的延迟就增大了。
对于很小的 batch size,GPU 内存带宽是瓶颈,此时每张卡计算每个 token 的时延就是 2 byte * 参数量 / 卡的数量 / 内存带宽,例如 8 卡 4090 跑 LLaMA-2 70B,就是 2 * 70G / 8 / 1 TB/s = 0.0175 秒。这里没有考虑 KV Cache 带来的节约。注意,8 张卡是串行处理的,因此每个 token 的时延还要乘以 8,也就是 0.14 秒。每秒只能输出 7 个 token,对于 70B 这么小的模型来说是有点慢了。
对于很大的 batch size,GPU 算力是瓶颈,此时每张卡计算每个 token 的时延就是 batch size * 2 * 参数量 / 卡的数量 / 算力,例如 batch size = 1024,同样的 8 卡例子,就是 1024 * 2 * 70G / 8 / 330 Tflops = 0.0543 秒。事实上,对于这么大的 batch size,KV Cache 和正向传播的中间结果先把 GPU 内存给吃满了。
那么要平衡利用 GPU 算力和内存带宽,batch size 需要是多少呢?这就是 2 byte * 参数量 / 卡的数量 / 内存带宽 = batch size * 2 * 参数量 / 卡的数量 / 算力,左右两边参数量和卡的数量互相抵消,得到 batch size = 算力 / 内存带宽。对于 4090,就是 330 / 1 = 330;对于 H100,就是 989 / 3.35 = 295。也就是说,对 4090 而言,batch size 小于 330 的时候 GPU 内存带宽是瓶颈,大于 330 的时候 GPU 算力是瓶颈。当 batch size = 330 的时候,理想情况下,内存带宽和算力恰好都打满,每张卡处理每个 token 的时间就是 17.5 ms。
其次是网络延迟。流水线并行相比张量并行的优点就是网络传输量小,流水级之间只需要传输 batch size * embedding size 这么多数据。例如 batch size = 8,embedding size = 8192,只需要传输 128 KB 数据,在 32 GB/s 的 PCIe Gen4 x16 上,只需要 4 us 就可以传输完成。当然,还需要考虑到通信库本身的开销,加上 4090 不支持 GPU 之间 P2P 传输,需要通过 CPU 中转,实际上需要几十 us 的时间,相比计算部分动辄几十 ms 的时延,可以忽略不计。
即使 batch size = 330,这 5.28 MB 数据在 PCIe 上也只需要传输 0.16 ms,相比计算部分的 17.5 ms 仍然可以忽略不计。
如果可以忍受流水线并行的推理延迟,甚至可以用多台主机来做流水线并行。我们假设主机间只有 1 Gbps 的普通以太网络,每台主机只有一张 4090。对于 batch size = 1,16 KB 数据需要 0.25 ms 才能传输完成,再加上 0.25 ms 两端网络协议栈的处理时间,每个流水级就需要 0.5 ms 的时延,8 张卡花在通信上的时间只有 4 ms,相比整体计算时延 140 ms 来说可以忽略,不会显著影响系统的推理延迟。
当 batch size 很小时,流水线推理中的网络流量是突发性(bursty)的,每过 18 ms 只会进行 0.25 ms 数据传输,只有 1/72 的占空比,不用担心流水线推理把局域网全部给占满了,搞得没法正常上网了。
如果为了充分利用算力,把 batch size 设置得很大,比如 330,那么 16 KB * 330 = 5.28 MB 数据需要传输 41 ms,8 张卡花在通信上的时间高达 0.33 秒,这样就只有 3 token/s 的输出速度了,难以忍受。因此,如果用主机间通信来做流水线并行,主机间又没有很高的通信带宽,就势必需要牺牲一定的吞吐量。
例如,我们设置输出速度不小于 5 token/s,这时留给通信的时间是 60 ms,每个流水级至多 7.5 ms,1 Gbps 网络可以传输 960 KB 数据,这时 batch size 至多设置为 60,也就是这 8 张 4090 的总吞吐量是 2400 token/s。此时的有效算力利用率只有不到 20%。
最近有一个比较火的 [Petals 开源项目:https://github.com/bigscience-workshop/petals],就是利用流水线并行,把 GPU 做成了一个类似 BitTorrent 的分布式网络。虽然推理延迟确实比较高,但至少说明了分布式 GPU 推理的可行性。
推理用张量并行怎么样?
前面讲到,流水线并行的最大缺点是 GPU 串行处理,延迟较高,导致输出 token 比较慢。而张量并行的最大缺点是传输数据量大,网络带宽低的设备不一定 hold 得住。
但是推理要传输的数据量跟训练要传输的数据量可不是一回事啊!推理只需要传输正向传播的中间结果(activation),而训练还需要传输所有参数的梯度,梯度才是数据量的大头。
在推理中,如果使用张量并行,Transformer 的每一层都需要传输把自己负责的结果向量(大小为 batch size * embedding size / num GPUs)广播给其他所有 GPU,并接受来自所有其他 GPU 广播来的数据。计算 attention 的时候需要传输一次,计算 feed-forward network 的时候又需要传输一次,也就是总共需要传输 2 * 层数这么多次。
每次发送就是 batch size * embedding size(发送和接收是不同的方向,不能算两次),对于 batch size = 1, embedding size = 8192,只需要传输 16 KB 数据,在 32 GB/s 的 PCIe Gen4 上传输只需要 1 us。当然,考虑到前面讨论的 CPU 中转开销,还是需要大约 30 us 的。一共 160 次传输,需要 4.8 ms。
我们再考虑计算的开销。还是考虑 batch size = 1 的情形,GPU 内存带宽是瓶颈,此时每张卡计算每个 token 的时延就是 2 byte * 参数量 / 卡的数量 / 内存带宽,代入我们前面的数值,仍然是 17.5 ms。但是这里 8 张卡是并行处理的,因此总的处理时长就是计算时间 + 通信时间 = 17.5 ms + 4.8 ms = 22.3 ms。这就意味着每秒可以生成 45 个 token,这个 token 生成速度已经很不错了,至少人类的阅读速度是很难赶上生成的速度了。
如果 batch size 更大会怎样?例如 batch size = 330,把 GPU 算力和内存带宽都充分利用起来,每次需要传输的数据量是 330 * 8192 * 2 = 5.4 MB,在 32 GB/s 的 PCIe Gen4 上需要 0.17 ms。一共 160 次传输,就是 27 ms。这下网络通信开销成了延迟的大头,总处理时长为 27 + 17.5 = 44.5 ms,每秒只能生成 22 个 token 了,但也不算慢。
注意,不管用多少个 GPU 做并行推理,只要用的是张量并行,网络传输的总数据量是相同的,因此增加 GPU 的数量只能加速计算,不能加速通信。
因此,A100/H100 的 NVLink 在降低推理延迟方面还是有很大作用的。如果用 H100,取 batch size = 295 达到算力和带宽的平衡利用,这 4.72 MB 数据只需要 4.72 MB / 450 GB/s = 0.01 ms。一共 160 次传输,也只有 1.6 ms。由于内存带宽大了,计算时间也可以大幅缩短,例如 H100 的计算时间为 2 * 70G / 8 / 3.35 TB/s = 5.2 ms。总处理时长只有 5.2 ms + 1.6 ms = 6.8 ms,每秒可以生成 147 个 token,非常棒!
可以说,如果论单个 prompt 的 token 生成速度,无论用多少块 4090 也追不上 8 卡 H100。
用 4090 做推理的成本怎么样?
对于推理,不管用流水线并行还是张量并行,batch size 不算高到太离谱的情况下内存带宽都是瓶颈。
假如 batch size 能够高到把算力 100% 利用起来,并且还能解决 KV Cache 不够大的问题,能解决中间结果占用内存过多的问题,那么这 8 张 4090 可以达到多少吞吐量?
当然,这两个问题都不好解决,因此推理优化才是一个热门的研究领域,存在很多的 trade-off 和奇技淫巧。如果只是用标准的 PyTorch,那推理性能距离把算力 100% 利用起来还远得很呐。
假设都解决了,在张量并行的通信过程中我们可以利用 double buffer 做另外一个 batch 的计算,也就是计算和通信并行,进一步提高吞吐量。通信和计算分别是 27 ms 和 17.5 ms,传输的 27 ms 是瓶颈,也就是每 27 ms 输出一组 token,一个 batch 330 个 prompt,那这 8 张 4090 真是可以达到每秒 330 / 0.027 = 12.2K token 的吞吐量。
8 张 4090 的成本是 12800 美金,8 卡 PCIe Gen4 服务器本身要 2 万美金,加上网络设备,平均每台 4 万美金的设备成本。固定资产按照 3 年摊销,每小时 1.52 美元。整机功耗大约 400W * 8 + 2 kW = 5 kW,按照 0.1 美元一度电算,每小时 0.5 美元。一个机架可以放 4 台这样的 8 卡服务器,数据中心机柜租用成本(不含电费)一个月 1500 美元算贵的了,合每小时 0.5 美元。这 2.5 美元一小时的机器,满打满算能生成 12.2K * 3600 = 44M tokens,也就是说 1 美元能生成 17.6M tokens。
是不是比 GPT-3.5 Turbo 的 $0.002 / 1K tokens,也就是 1 美元 0.5M tokens 便宜 35 倍?当然,账不能这么算。
- 首先,GPU 的算力利用率到不了 100%;
- 其次,如同所有 SaaS 服务一样,用户的请求数量有波峰有波谷,用户是按量付费的,平台提供方可是不管有没有人用都在烧钱的;
- 此外,每个 batch 中不同 prompt 的长度和响应 token 数量都不同,消耗的算力是 batch 中最大的那个,但收的钱是用户实际用的 token 数;
- 再次,GPT-3.5 是 175B 的模型,比 70B 的 LLaMA 很可能推理成本更高;
- 最后,OpenAI 开发 GPT-3.5 是烧了不知道多少钱的,人家至少要赚回训练成本和研发人员的工资吧。
其实 GPT-3.5 Turbo 的 $0.002 / 1K tokens 真的挺良心的,有的卖 API 的,LLaMA-2 70B 都敢比 GPT-3.5 Turbo 卖得贵。
如果换成用 H100 做推理,重新算一下这笔账。一张 H100 至少要 3 万美金,一台 8 卡 H100 高配服务器加上配套的 IB 网络,起码要 30 万美金,同样按照 3 年摊销,每小时 11.4 美元。10 kW 功耗,电费每小时 1 美元。一个普通供电和散热的机架只能放 2 台 8 卡 H100,机柜租用成本(不含电费)还按 1500 美元算,合每小时 1 美元。一共 13.4 美元一小时。
这其实已经是非常良心的价格了,你在任何云服务商都不可能租得到这么便宜的 8 卡 H100。所以说从云服务商租卡卖没有护城河的 SaaS 服务,比如开源模型的推理 API,除非有一种提高推理性能的独门绝技,很难赚得了什么大钱,二房东的生意不是这么好做的。
再算算这台 8 卡 H100 机器的吞吐量,张量并行也采用传输和计算并行,H100 的通信比较快,因此计算是瓶颈,每 5.2 ms 可以输出一组 token,一个 batch 295 个 prompt,满打满算可以达到每秒 295 / 0.0052 = 56K token 的吞吐量。理想情况下,一小时能生成 204M tokens,也就是 1 美元能生成 15.2M tokens,H100 单位 token 的成本比 4090 仅仅高 16%,可以算打个平手吧。
为什么 8 卡 H100 机器是 4090 机器生命周期价格的 5 倍,性价比却跟 4090 差不多?因为一张 H100 的算力是 4090 的 3 倍,内存带宽是 4090 的 3.35 倍,不管按延迟还是按带宽算,单卡的性能就基本上是 3 倍。而且,H100 比 4090 的网络带宽强太多了,导致 4090 在张量并行中网络通信成了瓶颈,浪费了有效算力。因此,同样的 8 卡机器吞吐量可以达到 4090 的 4.6 倍。虽然一张 H100 卡的价格是 4090 的 20 倍以上,但算上服务器本身的成本、电费和数据中心托管费用,整机的成本只是 5 倍左右。
用最便宜的设备搞出最高的推理性能
我们发现在 8 卡 4090 机器中,3 万美金的设备成本,GPU 卡只占了 1.28 万美金,不像 H100 机器那样 GPU 成本占了大头。还有办法进一步降低吗?
如果我们可以忍受 5 token/s 的输出速度,甚至可以利用流水线并行,用家用台式机和 4090 攒出个推理集群来。
遥想我当年在 MSRA 的时候,在一台只用 1000 美金攒出来的机器上插了 10 块 FPGA ,做出个世界最快的 Key-Value Store。其实如果让我去设计一个性价比最高的 4090 推理集群,有很多种方案可以尝试:
- 用流水线并行,台式机 + 10 Gbps 网卡,足够在 5 ms 内传输 batch size = 330 的 5.28 MB 数据了,通信 40 ms,计算 140 ms,达到 5 token/s 的单 prompt 输出速度,同时又能充分利用 4090 的算力。10 Gbps 的网卡和交换机都很便宜,Intel X710 网卡只要 150 美金,20 口交换机只要 1500 美金(每 8 个口 750 美金),一台家用台式机 700 美金,这只要 2 万美金就可以搞定原本需要 4 万美金的设备。
- 用张量并行,台式机 + 200 Gbps ConnectX-6 网卡,上 RoCE,可以把 batch size = 330 的 5.28 MB 数据在 0.22 ms 内传完,160 次传输是 35 ms,加上计算的 17.5 ms,一个 token 52.5 ms,可以达到 19 token/s 的单 prompt 输出速度,这个速度已经不错了。网卡 1000 美金,200G 交换机 2 万美金 40 个端口,平均每 8 个端口 4000 美金,一台家用台式机 700 美金,这只要 3 万美金就能搞定原本 4 万美金的设备。
- 主机内用张量并行,主机间用流水线并行,4 卡 PCIe Gen4 服务器主板只要 1000 美金而且能跑满 PCIe 带宽(因为 8 卡就需要 PCIe switch 了,价格会贵很多),两台主机之间用 25 Gbps 网卡直连,主机内张量并行的时延是 27 ms,主机间流水线并行只需 2 次 8 ms 的传输(注意 25G 的网络带宽是 4 张 GPU 卡共享的),加上两次流水线计算各 17.5 ms,总共 78 ms,可以达到 13 token/s 的单 prompt 输出速度。网卡 300 美金 * 2,服务器 3000 美金 * 2,这只要 1.95 万美金就可以搞定原本需要 4 万美金的设备。
2 万美金按照 3 年摊销是每小时 0.76 美元。按照 0.1 美元/度的电价,每小时的电费都要 0.5 美元,接近设备成本了,这有点挖矿的味道了。矿场里面可没有中央空调和 UPS,只有暴力风扇,托管费用比数据中心低很多,整机的成本是有可能压到 1.5 美元/小时的。如果跑满了 44M tokens 的吞吐量,1 美元能生成 30M tokens,正好是 8 卡 H100 的 15M token per dollar 的 2 倍。
为什么 H100 以 20 倍于 4090 的 GPU 价格,性价比却只差一倍?首先是因为能耗成本更低,8 卡 H100 的功耗是 10 kW,但 9 台 8 卡 4090 的功耗是 45 kW;其次是因为主机和网络设备成本更低,一台 8 卡 H100 准系统虽然贵,但只占整机价格的 20% 左右;但 4090 因为卡多,除非像 GPU 矿机那样压成本,只要还是用数据中心级的设备,准系统价格就要占到 35% 以上。
其实,这个世界上不止有 A100/H100 和 4090,还有 A10、A40 等计算卡和 3090 等游戏卡,还有 AMD 的 GPU 和很多其他厂商的 AI 芯片。H100 和 4090 大概率都不是性价比的最优解,例如 A10、A40 和 AMD GPU 的性价比有可能就更高。
我都想搞一个推理性价比挑战赛,看谁能用最便宜的设备搞出最强的推理吞吐量,同时延迟不能太高;或者用最便宜的设备搞出最低的推理延迟,同时吞吐量不能太低。
这一切都是在假设使用 LLaMA-2 70B 模型,没有做量化压缩的前提下。如果做了量化压缩,那性能就更高,甚至在 Unified Memory 够大的 MacBook Pro 上都能单机跑了。
License 问题怎么办?
我把这个问题放到最后。[NVIDIA Geforce driver 的 License:https://www.nvidia.com/en-us/drivers/geforce-license/] 里写道:
No Datacenter Deployment. The SOFTWARE is not licensed for datacenter deployment, except that blockchain processing in a datacenter is permitted.
既然机器都是用台式机攒起来的,这能叫 data center 吗?还是叫矿场比较合适吧。人家也说了,4090 用来做区块链是允许的。
我有一个大胆的想法,如果未来的区块链不再用挖矿来做 proof of work,而是用大模型推理来做 proof of work,这是不是很有意思?每个人买几块显卡,接到矿池上,既可以自己用来玩游戏,闲时又可以贡献算力。矿池直接就是个卖大模型推理 SaaS 服务的公司,提供前所未有的低价 API。甚至需要大模型推理服务的人可以在区块链里自己 P2P 玩起来,谁要用大模型就付点 gas。
当然,目前的 proof of work 都是计算很复杂,验证很简单的。如果真用大模型推理做 proof of work,必须防止用户随意编造一个结果交上去。当然这也是有解决方案的,就像 BitTorrent 和其他一些去中心化网络一样,采用信用机制,新人只能做验证别人计算结果的工作,积攒信用;老人每次算错了,都有比较严厉的惩罚。
从另一个角度看,家庭局域网络的速度也越来越快,比如我家就自己部署了 10 Gbps 的网络。家中的智能设备越来越多,算力越来越强。光纤入户也越来越普遍,小区和城市的运营商机房里部署了越来越多的边缘计算节点。前面我们用 1 Gbps 的网络就足以把多台主机上的 GPU 组成流水线并行,那么在未来的家庭高速网络中,流水线并行甚至张量并行都将成为可能。
大多数搞 AI 推理的都只关心数据中心,忽略了家中的分布式算力。只要解决了安全、隐私和经济动机问题,我家的 Siri,也许就跑在邻居家里的 GPU 上。
很多人都在说要 democratize AI。但现在大模型平民化的最大障碍就是成本,而成本最大的来源又是 GPU 市场上计算卡和游戏卡价格的剪刀差。这并不是指责某家公司,其他做 AI 芯片的公司,AI 芯片的算力也并不便宜。毕竟芯片、软件和生态的研发都是白花花的银子。
就像本文开头提到的微软给每台服务器部署 FPGA 一样,大规模量产的芯片价格就像沙子一样。到时候,能限制大模型推理算力的就只有能源了,就像区块链挖矿和通用 CPU 的云计算一样,都在找最便宜的电力供应。我在之前的一个采访中就表示,长期来看,能源和材料可能是制约大模型发展的关键。让我们期待廉价的大模型走进千家万户,真正改变人们的生活。
#大模型~简化开发过程
史蒂夫・乔布斯曾经把计算机称作 “心灵之自行车”。不过,人们对他这个比喻的背景知之甚少,他是在谈及地球上所有物种移动效率的时候提到的。
由 DALL·E 3 生成的图片,提示 “将计算机想象成心灵的自行车”
秃鹫赢了,位居榜首,超过了其他所有物种。人类排在榜单大约三分之一的位置…… 但是,一旦人类骑上自行车,就能远远超越秃鹫,登顶榜首。这让我深受启发,人类是工具制造者,我们可以制造出将这些固有能力放大到惊人程度的工具。对我来说,计算机一直是思维的自行车,它让我们远远超越了固有的能力。我认为我们只是处于这个工具的早期阶段,非常早期的阶段。我们只走了很短的一段距离,它仍处于形成阶段,但我们已经看到了巨大的变化。我认为,与未来 100 年发生的事情相比,这算不了什么。
—— 史蒂夫・乔布斯(1990)
#01 谨慎乐观
LLM 在加速软件开发方面的作用引发了广泛讨论。有人认为,自动生成的代码质量过低,以至于使用这些代码产生的是负面效果。而另一方面,许多人声称编程的时代已经结束。已经有众多研究试图客观评估 LLM 在诸如 HumanEval 或 MBPP 这样的代码质量基准数据集上的表现。这些评估对于该领域的发展至关重要,但这并非本文的焦点。
本文旨在为开发者,尤其是那些对利用这些模型持保留态度的开发者,提供实用的建议。要注意的是,我相信通过恰当地使用这项技术,一个人的工作效率确实能够实现质的飞跃。
需要明确的是,使用 LLM 并不意味着一定能产出高性能的代码。即便是最先进的模型也会犯下严重的错误。但重要的是要记住,最优秀的人类工程师也难免犯错。这就是我们为什么不会在没有设置防护措施的情况下发布代码,比如进行自动化测试和同行评审。LLM 没有改变最佳实践的本质,但可能改变了实现这些最佳实践的容易度。
编写高性能代码的速度真的那么重要吗?
如果大多数专业软件工程师的大部分时间都花在了键盘敲击上,我会觉得非常意外。如果真是这样,我们或许会更认真地把 “代码行数” 当作衡量工作表现的一个标准。我认为,工程师实际上更多的时间是用来决定首先需要编写哪些代码。目前,LLM 还没有能力进行这种高级的抽象思维 ——OpenAI 和许多其他机构正在努力使这些步骤成为可能,但这很可能需要的不仅仅是增加模型的参数或扩大训练数据,而是需要一种全新的思维方式。然而,LLM 已经能够在很大程度上加速我们的工作中的很大一部分了。
#02 编写高效的函数
让我举个例子来说明,当 LLM 生成了一小段代码后,我们需要采取哪些步骤来优化它,使之变得实用。在地球上的某个点和另一组坐标之间测量距离是地理信息分析中的一个常见任务。想象一下这样的情境。
首先看到的代码输出是这样的:
import mathdef haversine(lat1, lon1, lat2, lon2):# Radius of the Earth in kilometersR = 6371.0# Convert degrees to radianslat1_rad = math.radians(lat1)lon1_rad = math.radians(lon1)lat2_rad = math.radians(lat2)lon2_rad = math.radians(lon2)# Differences in coordinatesdlat = lat2_rad - lat1_raddlon = lon2_rad - lon1_rad# Haversine formulaa = math.sin(dlat / 2)**2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2)**2c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))# Distance in kilometersdistance = R * creturn distancedef find_closest_oil_platform(vessel_lat, vessel_lon, platforms):closest_platform = Nonemin_distance = float('inf')for platform in platforms:distance = haversine(vessel_lat, vessel_lon, platform['latitude'], platform['longitude'])if distance < min_distance:min_distance = distanceclosest_platform = platformreturn closest_platform, min_distance
那么,这段代码的质量怎么样呢?
这段代码的质量并不理想。它的问题不少,最明显的是它难以阅读,因为它使用了难以理解和不一致的变量命名,包括大小写的单字符命名。此外,它缺少了文档说明和类型提示,使用了一个不必要且低效的 for 循环,也没有处理一些特殊情况,比如相等距离的问题。
但经过修改之后:
from typing import Tuple, List
import numpy as npdef haversine_vectorized(vessel_lat: float, vessel_lon: float, platform_lats: np.ndarray, platform_lons: np.ndarray) -> np.ndarray:"""Calculate the Haversine distance between a vessel and multiple platforms in a vectorized manner.Parameters:- vessel_lat: Latitude of the vessel in decimal degrees.- vessel_lon: Longitude of the vessel in decimal degrees.- platform_lats: Numpy array of latitudes of oil platforms in decimal degrees.- platform_lons: Numpy array of longitudes of oil platforms in decimal degrees.Returns:- distances: Numpy array of distances from the vessel to each platform in kilometers."""# Convert decimal degrees to radianslat1, lon1, lat2, lon2 = map(np.radians, [vessel_lat, vessel_lon, platform_lats, platform_lons])# Haversine formuladlat = lat2 - lat1dlon = lon2 - lon1a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))r = 6371 # Radius of Earth in kilometersreturn c * rdef find_closest_platform(vessel_lat: float, vessel_lon: float, platform_lats: np.ndarray, platform_lons: np.ndarray) -> Tuple[List[int], float]:"""Finds the closest oil platform(s) to a vessel given arrays of platform latitudes and longitudes, handling equidistant platforms.Parameters:- vessel_lat: Latitude of the vessel in decimal degrees.- vessel_lon: Longitude of the vessel in decimal degrees.- platform_lats: Numpy array of latitudes for oil platforms.- platform_lons: Numpy array of longitudes for oil platforms.Returns:- A tuple containing a list of indices of the closest platforms and the distance to them in kilometers."""# Calculate distances to all platformsdistances = haversine_vectorized(vessel_lat, vessel_lon, platform_lats, platform_lons)# Find the minimum distancemin_distance = np.min(distances)# Find all indices with the minimum distanceclosest_indices = np.where(distances == min_distance)[0].tolist()# Return the indices of all closest platforms and the minimum distancereturn closest_indices, min_distance
改进后的代码有了明显提升。它变得更容易阅读了,增加了文档说明和类型提示,并且用更高效的向量计算方式替换了原有的 for 循环。
但是,代码的 “好坏”,更重要的是,它是否满足需求这些都取决于代码将要运行的具体环境。要知道,我们无法仅凭几行代码就能有效评估其质量,这一点对人类如此,对 LLM 也是如此。
比如说,这段代码的准确度是否满足用户的预期?它会被频繁运行吗?是一年一次,还是每微秒一次?使用的硬件条件如何?预期的使用量和规模是否值得我们去追求那些细小的优化?在考虑到你的薪资之后,这样做是否划算?
让我们在上述因素的基础上来评估这段代码。
在准确性方面,虽然半正矢公式(haversine formula)表现不错,但并非最佳选择,因为它将地球视为一个完美的球体,而实际上地球更接近于一个扁球体。在需要跨越巨大距离进行毫米级精确测量时,这种差异变得非常重要。如果真的需要这样的精确度,虽然有更精确的公式(如 Vincenty 公式)可用,但这会带来性能上的折中。因为对于这段代码的用户而言,毫米级的精确度并不是必须的(事实上,由于卫星图像导出的船舶坐标的误差,这种精度也并不相关),所以在准确性方面,半正弦函数是一个合理的选择。
代码运行得够快吗?考虑到只需要对几千个海上石油平台计算距离,特别是通过向量计算方法,这种计算是非常高效的。但如果应用场景变成了计算与岸边任意点的距离(岸线上有数以亿计的点),那么采用 “分而治之” 的策略可能会更加合适。在实际应用中,考虑到节约计算成本的需要,这个函数设计为在一个尽可能配置低的虚拟机上每天运行约 1 亿次。
基于这些详细的背景信息,我们可以认为上面的代码实现是合理的。这也意味着,在代码最终合并前,它应该先经过测试(我通常不推荐仅依赖 LLM 进行测试)和人工同行评审。
#03 加速前进
像之前那样利用 LLM 自动生成实用的函数不仅可以节省时间,而且当你开始利用它们来生成整套的库、处理模块间的依赖、撰写文档、实现可视化(通过多模态能力)、编写 README 文件、开发命令行接口等时,它们带来的价值将会成倍增长。
我们来试着从零开始,借助 LLM 的广泛辅助,创建、训练、评估并推断一个全新的计算机视觉模型。以一篇最近发表的论文为例,“通过深度学习识别 Sentinel-2 图像中船舶尾迹组件的关键点方法”(Del Prete 等人,IEEE GRSL,2023),这篇论文就是我们前进的动力和灵感来源。
论文链接:https://www.semanticscholar.org/paper/Keypoints-Method-for-Recognition-of-Ship-Wake-in-by-Prete-Graziano/a38d19b5ebaa2441e1bef2af0ecf24332bd6ca5b
为什么我们需要关心船舶在卫星图像中的行进方向,这项任务有什么难点呢?
通过静态图像识别船只的航行方向,对于那些需要监控水域中人类活动的组织来说,是极其宝贵的信息。比如,如果一艘船正朝向一个海洋保护区行进,这可能意味着需要警觉或者采取拦截措施。通常,全球范围内公开的卫星图像的分辨率不足以精确判断一艘船的朝向,尤其是那些在图像上只占据几个像素的小型船只(例如,Sentinel-2 的图像分辨率为 10 米 / 像素)。然而,即便是小型船只留下的水波纹也可能相当明显,这就为我们提供了一个判断船只朝向和行进方向的线索,即使船的尾部无法直接识别。
这项研究之所以引人注目,是因为它采用的模型基于 EfficientNetB0,这是一个足够小的模型,能够在不花费太多计算资源的情况下进行大规模应用。虽然我没有找到具体的代码实现,但作者公开了包括标注在内的数据集,这是值得赞赏的一步。
https://zenodo.org/records/7947694
开始我们的探索吧!
如同启动任何新的机器学习项目一样,首先对数据进行可视化是极富启发性的一步。
import os
import json
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import seaborn as sns# Define the path to your data directory
data_dir = "/path/to/your/data" # Adjust this to the path of your data directory
annotations_dir = os.path.join(data_dir, "annotations")
images_dir = os.path.join(data_dir, "imgs")# Initialize Seaborn for better visual aesthetics
sns.set(style="whitegrid", palette="muted")# Create a list to hold file paths for images and their corresponding annotations
image_files = []
annotation_files = []# Loop through the annotations directory to get the list of annotation files
for annotation_file in os.listdir(annotations_dir):if annotation_file.endswith(".json"):annotation_path = os.path.join(annotations_dir, annotation_file)image_file = annotation_file.replace(".json", ".png") # Assuming image file names match annotation file namesimage_path = os.path.join(images_dir, image_file)# Check if the corresponding image file existsif os.path.exists(image_path):annotation_files.append(annotation_path)image_files.append(image_path)# Plotting
num_examples = min(len(image_files), 10) # Limiting to 10 examples for visualization
fig, axes = plt.subplots(2, 5, figsize=(20, 8))for idx, (image_path, annotation_path) in enumerate(zip(image_files[:num_examples], annotation_files[:num_examples])):# Load the imageimg = Image.open(image_path).convert("RGB") # Ensure the image is treated as RGBdraw = ImageDraw.Draw(img)# Load the corresponding annotations and draw keypointswith open(annotation_path, 'r') as f:annotations = json.load(f)for point in annotations["tooltips"]:x, y = point["x"], point["y"]# Draw keypoints in cyan for visibilitydraw.ellipse([(x-10, y-10), (x+10, y+10)], fill='cyan', outline='black')# Plot the image with keypointsax = axes[idx // 5, idx % 5]ax.imshow(img)ax.axis('off')ax.set_title(f"Image {idx+1}")plt.tight_layout()
plt.show()
结果如下:
Sentinel-2 卫星图像中的船只和尾流关键点注释(青色)样本。
结果展示了几个关键的发现:首先,不同图片中标注的关键点(青色点)数量不一;其次,这些图像都是灰度的,并且形状各异;第三,尽管样本量不大,但其内部差异性却相当大。在明确了这些基础之后,我们接下来将编写用于加载数据集、定义神经网络、进行训练和验证的代码。在请求代码之前,我们会先向 LLM 提供一些背景信息。
这导致了以下输出:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image
from typing import Tuple, Listclass WakeDataset(Dataset):"""Custom Dataset for loading wake detection imagery and annotations, with preprocessing for EfficientNet.Images are grayscale and will be converted to three-channel images for model compatibility."""def __init__(self, annotations_dir: str, images_dir: str, transform=None):"""Initializes the dataset with paths to annotations and images directories, and a transform.Parameters:- annotations_dir: Path to the directory containing annotation JSON files.- images_dir: Path to the directory containing grayscale image files.- transform: A torchvision.transforms transform to be applied on a sample."""self.annotations_dir = annotations_dirself.images_dir = images_dirself.transform = transform if transform else self.default_transform()self.annotations = self._load_annotations()def default_transform(self):"""Defines the default transform to resize images, convert them to three-channel format, and to tensors.Returns:- A torchvision.transforms composition for image preprocessing."""return transforms.Compose([transforms.Resize((224, 224)), # Resize to the input size for EfficientNettransforms.Grayscale(num_output_channels=3), # Convert grayscale to three-channeltransforms.ConvertImageDtype(torch.float), # Ensures image is a float tensor])def _load_annotations(self) -> List[Tuple[str, List[int]]]:"""Loads and filters annotations with at least 2 keypoints.Returns:- A list of tuples, where each tuple contains the image filename and its keypoints."""annotations = []for filename in os.listdir(self.annotations_dir):if not filename.endswith('.json'):continuefilepath = os.path.join(self.annotations_dir, filename)with open(filepath, 'r') as f:data = json.load(f)if len(data['tooltips']) >= 2: # Filter annotations with at least 2 keypointskeypoints = [item for sublist in [[kp['x'], kp['y']] for kp in data['tooltips']] for item in sublist]annotations.append((filename.replace('.json', '.png'), keypoints))return annotationsdef __len__(self) -> int:"""Returns the number of items in the dataset."""return len(self.annotations)def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:"""Retrieves an item from the dataset at the specified index, with preprocessing for EfficientNet.Parameters:- idx: Index of the item to retrieve.Returns:- A tuple containing the preprocessed image tensor and its corresponding keypoints tensor."""img_name, keypoints = self.annotations[idx]img_path = os.path.join(self.images_dir, img_name)image = read_image(img_path) # Load imagekeypoints = torch.tensor(keypoints + [0] * (6 - len(keypoints)), dtype=torch.float) # Pad keypointsif self.transform:image = self.transform(image)return image, keypoints
最终得到的代码看起来还算合理,并且似乎注意到了问题的一些细节。虽然不是完美的(例如,选择在 default_transform 中将图像从灰度转为 RGB 颜色模式是个奇怪的决定,我个人更倾向于使用 pathlib 而非 os。listdir),但我们现在可以先放下这些细节不谈。
import torch
import torch.nn as nn
from torchvision import models
from typing import Tupleclass EfficientNetB0KeypointDetector(nn.Module):"""A custom EfficientNet-B0 model for keypoint detection in grayscale images.The network modifies the input layer to accept single-channel images and adds a custom output layer for predicting a fixed number of keypoints."""def __init__(self, num_keypoints: int = 3):"""Initializes the model.Parameters:- num_keypoints: The number of keypoints to predict. Default is 3, corresponding to the maximum number of keypoints in the dataset."""super(EfficientNetB0KeypointDetector, self).__init__()self.num_keypoints = num_keypoints# Load a pre-trained EfficientNet-B0self.efficientnet_b0 = models.efficientnet_b0(pretrained=True)# Modify the first convolutional layer to accept single-channel (grayscale) imagesoriginal_first_conv = self.efficientnet_b0.features[0][0]self.efficientnet_b0.features[0][0] = nn.Conv2d(in_channels=1, out_channels=original_first_conv.out_channels, kernel_size=original_first_conv.kernel_size, stride=original_first_conv.stride, padding=original_first_conv.padding, bias=False)# Modify the classifier to predict keypoints (num_keypoints * 2 because x, y for each keypoint)self.efficientnet_b0.classifier = nn.Linear(in_features=self.efficientnet_b0.classifier[1].in_features, out_features=num_keypoints * 2)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Forward pass of the network.Parameters:- x: Input tensor of shape (batch_size, 1, 224, 224)Returns:- A tensor of shape (batch_size, num_keypoints * 2) representing the predicted keypoints."""return self.efficientnet_b0(x)# Example usage
if __name__ == "__main__":model = EfficientNetB0KeypointDetector()print(model)# Example input tensor (batch_size, channels, height, width)example_input = torch.randn(1, 1, 224, 224)output = model(example_input)print(output.shape) # Expected shape: (1, 6) for 3 keypoints (each with x, y coordinates)
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import WakeDataset, get_dataloader # Assuming dataset.py is in the same directory
from net import EfficientNetB0KeypointDetector
from torchvision import transforms
import wandb# Initialize Weights & Biases
wandb.init(project="wake_model_llm_assist", entity="your_wandb_entity", mode="offline")def custom_collate_fn(batch):"""Custom collate function to pad keypoints to a fixed size.Parameters:- batch: A list of tuples (image, keypoints).Returns:- Batched images and keypoints with padding as necessary."""images, keypoints = zip(*batch)images = torch.stack(images, 0)max_keypoints = 3 * 2 # Maximum number of keypoints (3) times 2 for (x, y)padded_keypoints = [torch.cat((kp, torch.zeros(max_keypoints - len(kp)))) for kp in keypoints]keypoints = torch.stack(padded_keypoints, 0)return images, keypointsdef train_one_epoch(model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer, device: torch.device):"""Trains the model for one epoch.Parameters:- model: The neural network model.- dataloader: DataLoader providing the training data.- optimizer: Optimizer used for model training.- device: The device to train on."""model.train()total_loss = 0.0for images, keypoints in dataloader:images, keypoints = images.to(device), keypoints.to(device)optimizer.zero_grad()outputs = model(images)loss = nn.MSELoss()(outputs, keypoints)loss.backward()optimizer.step()total_loss += loss.item()average_loss = total_loss / len(dataloader)wandb.log({"train_loss": average_loss})def evaluate(model: nn.Module, dataloader: DataLoader, device: torch.device):"""Evaluates the model on the validation set.Parameters:- model: The neural network model.- dataloader: DataLoader providing the validation data.- device: The device to evaluate on."""model.eval()total_loss = 0.0with torch.no_grad():for images, keypoints in dataloader:images, keypoints = images.to(device), keypoints.to(device)outputs = model(images)loss = nn.MSELoss()(outputs, keypoints)total_loss += loss.item()average_loss = total_loss / len(dataloader)wandb.log({"val_loss": average_loss})def main():# Setupdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = EfficientNetB0KeypointDetector().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)# DataLoaderstrain_dataloader = get_dataloader('path/to/train/annotations', 'path/to/train/images', batch_size=4, transform=transforms.ToTensor(), collate_fn=custom_collate_fn)val_dataloader = get_dataloader('path/to/val/annotations', 'path/to/val/images', batch_size=4, transform=transforms.ToTensor(), collate_fn=custom_collate_fn)# Training loopfor epoch in range(num_epochs):train_one_epoch(model, train_dataloader, optimizer, device)evaluate(model, val_dataloader, device)# Log model checkpoints, images, or other metrics as neededwandb.log({})if __name__ == "__main__":main()
当我开始尝试运行代码时,遇到了一系列与张量形状不匹配相关的问题(想起了之前 RGB 与灰度图像的差异以及自定义 collate 函数的问题)。我尝试了几分钟的调试,但最终选择将所有模块代码复制到提示框里,让它帮我一起找出问题所在。
这样做有效解决了所有挂起的问题,我随即开始了模型的训练…… 但训练速度慢得出奇。这时我才意识到我没有利用 Apple 芯片上的 Metal 性能着色器(MPS)后端。虽然我平时不太在个人电脑上训练模型,MPS 对我还比较新,但我还是决定添加一个条件判断来使用它。
鉴于训练数据量适中(共 581 张图片),加上 EfficientNet 已经在 ImageNet 上进行了预训练,我决定进行 1000 个周期的训练。
500 个周期后,train_loss 仍在减少,但验证损失似乎已经收敛(至少对于快速评估而言足够收敛)。图像是根据权重和偏差复制的。
值得注意的是,尽管存在一些已经在卫星图像上进行预训练的基础模型(相对于 ImageNet),这些模型很可能对这项任务更有效,但这些网络体积比 EfficientNet 大得多,因此训练速度更慢(而且太新,还没被 LLM 训练数据集包括)。
现在,让我们编写一个推理脚本和一个命令行界面(CLI),来测试我们刚训练好的模型。
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import click# Assuming net.py contains the EfficientNetB0KeypointDetector class
# and dataset.py contains the transforms used during training
from net import EfficientNetB0KeypointDetector
from dataset import WakeDataset # Modify as necessary to import transformsdef load_model(model_path):"""Load the trained model from a file."""model = EfficientNetB0KeypointDetector()model.load_state_dict(torch.load(model_path))model.eval() # Set the model to inference modereturn modeldef preprocess_image(image_path, transform):"""Load and preprocess an image."""image = Image.open(image_path).convert("L") # Assuming grayscale conversion as in your datasetimage = transform(image)# Add batch dimension (BxCxHxW)image = image.unsqueeze(0)return imagedef plot_keypoints(image, keypoints):"""Plot keypoints on the image."""plt.imshow(image.squeeze(), cmap='gray') # Remove batch dimension and show imageplt.scatter(keypoints[:, 0], keypoints[:, 1], s=50, marker='.', c='red')plt.show()@click.command()
@click.argument('model_path', type=click.Path(exists=True))
@click.argument('image_path', type=click.Path(exists=True))
def run_inference(model_path, image_path):"""Run inference on an image using a trained model."""# Use the same transforms as during trainingtransform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Grayscale(num_output_channels=3),])model = load_model(model_path)image = preprocess_image(image_path, transform)# Perform inferencewith torch.no_grad():keypoints = model(image)keypoints = keypoints.view(-1, 2).cpu().numpy() # Reshape and convert to numpy for plotting# Load original image for plottingoriginal_image = Image.open(image_path).convert("L")plot_keypoints(original_image, keypoints)if __name__ == '__main__':run_inference()
让我们开始吧!
虽不完美,但对于第一次通过来说是合理的。
你可以在 GitHub 上找到包括所有模块、模型及权重(第 500 周期的)和一个 readme 的完整代码。我花了不到一个小时就生成了整个库,这个过程比写这篇文章花费的时间要少得多。所有这些工作都是在我的个人开发环境中完成的:MacBook Air M2 + VS Code + Copilot + 保存时自动格式化(使用 black、isort 等)+ 一个 Python 3.9.6 的虚拟环境(.venv)。
GitHub:https://github.com/pbeukema/wakemodel_llmassist
学到的教训
- 向模型提供尽可能多的相关上下文,帮助其解决任务。要记住,模型缺少许多你可能认为理所当然的假设。
- LLM 生成的代码通常远非完美,预测其失败的方式也颇具挑战。因此,在 IDE 中有一个辅助工具(比如 Copilot)非常有帮助。
- 当你的代码高度依赖 LLM 时,要记得编写代码的速度往往是限制因素。避免请求重复且不需要任何改动的代码,这不仅浪费能源,也会拖慢你的进度。
- LLM 很难 “记住” 它们输出的每一行代码,经常需要提醒它们当前的状态(特别是当存在跨多个模块的依赖时)。
- 对 LLM 生成的代码保持怀疑态度。尽可能多地进行验证,使用测试、可视化等手段。并且在重要的地方投入时间。相比于神经网络部分,我在 haversine 函数上花费了更多的时间(因为预期规模对性能的要求较高),对于神经网络,我更关注的是快速发现失败。
#04 LLM 与工程领域的未来
唯有变化是永恒的。
—— 赫拉克利特
在 LLM 引发的热潮和巨额资金流动的背景下,人们很容易一开始就期待完美。然而,有效利用这些工具,需要我们勇于尝试、学习并做出调整。
LLM 是否会改变软件工程团队的根本结构呢?可能吧,我们现在只是新世界的门前小道。但 LLM 已经使代码的获取变得更加民主化了。即使是没有编程经验的人,也能快速而容易地构建出功能性原型。如果你有严格的需求,将 LLM 应用在你已经熟悉的领域或许更为明智。根据我个人的经验,LLM 能够使得编写高效代码所需的时间缩短约 90%。如果你发现它们一直输出低质量的代码,那么也许是时候重新审视你的输入了。
原文链接:https://towardsdatascience.com/accelerating-engineering-with-llms-e83a524a5a13