目录
- 前言
- 1. 从MHA、MQA、GQA到MLA
- 1.1 MHA
- 1.2 瓶颈
- 1.3 MQA
- 1.4 GQA
- 1.5 MLA
- 1.5.1 Part 1
- 1.5.2 Part 2
- 1.5.3 Part 3
- 结语
- 参考
前言
学习 DeepSeek 中的 MLA 模块,究极缝合怪,东抄抄西抄抄,主要 copy 自苏神的文章,仅供自己参考😄
refer1:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
refer2: 博客分享:从MHA、MQA、GQA到MLA
1. 从MHA、MQA、GQA到MLA
以下内容均来自于苏神的文章:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
1.1 MHA
MHA(Multi-Head Attention),也就是多头注意力,是开山之作 《Attention is all you need》 所提出的一种 Attention 的形式
在数学上,多头注意力 MHA 等价于多个独立的单头注意力的拼接,假设输入的(行)向量序列为 x 1 , x 2 , ⋯ , x l \bm{x}_1,\bm{x}_2,\cdots,\bm{x}_l x1,x2,⋯,xl,其中 x i ∈ R d \bm{x}_i \in \mathbb{R}^d xi∈Rd,那么 MHA 可以形式地记为:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) = x i W k ( s ) ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) = x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \bm{o_{t}}=\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}=\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}=x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{(s)}=x_{i}W_{k}^{(s)}\in\mathbb{R}^{d_{k}},\quad W_{k}^{(s)}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{(s)}=x_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d\times d_{v}}} \end{array} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)=xiWk(s)∈Rdk,Wk(s)∈Rd×dkvi(s)=xiWv(s)∈Rdv,Wv(s)∈Rd×dv
简单起见,这里省略了 Attention 矩阵的缩放因子 1 d k \frac{1}{\sqrt{d_k}} dk1
实践上,常见的设置是 d k = d v = d / h d_k=d_v=d/h dk=dv=d/h,例如对于 LLaMA2-7B 有 d = 4096 , h = 32 , d k = d v = 128 {d=4096,h=32,d_{k}=d_{v}=128} d=4096,h=32,dk=dv=128,LLaMa2-70B 则是 d = 8192 , h = 64 , d k = d v = 128 {d=8192,h=64,d_{k}=d_{v}=128} d=8192,h=64,dk=dv=128
这里只考虑主流自回归 LLM 所用的 Causal Attention,在 token by token 递归生成时,新预测出来的第 t + 1 t+1 t+1 个 token,并不会影响到已经算好的 k ≤ t ( s ) , v ≤ t ( s ) {k_{\leq t}^{(s)},v_{\leq t}^{(s)}} k≤t(s),v≤t(s),因此这部分结果我们可以缓存下来,供后续生成调用,避免不必要的重复计算,这就是所谓的 KV Cache
关于 KV Cache 大家感兴趣的可以看看:KV Cache的原理与实现
后面的 MQA、GQA、MLA 都是围绕“如何减少 KV Cache 同时尽可能地保证效果”这个主题发展而来的产物
上图展示了标准 MHA 下的 KV Cache 是多大,它和注意力头数、序列长度等相关,此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ n u m _ h e a d ∗ h e a d _ d i m 2*seq\_len * num\_head * head\_dim 2∗seq_len∗num_head∗head_dim
Note:该图片来自于 https://github.com/preacher-1/MLA_tutorial
代码实现如下:
import math
import torch
import torch.nn as nn# Multi-Head Attention
class MHA(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsassert d_model % num_heads == 0self.head_dim = d_model // num_headsself.q_linear = nn.ModuleList([nn.Linear(d_model, self.head_dim, bias=False) for _ in range(num_heads)])self.k_linear = nn.ModuleList([nn.Linear(d_model, self.head_dim, bias=False) for _ in range(num_heads)])self.v_linear = nn.ModuleList([nn.Linear(d_model, self.head_dim, bias=False) for _ in range(num_heads)])self.out_linear = nn.Linear(d_model, d_model, bias=False) def forward(self, x):bsz, seq_len, _ = x.shapeoutputs = []# Parallelfor i in range(self.num_heads):q = self.q_linear[i](x) # (bsz, seq_len, head_dim)k = self.k_linear[i](x) # (bsz, seq_len, head_dim)v = self.v_linear[i](x) # (bsz, seq_len, head_dim)# RoPE# TODO: Implement RoPE# Attentionattention = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, seq_len, seq_len)# Casual maskmask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()mask = mask.unsqueeze(0).to(x.device) # (1, seq_len, seq_len)attention = attention.masked_fill(mask, float('-inf'))attention = torch.softmax(attention, dim=-1)# Outputoutput = torch.matmul(attention, v) # (bsz, seq_len, seq_len)outputs.append(output)# Linear projectionoutput = torch.cat(outputs, dim=-1) # (bsz, seq_len, d_model)output = self.out_linear(output)return output# Another implement for Multi-Head Attention
class MHA2(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsassert d_model % num_heads == 0self.head_dim = d_model // num_headsself.q_linear = nn.Linear(d_model, num_heads * self.head_dim, bias=False)self.k_linear = nn.Linear(d_model, num_heads * self.head_dim, bias=False)self.v_linear = nn.Linear(d_model, num_heads * self.head_dim, bias=False)self.out_linear = nn.Linear(d_model, d_model, bias=False)def forward(self, x):bsz, seq_len, _ = x.shapeq = self.q_linear(x) # (bsz, seq_len, num_heads * head_dim)k = self.k_linear(x) # (bsz, seq_len, num_heads * head_dim)v = self.v_linear(x) # (bsz, seq_len, num_heads * head_dim)# matmul 只能在最后两个维度相乘, 需要对 NxD 的矩阵相乘, 做 1,2 维度的交换# (bsz, num_heads, seq_len, head_dim)q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# RoPE# TODO: Implement RoPE# Attentionattention = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_heads, seq_len, seq_len)# Casual maskmask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()mask = mask.unsqueeze(0).to(x.device) # (1, seq_len, seq_len)attention = attention.masked_fill(mask, float('-inf'))attention = torch.softmax(attention, dim=-1)# Outputoutput = torch.matmul(attention, v) # (bsz, num_heads, seq_len, head_dim)output = output.transpose(1, 2) # (bsz, seq_len, num_heads, head_dim) output = output.contiguous().view(bsz, seq_len, -1) # (bsz, seq_len, d_model)output = self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model = 512
num_heads = 8
mha = MHA2(d_model, num_heads)
x = torch.randn(10, 20, d_model) # (bsz, seq_len, d_mdeol)
output = mha(x)
print(output.shape) # (10, 20, 512)
Note:代码参考自:https://github.com/preacher-1/MLA_tutorial
此外 MHA 还有另外一种实现,一次性将 Q , K , V Q,K,V Q,K,V 投影,具体代码可以参考 https://github.com/karpathy/minGPT/tree/master/mingpt
1.2 瓶颈
为什么降低 KV Cache 的大小如此重要呢?🤔
众所周知,一般情况下 LLM 的推理都是在 GPU 上进行的,而单张 GPU 的显存是有限的,一部分我们要用来存放模型的参数和前向计算的激活值,这部分依赖于模型的体量,选定模型后它就是个常数;另外一部分我们要用来存放模型的 KV Cache,这部分不仅依赖于模型的体量,还依赖于模型的输入长度,也就是在推理过程中是动态增长的,当 Context 长度足够长时,它的大小就会占主导地位,可能超过一张卡甚至一台机(8张卡)的总显存量
在 GPU 上部署模型的原则是:能一张卡部署的,就不要跨多张卡;能一台机部署的,就不要跨多台机。这是因为“卡内通信带宽 > 卡间通信带宽 > 机间通信带宽”,由于“木桶效应”,模型部署时跨的设备越多,受设备间通信带宽的的“拖累”就越大,事实上即便是单卡 H100 内 SRAM 与 HBM 的带宽已经达到了 3TB/s,但对于 Short Context 来说这个速度依然还是推理的瓶颈,更不用说更慢的卡间、机间通信了
所以,减少 KV Cache 的目的就是要实现在更少的设备上推理更长的 Context,或者在相同的 Context 长度下让推理的 batch size 更大,从而实现更快的推理速度或者更大的吞吐量。当然,最终的目的都是为了实现更低的推理成本
1.3 MQA
MQA(Multi-Query Attention)是减少 KV Cache 的一次非常朴素的尝试,首次提出自 《Fast Transformer Decoding: One Write-Head is All You Need》
MQA 的思路很简单,直接让所有 Attention Head 共享同一个 K、V,用公式来说就是取消 MHA 所有的 k , v \bm{k},\bm{v} k,v 的上标 ( s ) ^{(s)} (s)
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) = x i W k ( s ) ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) = x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \bm{o_{t}}=\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}=\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{\color{red}{\bcancel{(s)}}}, \bm{v}_{\leq t}^{\color{red}{\bcancel{(s)}}}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{\color{red}{\bcancel{(s)}}}{}^{\top}\Bigr)\bm{v}_{i}^{\color{red}{\bcancel{(s)}}}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{\color{red}{\bcancel{(s)}}}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}=x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{\color{red}{\bcancel{(s)}}}=x_{i}W_{k}^{\color{red}{\bcancel{(s)}}}\in\mathbb{R}^{d_{k}},\quad W_{k}^{\color{red}{\bcancel{(s)}}}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{\color{red}{\bcancel{(s)}}}=x_{i}W_{v}^{\color{red}{\bcancel{(s)}}}\in\mathbb{R}^{d_{v}},\quad W_{v}^{\color{red}{\bcancel{(s)}}}\in \mathbb{R}^{d\times d_{v}}} \end{array} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s) ,v≤t(s) )≜∑i≤texp(qt(s)ki(s) ⊤)∑i≤texp(qt(s)ki(s) ⊤)vi(s) qi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s) =xiWk(s) ∈Rdk,Wk(s) ∈Rd×dkvi(s) =xiWv(s) ∈Rdv,Wv(s) ∈Rd×dv
使用 MQA 的模型包括 PaLM、StarCoder、Gemini 等。很明显,MQA 直接将 KV Cache 减少到了原来的 1 / h 1/h 1/h,这是非常可观的,单从节省显存角度看已经是天花板了
效果方面,目前看来大部分任务的损失都比较有限,且 MQA 的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到 MQA 由于共享了 K、V,将会导致 Attention 的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大 FFN/GLU 的规模,这也能弥补一部分效果损失
Note:该图片来自于 https://github.com/preacher-1/MLA_tutorial
上图展示了标准 MQA 下的 KV Cache 是多大,和标准的 MHA 相比 Q Q Q 保持不变,但所有头的 K , V K,V K,V 共享,此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ 1 ∗ h e a d _ d i m 2*seq\_len * 1 * head\_dim 2∗seq_len∗1∗head_dim
代码实现如下:
import math
import torch
import torch.nn as nn# Multi-Query Attention
class MQA(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsassert d_model % num_heads == 0self.head_dim = d_model // num_headsself.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, self.head_dim)self.v_linear = nn.Linear(d_model, self.head_dim)self.out_linear = nn.Linear(d_model, d_model)def forward(self, x):bsz, _, _ = x.shape# Linear projections, all heads share the same K, Vq = self.q_linear(x) # (bsz, seq_len, d_model)k = self.k_linear(x) # (bsz, seq_len, head_dim)v = self.v_linear(x) # (bsz, seq_len, head_dim)# Reshape for multi-head attentionq = q.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)# (bsz, num_heads, seq_len, head_dim)k = torch.unsqueeze(k, 1) # (bsz, 1, seq_len, head_dim)v = torch.unsqueeze(v, 1) # (bsz, 1, seq_len, head_dim)# Attentionattention = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_heads, seq_len, seq_len)attention = torch.softmax(attention, dim=-1)# Outputoutput = torch.matmul(attention, v) # (bsz, num_heads, seq_len, head_dim)output = output.transpose(1, 2) # (bsz, seq_len, num_heads, head_dim) output = output.contiguous().view(bsz, -1, d_model)# Linear projectionoutput = self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model = 512
num_heads = 8
mqa = MQA(d_model, num_heads)
x = torch.randn(10, 20, d_model) # (bsz, seq_len, d_mdeol)
output = mqa(x)
print(output.shape) # (10, 20, 512)
Note:代码参考自:https://github.com/preacher-1/MLA_tutorial
这个代码和 MHA 实现类似,不同的是由于 MQA 所有头共享同一个 K , V K,V K,V,因此这里的 W k , W v W_k,W_v Wk,Wv 投影矩阵的维度是 head_dim 而不再是 num_heads * head_dim,在 forward 时通过广播机制将 W k , W v W_k,W_v Wk,Wv 共享到其他头即可
1.4 GQA
然而,也有人担心 MQA 对 KV Cache 的压缩太严重,以至于会影响模型的学习效率以及最终结果。为此,一个 MHA 和 MQA 之间的过渡版本 GQA(Grouped-Query Attention)应运而生,出自论文 《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》,是 23 年的工作
GQA 的思想也很朴素,它就是将所有 Head 分为 g g g 个组( g g g 可以整除 h h h),每组共享同一对 K、V,用数学公式表示为:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( ⌈ s g / h ⌉ ) , v ≤ t ( ⌈ s g / h ⌉ ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( ⌈ s g / h ⌉ ) ⊤ ) v i ( ⌈ s g / h ⌉ ) ∑ i ≤ t exp ( q t ( s ) k i ( ⌈ s g / h ⌉ ) ⊤ ) q i ( s ) = x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( ⌈ s g / h ⌉ ) = x i W k ( ⌈ s g / h ⌉ ) ∈ R d k , W k ( ⌈ s g / h ⌉ ) ∈ R d × d k v i ( ⌈ s g / h ⌉ ) = x i W v ( ⌈ s g / h ⌉ ) ∈ R d v , W v ( ⌈ s g / h ⌉ ) ∈ R d × d v \bm{o_{t}}=\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}=\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{\color{red}{(\lceil sg/h \rceil)}}, \bm{v}_{\leq t}^{\color{red}{(\lceil sg/h \rceil)}}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{\color{red}{(\lceil sg/h \rceil)}}{}^{\top}\Bigr)\bm{v}_{i}^{\color{red}{(\lceil sg/h \rceil)}}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{\color{red}{(\lceil sg/h \rceil)}}{}^{\top}\Bigr)} \\ \bm{q_{i}^{(s)}=x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{\color{red}{(\lceil sg/h \rceil)}}=x_{i}W_{k}^{\color{red}{(\lceil sg/h \rceil)}}\in\mathbb{R}^{d_{k}},\quad W_{k}^{\color{red}{(\lceil sg/h \rceil)}}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{\color{red}{(\lceil sg/h \rceil)}}=x_{i}W_{v}^{\color{red}{(\lceil sg/h \rceil)}}\in\mathbb{R}^{d_{v}},\quad W_{v}^{\color{red}{(\lceil sg/h \rceil)}}\in \mathbb{R}^{d\times d_{v}}} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(⌈sg/h⌉),v≤t(⌈sg/h⌉))≜∑i≤texp(qt(s)ki(⌈sg/h⌉)⊤)∑i≤texp(qt(s)ki(⌈sg/h⌉)⊤)vi(⌈sg/h⌉)qi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(⌈sg/h⌉)=xiWk(⌈sg/h⌉)∈Rdk,Wk(⌈sg/h⌉)∈Rd×dkvi(⌈sg/h⌉)=xiWv(⌈sg/h⌉)∈Rdv,Wv(⌈sg/h⌉)∈Rd×dv
其中 ⌈ ⋅ ⌉ \lceil\cdot\rceil ⌈⋅⌉ 是上取整符号
GQA 提供了从 MHA 到 MQA 的自然过渡,当 g = h g=h g=h 时就是 MHA;当 g = 1 g=1 g=1 时就是 MQA;当 1 < g < h 1<g<h 1<g<h 时,它只将 KV Cache 压缩到 g / h g/h g/h,压缩率不如 MQA,但同时也提供了更大的自由度,效果上更有保证。GQA 最知名的使用者,大概是 Meta 开源的 LLAMA2-70B,以及 LLAMA3 全系列,此外使用 GQA 的模型还有 TigerBot、DeepSeek-V1、StarCoder2、Yi、ChatGLM2、ChatGLM3 等,相比使用 MQA 的模型更多
Note:该图片来自于 https://github.com/preacher-1/MLA_tutorial
上图展示了标准 GQA 下的 KV Cache 是多大,GQA 是 MHA 和 MQA 的一种折中,它将 K , V K,V K,V 分成 group 组,每组共享同一个 K , V K,V K,V,此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ n _ g r o u p s ∗ h e a d _ d i m 2*seq\_len * n\_groups * head\_dim 2∗seq_len∗n_groups∗head_dim
代码实现如下:
import math
import torch
import torch.nn as nn# Grouped-Query Attention
class GQA(torch.nn.Module):def __init__(self, d_model, num_heads, num_groups):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.num_groups = num_groupsself.group_heads = num_heads // num_groupsself.head_dim = d_model // num_headsself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, self.head_dim * num_groups)self.W_v = nn.Linear(d_model, self.head_dim * num_groups)self.out_linear = nn.Linear(d_model, d_model)def forward(self, x):bsz, seq_len, _ = x.shape# Linear projections, each group share the same K, Vq = self.W_q(x) # (bsz, seq_len, d_model)k = self.W_k(x) # (bsz, seq_len, head_dim * num_groups)v = self.W_v(x) # (bsz, seq_len, head_dim * num_groups)# Reshape for multi-head attention# (bsz, num_groups, gropus_head, seq_len, head_dim)q = q.view(bsz, seq_len, self.num_groups, self.group_heads, self.head_dim).permute(0, 2, 3, 1, 4)k = k.view(bsz, seq_len, self.num_groups, self.head_dim).transpose(1, 2) # (bsz, num_groups, seq_len, head_dim)v = v.view(bsz, seq_len, self.num_groups, self.head_dim).transpose(1, 2) # (bsz, num_groups, seq_len, head_dim)k = torch.unsqueeze(k, 2) # (bsz, num_groups, 1, seq_len, head_dim)v = torch.unsqueeze(v, 2) # (bsz, num_groups, 1, seq_len, head_dim)# Attentionattention = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_gropus, gropus_head, seq_len, seq_len)attention = torch.softmax(attention, dim=-1)# Outputoutput = torch.matmul(attention, v) # (bsz, num_groups, gropus_head, seq_len, head_dim)output = output.permute(0, 3, 1, 2, 4).contiguous().view(bsz, -1, self.d_model)# Linear projectionoutput = self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model = 512
num_heads = 8
num_groups = 4
gqa = GQA(d_model, num_heads, num_groups)
x = torch.randn(32, 10, d_model) # (bsz, seq_len, d_mdeol)
output = gqa(x)
print(output.shape) # (32, 10, 512)
Note:代码参考自:https://github.com/preacher-1/MLA_tutorial
1.5 MLA
有了 MHA、MQA、GQA 的铺垫,我们理解起 MLA(Multi-head Latent Attention)就相对容易一些了。DeepSeek-V2 的技术报告里是从低秩投影(类似于 LoRA)的角度引入 MLA 的,但苏神认为低秩投影这个角度并不贴近本质,MLA 的本质是低秩投影之后的工作
1.5.1 Part 1
GQA 在投影之后做了什么呢?首先它将向量对半分两份分别作为 K、V,然后每一份又均分为 g g g 份,每一份复制 h / g h/g h/g 次,以此来“凑”够 h h h 个 Attention Head 所需要的 K、V
我们知道分割、复制都是简单的线性变换,所以 MLA 的第一个想法是将这些简单的线性变换换成一般的线性变换,以增强模型的能力:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) = c i W k ( s ) ∈ R d k , W k ( s ) ∈ R d c × d k v i ( s ) = c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i = x i W c ∈ R d c , W c ∈ R d × d c \bm{o_{t}}=\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}=\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}=x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{(s)}=c_{i}W_{k}^{(s)}\in\mathbb{R}^{d_{k}},\quad W_{k}^{(s)}\in \mathbb{R}^{d_c\times d_{k}}}\\ \bm{v_{i}^{(s)}=c_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d_c\times d_{v}}} \end{array}\\ \bm{c_{i}=x_{i}W_{c}\in\mathbb{R}^{d_{c}},\quad W_{c}\in \mathbb{R}^{d\times d_{c}}} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)=ciWk(s)∈Rdk,Wk(s)∈Rdc×dkvi(s)=ciWv(s)∈Rdv,Wv(s)∈Rdc×dvci=xiWc∈Rdc,Wc∈Rd×dc
这里博主有些困惑,那看了 博客分享:从MHA、MQA、GQA到MLA 视频之后大概理解了 MLA 想要做的事情,前面我们提到 GQA 的实现中有一些分割、复制的变换,MLA 出于增强 GQA 性能的目的,想要将这些简单的线性变换加上一些可学习的参数让其变成一般的线性变换
以下分析内容来自于:博客分享:从MHA、MQA、GQA到MLA
MQA 和 GQA 的“升维”投影矩阵
原始的 GQA 先将输入 x i x_i xi 分别压缩到 g d k gd_k gdk 和 g d v g d_v gdv 维,再复制 g g g 份得到可以直接和 q i q_i qi 相乘的 k i k_i ki 和 v i v_i vi。将 GQA 的投影矩阵记为 W c \boldsymbol{W}_c Wc,那么有:
c i = x i W c ∈ R g ( d k + d v ) , W c ∈ R d × g ( d k + d v ) c i = [ k i ( 1 ) , ⋯ , k i ( g ) , v i ( 1 ) , ⋯ , v i ( g ) ] = [ c k i , c v i ] W s p l i t k = [ I g d k 0 g d k ] W s p l i t v = [ 0 g d v I g d v ] c k i = c i W s p l i t k ∈ R g d k , c v i = c i W s p l i t v ∈ R g d v \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c \in \mathbb{R}^{g(d_k+d_v)},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_k+d_v)}\\\boldsymbol{c}_i = [\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(g)}, \boldsymbol{v}_i^{(1)}, \cdots, \boldsymbol{v}_i^{(g)}] = [\boldsymbol{ck}_i, \boldsymbol{cv}_i]\\ \boldsymbol{W}_{split}^k = \begin{bmatrix}\boldsymbol{I}_{gd_k} \\\boldsymbol{0}_{gd_k}\end{bmatrix} \quad\boldsymbol{W}_{split}^v = \begin{bmatrix}\boldsymbol{0}_{gd_v} \\\boldsymbol{I}_{gd_v}\end{bmatrix} \\\boldsymbol{ck}_i = \boldsymbol{c}_i \boldsymbol{W}_{split}^k \in \mathbb{R}^{gd_k},\quad \boldsymbol{cv}_i = \boldsymbol{c}_i \boldsymbol{W}_{split}^v \in \mathbb{R}^{gd_v}\\ ci=xiWc∈Rg(dk+dv),Wc∈Rd×g(dk+dv)ci=[ki(1),⋯,ki(g),vi(1),⋯,vi(g)]=[cki,cvi]Wsplitk=[Igdk0gdk]Wsplitv=[0gdvIgdv]cki=ciWsplitk∈Rgdk,cvi=ciWsplitv∈Rgdv
这里 W s p l i t k , W s p l i t v \boldsymbol{W}_{split}^k,\boldsymbol{W}_{split}^v Wsplitk,Wsplitv 实现了形式上的分割操作,得到 c k i , c v i \boldsymbol{ck}_i, \boldsymbol{cv}_i cki,cvi。下面我们将构造“复制”操作的投影矩阵:
W k ∈ R g d k × h d k = [ I d k I d k ⋯ I d k 0 d k 0 d k ⋯ 0 d k ⋯ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k I d k I d k ⋯ I d k ⋯ 0 d k 0 d k ⋯ 0 d k ⋮ ⋮ ⋱ ⋮ ⋮ ⋮ ⋱ ⋮ ⋱ ⋮ ⋮ ⋱ ⋮ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k ⋯ I d k I d k ⋯ I d k ] \boldsymbol{W}_k \in \mathbb{R}^{g d_k\times h d_k}= \begin{bmatrix} \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } \\ \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots & \ddots & \vdots & \vdots & \ddots & \vdots \\ \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} \end{bmatrix} Wk∈Rgdk×hdk= Idk0dk⋮0dkIdk0dk⋮0dk⋯⋯⋱⋯Idk0dk⋮0dk0dkIdk⋮0dk0dkIdk⋮0dk⋯⋯⋱⋯0dkIdk⋮0dk⋯⋯⋱⋯0dk0dk⋮Idk0dk0dk⋮Idk⋯⋯⋱⋯0dk0dk⋮Idk
其中每行 I d k \boldsymbol{I}_{d_k} Idk 重复 h / g h/g h/g 遍,代表从 groups 到 heads 的“放缩”倍数或“复制”次数;一共有 g g g 行,对应原来的 d i ( s ) d_i^{(s)} di(s),共 g g g个。 W v ∈ R g d v × h d v \boldsymbol{W}_v \in \mathbb{R}^{g d_v\times h d_v} Wv∈Rgdv×hdv 的形式与 W k \boldsymbol{W}_k Wk 相同,故不赘述。将前者左乘 c k i \boldsymbol{ck}_i cki,则有:c k i W k = [ k i ( 1 ) , ⋯ , k i ( g ) ] ⋅ [ I d k I d k ⋯ I d k 0 d k 0 d k ⋯ 0 d k ⋯ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k I d k I d k ⋯ I d k ⋯ 0 d k 0 d k ⋯ 0 d k ⋮ ⋮ ⋱ ⋮ ⋮ ⋮ ⋱ ⋮ ⋱ ⋮ ⋮ ⋱ ⋮ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k ⋯ I d k I d k ⋯ I d k ] = [ k i ( 1 ) I d k , k i ( 1 ) I d k , ⋯ , k i ( 1 ) I d k , k i ( 2 ) I d k , k i ( 2 ) I d k , ⋯ , k i ( 2 ) I d k , ⋯ , k i ( g ) I d k , k i ( g ) I d k , ⋯ , k i ( g ) I d k ] = [ k i ( 1 ) , ⋯ , k i ( 1 ) , k i ( 2 ) , ⋯ , k i ( 2 ) , ⋯ , k i ( g ) , ⋯ , k i ( g ) ] ∈ R h d k \begin{aligned}\boldsymbol{ck}_i \boldsymbol{W}_k &= [\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(g)}] \cdot \begin{bmatrix} \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } \\ \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots & \ddots & \vdots & \vdots & \ddots & \vdots \\ \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{0}_{d_k } & \cdots & \boldsymbol{I}_{d_k} & \boldsymbol{I}_{d_k} & \cdots & \boldsymbol{I}_{d_k} \end{bmatrix}\\&=[\boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k}, \boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k}, \cdots,\boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k}]\\ &=[\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(1)},\boldsymbol{k}_i^{(2)}, \cdots, \boldsymbol{k}_i^{(2)},\cdots,\boldsymbol{k}_i^{(g)}, \cdots, \boldsymbol{k}_i^{(g)}] \in \mathbb{R}^{h d_k} \end{aligned} ckiWk=[ki(1),⋯,ki(g)]⋅ Idk0dk⋮0dkIdk0dk⋮0dk⋯⋯⋱⋯Idk0dk⋮0dk0dkIdk⋮0dk0dkIdk⋮0dk⋯⋯⋱⋯0dkIdk⋮0dk⋯⋯⋱⋯0dk0dk⋮Idk0dk0dk⋮Idk⋯⋯⋱⋯0dk0dk⋮Idk =[ki(1)Idk,ki(1)Idk,⋯,ki(1)Idk,ki(2)Idk,ki(2)Idk,⋯,ki(2)Idk,⋯,ki(g)Idk,ki(g)Idk,⋯,ki(g)Idk]=[ki(1),⋯,ki(1),ki(2),⋯,ki(2),⋯,ki(g),⋯,ki(g)]∈Rhdk
于是我们就得到了维度为 h d k h d_k hdk 的 k i \boldsymbol{k}_i ki,其中每个 k i ( s ) \boldsymbol{k}_i^{(s)} ki(s) 都被复制了 h / g h/g h/g 次,实现了“复制”操作。同理, c v i W v \boldsymbol{cv}_i \boldsymbol{W}_v cviWv 得到 v i \boldsymbol{v}_i vi,维度为 h d v h d_v hdv。
这里我们讨论的都是单个 token 的行向量,而对于实际输入序列,其(最后两个)维度为 (seq_len, d),同样可以直接替换上面的单一向量。在上面所构造的所有矩阵中,最重要的是 W k \boldsymbol{W}_k Wk 和 W v \boldsymbol{W}_v Wv,可以看出两者都是由若干单位矩阵组成的稀疏矩阵,是 GQA 的分割、复制操作的矩阵形式描述,那么正如苏神文章中所述,我们可以将让这两个矩阵变成可学习的参数,比如在“复制”过程中给每个头一个不同的权重,这样理论上可以增强 GQA 的能力。
这就是 MLA 的思想,它可以看作是 GQA 的一种改进,在压缩到 c \boldsymbol{c} c 维之后又用一个上投影矩阵来恢复到更高的维度,在 DeepSeek-V2 的技术报告中是先利用下投影矩阵 W D K V W^{DKV} WDKV 将隐藏层输入 h t \mathbf{h}_t ht 投影得到 c t K V \mathbf{c}_t^{KV} ctKV,然后再用两个上投影矩阵 W U K , W U V W^{UK},W^{UV} WUK,WUV 将 c t K V \mathbf{c}_t^{KV} ctKV 还原得到 k t C , v t C \mathbf{k}_t^C,\mathbf{v}_t^C ktC,vtC
然而,理论上这样是能增加模型能力,但别忘了 GQA 的主要目的是减少 KV Cache,出于节省计算和通信成本的考虑,我们一般缓存的是投影后的 k i , v i \bm{k_{i}},\bm{v_{i}} ki,vi 而不是投影前的 c i \bm{c_{i}} ci 或 x i \bm{x_{i}} xi,而 MLA 的这个做法,通过不同的投影矩阵再次让所有的 K、V Head 都变得各不相同,那么 KV Cache 的大小就恢复成跟 MHA 一样大了,违背了 GQA 的初衷。
对此,MLA 发现,我们可以结合 Dot-Attention 的具体形式,通过一个简单但不失巧妙的恒等变换来规避这个问题。首先,在训练阶段还是照常进行,此时优化空间不大;然后,在推理阶段,我们利用:
q t ( s ) k i ( s ) ⊤ = ( x t W q ( s ) ) ( c i W k ( s ) ) ⊤ = x t ( W q ( s ) W k ( s ) ⊤ ) c i ⊤ \bm{q_{t}^{(s)}}\bm{k_{i}^{(s)\top}}=\left(\bm{x_{t}}\bm{W_{q}^{(s)}}\right) \left(\bm{c_{i}}\bm{W_{k}^{(s)}}\right)\bm{{}^{\top}}=\bm{x_{t}}\left(\bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}}\right)\bm{c_{i}^{\top}} qt(s)ki(s)⊤=(xtWq(s))(ciWk(s))⊤=xt(Wq(s)Wk(s)⊤)ci⊤
这意味着推理阶段,我们可以将 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 合并起来作为 Q 的投影矩阵,那么 c i \bm{c_i} ci 则取代了原本的 k i \bm{k_{i}} ki
同理,在 o t \bm{o_{t}} ot 后面我们还有一个投影矩阵,于是 v i ( s ) = c i W v ( s ) \bm{v_{i}^{(s)}=c_{i}W_{v}^{(s)}} vi(s)=ciWv(s) 的 W v ( s ) \bm{W_{v}^{(s)}} Wv(s) 也可以吸收到后面的投影矩阵中去,于是等效地 v i \bm{v_{i}} vi 也可以用 c i \bm{c_i} ci 代替,也就是说此时 KV Cache 只需要存下所有的 c i \bm{c_i} ci 就行,而不至于存下所有的 k i ( s ) \bm{k_{i}^{(s)}} ki(s)、 v i ( s ) \bm{v_{i}^{(s)}} vi(s)。注意到 c i \bm{c_i} ci 跟 ( s ) ^{(s)} (s) 无关,也就是说所有的头共享的,即 MLA 在推理阶段它可以恒等变换为一个 MQA
再次强调,我们的主题一直都是减少 KV Cache,那到目前为止,MLA 做到了什么呢?答案是通过不同的投影矩阵来增强 GQA 的能力,并且推理时可以保持同样大小的 KV Cache。那么反过来,如果我们只需要跟 GQA 相近的能力,那么是不是就可以再次减少 KV Cache 了?换言之, d c d_c dc 没必要取 g ( d k + d v ) g(d_k+d_v) g(dk+dv),而是取更小的值(DeepSeek-V2 取了 512),从而进一步压缩 KV Cache,这就是 MLA 的核心思想
1.5.2 Part 2
一切似乎都很完美,但到目前为止的 MLA 有一个难以绕开的缺陷—不兼容 RoPE(旋转位置编码)
关于 RoPE 大家感兴趣的可以看看:RoPE旋转位置编码原理浅析
前面我们说了,MLA 之所以能保持跟 GQA 一样大小的 KV Cache,其关键一步是将 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 合并成一个(跟位置无关的)矩阵作为 Q 的投影矩阵,但如果加了 RoPE 的话,这一步就无法实现了。这是因为 RoPE 是一个跟位置相关的、 d k × d k d_k\times d_k dk×dk 的分块对角矩阵 R m \bm{\mathcal{R}_{m}} Rm,满足 R m R n ⊤ = R m − n \bm{\mathcal{R}_{m}}\bm{\mathcal{R}_{n}}^{\top}=\bm{\mathcal{R}_{m-n}} RmRn⊤=Rm−n,MLA 加入 RoPE 之后会让 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 之间多插入了一项 R t − i \bm{\mathcal{R}_{t-i}} Rt−i:
q i ( s ) = x i W q ( s ) R i , k i ( s ) = c i W k ( s ) R i q t ( s ) k i ( s ) ⊤ = ( x t W q ( s ) R i ) ( c i W k ( s ) R i ) ⊤ = x t ( W q ( s ) R t − i W k ( s ) ⊤ ) c i ⊤ \bm{q_{i}^{(s)}}=\bm{x_{i}W_{q}^{(s)}{\color{red}{\mathcal{R}_{i}}}}\quad,\quad\bm{k_{i}^{(s) }}=\bm{c_{i}W_{k}^{(s)}{\color{red}{\mathcal{R}_{i}}}} \\ \bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}=\left(\bm{x}_{t}\bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{i}}}}\right)\left(\bm{c}_{i}\bm{W}_{k}^{(s)} {\color{red}{\bm{\mathcal{R}_{i}}}}\right) {}^{\top}=\bm{x}_{t}\left(\bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{t-i}}}}\bm{W}_{k}^{(s)}{ }^{\top}\right)\bm{c}_{i}^{\top} qi(s)=xiWq(s)Ri,ki(s)=ciWk(s)Riqt(s)ki(s)⊤=(xtWq(s)Ri)(ciWk(s)Ri)⊤=xt(Wq(s)Rt−iWk(s)⊤)ci⊤
这里的 W q ( s ) R t − i W k ( s ) ⊤ \bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{t-i}}}}\bm{W}_{k}^{(s)}{ }^{\top} Wq(s)Rt−iWk(s)⊤ 就无法合并为一个固定的投影矩阵了(跟位置差 t − i t-i t−i 相关),从而 MLA 的想法无法结合 RoPE 实现
最后发布的 MLA 通过将每个 Attention Head 的 Q、K 新增 d r d_r dr 个维度用来添加 RoPE,其中 K 新增的维度每个 Head 共享:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = [ x i W q c ( s ) , x i W q r ( s ) R i ] ∈ R d k + d r , W q c ( s ) ∈ R d × d k , W q r ( s ) ∈ R d × d r k i ( s ) = [ c i W k c ( s ) , x i W k r ( s ) R i ] ∈ R d k + d r , W k c ( s ) ∈ R d c × d k , W k r ( s ) ∈ R d × d r v i ( s ) = c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i = x i W c ∈ R d c , W c ∈ R d × d c \bm{o_{t}}=\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}=\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \bm{q}_{i}^{(s)}=\left[\bm{x_{i}}\bm{W}_{qc}^{(s)}\bm{,x_{i}}\bm{W}_{qr}^{(s)} {\color{red}\bm{\mathcal{R}}_{i}}\right]\in\mathbb{R}^{d_{k}+d_{r}}\bm{,}\quad\bm{W}_{qc}^ {(s)}\in\mathbb{R}^{d\times d_{k}}\bm{,W}_{qr}^{(s)}\in\mathbb{R}^{d\times d _{r}} \\ \bm{k}_{i}^{(s)}=\left[\bm{c_{i}}\bm{W}_{kc}^{(s)}\bm{,x_{i}}\bm{W}_{kr}^{\color{red}\bcancel{(s)}} {\color{red}\bm{\mathcal{R}}_{i}}\right]\in\mathbb{R}^{d_{k}+d_{r}}\bm{,}\quad\bm{W}_{kc}^ {(s)}\in\mathbb{R}^{d_c\times d_{k}}\bm{,W}_{kr}^{\color{red}\bcancel{(s)}}\in\mathbb{R}^{d\times d _{r}} \\ \bm{v_{i}^{(s)}=c_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d_c\times d_{v}}} \\ \bm{c_{i}=x_{i}W_{c}\in\mathbb{R}^{d_{c}},\quad W_{c}\in \mathbb{R}^{d\times d_{c}}} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=[xiWqc(s),xiWqr(s)Ri]∈Rdk+dr,Wqc(s)∈Rd×dk,Wqr(s)∈Rd×drki(s)=[ciWkc(s),xiWkr(s) Ri]∈Rdk+dr,Wkc(s)∈Rdc×dk,Wkr(s) ∈Rd×drvi(s)=ciWv(s)∈Rdv,Wv(s)∈Rdc×dvci=xiWc∈Rdc,Wc∈Rd×dc
这样一来,没有 RoPE 的维度就可以重复 “Part 1” 的操作,在推理时 KV Cache 只需要存 c i \bm{c_i} ci,新增的带 RoPE 的维度就可以用来补充位置信息,并且由于所有 Head 共享,所以也就只有在 K Cache 这里增加了 d r d_r dr 个维度,原论文取了 d r = d k / 2 = 64 d_r=d_k/2=64 dr=dk/2=64,相比原本的 d c = 512 d_c=512 dc=512,增加的幅度不大
1.5.3 Part 3
最后有一个细节,就是 MLA 的最终版本,还将 Q 的输入也改为了低秩投影形式,这与减少 KV Cache 无关,主要是为了减少训练期间参数量和相应的梯度所占的显存:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = [ c i ′ W q c ( s ) , c i ′ W q r ( s ) R i ] ∈ R d k + d r , W q c ( s ) ∈ R d c ′ , W q r ( s ) ∈ R d c ′ × d r k i ( s ) = [ c i W k c ( s ) , x i W k r ( s ) R i ] ∈ R d k + d r , W k c ( s ) ∈ R d c , W k r ( s ) ∈ R d × d r v i ( s ) = c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i ′ = x i W c ′ ∈ R d c ′ , W c ′ ∈ R d × d c ′ c i = x i W c ∈ R d c , W c ∈ R d × d c \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}\\ \boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c}, \boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=[ci′Wqc(s),ci′Wqr(s)Ri]∈Rdk+dr,Wqc(s)∈Rdc′,Wqr(s)∈Rdc′×drki(s)=[ciWkc(s),xiWkr(s) Ri]∈Rdk+dr,Wkc(s)∈Rdc,Wkr(s) ∈Rd×drvi(s)=ciWv(s)∈Rdv,Wv(s)∈Rdc×dvci′=xiWc′∈Rdc′,Wc′∈Rd×dc′ci=xiWc∈Rdc,Wc∈Rd×dc
注意 k i ( s ) \boldsymbol{k}_i^{(s)} ki(s) 中的第二项,带 RoPE 的部分,其输入还是 x i \boldsymbol{x}_i xi 而不是 c i \boldsymbol{c}_i ci,这里保持了原论文的设置,不是笔误, d c ′ d_c' dc′ 原论文的取值是 1536,跟 d c = 512 d_c=512 dc=512 不同。
同时,我们把带 RoPE 的 MHA 放在下面,方便大家对比:
o t = [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = x i W q ( s ) R i ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) = x i W k ( s ) R i ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) = x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)}=\boldsymbol{x}_i\boldsymbol{W}_q^{(s)}{\color{red}\mathcal{R}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k} \\\boldsymbol{k}_i^{(s)}=\boldsymbol{x}_i\boldsymbol{W}_k^{(s)}{\color{red}\mathcal{R}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{v}_i^{(s)}=\boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v} \end{gathered} ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=xiWq(s)Ri∈Rdk,Wq(s)∈Rd×dkki(s)=xiWk(s)Ri∈Rdk,Wk(s)∈Rd×dkvi(s)=xiWv(s)∈Rdv,Wv(s)∈Rd×dv
可以发现,其实在训练阶段,除了多了一步低秩投影以及只在部分维度加 RoPE 外,MLA 与 Q、K 的 Head Size 由 d k d_k dk 换成 d k + d r d_k+d_r dk+dr 的 MHA 基本无异
推理阶段的 MLA 则改为:
o t = [ o t ( 1 ) W v ( 1 ) , o t ( 2 ) W v ( 2 ) , ⋯ , o t ( h ) W v ( h ) ] o t ( s ) = A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , c ≤ t ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) c i ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) = [ c i ′ W q c ( s ) W k c ( s ) ⊤ , c i ′ W q r ( s ) R i ] ∈ R d c + d r k i ( s ) = [ c i , x i W k r ( s ) R i ] ∈ R d c + d r W q c ( s ) ∈ R d c ′ × d k , W k c ( s ) ∈ R d c × d k , W q r ( s ) ∈ R d c ′ × d r , W k r ( s ) ∈ R d × d r c i ′ = x i W c ′ ∈ R d c ′ , W c ′ ∈ R d × d c ′ c i = x i W c ∈ R d c , W c ∈ R d × d c \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{red}{\smash{\bcancel{(s)}}}} ,\boldsymbol{c}_{\leq t}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_c + d_r}\\ \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}} = \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_c+d_r}\\ \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r},\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} ot=[ot(1)Wv(1),ot(2)Wv(2),⋯,ot(h)Wv(h)]ot(s)=Attention(qt(s),k≤t(s) ,c≤t)≜∑i≤texp(qt(s)ki(s) ⊤)∑i≤texp(qt(s)ki(s) ⊤)ciqi(s)=[ci′Wqc(s)Wkc(s)⊤,ci′Wqr(s)Ri]∈Rdc+drki(s) =[ci,xiWkr(s) Ri]∈Rdc+drWqc(s)∈Rdc′×dk,Wkc(s)∈Rdc×dk,Wqr(s)∈Rdc′×dr,Wkr(s) ∈Rd×drci′=xiWc′∈Rdc′,Wc′∈Rd×dc′ci=xiWc∈Rdc,Wc∈Rd×dc
此时 Q、K 的 Head Size 变成了 d c + d r d_c + d_r dc+dr,V 的 Head Size 则变成了 d c d_c dc,按照原论文的设置,这是 d k d_k dk、 d v d_v dv 的 4 倍。所以实际上 MLA 在推理阶段做的这个转换,虽然能有效减少KV Cache,但其推理的计算量是增加的。
那为什么还能提高推理效率呢?这又回到“瓶颈”一节所讨论的问题了,我们可以将LLM的推理分两部分:第一个 Token 的生成(Prefill)和后续每个 Token 的生成(Generation),Prefill 阶段涉及到对输入所有 Token 的并行计算,然后把对应的 KV Cache 存下来,这部分对于计算、带宽和显存都是瓶颈,MLA 虽然增大了计算量,但 KV Cache 的减少也降低了显存和带宽的压力,大家半斤八两;但是 Generation 阶段由于每步只计算一个 Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此 MLA 的引入理论上能明显提高 Generation 的速度。
还有一个细节充分体现了这个特性。一般的 LLM 架构参数满足 h = d h=d h=d,即 num_heads * head_size = hidden_size,但 DeepSeek-V2 不一样,它 d k = 128 , d = 5120 d_k=128,d=5120 dk=128,d=5120,但 h = 128 h=128 h=128,是一般设置的 3 倍!这是因为 MLA 的 KV Cache 大小跟 h h h 无关,增大 h h h 只会增加计算量和提升模型能力,但不会增加 KV Cache,所以不会带来速度瓶颈。
由于篇幅原因(CSDN 对正文字数有限制,真服了😒),MLA 的代码实现我们放在下篇文章
结语
MLA 可以看作是 GQA 的优化,通过投影矩阵的方式替换 GQA 中的分割、复制等线性变换操作,并引入了一个恒等变换在推理阶段通过矩阵吸收来进一步压缩 KV Cache,同时采用了一种混合方法通过新增维度来兼容 RoPE 旋转位置编码,总的来说,MLA 算得上一种非常实用的注意力变体
大家可以多看看苏神的文章来加深理解🤗
参考
- 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
- 博客分享:从MHA、MQA、GQA到MLA
- https://github.com/preacher-1/MLA_tutorial
- https://github.com/deepseek-ai/DeepSeek-V3