FlashSequence: SORA视频生成长序列任务训练解决方案

作者:黄奕桐、沈雯婷、艾宝乐、王昂、九丰

摘要

我们提出了长序列训练方案 FlashSequence 并集成在 PAI-TorchAcc (阿里云机器学习平台开发的Pytorch上的大模型训练加速框架)中,该方案能够支持SORA类超长序列模型的高效训练。在两机 16 卡 A100 上,FlashSequence 能够训练 1M 的长序列模型,并达到了 51.7%的 MFU,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。

一、横空出世的 SORA

SORA 介绍

SORA 是一个文生视频的模型,可以根据输入的文本生成对应的视频。

图1: SORA,这个图的核心部分来自:https://openai.com/research/video-generation-models-as-world-simulators。加上了text encoder和DiT blocks。

SORA 在训练时输入的视频可以看成是若干帧图像, 通过visual encoder得到spatial tempral patches 并 flatten成一维作为transformer tokens,同时输入的文本通过 text encoder 生成 embed,两者送入 diffusion transformer (DiT) 进行训练。

DiT 模型

图2: DiT模型的网络结构,来自:https://arxiv.org/abs/2212.09748

图 2 是 DiT 模型的网络结构,可以看到,DiT 模型和 LLAMA 等 LLM 模型的结构上大体上相同,DiT 模型多了对于输入 latent 的 patchify 处理(转换为 LLM 模型需要的 tokens 输入)、在 DecoderLayer 中增加了与文本输入的交互等。基本上 LLM 有的 Multi-Head Self-Attention、Pointwise Feadforward(MLP)结构,DiT 模型也有。

从整体上来看,DiT 模型与 LLM 模型的区别不大,从计算上来看,主要的计算量还是在 Attention 和 MLP 部分,区别在于 Attention 部分通常不会使用 casual mask,导致计算量较大;从显存上看,文本交互的部分会引入额外的显存使用。

训练需求

和LLM只有一个模型不同的是,文生视频模型由多个模型组成,包括:对文本进行编码的 text encoder、对视频进行编码的 visual encoder、DiT、用于推理的和 DiT 模型相同大小的 EMA 模型等。一般来说,text encoder为一个LM模型,预计在几 B 左右;visual encoder 通常为包含conv的VAE,参数量较小;DiT 为一个中小规模的模型,在 1B ~ 30B 左右。其中text encoder, visual encoder 通常为 pretrained 模型,在一些场景下无需进行训练。整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。

同时,与一般的 LLM 模型不同的是,文生视频模型的输入 token 数在几十K到几M之间,因此文生视频模型训练的核心挑战是对长序列的高效支持

二、 workload 分析

计算量

对于 text encoder 和 DiT 模型,主要计算量为L(\alpha bsh^2+\beta bs^2h),其中 L 表示 decoder layer 层数、b 表示 micro batch size,s 表示 sequence length,h 表示 hidden dim size,\alpha表示 linear 层的系数,\beta表示 attention 层的系数(在一般的非 causal attention 上前向是 4,后向是 8)。在 s 比较大时,如几十K到几M 时,其他的算子如 element wise 算子在这里可以忽略。text encoder 的 s 通常比较固定的,且远小于视频输入的 token 数。对于 VAE 模型,其主要计算量为一系列卷积操作,在长序列场景,计算量一般小于 DiT 模型。

因此,在视频输入的 token 数 s 比较大时,如几十K到几M 时,整个训练的计算量绝大部分在 DiT 模型上。同时,在 s 比较大的时候,attention 部分的计算量会按照 token 数平方增长,而 linear 层的增长只是线性增长,因此,attention 部分的计算量会逐渐成为整个训练过程中的瓶颈。在 1M 场景,attention 部分的计算时间可以达到 E2E 时间的 95%。

显存使用

如前所述,文生视频整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。在使用 7B 的 text encoder 和 7B 的 DiT 模型时,模型常驻的显存大概在 130GB(包含各个模型的参数和 DiT 部分的 optimizer state)。

在文生视频的训练过程中,通常只会训练 DiT 部分,而 text encoder 和 VAE 部分通常不会进行训练,所以显存使用主要在 DiT 的训练部分。text encoder 和 VAE 部分的显存主要考虑临时的 tensor 使用会不会导致 OOM。

在目前的 LLM 模型中,通常会使用 FlashAttention 进行性能和显存优化。在使用了 FlashAttention 之后,DiT 部分的显存使用为\gamma Lbsh,其中 \gamma为一个 Decoder Layer 的显存使用系数,取决于 DiT 模型的实现,这个值会有所变化,但是由于 text 的输入,通常会比普通的 LLM 模型大,例如,一般的 LLM 模型可以是 34,而 DiT 模型会达到 60 ~ 70。在 tokens 数为 1M 、micro batch size 为 1 的场景下,7B 的 DiT 模型总的 activation 显存使用可以达到 8000 GB 以上。

可以看到,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这也为后续显存优化提供了参考。

三、FlashSequence 长序列训练方案

FlashSequence

基于 workload 分析,我们提出了 FlashSequence 这一解决方案:

  • 分布式策略:

  • 为了切分中小规模模型的参数,FlashSequence 使用了 FSDP 这一分布式策略,同时,FlashSequence 在 FSDP 外面嵌套使用了 DP 提升多机拓展性。

  • 为了切分长序列训练场景下的 activation,FlashSequence 使用了 context parallel 对 sequence 维度进行切分,同时,FlashSequence 提出 2D context parallel 的方案减少 context parallel 跨机的通信开销。我们还去除了使用 context parallel 之后带来的冗余重复计算。

  • 显存优化策略:

  • FlashSequence 通过使用 CPU offloading 将 activation offload 到 CPU 内存上减少显存,同时极大减少了 gradient checkpoint(GC)带来的额外重算计算量。CPU offloading 的策略在长序列场景下数据传输时间能够和计算时间完全 overlap,相比 GC 能够在不影响 E2E 时间的情况下减少显存。

  • 为了避免 CPU 内存 OOM 和减少一部分 offloading 时间,FlashSequence 使用了 selective GC,selective GC 会优先选择显存计算比高的部分。

  • FlashSequence 还使用了 PyTorch expandable allocator 解决长序列场景下显存碎片过多的问题。

分布式策略

整体思路

在模型中存在两种类型的 tensor,一种是参数相关的包括模型参数、optimizer state、gradients,另一种是 activation。由于长序列场景下参数和 activation 都是不可忽视的,为了避免 OOM,我们需要同时切分参数和 activation。例如,常驻的参数和 optimizer state 可以达到 130GB,而 activation 在 1M 场景下可以达到 8000GB 以上。

参数切分

在参数的切分方面,我们存在多种选择,比如 TP、PP、FSDP 等,但是由于模型本身规模不是特别大,同时考虑到计算和通信的 overlap 情况,FlashSequence 选择了 FSDP 这种参数切分策略。不同于 TP 和 PP,FSDP 的通信除了第一个 layer 的 allgather 之外都能和计算 overlap,没有和计算 overlap 的通信时间在 FSDP 较小的情况下通常可以忽略。虽然 TP 也可以同时切分 activation,但是 TP 会引入无法 overlap 的通信,同时 PP 需要比较大的 gradient accumulation steps 才能掩盖 bubble。

activation 切分

对于 activation,DiT blocks 输入的 shape 为 [batch, sequence, hidden_dim]。由于长序列场景 activation 非常大,所以 micro batch size 通常为 1,这一维度无法切分。在 sequence 维度的切分目前存在 context parallel 如DeepSpeed-Ulysses 和 Ring Attention,以及 Megatron 的 sequence parallel。在 hidden_dim 的维度的切分主要是 Megatron 的 tensor parallel(以切分 weight 的方式实现对 activation 的 hidden dim 维度的切分)。纯粹的 tensor parallel 在 layer norm 等部分还是需要全量的 tensor,这一点在长序列场景是不可接受的。通常目前的主流做法是 Megatron 的 TP-SP 切分方式,这种切分方式和 context parallel 一样可以完整切分 layer 内的 activation。

对于 TP-SP 的切分方式,通信量为\frac{16Lsbh(t-1)}{t},其中t为 TP-SP 的数目,L 为 layer 数、s 为 sequence、b 为 micro batch size、h 为 hidden dim。对于 context parallel,以 DeepSpeed-Ulysses 为例,通信量为\frac{8\Psi(t-1)}{t}+\frac{16Lsbh(t-1)}{t^2},其中\Psi为模型参数量,前面一项是对模型参数的 all reduce 通信,后面一项是对 self attn 的 q、k、v、out 的 alltoall 通信。对于DeepSpeed-Ulysses,模型参数的 all reduce 可以被计算 overlap(类似 DDP),而后面不能 overlap 的通信小于 TP-SP 的切分方式。

从上面的对比可以看出,DeepSpeed-Ulysses 不能 overlap 的通信理论上是小于 TP-SP 的(即使考虑 TP-SP 后向通信可以 overlap)。同时,我们使用 FSDP 切分参数之后也不再需要 TP 对模型参数进行切分。在这种场景下面,FSDP 只是一种切分模型参数的分布式策略,其数据并行的含义被弱化了,不再是开启多少 FSDP 读取多少不同的数据样本,只需要保证 context parallel 的一个 group 内读取相同的数据即可。

综上所述,FSDP+context parallel 的方式优于 TP-SP 的切分方式。同时context parallel 还可以使用 Ring Attention 的方式进一步减少不能 overlap 的通信。

FSDP+DP

在文生视频这种中小规模模型的场景下,FSDP 不需要开很大就可以避免 OOM,在 7B 及以下规模,使用 FSDP=8 就足够满足显存使用需求,同时还能使用高速的机内带宽进行通信。

在更多的卡数下,FSDP 的拓展性会存在一些问题,为了避免这些问题,FlashSequence 进行了 DP 和 FSDP 的嵌套,在外层使用 DP,在内层使用 FSDP。虽然 FSDP 和 DP 的通信都能被计算 overlap,但是 DP 的通信量小于 FSDP,同时 DP 只在计算时间更长的后向进行通信,所以,DP 相比于 FSDP 拥有更好的多机拓展性。在使用了 DP+FSDP 的组合之后,不只能满足参数切分的需求,同时提升了多机的拓展性。

Context Parallel

context parallel 的好处是只在 attention 部分和 transformer 模型之后引入了额外通信,在其他的部分比如 MLP 均不需要额外的通信,而且 gradients 的同步使用 DP+FSDP 就可以完成。同时,在 context parallel 的作用域之内 activation 和计算可以被均匀切分。

目前的 context parallel 都是在一开始就对 sequence 维度进行切分。唯一的区别在于 attention 部分的处理,DeepSpeed-Ulysses 会将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,而 RingAttention 会依然保留 sequence 维度的切分对 attention 的计算进行特殊处理。

DeepSpeed-Ulysses

图3:DeepSpeed-Ulysses,来自:https://arxiv.org/abs/2309.14509

如图 3 所示,DeepSpeed-Ulysses 会对 q、k、v 分别进行 all to all,将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,然后再对 attention 的输出进行 all to all,将 head 维度的切分转换回 sequence 维度的切分。由于 attention 的计算在 head 维度是并行的,所以这样操作之后不需要对 attention 的计算进行额外处理。可以看到,DeepSpeed-Ulysses 切分的是 head 维度,所以这使得DeepSpeed-Ulysses 的并行数目最多开到 head 的大小。

单个 layer 内 DeepSpeed-Ulysses 的通信和计算对比为:\frac{16sbh(t-1)}{t^2B}:\frac{\alpha bsh^2+\beta bs^2h}{tF}= \frac{16(t-1)F}{t(\alpha h+\beta s)B},其中 F 为 GPU 计算 FLOPS,B 为 alltoall 通信带宽。可以看到,随着 s 的变大,DeepSpeed-Ulysses 的通信占比会逐渐降低,最终达到一个可以忽略的程度,在 seq len = 256K 单机 8 卡的场景下,DeepSpeed-Ulysses 的通信时间在 E2E 的时间占比已经低于 1%,在 seq len=64K 的场景下也只有 2%~ 3%。但是,在涉及到跨机通信时,DeepSpeed-Ulysses 的通信开销由于机间通信带宽较低会变得不可忽视。在 256K 的场景下 2 机 16 卡会达到 10%以上。

Ring Attention

图4:Ring Attention,来自:https://arxiv.org/abs/2310.01889
 

如图 4 所示,Ring Attention 的实现过程中会保持 sequence 维度的切分。Ring Attention 会以 ring 的方式发送和接收其他 device 上的 k 和 v,同时计算本地的 q、k、v 分块的 attention,对输出进行一些矫正保证正确性。这种方式可以使得计算和通信能够 overlap 起来。

Ring Attention 计算和通信 overlap 的理论条件是:考虑前向的一个小的 Attention,通信量为 k 和 v:4bsh,计算量为:4bs^2h,所以计算能够掩盖通信的条件为:4bs^2h/F \ge 4bsh/B \implies s \ge F/B,其中 F 为 GPU 计算 FLOPS,B 为 send/recv 通信带宽。在实际运行过程中,还需要考虑 Flash Attention 的计算利用率和 send/recv 的带宽利用率,根据机器和算子性能的不同,在涉及跨机通信时,在 A100 上面下单卡需要 24K 的序列长度才能 overlap。

可以看到,Ring Attention 的优势是通信能够和计算 overlap,但是需要保证 s 切分后单 GPU 卡上的句子长度满足 overlap 条件。

2D context parallel

对于 context parallel,由于只有 attention 部分存在通信,所以我们需要考虑的只是 attention 部分的处理。在 attention 部分,activation 的 shape 为 [batch, sequence, heads, head_dim],由于维度的大小关系,在这其中 sequence、heads 和 head_dim 是可以进行切分的 sequence 和 heads 的切分分别代表了 Ring Attention 和 Ulysses。head_dim 维度由于是矩阵乘的 contracted 维度,这种维度的切分一般不可避免会引入无法 overlap 的 allreduce 或者 allgather 等通信算子,这会使得通信量大于 Ulysses 的 alltoall。

除此之外,我们还可以同时切分 sequence 维度和 heads 维度。在这种情况下,我们只需要进行一部分通信量较少的 alltoall 通信将一部分 sequence 维度转换为 head 维度,同时,针对剩余的 sequence 维度的切分,可以使用可以 overlap 的 send/recv 通信进行处理。由于 alltoall 的跨机性能较差同时 send/recv 的通信时间可以被计算 overlap,FlashSequence 让外层 alltoall 的通信使用机内的 nvlink 进行通信,内层的 send/recv 使用机间带宽进行通信。我们称这种 context parallel 为 2D context parallel。

2D context parallel 相比 DeepSpeed-Ulysses 可以减少没有 overlap 的 alltoall 时间,相比 Ring Attention 可以在单机 tokens 数较小时减少 send/recv 的次数和 attention 的计算时间,使得 send/recv 和计算可以 overlap。这种设计在 context parallel 涉及跨机通信时会显著减少没有和计算 overlap 的通信时间在 E2E 中的占比,在 seqlen = 256K、2 机 16 卡的场景可以将 DeepSpeed-Ulysses 的通信时间从 10%以上减少到低于 1%。

分布式策略的冗余计算优化

在上面我们提到使用 context parallel 对 sequence 维度进行切分,但是这个切分是存在边界的,一般情况下我们会在 activation 的 shape 转换为 transformer 需要的 shape 之后(比如 DiT 模型的 patchify 之后)才对sequence 维度进行切分。由于 context parallel 需要 group 内的 device 读取相同的数据,这就会导致从 dataloader 读取样本到 sequence 维度切分之间在 group 内的 device 进行的是相同的计算。这一部分在 SORA 模型中通常是 visual encoder 和 text encoder 模型,分别负责对视频和文本进行编码。这些计算在中小长度的序列长度下占比比较高,取决于具体模型实现和序列长度,可以达到 20%甚至 70%。

为了为了去除这一部分的冗余计算,我们可以让 context parallel group 内的 device 读取不同的数据,在需要 sequence 维度切分时进行一个 context parallel 大小的 loop 遍历,依次对前面不同 device 读取的数据进行 broadcast ,使得 transformer 的部分输入的数据一样。这样处理之后,VAE+text encoder 的时间占比会减少到之前的 1/context parallel size,带来 E2E 性能提升。

显存优化策略

使用分布式策略可以进行模型参数和 activation 的切分以减少显存,但是分布式策略的切分会引入通信开销,在更多卡参与切分时,这些开销会逐渐变得不可忽视。例如 activation 在 1M 场景下可以达到 8000GB 以上,使用 80GB 的 GPU 就需要至少 100 张卡,这是不可接受的。因此,我们还需要一些显存优化策略来进一步减少显存。

在目前的实践中,gradient checkpoint(GC)是较为常见的策略,GC 的重点是选择合适的重算部分以减少额外的计算开销。CPU offloading 在 DeepSpeed 中通常是对参数进行 offload,但是在长序列场景,我们发现 CPU offloading 在 activation 上相比 GC 也能带来明显的性能提升。显存碎片在长序列场景也会经常遇到,经常会出现 PyTorch reserve 了 10 几 GB 的显存却无法分配一个几百 MB 的 tensor,进而导致 OOM。

Selective GC

gradient checkpoint(GC)的思想是在前向过程中不保留 activation,在后向时重新运行一次前向生成 activation。在使用 GC 的过程中,最主要的问题是选择好重算的部分。目前主流的做法是对整个 decoder layer 进行 GC(full GC)或者对 Attention 部分进行 GC(Megatron selective GC)。

但是,如前所述,在长序列场景,attention 部分占据了绝大部分的计算,重算 attention 的开销很大。同时,与较小序列不同的是,在长序列场景,MLP 部分也是可以考虑进行 GC 的,在 1M 场景,MLP 的 E2E 占比已经低于 5%,重算的开销较小。

FlashSequence 会优先选择显存计算比高的部分。按照模型中的算子 FLOPS以及算子节省的显存量,FlashSequence 会选择依次节省显存收益较大的部分。

CPU Offloading

CPU Offloading 的思想是将部分 tensor 从显存传输到 CPU 内存上,在需要时再 prefetch 回来。在 DeepSpeed 中,这一技术通常只在参数上使用,这是因为之前的场景 offload activation 会有比较大的 PCIe 传输开销。然而,在长序列场景,如上面所述,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这就使得在长序列场景,计算的时间会逐渐超过 offload activation 的 PCIe 传输时间。在 64K 场景,offload 一层 decoder layer activation 的 PCIe 传输时间可能需要 2 ~ 3 层 layer 的计算进行 overlap,而在超过 256K 的场景,offload 一层 decoder layer activation 的 PCIe 传输时间仅需一层 layer 的计算就可以 overlap。在不同的模型下,这个 overlap 的 layer 的数目会有所区别,但是随着序列长度 s 的增长,最终都会达到一个可用的状态,比如在 64K 上使用 offloading 就可以无损减少多层 decoder layer 的 activation 显存占用。

以一层 decoder layer 的 activation 作为 offload 的粒度,offload 一层 decoder layer 可以达到和 GC 一层 decoder layer 类似的显存减少量,同时在长序列场景,offloading 的传输时间能够被计算时间 overlap,相当于在 E2E 性能无损的情况下减少了显存,相比于 GC 能够减少额外的计算开销。

虽然 offloading 在长序列场景拥有比 GC 更好的性能表现,offloading 本身也存在一些问题:

  1. 较短的序列长度需要多层 layer 的计算才能 overlap 传输时间,当然这个在更长的序列长度上不是问题。

  2. offloading 需要使用 CPU 的 pinned 内存,而 CPU 的内存虽然有 1TB ~ 2TB,但是在长序列场景,8 张卡的 offloading 所需要使用的内存总量会很快超过 CPU 的内存。这可以通过结合部分 selective GC 进行解决。

  3. offloading 会和跨机通信(RDMA 也会使用一部分 PCIe 资源)竞争 PCIe,这种影响在前向计算中比较明显。但是在使用了 DP+FSDP 和 2D context parallel 的组合之后,大部分通信都是使用 nvlink,机间通信也能够被计算 overlap,所以对 E2E 的性能影响不大。

基于上述问题,FlashSequence 在优先 CPU offloading 的同时使用 selective GC,避免 CPU 内存 OOM 的同时减少重算的 FLOPS。

显存碎片

在长序列场景,一个 tensor 的显存使用可以达到几百 MB 甚至几 GB,在这种场景下,PyTorch 的 caching allocator 会导致比较多的显存碎片。

图5: caching allocator的显存分配情况

图6: expandable allocator的显存分配情况

图 5 是某个长序列场景 OOM 时的显存使用情况,其中空白部分是还没分配但是被 PyTorch reserve 的显存。这个 OOM 本来是不应该出现的,因为这个时候请求分配的 tensor 只需要 500 多 MB 的显存,而 PyTorch reserve 的未分配显存有 7.5GB。但是因为 PyTorch reserve 的未分配显存都是不连续的(大的空白是 200 多 MB),所以导致了 OOM。这个显存碎片问题在更长的序列场景会更加常见,有时候可以达到 10 ~ 20GB 的显存碎片。

在 PyTorch 的 2.2 及以上版本,引入了expandable 的 allocator,这个 allocator 可以在有更大显存分配请求的情况下拓展已有的空闲显存块,进而减少原始 caching allocator 的显存碎片。从图 6 中可以看到显存碎片低于 1GB。在大部分长序列场景下,expandable allocator 的显存碎片都比 caching allocator 的小,同时在我们场景下性能基本没有变化。

计算优化

FlashAttention 优化

FlashAttention是DiT 模型中attention部分的常用优化手段,FlashAttention的前向计算量为4bs^2hFLOPS,后向的计算量是10bs^2hFLOPS(FlashAttention 在后向存在部分重算),如前文对计算量的分析,随着序列长度的增长,attention部分在端到端的训练时间中甚至占比到95%以上。因此,FlashAttention的计算性能,也成为整个训练任务最为dominate的部分。

我们对TriDao版本的FlashAttention2在不同序列长度下的性能做了A100上的kernel的性能测试,以batch-size=1, hidden-dim=128为例。通过图 7 和图 8 的性能数字,可以看到:

1.  序列长度和N_HEADS同时太大(N_CTX>=512K and N_HEADS>=8)或太小(N_CTX<=4K 或 N_CTX<=32K and N_HEADS<=4),都会造成一定程度的性能损失。这个性能数字也可以指导对N_HEADS和N_CTX的分片;

2.  根据FlashAttention的性能,我们可以预估一次迭代的训练时长。如1M下,按照前向227TFLOPS/s,后向191TFLOPS/s计算,一个layer计算batch-size=1, hidden-size=4096的FlashAttention计算时间为4bs^2h \times 10^{-12}/227+10bs^2h\times 10^{-12}/191,约5min,L个layer的FlashAttention的用#n_gpus并行计算时间为5Lmin/n_gpus。

图7: FlashAttention在A100上面不同序列长度和HEAD大小下前向的性能

图8: FlashAttention在A100上面不同序列长度和HEAD大小下后向的性能

不同的 Attention 实现

前文中我们预估了1M序列长度下FlashAttention的计算时间,可以看到由于序列长度平方项的计算量的存在,一轮迭代的时间在分钟级别,导致模型训练的速度非常慢。有很多降低Attention二次方序列长度计算量的工作,提升Transformer效率的工作,主要包含以下几种:Linear Attention、Sparse Attention、Mamba、Compress Memory。这些算法由于计算量的减少,在模型的效果方面和原始Transformer存在差异,后续我们将对这方面的工作进行探索,并集成到系统中。

  • Linear Attention:

线性Attention的核心思想是用kernel函数代替softmax,然后通过矩阵乘法的结合律,将序列长度维度的两层循环减少为一层循环,从而将Attention的计算量从序列长度的平方项降低为线性项。线性Attention减少计算量也会造成模型的精度损失,取决于核函数的设计。

相关的工作有Transformers are RNNs, RMKW, Linformer, Lightning, DiJiang等。

  • Sparse Attention:

Sparse Attention通过让每个token对应的向量只跟部分token对应的向量(可见域)计算相关度,使Attention矩阵计算变得稀疏。在Sparse Attention中,如何选择有相关性的元素进行计算,成为影响模型精度的关键。相关工作探索了固定可见域,如OpenAI 2019的Sparse Transformers,与输入数据相关的可见域,如ICLR 2024的Transformer-VQ等多种算法。

  • Mamba

Mamba基于状态空间模型(SSM, State space models),结合了RNN(表达隐藏状态和输入关系,去掉非线性激活函数)和CNN(并行训练),并根据输入动态调整模型的选择性参数,包括当前输入和历史状态信息对输出的影响系数、无关信息的过滤参数等,并通过硬件感知算法来优化计算效率,将Transoformer的计算效率变成序列长度线性相关。

相关的工作包括Mamba,Mamba在视觉模型上的应用如ZigMA, ICLR 24 Diffusion SSM等。

  • Compress Memory

Compress Memory是一种将长序列切分为一个个segment,将历史segment的信息编码到一个固定大小的memory中,将当前segment的attention和memory信息concat到一起,从而将计算复杂度降低到1/n_segments,而memory占用的空间为固定大小,以此方式可以计算无限长度的序列的attention。相关工作如Infini-attention,ICAE等。

四、实验

图9: 不同sequence length和context parallel(CP)下的MFU

为了衡量我们提出的 FlashSequence 的解决方案,我们以纯 Ulysses+FSDP 并使用 full GC 作为 baseline,其中 full GC 的 layer 数依据显存使用决定,少于模型 layer 数。图 9 展示了在 A100 上不同sequence length和context parallel(CP)下的MFU,可以看到,FlashSequence 的 MFU 相比 baseline 平均提高了 11.75%,性能平均提升了 23.3%。同时,FlashSequence 的方案在长序列场景可以获得和 FlashAttention 接近的 MFU,比如在 1M、CP=16 的场景下,FlashSequence 的 MFU 为 51.7%,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。

五、总结与展望

PAI-TorchAcc(Torch Accelerator)是阿里云机器学习平台开发的Pytorch上的大模型训练加速框架。PAI-TorchAcc 通过进行分布式优化、计算优化、显存优化等,为包括 SORA 模型在内的Pytorch上的模型提供高效训练支持。

目前,FlashSequence 已经集成到 PAI-TorchAcc 中,并在开源的 DiT 模型上验证了效果。此外,由于 SORA 所使用的 DiT 类模型结构与 LLM 模型基本类似, FlashSequence 也可以应用在大部分长序列训练场景。后续我们会陆续开源这些工作。

同时,目前长序列的最主要瓶颈都在 FlashAttention 的计算上面,如何优化 FlashAttention 的计算将成为长序列场景下的主要问题。由于 FlashAttention 的计算量按照 token 数平方增长,未来更可能的优化方向是探索计算量更低的 Attention 实现比如线性的 Attention,同时低精度如 FP8 训练、稀疏训练等也都是一些可以探索的方向。

【招聘】最后,如果你对大模型训练加速技术感兴趣,欢迎加入到我们的团队中。目前研究型实习生和社招都在火热招聘中,欢迎投递简历到研究型实习生 - 基于负载与硬件特性协同的大模型训练加速技术研究 或邮箱 wenting.swt@alibaba-inc.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/26115.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CAS Server Restful接口实现后台认证

背景 对于一些比较复杂定制化登录页的情况下&#xff0c;之前提到过可以自定义修改使用CAS Server提供的登录页这种操作已经明显跟不上复杂定制场景了&#xff0c;所以CAS Server也提供了支持Restful接口&#xff0c;支持服务端后台登陆&#xff0c;对于复杂登陆场景时&#x…

无忧易售ERP - 助力您的沃尔玛平台刊登之旅

在跨境电商的广阔天地里&#xff0c;沃尔玛平台以其庞大的流量与高质量的顾客群体吸引了无数卖家的目光。但要想在这片蓝海中乘风破浪&#xff0c;高效、精准的产品刊登策略是关键。今天&#xff0c;我们将借助强大的无忧易售ERP系统&#xff0c;为您带来一站式沃尔玛平台产品刊…

何为屎山代码?

在编程界&#xff0c;有一种代码被称为"屎山代码"。这并非指某种编程语言或方法&#xff0c;而是对那些庞大而复杂的项目的一种形象称呼。屎山代码&#xff0c;也被称为"祖传代码"&#xff0c;是历史遗留问题&#xff0c;是前人留给我们的"宝藏"…

【Go语言精进之路】构建高效Go程序:了解map实现原理并高效使用

&#x1f525; 个人主页&#xff1a;空白诗 &#x1f525; 热门专栏&#xff1a;【Go语言精进之路】 文章目录 引言一、什么是map1.1 map的基本概念与特性1.2 map的初始化与零值问题1.3 map作为引用类型的行为 二、map的基本操作2.1 插入数据2.2 获取数据个数2.3 查找和数据读取…

前端开发部署:Visual Studio Code + vue

〇 说明 本教程全部采用默认安装路径&#xff0c;因为在进行自定义路径安装的时候&#xff0c;需要配置各种环境变量&#xff0c;在这个配置过程中&#xff0c;可能出现各种很混乱的问题。 一 安装Node.js 1 下载https://nodejs.org/en 2 按照默认NEXT执行 C:\Program Files…

文件传输系统主要用于哪些场景?要如何选型?

文件传输系统是一种用于在不同设备、网络或地理位置之间传输文件的产品解决方案&#xff0c;在各行各业中的应用还是很广泛的。 文件传输系统可以应用于多种场景&#xff0c;主要包括&#xff1a; 1、企业内部文件共享&#xff1a;在公司内部不同部门或团队之间共享文件&#…

9、编写业务逻辑

9、编写业务逻辑 9.1 编写博客接口(新增和查询一起编写了) 响应实体:(随便封装的,可以根据自己的想法封装) // entity/Response package com.example.fullstackblogback.commen;import lombok.Data;import java.util.List;@Data public class Response<T> {pri…

[经验] 梦见自己游泳是什么意思 周公解梦 #职场发展#微信#媒体

梦见自己游泳是什么意思 周公解梦 1、梦见自己游泳 梦见自己游泳是一种非常常见的梦境&#xff0c;而这个梦境通常代表着我们内心深处的渴望和憧憬。 游泳是一项需要技巧和勇气的运动&#xff0c;它需要我们在水中保持平衡和控制自己的呼吸。因此&#xff0c;梦见自己游泳通常…

代码随想录算法训练营第三十五天| 1005.K次取反后最大化的数组和、134. 加油站、135. 分发糖果

LeetCode 1005.K次取反后最大化的数组和 题目链接&#xff1a;https://leetcode.cn/problems/maximize-sum-of-array-after-k-negations/description/ 文章链接&#xff1a;https://programmercarl.com/1005.K%E6%AC%A1%E5%8F%96%E5%8F%8D%E5%90%8E%E6%9C%80%E5%A4%A7%E5%8C%9…

idea开发工具清除Git凭证(含Git凭证管理策略)

前言 网上很多人出现这个问题&#xff0c;也有很多文章或博客来说明这个问题&#xff0c;但是几乎都没有说到点子上&#xff0c;全网几乎都说清除credential.helper配置或者清空windows凭证管理器&#xff0c;还有一些文章说清除IDEA缓存&#xff0c;其实都是不对的。 creden…

黑龙江三级等保测评内容与等级划分

一、黑龙江等保三级测评内容 黑龙江等保三个层次&#xff0c;也就是三个级别的信息安全防护&#xff0c;这是我们国家的一项基础性的信息安全体系。在此基础上&#xff0c;提出了一种适用于非银行机构的最高级别的保障制度&#xff0c;即当该制度遭到破坏时&#xff0c;可能给…

Bankless:为什么 AI 需要 Crypto 的技术?

原文标题&#xff1a;《Why AI Needs Crypto’s Values》 撰文&#xff1a;Arjun Chand&#xff0c;Bankless 编译&#xff1a;Chris&#xff0c;Techub News 原文来自香港Web3媒体&#xff1a;Techub News 人工智能革命的梦想一直是一把双刃剑。 释放人工智能的潜力可以解…

springboot3一些听课笔记

文章目录 一、错误处理机制1.1 默认1.2 自定义 二、嵌入式容器 一、错误处理机制 1.1 默认 错误处理的自动配置都在ErrorMvcAutoConfiguration中&#xff0c;两大核心机制&#xff1a; ● 1. SpringBoot 会自适应处理错误&#xff0c;响应页面或JSON数据 ● 2. SpringMVC的错…

深入解析ETL与ELT架构:数据集成技术的演进与发展

摘要&#xff1a;随着大数据时代的到来&#xff0c;数据集成成为企业信息化建设的重要环节。本文将深入探讨ETL与ELT两种架构&#xff0c;分析它们在数据处理、性能、可扩展性等方面的差异&#xff0c;为企业数据集成提供技术指导。 一、引言 在大数据时代&#xff0c;企业需要…

13- 函数的定义与使用+形参实参区分

13- 函数的定义与使用形参实参区分 文章目录 13- 函数的定义与使用形参实参区分一、函数的定义与使用1.1 函数的结构1. 函数头2. 函数体 1.2 示例代码例子 1&#xff1a;无参数和无返回值的函数例子 2&#xff1a;带参数和返回值的函数 1.3 函数的基本语法1.4 函数的使用示例例…

Faster-RCNN基本思想和网络结构

简单来说&#xff0c;Faster RCNN RPN Fast RCNN RPN 是指 Region Proposal Network&#xff0c;建议区域生成网络。 Faster RCNN 中用 RPN 来代替了 Fast RCNN 中的SS算法。 算法流程&#xff1a; &#xff08;1&#xff09;将图像输入CNN网络得到相应的特征图。 &#x…

单机多卡分布式训练策略——MirroredStrategy

前言 分布式训练是一种用于在多个设备或机器上同时训练深度学习模型的技术&#xff0c;它有助于减少训练时间&#xff0c;允许使用更多数据更快训练大模型。分布式训练重点关注数据并行性&#xff0c;本次试验使用的是单机多卡的分布式训练策略&#xff0c;也就是 MirroredStr…

算法题目学习汇总

1、二叉树前中后序遍历:https://blog.csdn.net/cm15835106905/article/details/124699173 2、输入一棵二叉搜索树&#xff0c;将该二叉搜索树转换成一个排序的双向链表。要求不能创建任何新的结点&#xff0c;只能调整树中结点指针的指向。 public class Solution {private Tr…

多模态AI的挑战与早期壁垒的构建

伴随着Sora、GPT40的推出&#xff0c;多模态AI逐渐成为研究的热点和应用的趋势。然而&#xff0c;多模态AI的发展并非一帆风顺&#xff0c;它面临着诸多挑战和壁垒。 一、多模态AI的难点 多模态AI的核心在于将不同模态的信息&#xff08;如文本、图像、音频、视频等&#xff…

离线翻译器下载哪个好?这几个翻译器用过的人都说好

面对跨文化交流的挑战&#xff0c;如国际旅行或多元工作环境&#xff0c;语言障碍尤为突出。 特别是在信号弱或无网络覆盖的地区&#xff0c;翻译需求变得更加迫切。此时&#xff0c;一款优质的离线翻译app显得尤为重要。它能够在没有网络支持的情况下提供即时翻译服务&#x…