Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码?

原文:Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码? - 知乎

为什么要研究KV cache?


设输入序列的长度为 s ,输出序列的长度为 n ,模型深度为l,维度为h,以 FP16 来保存KV cache,那么KV cache的峰值显存占用大小为 b(s+n)h∗l∗2∗2=4blh(s+n) 。这里第一个2表示K/V cache,第二个2表示 FP16 占2个bytes。
以 GPT3 (175B) 为例,对比 KV cache 与模型参数占用显存的大小。GPT3 模型weight占用显存大小为350GB (FP16),层数 l为96,维度h为12888。

batch sizes+nKV cache(GB)KV cache/weight
4409675.50.22
1640963020.86
64409612083.45


参考上图,随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。从LLM的趋势上而讲,主要有三个方面来说明kv cache优化的必要性:
1、总体趋势上LLM 的窗口长度在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 非常必要。
OpenAI API场景,API最烧钱的是输入而非输出,输入包括prefill prompt 和conversation,长度动辄数十K token。虽说每输入token比每输出token便宜,但能够降低kv重新计算的开销,无论是硬件资源门槛,还是模型推理降本,都有着极为积极的作用。
2、对于消费级显卡这种性价比较高的显卡而言,显存容量相对较小,KV cache从一定程度上降低了模型的batch size,因而KV cache优化在工程落地中更显重要。

框架模型input/output机器配置tokens/s最佳batch_sizegpu/cpu负载
TGIllama2 70B128/5124090*8,32c29113cpu25%,gpu95%
TGIllama2 70B128/512A800,32c112243cpu25%,gpu95%


从上表能够看出,类似4090等消费级显卡,其主要的瓶颈时batch_size,而影响batch_size的主要因素就是显存容量,而KV cache在推理过程中占用大量的显存。如果能通过KV cache降低显存占用,从一定程度上就能提升消费级显卡的性价比,带来非常高的商业收益。
3、sora/sd3等文生视频或者文生图的模型,纷纷放弃u-net架构,转而支持DIF(diffusion transformer)架构。对此类AIGC模型而言, KV cache同样能起到类似LLM上的加速效果。
根据资料,Sora类训练任务的特点是模型本体不大(10B以下),但是由于视频复杂性带来的序列长度特别长(接近1000kpatches的长度),可以对模型推理进行简易测算:

  • 按照batch size = 1 进行测算,kv cache和模型权重对显卡占比能达到10:1(例如4090的24G显存,2G分给模型,22G分给kv cache)左右,这个场景的显存分配占比与LLM差异性还是非常的大。
  • 按照batch size = 4 进行测算,kv cache和模型权重对显卡占比能达到40:1(batch size越大,kv cache的显存越大)

由此可见,KV cache会成为视频生成领域的一个重要瓶颈,但不排除有替代kv cache的加速方案。


KV cache的作用


解释kv cache之前,先看一组对话:
paki:What is the apples?
llama:Apples are a boring fruit.
上述对话中,paki是自然人,llama是模型。如果对上述对话进行分析,实际上需要将llama的推理步骤分成两个阶段,即prefill和decode。
prefill阶段:输入为Q,即‘What is the apples?’,返回了第一个token,即‘Apples’,同时初始化了kv cache。
decode阶段:输入为单个词或者说q,通过自回归的方式,生成‘Apples are a boring fruit’这个句子。需要注意的是,decode计算的过程中,q的长度为1,即当前词,返回下一个词,例如通过‘Apples’生成‘are’,同时更新kv cache。
暂时无法在飞书文档外展示此内容


如上图所示,kv cache是attention计算中的全量kv 缓存,主要作用在decode阶段,目的是将输入Q优化成输入q。
我们举例说明,假设通过decode阶段通过自回归生成‘Apples are a boring fruit’这句话,当生成到‘fruit’这个词的时候,如果没有KV cache,输入为Q(‘Apples are a boring’),进行attention计算。反之,如果有了kv cache之后,输入只需要q(‘boring’这个词),即可完成attention计算。
为什么会这样,主要跟下一个token的生成给当前token的q和全量KV有关,具体attention的计算公式不再粘贴。
从这里也能看出,为什么 KV cache那么吃显存,其实主要因为随着seq长度变长和batch size增大,KV cache需要存储历史全量KV,从而跟着增大。
那就有这么一个思路,KV cacha本质上是attention计算中的一部分,如果对其进行压缩或者优化,是不是能起到推理的加速效果?
答案是肯定的,下面介绍一下主要的优化方法。


基于KV cache的加速策略


从整体上来讲,KV cache主要分成5个方向的优化,即Sparse、Quantization、Allocator、Window、share,我们逐个对5个方向的最新技术,做一些探讨。


Window--窗口


多轮对话场景的 LLMs 有两个难点:1. 解码阶段缓存 KV 需要耗费大量的内存;2. 流行的 LLMs 不能拓展到训练长度之外。
基于window方向的技术,解决上述问题主要有StreamingLLM[8]和LM-Infinite[14]两种方案,我们基于StreamingLLM进行介绍。
首先,自回归LLMs的一个有趣现象:无论它们与语言建模任务的相关性如何,初始tokens都被分配了惊人的大量注意力得分,并将tokens称为“attention sinks”,参考下图:

Visualization of the average attention logits in Llama-2-7B over 256 sentences, each with a length of 16.


根据attention sinks特性,我们参考之前多轮对话场景的解决方法,给出StreamingLLM的解决方案,即只保留attention sink tokens的KV(只需4个初始tokens)以及滑动窗口的KV,以锚定注意力计算并稳定模型性能的方法。

Illustration of StreamingLLM vs. existing methods. The language model, pre-trained on texts of length L, predicts the Tth token (T ≫ L).

上图的四个模块,分别对应下面的描述:

(a) 密集注意力(Dense Attention)具有 $$O(T^2$$ 的时间复杂度和不断增加的缓存大小。当文本长度超过预训练文本长度时,其性能会下降。
(b) 窗口注意力(Window Attention)缓存最近的 $$$$ 个tokens的$KV$。虽然在推理中效率高,但一旦开始tokens的键和值被删除,性能就会急剧下降。 这种方法相当于只在最近的tokens的KV状态上维护一个固定大小的滑动窗口,太过简单粗暴,没有实用价值。
(c) 带重计算的滑动窗口(Sliding Window with Re-computation)为每个新token重建来自 $L$个最近tokens的$KV$状态。虽然它在长文本上表现良好,但由于在上下文重计算中的二次注意力,其的$O(TL^2 ) $ 复杂度使其相当缓慢,使得这种方法不适用于实际的流式应用。
(d) StreamingLLM 保留了用于稳定注意力计算的attention sink(几个初始tokens),并结合了最近的tokens。它高效并且在扩展文本上提供稳定的性能。
粘贴一段StreamingLLM的代码,说明StreamingLLM的运行机制

def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):past_key_values = Nonefor idx, prompt in enumerate(prompts):prompt = "USER: " + prompt + "\n\nASSISTANT: "print("\n" + prompt, end="")input_ids = tokenizer(prompt, return_tensors="pt").input_idsinput_ids = input_ids.to(model.device)seq_len = input_ids.shape[1]if kv_cache is not None:space_needed = seq_len + max_gen_len#通过每个prompt进行kv cache清理,并不是decode过程中进行kv cache清理。past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)#通过prefill和decode进行模型推理past_key_values = greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len)


通过上述代码逻辑,我们可以得到一个结论:StreamingLLM通过添加attention sinks和最近的tokens,可以将LLMs应用于远远超出预训练窗口大小的文本,甚至可能是无限长度的文本,但同时具有较强的场景局限性。


Sparse--稀疏化


通过对StreamingLLM的分析,其实可以发现基于window的方法,可以看作是一种相对粗糙的方法,那有没有更优的方法?答案就是Sparse,通过稀疏对KV cache压缩是一种相对模型性能更高的方法。
目前基于Sparse对kv cache压缩的方法在学术上有较多的进展,但在主流的推理框架上,例如TGI/vLLM等框架,尚未得到有效支持。主要介绍H2O、SubGen、LESS三个项目,来对KV cache稀疏化方案做个总体介绍。


H2O
简要介绍:基于attention的观察,即在计算attention分数时,一小部分token贡献了大部分价值。我们将这些token称为重击者(H2 )。通过全面的调查,我们发现H2的出现是自然的,并且与文本中标记的频繁共现密切相关,( ii )删除它们会导致性能显着下降。基于这些见解,我们提出了 Heavy Hitter Oracle (H 2 O),这是一种 KV 缓存驱逐策略(贪婪算法),可动态保持最近token和 H2 token的平衡。我们使用 OPT、LLaMA 和 GPT-NeoX 在各种任务中验证了算法的准确性。我们采用20%重量级的 H2O实施方案,在 OPT-6.7B 和 OPT- 上,与三个领先的推理系统 DeepSpeed Zero-Inference、Hugging Face Accelerate 和 FlexGen 相比,吞吐量提高了高达29 倍、 29 倍和3 倍。
H2O对比StreamingLLM而言,从一种的window方法演进成一种基于H2和最新token的驱逐策略,模型性能更加优秀,具体参考如下数据:

Comparison results of StreamLLM [52] and our H2O on generization tasks. The number in each method represents the KV Cache budget of the start/heavy-hitter tokens and the local tokens, respectively. For example, H2O-256-256 means maintaining 256 Heavy-Hitters and 256 local tokens.

通过H2O的KV cache逐出策略,能够实现20%缓存接近于全量KV cache的效果,这对KV cache的意义比较有意义,能极大的提升推理并发并同时降低延迟,参考下表:

Quantatively comparison between H2O with Full methods of different number of shots

Generation throughput and latency on an A100 GPU. In the sequence length row, we use “7000 + 1024” to denote a prompt length of 7000 and a generation length of 1024. “OOM” means out-of-memory

综合来看,H2O是一种相对高效的KV cache压缩框架,值得进行深入研究。


SubGen
SubGen是一种为 KV cache开发的高效压缩技术。经验证据表明,attention模块中的关键嵌入存在显著的聚类趋势。基于这一关键见解,设计了一种具有亚线性复杂度的新颖缓存方法,对关键标记采用在线聚类并对值进行在线ℓ (2)采样。
如下图所示,使用 MT Bench 数据集从 Llama-2-7B收集key和value嵌入,同时模型生成 1024 个标记的序列。然后,使用 t-SNE [ 24 ] 跨各个层和头可视化嵌入,通过贪婪 k 中心算法 识别聚类中心点,根据观察结果表明,与所有随机选择的层和头的value嵌入相比,key嵌入(第一行)表现出更高程度的可聚类性。

A t-SNE plot of cached keys (first row) and values (second row) embeddings over 1024 timesteps from Llama2-7B using MT Bench dataset

如下表所示,基于聚类的方法在所有序列长度上始终优于其他算法。例如,仅利用一半长度为 9k token 的缓存 KV 嵌入,就实现了 44% 的准确率,而 H2O 和 Sink 的准确率都低了 10%。这一发现表明,与注意力分数和位置信息相比,维护嵌入信息对于维持法学硕士的表现具有更大的意义。

Results on accuracy of line retrieval from LongEval [13] dataset with context length 5k-9k. Under the sublinear budgets on cache size, the proposed algorithm based on k-center algorithm outperforms other methods over all sequence lengths.

同时,根据SubGen的实验数据,可以观测到随着seq的长度变长,基于稀疏的Sink、H2O、SUBGEN等方法,都呈现出效果逐渐下降的趋势,这主要因为对于稀疏缓存策略,随着序列长度的增加,会保存一小部分 KV 对,因此会省略更多信息


LESS
LESS(Low-rank Embedding Sidekick with Sparse policy)来学习原始注意力输出和稀疏策略近似的注意力输出之间的残差,通过将稀疏策略丢弃的信息累积到恒定大小的低等级缓存或状态中来实现此目的,从而允许查询仍然访问信息以恢复注意力图中先前省略的区域。

Toy (top row) and Llama 2 7B (bottom row) example decoder attention maps with H2O as the underlying sparse policy, Sparse attention policies zero out many positive attention probabilities. Our method, LESS, ensures all previous tokens will have some contribution to the attention

LESS这种方法所具有的优点:

  • 性能改进:LESS 将稀疏 KV 策略与低秩状态综合起来,以弥补这些稀疏算法表现出弱点的各种任务的性能差距。事实上,LESS 比简单地将内存用于存储更多 KV 更能提高性能。
  • 恒定的低等级缓存大小:LESS 中的低等级缓存占用相对于序列长度恒定的内存。
  • 廉价的集成:对LLM架构的更改很小,并且不会扰乱原始权重。对 LLM 的唯一修改是在每个注意力层添加微小的多层感知(MLP)。例如,将 LESS 与 Llama 2 13B 结合使用,添加的参数总数不到 2%。

参考下图,Baseline+ (H2O参考)和 LESS (1% H (2 )O) 对于较短的序列似乎表现相似,但对于较长的序列则不同,LESS接近Full的水准。究其原因,低秩状态从一定程度上能解决部分重要attention被忽略的问题。

Relationship between Rouge-1 score and prompt length for Llama 2 7B with different cachemethods on CNN/DailyMail (left) and XSum (right).

Quantization--量化


当前主流推理框架都在逐步支持 KV cache 量化,例如exllama支持8位kv量化,LMDeploy支持到4位的kv量化。我们提一下刚出的WKVQuan,即一种低于4位的PTQ 框架。
WKVQuant,是一个专门为量化权重和 LLM 的键/值 (KV) 缓存而设计的 PTQ 框架。具体来说,我们结合了仅过去的量化来改进注意力的计算。此外,我们引入了二维量化策略来处理 KV 缓存的分布,以及用于参数优化的跨块重建正则化。实验表明WKVQuant实现了几乎与权重激活量化相当的内存节省,同时也接近仅权重量化的性能。

Longtext scores. Results of LLaMA-2-13B and LLaMA-13B can be found in A.5.

上图表明,WKVQuant的Weight-KV缓存方法具有卓越的量化准确性。


Allocator--显存分配


KV cache的内存分配策略,从一定程度上会影响模型推理的性能。
Page Attention 采用的是另外一种显存管理方式。允许生成过程中不断为用户追加显存。类似于操作系统中的页式存储或者内存分页。当一个请求来到之后,系统会为这个请求分配一小块显存,这一小块显存通常只够生成 8 个字符,当请求生成了 8 个字符之后,系统会追加一块显存,可以把结果再写到这块显存里面,同时系统会维护一个显存块和显存块之间的链表,从而使得算子可以正常地进行输出。当生成的长度不断变长时,会不断地给用户追加显存块的分配,并且可以动态维护显存块分配的列表,使系统不会存在大量浪费的资源,不需要为这个请求保留太多的显存空间。
除了Page Attention之外,PPL.LLM的Virtual Memory 和S3[7]均提出了一种预测器机制,即为每一个请求预测一个它所需的生成长度。每个请求进来之后,都会直接为其分配一个连续的空间,这个连续空间的长度是预测出来的。但理论上看可能难以实现,尤其到了线上推理阶段,不太可能清楚地知道每个请求究竟要生成多长的内容。因此我们推荐训练一个模型去做这件事情。因为即使我们采用了 Page Attention 这样的模式,依然会遇到问题。PageAttention 在运行的过程中,具体到一个特定的时间点,比如当前系统上已经有了四个请求,系统里面还剩余 6 块显存没有被分配。这时我们无法知道是否会有新的请求进来,能否为其继续提供服务,因为当前的四个请求还没有结束,可能未来还要继续为它们追加新的显存块。所以即使是 Page Attention 机制,还是需要预测每一个用户实际的生成长度。这样才知道在具体的一个时间点上能不能接受一个新的用户的输入。

Average batch size and number of iterations for different models

如上图所示,S3 生成类似的理想情况下的序列数量比 ORCA (一种高吞吐量的 Transformer 服务系统)多 1.13 倍至 6.49 倍。


Share--KV cache共享


MQA/GQA同属于基于attention变体的KV cache共享方法,相当于不同的注意力头或者同一组的注意力头共享一个K和V的集合,因为只单独保留了一份查询参数。因此K和V的矩阵仅有一份,这大幅度减少了显存占用,使其更高效。除了共享attention这种思路之外,还有其它的共享思路,我们当中重点说一下FlashInfer和Hydragen
FlashInfer(Cascade Inference)提出Recursive Attention的方法,也就是采用multi-query attention计算共享部分的doc kv,采用single-query attention计算分开部分的kv1和kv2。简单来说,我们有三个子任务,分别计算Attn([q1, q2], doc kv),Attn(q1, kv1),和Attn(q2, kv2)。最后,采取类似于FlashDecoding的方式将中间结果进行合并。

Workflow of Cascade Inference


Hydragen 分别计算对共享前缀和唯一后缀的注意力。这种分解通过跨序列批量查询来实现高效的前缀注意力,减少冗余内存读取并支持使用硬件友好的矩阵乘法。与竞争基线相比,我们的方法可以将端到端 LLM 吞吐量提高高达 32 倍,并且加速随着批量大小和共享前缀长度而增长。Hydragen 还支持使用非常长的共享上下文:在批量大小较高的情况下,将前缀长度从 1K 令牌增加到 16K 令牌会使 Hydragen 吞吐量降低不到 15%,而基线吞吐量则下降超过 90%。

Hydragen workflow

Hydragen对应的结构描述:

  • 左:LLM 推理场景示例,其中聊天机器人模型处理许多共享大共享前缀(系统提示)。
  • 中:Hydragen 的概述,其中整体注意力被分解为共享前缀上的注意力(批量批次中的所有查询)并注意其余后缀(独立于序列,如通常所做的那样)。
  • 右上:Hydragen 的注意力分解允许用更少的矩阵-矩阵乘积替换许多矩阵向量乘积。
  • 右下:使用矩阵-矩阵乘积尤其重要,因为 GPU 将其总 FLOP 中越来越大的比例分配给张量核心专门用于矩阵乘法。
     

总结


KV cache对应的优化方法,总结成下表:

由上表可以看出,KV cache是个值得投入精力去研究的一个重要方向,算法上有着许多未知的方法可以去探索,工程上相对滞后,至少在主流推理框架上对部分方向的优化策略相对保守,这就给了足够多的机会。


参考资料



[1]LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale
[2]LLM-FP4: 4-Bit Floating-Point Quantized Transformers
[3]GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers
[4]https://k.sina.com.cn/article_2674405451_9f68304b019014oo6.html
[5]Fast Distributed Inference Serving for Large Language Models
[6]H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models
[7]S3: Increasing GPU Utilization during GenerativeInference for Higher Throughput
[8]EFFICIENT STREAMING LANGUAGE MODELS WITH ATTENTION SINKS
[9]https://flashinfer.ai/2024/02/02/cascade-inference.html
[10]Flash-Decoding for long-context inference
[11]https://github.com/openppl-public/ppl.llm.kernel.cuda/blame/master/src/ppl/kernel/llm/cuda/pmx/multi_head_cache_attention.cu
[12]FlashDecoding++: Faster Large Language Model Inference on GPUs
[13]FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
[14]LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models
[15]WKVQuant: Quantizing Weight and Key/Value Cache for Large Language Models Gains More
[16]Get More with LESS: Synthesizing Recurrence with KV Cache Compression for Efficient LLM Inference
[17]SubGen: Token Generation in Sublinear Time and Memory
[18]Hydragen: High-Throughput LLM Inference with Shared Prefixes
[19]LoMA: Lossless Compressed Memory Attention
[20]CacheGen: Fast Context Loading for Language Model Applications
[21]Orca: A distributed serving system for Transformer-Based generative models

发布于 2024-03-07 22:06・IP 属地日本

​赞同 12​​添加评论

​分享

​收藏​喜欢

收起​

更多回答

Young

Young

HPC高性能计算

​ 关注

王焱、刘聪NLP 等 490 人赞同了该回答

0. 引言

做大模型性能优化的一定对KV Cache不陌生,那么我们对这个技术了解到什么程度呢?请尝试回答如下问题:

  1. KV Cache节省了Self-Attention层中哪部分的计算?
  2. KV Cache对MLP层的计算量有影响吗?
  3. KV Cache对block间的数据传输量有影响吗?

本文打算剖析该技术并给出上面问题的答案。

1. KV Cache是啥?

大模型推理性能优化的一个常用技术是KV Cache,该技术可以在不影响任何计算精度的前提下,通过空间换时间思想,提高推理性能。网上有一些关于该技术的分析博客,但读过后仍然会很迷糊,甚至可能会被带偏,认为这个Cache过程和数据库读取或CPU Cache加速类似的荒谬结论。刚开始我也有类似误解,直到逐行查阅并运行源码,才清楚了解到其Cache了啥,以及如何节省计算的。

2. 背景

生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

如上描述是我们通常认知的GPT推理过程。代码描述如下:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizermodel = GPT2LMHeadModel.from_pretrained("/WORK/Test/gpt", torchscript=True).eval()# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("/WORK/Test/gpt")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
i = 0
with torch.no_grad():while out_token != token_eos:logits, _ = model(in_tokens)out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)in_tokens = torch.cat((in_tokens, out_token), 0)text = tokenizer.decode(in_tokens)print(f'step {i} input: {text}', flush=True)i += 1out_text = tokenizer.decode(in_tokens)
print(f' Input: {in_text}')
print(f'Output: {out_text}')

输出:

step 0 input: Lionel Messi is a player
step 1 input: Lionel Messi is a player who
step 2 input: Lionel Messi is a player who has
step 3 input: Lionel Messi is a player who has been
step 4 input: Lionel Messi is a player who has been a
step 5 input: Lionel Messi is a player who has been a key
step 6 input: Lionel Messi is a player who has been a key part
step 7 input: Lionel Messi is a player who has been a key part of
step 8 input: Lionel Messi is a player who has been a key part of the
step 9 input: Lionel Messi is a player who has been a key part of the team
step 10 input: Lionel Messi is a player who has been a key part of the team's
step 11 input: Lionel Messi is a player who has been a key part of the team's success
step 12 input: Lionel Messi is a player who has been a key part of the team's success.
step 13 input: Lionel Messi is a player who has been a key part of the team's success.Input: Lionel Messi is a
Output: Lionel Messi is a player who has been a key part of the team's success.

可以看出如上计算的问题吗?每次推理过程的输入tokens都变长了,导致推理FLOPs随之增大。有方法实现推理过程的FLOPs基本恒定不变或变小吗?(埋个伏笔,注意是基本恒定)。

3. 原理

在上面的推理过程中,每 step 内,输入一个 token序列,经过Embedding层将输入token序列变为一个三维张量[b, s, h],经过一通计算,最后经logits层将计算结果映射至词表空间,输出张量维度为[b, s, vocab_size]。

当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第�+1 轮输入数据只比第�轮输入数据新增了一个token,其他全部相同!因此第�+1轮推理时必然包含了第 � 轮的部分计算。KV Cache的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果,就是这么简单,不存在什么Cache miss问题。

4. 实现细节

目前各大模型推理都实现了KV Cache,下面就看如何使用了。我们可以在上面代码基础上修改,主要改动:

  • 在推理时新增了 past_key_values 参数,该参数就会以追加方式保存每一轮的K V值。kvcache变量内容为((k,v), (k,v), ..., (k,v)),即有 ������� 个 k,v 组成的一个元组,其中 k 和 v 的维度均为 [b, n_head, s, head_dims]。这里可以顺带计算出每轮推理对应的 cache 数据量为 2∗�∗�∗ℎ∗������� ,这里 � 值等于当前轮次值。以GPT3-175B为例,假设以 float16 来保存 KV cache,senquence长度为100,batchsize=1,则 KV cache占用显存为 2×100×12288×96×2 Byte= 472MB。
  • 推理输出的token直接作为下一轮的输入,不再拼接,因为上文信息已经在 kvcache 中。

代码示例:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizermodel = GPT2LMHeadModel.from_pretrained("/WORK/Test/gpt", torchscript=True).eval()# tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("/WORK/Test/gpt")
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))# inference
token_eos = torch.tensor([198]) # line break symbol
out_token = None
kvcache = None
out_text = in_text
i = 0
with torch.no_grad():while out_token != token_eos:logits, kvcache = model(in_tokens, past_key_values=kvcache) # 增加了一个 past_key_values 的参数out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)in_tokens = out_token # 输出 token 直接作为下一轮的输入,不再拼接text = tokenizer.decode(in_tokens)print(f'step {i} input: {text}', flush=True)i += 1out_text += textprint(f' Input: {in_text}')
print(f'Output: {out_text}')

通过上面代码只能看到调用层面的变化,实现细节还需看各框架的底层实现,例如Hugging Face的transformers库代码实现就比较清爽,在modeling_gpt2.py中Attention部分相关代码如下:

        query = self._split_heads(query, self.num_heads, self.head_dim)key = self._split_heads(key, self.num_heads, self.head_dim)value = self._split_heads(value, self.num_heads, self.head_dim)if layer_past is not None: # 当输出第一个token后,layer_past就是非None了past_key, past_value = layer_past # 取出之前计算好的 key, valuekey = torch.cat((past_key, key), dim=-2) # past_key 与当前 token 对应的 key 拼接value = torch.cat((past_value, value), dim=-2) # past_value 与当前 token 对应的 value 拼接if use_cache is True:present = (key, value)else:present = None

在 block 层面也有相关代码,大家有空细品吧。还是那句话,说一千道一万不如阅读并运行源码一次。

其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

  1. 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。
  2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

这里用图可能更有助理解,下图是一个Decoder Block,含有Self-Attention和MLP,标红部分为KV Cache影响到的内容,即KV Cache开启后,标红的序列长度 � 变为 1,当batch_size=1时,Self-Attention中的2个dense全都变为gemv操作,MLP中的dense也全都变为gemv操作。看懂这个图就可以答对上面的3个问题啦。

如下链接也有这方面的定量分析,写的很棒,推荐大家看看。

回旋托马斯x:分析transformer模型的参数量、计算量、中间激活、KV cache

5. 总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。本文尝试打开封装分析该技术内部实现,希望对大家有所帮助,文中如有纰漏,欢迎指正。

发布于 2023-05-22 23:37

​赞同 490​​25 条评论

​分享

​收藏​喜欢

收起​

小莲子

小莲子

网易 资深自然语言处理工程师

​ 关注

21 人赞同了该回答

通俗的的大白话来了,Key-Value Cache 的实质是,

把矩阵 Q 退化为当前时间步向量 q,把两个矩阵间的 QK 运算退化为向量和矩阵间的 qK 


为什么要只搞 Key-Value Cache?有没有听过把 Query 缓存起来的?

以单层 transformer 为例。不做 cache 时候,全序列并行计算。在时间步 �=3 时,输入序列�3=[�1,�2,�3],变换成 �3=[�1,�2,�3] 、 �3=[�1,�2,�3] 、 �3=[�1,�2,�3] 。self-attention 之后获得输出 �3=[�1,�2,�3],进而预测出 �4 。

放到 transformers 库的代码里是这样:

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")for _ in range(4):next_logits = model(input_ids)["logits"][:, -1:]# 取最大概率token为当前时间步输出.next_token_id = torch.argmax(next_logits,dim=-1)# 将历史输出拼接在一起,作为下一步的输入.input_ids = torch.cat([input_ids, next_token_id], dim=-1)

在下一个时间步 �=4 中,输入序列 �4=[�1,�2,�3,�4]=�3+[�4],变换成 �4=�3+[�4] 、 �4=�3+[�4] 、 �4=�3+[�4] 。由于 decoder 结构的 self-attention 是单向的,新增加的 �4 / �4 / �4,无法往前去改变 �1 / �2 / �3 的值,只能影响到 �4。这就导致了 �4=�3+[�4]。

单向 self-attention

到了下一层,把上一层的输出 �4 看作是输入 �4,就又是一样的逻辑。

所有的 transormer 层,新增的输入只有一个 �4,也只需要输出一个新的 �4 即可。输入 �4 经过变换得到 �4 / �4 / �4。由于 �4∼softmax(�4��T)�4,Query 不需要缓存直接取 �4,Key 和 Value 取出缓存 �3 / �3 并补上新位置得到 �4=�3+[�4] 、 �4=�3+[�4] ,送入 self-attention 得到输出 �4。

Key-Value 在 transformers 库里有个开关参数 use_cache=True,用上之后代码如下:

past_key_values = None # 即 key-value cache, 列表结构为 [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
generated_tokens = [] # 记录生成结果.
next_token_id = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")for _ in range(4):# 只传入当前 token 作为输入.next_logits, past_key_values = model(next_token_id, past_key_values=past_key_values, use_cache=True).to_tuple()next_logits = next_logits[:, -1:]# 取最大概率token为当前时间步输出.next_token_id = torch.argmax(next_logits, dim=-1)# 只记录, 不参与后续生成.generated_tokens.append(next_token_id.item())

model() 的输入从矩阵 input_ids 退化为 向量 next_token_id。

Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码? - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/675.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式:简单工厂模式(Simple Factory)

设计模式:简单工厂模式(Simple Factory) 设计模式:简单工厂模式(Simple Factory)模式动机模式定义模式结构时序图模式实现测试模式分析实例:Qt 控件类优缺点适用环境模式应用 设计模式&#xff…

关基网络战时代,赛宁网安电力网络攻防靶场全面提升电网安全防护力

随着网络空间成为与陆地、海洋、天空、太空同等重要的人类活动新领域,自网络空间向物理电网发起攻击,破坏电力等国家关键基础设施成为当前大国博弈、大规模战争的重要手段和常态进攻形式。同时,新型电力系统建设发展驱动电力系统形态和控制方…

基于Springboot的社区待就业人员信息管理系统(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的社区待就业人员信息管理系统(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三…

TaskWeaver使用记录

TaskWeaver使用记录 1. 基本介绍2. 总体结构与流程3. 概念细节3.1 Project3.2 Session3.3 Memory3.4 Conversation3.5 Round3.6 Post3.7 Attachment3.8 Plugin3.9 Executor 4. 代码特点5. 使用过程5.1 api调用5.2 本地模型使用5.3 添加插件 6. 存在的问题与使用体验6.1 判别模型…

笔记本电脑坏了硬盘数据会丢失吗 笔记本电脑坏了如何取出硬盘的资料 数据恢复软件

笔记本电脑对我们真的非常重要了,是实现无纸化办公和学习的重要工具,但是如果笔记本电脑坏了我们存储在电脑里的资料该怎么办?笔记本电脑坏了硬盘数据会丢失吗?相信有许多朋友都会有这样的担忧。本文今天就为大家解决笔记本电脑坏…

【银角大王———Django学习DAY0——基础准备】

银角大王——Django学习前情提要 (1)在pycharm中下载Flask(2)使用Flask(3)下载BootStrap框架(4) 使用BootStrap框架 (1)在pycharm中下载Flask 在设置——项目…

【若依】代码生成详细教程(单表、主从表、树形表增删改查)

若依代码生成开发接口 修改代码生成配置一、单表实现增删改查1. 新建数据库表结构2. 新建模块,解决项目依赖3. 启动项目,新建菜单4. 导入数据表,自动生成代码5. 将生成代码粘贴到对应的模块,执行生成的sql(用于生成菜单…

GitHub/R3D3项目环境配置踩坑记录

1、前言 项目链接地址:SysCV/r3d3 (github.com) 按照安装步骤容易出现的问题,environment.yaml文件中安装相关包,其中还有两个pip install githttps://github.com/..........这两个建议注释掉,后面再来安装这两个。 2、问题及解…

【C++题解】1020. 算算和是多少

问题:1020. 算算和是多少 类型:基本运算、拆位求解 题目描述: 输入一个三位正整数,然后与它倒过来的数相加,输出和。 如:输入167 ,则和为167761928。 输入: 只有一行&#xff0c…

全开源小狐狸Ai系统 小狐狸ai付费创作系统 ChatGPT智能机器人2.7.6免授权版

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 测试环境:Linux系统CentOS7.6、宝塔、PHP7.4、MySQL5.6,根目录public,伪静态thinkPHP,开启ssl证书 具有文章改写、广告营销文案、编程…

PostgreSql-Install

PostgreSql源码安装 一、源代码下载二、操作系统配置三、编译安装四、启动数据库五、相关命令 PostgreSQL是一个强大的 开源对象关系数据库系统,它使用并扩展了SQL语言,并结合了许多功能,可以安全地存储和扩展最复杂的数据工作负载。 一、源…

gin框架提高篇(四)

参数校验(一) uuid包:https://github.com/satori/go.uuid 因为作者更改了参数限制,导致会出问题 → 问题解决 package mainimport ("fmt""github.com/gin-gonic/gin""github.com/go-playground/validato…

盲人盲杖:科技革新,助力视障人士独立出行

在我们的社会中,盲人朋友们以其坚韧的精神风貌,生动诠释着生活的多样与可能。然而,当我们聚焦于他们的日常出行,那些普通人视为寻常的街道、路口,却成为他们必须面对的严峻挑战。如何切实提升盲人盲杖的功能&#xff0…

【Linux进阶之路】高级IO

一、 铺垫 I,即input为输入;O,即output为输出,IO,即input output为输入输出。IO一般是基于网卡,磁盘,光盘,U盘,磁盘,磁带等毫秒级别的外存,相较…

Python实现贪吃蛇

提供学习或者毕业设计使用,功能基本都有,不能和市场上正式游戏相提比论,请理性对待!通过购买专栏或者CSDN问答提问,采纳后,私信博主。提供源码! 说明:需要的话联系博主!谢谢。 代码: import pygame import random import tkinter as tk from tkinter import mess…

BetterZip 5 for Mac:轻松解压缩的得力助手

BetterZip 5 for Mac是一款专为苹果电脑用户设计的压缩与解压软件,以其强大的功能和便捷的操作赢得了广大用户的喜爱。 BetterZip 5 for Mac v5.3.4中文版下载 这款软件支持多种主流的压缩格式,如ZIP、RAR、7-Zip等,满足了用户多样化的需求。…

WordPress 主题选择与自定义配置

最近我在使用wordpress网站进行建站。 我是使用的hostease的主机产品进行wordpress建站,在选择wordpress主题时颇为头疼。后来咨询了hostease的客服人员,他们家的技术人员提供了诸多帮助。在WordPress网站建设时,主题选择对于建立各类网站至关…

【MIT6.824】lab2C-persistence, lab2D-log compaction 实现笔记

引言 lab2C的实验要求如下 Complete the functions persist() and readPersist() in raft.go by adding code to save and restore persistent state. You will need to encode (or “serialize”) the state as an array of bytes in order to pass it to the Persister. Us…

记录——FPGA的学习路线

文章目录 一、前言二、编程语言2.1 书籍2.2 刷题网站2.3 仿真工具 三、基础知识3.1 专业基础课3.2 fpga相关专业知识 四、开发工具五、动手实验 一、前言 也不是心血来潮想学习fpga了,而是祥哥还有我一个国科大的同学都在往fpga这个方向走 并且看过我之前文章的同…

合并有序表 (顺序存储 和 链式存储 方式实现)

代码详细解析: 合并有序表文章浏览阅读1.4k次,点赞6次,收藏7次。●假设有两个有序表 LA和LB , 将他们合并成一个有序表LC●要求不破坏原有的表 LA和 LB构思:把这两个表, 合成一个有序表 , 不是简简单单吗?就算是把他们先遍历不按顺序插入到表 C里面 , …