论文地址:https://arxiv.org/pdf/2309.17453
github地址:https://github.com/mit-han-lab/streaming-llm
1. 研究背景与挑战
随着大语言模型(LLMs)在对话系统、文档摘要、代码补全和问答等领域的广泛应用,如何高效且准确地处理长序列生成成为亟待解决的问题。然而,现有LLMs在流式应用(如多轮对话)中面临两大主要挑战:
挑战一:解码阶段内存消耗巨大
-
问题描述: Transformer架构的LLMs在解码阶段会缓存所有先前token的键值对(KV),导致内存消耗随序列长度呈二次增长,显著增加了解码延迟。
- 图示说明: 如下图1(a)所示,密集注意力机制(密集注意力)需要存储所有token的KV,导致时间复杂度为 O ( T 2 ) O(T^2) O(T2),且缓存大小随文本长度增加而增加。当文本长度超过预训练长度时,性能会下降。
图1:StreamingLLM与现有方法的对比。
挑战二:模型对长文本的泛化能力有限
- 问题描述: 现有模型在处理超过预训练时设定的注意力窗口长度的序列时,性能会显著下降。
- 图示说明: 如下图1(b)所示,窗口注意力机制(窗口注意力)仅缓存最近一定数量的token的KV,虽然在推理过程中效率较高,但一旦序列长度超过缓存大小,性能会急剧下降。
2. 现有方法的局限性
- 窗口注意力:
- 优点: 内存使用恒定,解码速度稳定。
- 缺点: 一旦序列长度超过缓存大小,性能会崩溃。例如,移除第一个token的KV会导致模型性能大幅下降。
- 滑动窗口重计算:
- 优点: 在长文本上表现良好。
- 缺点: 由于在上下文重计算中需要进行二次注意力计算,其时间复杂度为 O ( T L 2 ) O(TL^2) O(TL2),导致速度非常慢,不适用于实际流式应用。
3. 注意力汇聚点现象的发现与解释
为了解决窗口注意力的局限性,研究人员观察到一个有趣的现象:注意力汇聚点(Attention Sink)。
3.1 注意力汇聚点现象
-
现象描述: 自回归LLMs会分配大量注意力分数给初始token,无论其与语言建模任务的语义相关性如何。
- 图示说明: 如下图2所示,在Llama-2-7B模型中,初始token在所有层和注意力头中均获得了较高的注意力分数。
图2:Llama-2-7B模型中平均注意力logits的可视化。 -
原因分析:
- SoftMax函数的特性: SoftMax函数要求所有上下文token的注意力分数之和为1。因此,即使当前查询在许多先前token中没有强匹配,模型仍然需要将这些不必要的注意力值分配到某些token上。
- 初始token的全局可见性: 由于自回归语言建模的性质,初始token对所有后续token都是可见的,这使得它们更容易被训练成为注意力汇聚点,捕获不必要的注意力。
3.2 注意力汇聚点对模型性能的影响
-
实验验证:
- 将前四个token替换为换行符,模型仍然会显著关注这些初始换行符。
- 重新引入这些初始token后,语言建模的困惑度(perplexity)恢复到与原始初始token相当的水平。
- 结论: 初始token的绝对位置比其语义价值更重要。
-
对窗口注意力的影响:
- 移除初始token的KV会导致SoftMax函数分母发生显著变化,导致注意力分数分布偏离正常推理设置,从而导致模型性能下降。
4. StreamingLLM的提出与创新
基于上述分析,研究人员提出了StreamingLLM,一种高效的框架,使LLMs能够处理无限长度的文本,而无需任何微调。
4.1 StreamingLLM的核心思想
-
利用注意力汇聚点: StreamingLLM利用注意力汇聚点具有高注意力值的特点,通过保留它们,可以将注意力分数分布保持在接近正常的水平。
-
具体方法:
- 保留注意力汇聚点token的KV: 仅保留少量初始token(例如4个)的KV作为注意力汇聚点。
- 结合滑动窗口KV: 将注意力汇聚点token的KV与滑动窗口的KV结合起来,以锚定注意力计算并稳定模型性能。
图4:StreamingLLM的KV缓存。
4.2 StreamingLLM的优势
-
高效性:
- 与滑动窗口重计算基线相比,StreamingLLM实现了高达 22.2 × 22.2\times 22.2×的速度提升。
-
稳定性:
- 如下图3所示,StreamingLLM在处理长达20K tokens的文本时,其困惑度与重计算基线几乎一致,证明了其稳定性能。
图3:StreamingLLM在20K tokens文本上的语言建模困惑度。 -
可扩展性:
- StreamingLLM使包括Llama-2、MosaicML、MPT、Falcon和Pythia在内的模型能够可靠地建模高达400万tokens的文本,甚至更多。
4.3 预训练中加入Sink Token的改进
-
问题: 现有模型通常使用多个初始token作为注意力汇聚点,而不是仅使用一个。
-
解决方案:
- 在所有训练样本的开头添加一个可学习的占位符token(Sink Token)作为专门的注意力汇聚点。
-
实验结果:
- 如下图7所示,添加Sink Token后,模型在所有层和注意力头中均会一致地关注Sink Token,有效收集冗余注意力。
- 如下图6所示,添加Sink Token不会对模型收敛和后续性能产生负面影响。
- 如下图5所示,使用Sink Token训练的模型在流式应用中的困惑度更低,仅需添加Sink Token即可实现稳定的流式性能,而无需其他初始token。
图7:有无Sink Token的模型在平均注意力logits上的可视化。
5. 实验结果与分析
5.1 长文本语言建模
- StreamingLLM在处理长达400万tokens的文本时,其困惑度保持稳定,证明了其在各种LLM家族和规模下的有效性。
5.2 预训练中加入Sink Token的效果
- 预训练中加入Sink Token不会损害模型性能,并且在流式应用中表现更佳。
5.3 流式问答任务
- StreamingLLM在模拟真实世界聊天设置的流式问答任务中表现良好,精度与单样本基线相当,而窗口注意力方法由于输入长度超过缓存大小而导致精度较低。
5.4 消融研究
- 初始token数量: 引入四个初始token作为注意力汇聚点足以,添加更多token对性能提升有限。
- 缓存大小: 增加缓存大小并不总是能降低语言建模困惑度,表明这些模型可能没有充分利用其接收到的上下文。
5.5 效率结果
- StreamingLLM的解码速度随缓存大小线性增长,而滑动窗口重计算基线的解码延迟呈二次增长。StreamingLLM实现了高达 22.2 × 22.2\times 22.2×的加速,同时保持了与重计算基线相似的内存占用。
6. 结论与展望
- 结论: StreamingLLM通过引入注意力汇聚点,使LLMs能够高效处理无限长度的文本,解决了现有方法在流式应用中的局限性。
- 未来方向:
- 进一步研究如何更好地利用缓存中的上下文信息。
- 探索如何将StreamingLLM与上下文扩展技术相结合,以进一步提高模型性能。
7. 总结图示
图5:StreamingLLM在超长文本上的语言建模困惑度。
图10:滑动窗口重计算基线与StreamingLLM的每token解码延迟和内存使用对比。
8. 附加说明
- 应用场景: StreamingLLM特别适用于流式应用,如多轮对话、实时助手等。
- 局限性: StreamingLLM不扩展模型的上下文窗口或增强其长期记忆能力,不适用于需要长期记忆和广泛数据依赖的任务。
- 社会影响: StreamingLLM提高了LLMs的效率,使其在教育、医疗和客户服务等领域更具可访问性,但同时也带来了生成错误信息和有偏见内容的风险,需要谨慎使用。