原理
KV Cache的本质就是避免重复计算,把需要重复计算的结果进行缓存,生成式模型的新的token的产生需要用到之前的所有token的 K , V K,V K,V,在计算注意力的时候是当前的 Q Q Q和所有的 K , V K,V K,V来进行计算,所以是缓存 K , V K,V K,V。
由于Causal Mask的存在,前面已经生成的token不需要与后面的token产生attention,也就是用不到前面token的 Q Q Q,用的上前面token的 K , V K,V K,V,具体的公式如下:
a t t 1 ( Q , K , V ) = s o f t m a x ( Q 1 K 1 T D ) V 1 att_1(Q,K,V)=softmax(\frac{Q_1K_1^T}{\sqrt{D}})V_1 att1(Q,K,V)=softmax(DQ1K1T)V1
a t t 2 ( Q , K , V ) = s o f t m a x ( Q 2 K 1 T D ) V 1 + s o f t m a x ( Q 2 K 2 T D ) V 2 att_2(Q,K,V)=softmax(\frac{Q_2K_1^T}{\sqrt{D}})V_1+softmax(\frac{Q_2K_2^T}{\sqrt{D}})V_2 att2(Q,K,V)=softmax(DQ2K1T)V1+softmax(DQ2K2T)V2
a t t 3 ( Q , K , V ) = s o f t m a x ( Q 3 K 1 T D ) V 1 + s o f t m a x ( Q 3 K 2 T D ) V 2 + s o f t m a x ( Q 3 K 3 T D ) V 3 att_3(Q,K,V)=softmax(\frac{Q_3K_1^T}{\sqrt{D}})V_1+softmax(\frac{Q_3K_2^T}{\sqrt{D}})V_2+softmax(\frac{Q_3K_3^T}{\sqrt{D}})V_3 att3(Q,K,V)=softmax(DQ3K1T)V1+softmax(DQ3K2T)V2+softmax(DQ3K3T)V3
可以看出, K , V K,V K,V存在重复计算的情况,因此可以进行Cache。
KV Cache只适用于Decoder架构,因为有Causal Mask的存在,如果是Encoder,处理的是输入序列,是一次性完成整个序列attention的计算,并不像Decoder一样有自左向右的重复性的计算,Encoder由于其一次性和并行性,用不上KV-Cache,而解码器由于其自回归性,KV Cache是很有用的。