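Introduction
The previous article covered distributed training on multiple GPUs. This article introduces a common inference-acceleration technique for large models: the KV cache.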
KV Cache
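The KV cache is a common trick in large-model inference. At inference time a Transformer, like an RNN, can only predict one token at a time; this is called autoregressive generation. The KV cache stores the key and value states computed in the attention layers, as shown in the figure:
In the figure above (gray areas are masked) there is no KV cache, so at every generation step we recompute the attention for all previous tokens, even though we only need the attention for the new token.
For example, at the last step (step 4) we recompute Attention1 to Attention3, which were already computed in earlier steps. This work is unnecessary.
If we cache the keys and values computed earlier, most of this repeated computation disappears and each step only computes attention for the newest token:
In the figure above (blue marks cached keys and values), with a KV cache we pass in only the new query at each step, compute the new key and value, concatenate them with the cached key and value matrices, and then compute the attention for the newest token. That is the main idea of the KV cache. Note that no mask is needed here: the single new query sits at the last position, so every cached key is already in its past.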
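To check that caching only removes redundant work, here is a minimal single-head sketch in NumPy (not the article's PyTorch code; `full_causal_attention` and `cached_attention_step` are illustrative names). It computes masked attention over the whole sequence once, then recomputes it one token at a time with a growing key/value cache and no mask, and compares the two results.

```python
import numpy as np


def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_attention(x, wq, wk, wv):
    # no cache: project the whole prefix and apply a causal mask
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def cached_attention_step(x_new, wq, wk, wv, cache):
    # with a cache: project only the newest token and append its key/value
    q = x_new @ wq
    k_new, v_new = x_new @ wk, x_new @ wv
    if cache is None:
        k, v = k_new, v_new
    else:
        k = np.concatenate([cache[0], k_new], axis=0)
        v = np.concatenate([cache[1], v_new], axis=0)
    # (1, t) scores: no mask, every cached key is at an earlier position
    scores = q @ k.T / np.sqrt(k.shape[-1])
    return softmax(scores) @ v, (k, v)


rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(4, d))  # embeddings of 4 tokens
wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))

reference = full_causal_attention(x, wq, wk, wv)
cache = None
steps = []
for t in range(x.shape[0]):
    out, cache = cached_attention_step(x[t:t + 1], wq, wk, wv, cache)
    steps.append(out)
incremental = np.concatenate(steps, axis=0)
print(np.allclose(reference, incremental))  # True
```

Each cached step does one projection and one row of scores instead of recomputing the full matrix, which is where the savings come from.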
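This describes the KV cache in self-attention. Cross-attention is even simpler: the memory produced by the encoder does not change, so the key and value matrices computed from it can be cached directly, without any concatenation.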
To make our Transformer support the KV cache, some changes are needed. We start with the `MultiHeadAttention` module, mainly its `forward` method:
```python
from typing import Optional, Tuple

from torch import Tensor


def forward(
    self,
    query: Tensor,
    key_value: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
    past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
    use_cache: bool = False,
    keep_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
    """
    Args:
        query (Tensor): (batch_size, q_len, d_model)
        key_value (Tensor, optional): (batch_size, k_len/v_len, d_model) key and value are same.
        mask (Tensor, optional): mask for padding or decoder. Defaults to None.
        past_key_value (Tuple[Tensor], optional): cached past key and value states. Defaults to None.
        use_cache (bool, optional): whether to use kv cache during inference. Defaults to False.
        keep_attentions (bool): whether to keep attention weights or not. Defaults to False.

    Returns:
        output (Tensor): (batch_size, q_len, d_model) attention output
        present_key_value (Tuple[Tensor], optional): Cached present key and value states
    """
    if past_key_value is not None:
        assert self.is_decoder is True, "Encoder cannot cache past key value states"

    is_self_attention = key_value is None

    _query = query
    query = self._transform_and_split(self.q, query)

    if is_self_attention:
        # the 'self' attention: apply the Q/K/V projections first, then split into heads
        key = self._transform_and_split(self.k, _query, is_key=True)
        value = self._transform_and_split(self.v, _query)
        # concatenate the newest key/value with the cached ones (if there are any)
        key, value = self._concat_key_value(key, value, past_key_value)
    elif past_key_value is None:
        # the cross attention, key_value is memory
        key = self._transform_and_split(self.k, key_value, is_key=True)
        value = self._transform_and_split(self.v, key_value)
    else:
        # cross attention with a cache (is_self_attention is False and past_key_value is not None):
        # key_value is memory, which does not change during inference,
        # so the cached key and value are reused instead of being recomputed
        key, value = past_key_value

    if self.is_decoder and use_cache:
        # cache newest key and value
        present_key_value = (key, value)
    else:
        present_key_value = None

    attn_output = self.attenion(query, key, value, mask, keep_attentions)

    # Concat
    concat_output = self.merge_heads(attn_output)
    # the final linear
    # output (batch_size, q_len, d_model)
    output = self.concat(concat_output)

    return output, present_key_value
```
Its parameters have changed: the original `query, key, value` became `query, key_value`.
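First, `key` and `value` are merged into a single argument: in self-attention `query = key = value`, while in cross-attention `key = value = memory`. Checking whether `key_value` is `None` therefore tells us whether this call computes self-attention or cross-attention.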
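Second, two new parameters are added: `use_cache` indicates whether to use the KV cache, and `past_key_value` holds the cached key and value. Note that the cached k and v differ from each other even though they are computed from the same input, because they pass through different projections (the Key and Value matrices).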
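Now let's look inside the method. Note that the KV cache can only be used in the Decoder at inference time.
There are two cases: self-attention and cross-attention.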
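For self-attention, the projected query, key and value can all be computed directly from the `query` argument (the `if is_self_attention:` branch). When the cache is used, the `query` passed in always has length 1, because we only need attention scores for the newest query in order to predict one token. We still need the `key` and `value` obtained by projecting the current `query` with the K and V matrices; they are concatenated with the historical (cached) ones to compute the new token.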
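The article does not show `_concat_key_value`. Assuming it appends the new key/value to the cached ones along the sequence axis, a hypothetical NumPy stand-in could look like this. The layout `(batch_size, n_heads, seq_len, d_head)` is an assumption: the author's code passes `is_key=True` when projecting keys, so the real key tensor may use a different (for example transposed) layout and therefore a different concatenation axis.

```python
from typing import Optional, Tuple

import numpy as np


def concat_key_value(key, value, past_key_value: Optional[Tuple[np.ndarray, np.ndarray]]):
    """Hypothetical stand-in for _concat_key_value.

    key, value: (batch_size, n_heads, new_len, d_head) for the current step (assumed layout).
    past_key_value: cached (key, value), each (batch_size, n_heads, past_len, d_head), or None.
    """
    if past_key_value is None:
        # first decoding step or no cache: nothing to prepend
        return key, value
    past_key, past_value = past_key_value
    # append the new position after the cached ones along the sequence axis
    key = np.concatenate([past_key, key], axis=2)
    value = np.concatenate([past_value, value], axis=2)
    return key, value


batch_size, n_heads, d_head = 2, 8, 64
cache = None
for step in range(3):
    # with a cache, each step contributes exactly one new position
    k_new = np.zeros((batch_size, n_heads, 1, d_head))
    v_new = np.ones((batch_size, n_heads, 1, d_head))
    cache = concat_key_value(k_new, v_new, cache)
print(cache[0].shape)  # (2, 8, 3, 64)
```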
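For cross-attention, i.e. the second attention module in the Decoder, the query comes from the decoder while the key and value (the memory) come from the encoder. The memory stays the same throughout inference, so its key and value only need to be computed once (the `elif past_key_value is None:` branch) and are then stored in the `past_key_value` cache. Later steps skip the computation and reuse the cache (the final `else` branch, `key, value = past_key_value`).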
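A minimal sketch of this compute-once behaviour, using a hypothetical `CrossAttentionKV` helper (not part of the article's code) that plays the role of the cross-attention half of `past_key_value`: the memory is projected on the first decoding step only, and later steps return the cached pair.

```python
import numpy as np


class CrossAttentionKV:
    """Hypothetical helper: project the encoder memory into K/V once, then reuse it."""

    def __init__(self, wk, wv):
        self.wk, self.wv = wk, wv
        self.cached = None        # plays the role of the cross-attention past_key_value
        self.num_projections = 0  # how many times the memory was actually projected

    def __call__(self, memory):
        if self.cached is None:
            # first step: memory -> key/value, then cache them
            self.num_projections += 1
            self.cached = (memory @ self.wk, memory @ self.wv)
        # later steps: memory is unchanged during inference, so reuse the cache
        return self.cached


rng = np.random.default_rng(0)
memory = rng.normal(size=(5, 16))  # (src_len, d_model), batch dimension omitted
kv = CrossAttentionKV(rng.normal(size=(16, 16)), rng.normal(size=(16, 16)))
for _ in range(10):  # 10 decoding steps
    key, value = kv(memory)
print(kv.num_projections)  # 1
print(key.shape)  # (5, 16)
```

Note that the cached shape depends on the source length, not on how many tokens have been generated.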
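The newest key and value are cached only when the module is a Decoder and caching is enabled.
Finally, the attention scores are computed exactly as before.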
Next, we modify the `forward` code in `DecoderBlock`:
```python
from typing import Optional, Tuple

from torch import Tensor


def forward(
    self,
    tgt: Tensor,
    memory: Tensor,
    tgt_mask: Tensor = None,
    memory_mask: Tensor = None,
    past_key_value: Tuple[Tensor] = None,
    use_cache: bool = True,
    keep_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tuple[Tensor, ...]]]:
    """
    Args:
        tgt (Tensor): (batch_size, tgt_seq_len, d_model) the (target) sequence to the decoder block.
        memory (Tensor): (batch_size, src_seq_len, d_model) the sequence from the last layer of the encoder.
        tgt_mask (Tensor, optional): (batch_size, 1, tgt_seq_len, tgt_seq_len) the mask for the tgt sequence.
        memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence.
        past_key_value (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
        use_cache (bool, optional): whether use kv cache during inference or not. Defaults to True.
        keep_attentions (bool): whether keep attention weights or not. Defaults to False.

    Returns:
        tgt (Tensor): (batch_size, tgt_seq_len, d_model) output of decoder block
        present_key_value_state (Tuple[Tensor], optional): (self-attn key, self-attn value,
            cross-attn key, cross-attn value) when caching, otherwise None
    """
    if past_key_value is not None:
        # past_key_value is a tuple: the first two elements are the self-attention
        # key and value, the last two are the cross-attention key and value
        self_attn_past_key_value = past_key_value[:2]
        cross_attn_past_key_value = past_key_value[2:]
    else:
        self_attn_past_key_value = None
        cross_attn_past_key_value = None

    x = tgt
    # self-attention
    self_attn_outputs = self._sa_sub_layer(
        x,
        tgt_mask,
        self_attn_past_key_value,
        use_cache,
        keep_attentions,
    )
    # self attention output and present key value state:
    # x is the same output as before, plus present_key_value_state holding the key and value
    x, present_key_value_state = self_attn_outputs

    # cross-attention
    cross_attn_outputs = self._ca_sub_layer(
        x,
        memory,
        memory_mask,
        cross_attn_past_key_value,
        use_cache,
        keep_attentions,
    )
    x = cross_attn_outputs[0]
    if present_key_value_state is not None:
        # append the cross-attention key and value states to the present key value states,
        # giving the 4-element tuple
        present_key_value_state = present_key_value_state + cross_attn_outputs[1]

    x = self._ff_sub_layer(x)

    # don't forget to return the cache as well
    return x, present_key_value_state
```
It calls the following two sub-layer methods:
```python
from typing import Optional, Tuple

from torch import Tensor


# self-attention sub layer
def _sa_sub_layer(
    self,
    x: Tensor,
    attn_mask: Tensor,
    past_key_value: Optional[Tuple[Tensor, Tensor]],
    use_cache: bool,
    keep_attentions: bool,
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
    residual = x
    x, present_key_value = self.masked_attention(
        query=self.norm1(x),
        past_key_value=past_key_value,
        use_cache=use_cache,
        mask=attn_mask,
        keep_attentions=keep_attentions,
    )
    x = self.dropout1(x) + residual
    return x, present_key_value


# cross attention sub layer
def _ca_sub_layer(
    self,
    x: Tensor,
    mem: Tensor,
    attn_mask: Tensor,
    past_key_value: Optional[Tuple[Tensor, Tensor]],
    use_cache: bool,
    keep_attentions: bool,
) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
    residual = x
    x, present_key_value = self.cross_attention(
        query=self.norm2(x),
        key_value=mem,
        mask=attn_mask,
        past_key_value=past_key_value,
        use_cache=use_cache,
        keep_attentions=keep_attentions,
    )
    x = self.dropout2(x) + residual
    return x, present_key_value
```
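Pre-LN is now the default: layer normalization is applied to the sub-layer input first, and the residual connection is added at the end.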
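For comparison, a NumPy sketch of the two residual orderings (the functions are illustrative; learnable LayerNorm parameters and dropout are left out). In Pre-LN the residual path carries `x` through unchanged, whereas Post-LN normalizes after the addition.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # normalize each position over the feature axis (no learnable gain/bias in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pre_ln(x, sublayer):
    # Pre-LN, as in _sa_sub_layer/_ca_sub_layer: norm -> sublayer -> residual add
    return x + sublayer(layer_norm(x))


def post_ln(x, sublayer):
    # Post-LN (original Transformer): sublayer -> residual add -> norm
    return layer_norm(x + sublayer(x))


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))
x = rng.normal(loc=3.0, size=(2, 8))


def sublayer(h):
    # stand-in for attention or the feed-forward block; dropout omitted (inference)
    return h @ w


pre, post = pre_ln(x, sublayer), post_ln(x, sublayer)
# Pre-LN: the output minus the untouched residual is just the sublayer of the normalized input
print(np.allclose(pre - x, sublayer(layer_norm(x))))  # True
# Post-LN: the output itself is normalized, so each row has (approximately) zero mean
print(np.allclose(post.mean(axis=-1), 0.0))  # True
```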
Another very important change is in `PositionalEncoding`:
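```python
from typing import Union

from torch import Tensor


def forward(self, x: Tensor, position_ids: Union[int, list[int]] = None) -> Tensor:
    """
    Args:
        x (Tensor): (batch_size, seq_len, d_model) embeddings
        position_ids (Union[int, list[int]]): single position id or list

    Returns:
        Tensor: (batch_size, seq_len, d_model)
    """
    if position_ids is None:
        # no explicit positions: use 0 .. seq_len - 1
        position_ids = range(x.size(1))
    # assuming self.pe is (1, max_positions, d_model): an int index drops the sequence
    # dimension and broadcasting restores (batch_size, 1, d_model)
    return self.dropout(x + self.pe[:, position_ids, :])
```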
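A parameter for the position ID is added. With the cache, the `seq_len` passed in is always 1, yet the actual position of that token keeps increasing. Without this change, the default `range(x.size(1))` would always fetch the encoding at index 0, which badly degrades the results. We therefore pass in the current position.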
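The failure mode can be reproduced with a small NumPy sketch, assuming the usual sinusoidal table (here without the leading batch dimension that `self.pe` has in the article; `sinusoidal_table` is an illustrative name). With `seq_len == 1`, the default `range(seq_len)` always selects row 0, while passing the real position selects the right row.

```python
import numpy as np


def sinusoidal_table(max_positions, d_model):
    # standard sinusoidal positional encoding, shape (max_positions, d_model)
    pos = np.arange(max_positions)[:, None]
    i = np.arange(0, d_model, 2)[None, :]
    angle = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((max_positions, d_model))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe


pe = sinusoidal_table(50, 16)
step = 5                  # decoding the token at position 5
x = np.zeros((1, 1, 16))  # with a cache, seq_len is always 1

wrong = x + pe[list(range(x.shape[1]))]  # default range(seq_len): always position 0
right = x + pe[[step]]                   # explicit position id

print(np.allclose(wrong[0, 0], pe[0]))     # True
print(np.allclose(right[0, 0], pe[step]))  # True
```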
Since the cache only affects the Decoder, we can simply modify the `decode` method of the `Transformer` module:
```python
from typing import Optional, Tuple

from torch import Tensor


def decode(
    self,
    tgt: Tensor,
    memory: Tensor,
    tgt_mask: Tensor = None,
    memory_mask: Tensor = None,
    past_key_values: Tuple[Tensor] = None,
    use_cache: bool = False,
    keep_attentions: bool = False,
) -> Tuple[Tensor, Optional[list]]:
    """
    Args:
        tgt (Tensor): (batch_size, tgt_seq_len) the sequence to the decoder.
        memory (Tensor): (batch_size, src_seq_len, d_model) the sequence from the last layer of the encoder.
        tgt_mask (Tensor, optional): (batch_size, 1, 1, tgt_seq_len) the mask for the target sequence. Defaults to None.
        memory_mask (Tensor, optional): (batch_size, 1, 1, src_seq_len) the mask for the memory sequence. Defaults to None.
        past_key_values (Tuple[Tensor], optional): the cached key and value states. Defaults to None.
        use_cache (bool, optional): whether use kv cache during inference or not. Defaults to False.
        keep_attentions (bool, optional): whether keep attention weights or not. Defaults to False.

    Returns:
        Tensor: output (batch_size, tgt_seq_len, d_model), projected to the vocabulary later by lm_head
        past_key_values: the updated key and value caches
    """
    if past_key_values is None:
        past_key_values = [None] * len(self.decoder.layers)
        # no cache yet: pass None so PositionalEncoding uses range(seq_len)
        position_ids = None
    else:
        # when use_cache we only care about the current position:
        # dimension 2 of the cached self-attention value (index 1) is the number of cached tokens
        position_ids = past_key_values[0][1].size(2)

    tgt_embed = self.dec_pos(self.tgt_embedding(tgt), position_ids)
    # logits (batch_size, tgt_seq_len, d_model)
    logits, past_key_values = self.decoder(
        tgt_embed,
        memory,
        tgt_mask,
        memory_mask,
        past_key_values,
        use_cache,
        keep_attentions,
    )
    return logits, past_key_values
```
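The comments explain the idea: when the cache is used, we need the length of the cached keys/values. Because `seq_len` is always 1 and the new token comes right after the cached ones, nothing needs to be added to that length: with 0-based positions, `past_key_values[0][1].size(2)` is exactly the position ID we want.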
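A small simulation of this bookkeeping, with made-up shapes and a hypothetical `fake_layer_cache` helper that builds the per-layer 4-tuple (self-attention key, self-attention value, cross-attention key, cross-attention value) that `DecoderBlock` returns. As in `decode`, only index 1, the self-attention value, is read, so the assumed key layout does not matter here.

```python
import numpy as np

batch_size, n_heads, d_head, num_layers, src_len = 2, 8, 64, 6, 7


def fake_layer_cache(cached_len):
    # hypothetical cache for one decoder layer:
    # (self-attn key, self-attn value, cross-attn key, cross-attn value)
    return (
        np.zeros((batch_size, n_heads, cached_len, d_head)),
        np.zeros((batch_size, n_heads, cached_len, d_head)),
        np.zeros((batch_size, n_heads, src_len, d_head)),
        np.zeros((batch_size, n_heads, src_len, d_head)),
    )


past_key_values = None
positions = []
for step in range(4):
    # same rule as decode(): no cache -> None (PositionalEncoding falls back to range(1) == [0])
    position_ids = None if past_key_values is None else past_key_values[0][1].shape[2]
    positions.append(0 if position_ids is None else position_ids)
    # after this step, every layer's self-attention cache holds one more token
    past_key_values = [fake_layer_cache(step + 1) for _ in range(num_layers)]

print(positions)  # [0, 1, 2, 3]
```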
Finally, we make a small change to the greedy decoding implementation:
```python
import torch
from torch import Tensor


def _greedy_search(
    self,
    src: Tensor,
    src_mask: Tensor,
    max_gen_len: int,
    use_cache: bool,
    keep_attentions: bool,
):
    memory = self.transformer.encode(src, src_mask)

    batch_size = src.shape[0]
    device = src.device

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

    # every sequence starts with the BOS token
    decoder_inputs = torch.full((batch_size, 1), self.bos_idx, dtype=torch.long, device=device)
    input_ids = decoder_inputs

    eos_idx_tensor = torch.tensor([self.eos_idx]).to(device)

    finished = False
    past_key_values = None
    # with the cache, tgt_mask can be None: the single query may attend to every cached key
    tgt_mask = None

    while True:
        if not use_cache:
            tgt_mask = self.generate_subsequent_mask(decoder_inputs.size(1), device)

        outputs = self.transformer.decode(
            input_ids,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=src_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            keep_attentions=keep_attentions,
        )

        logits = self.lm_head(outputs[0])
        past_key_values = outputs[1]

        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
        # finished sentences should have their next token be a pad token
        next_tokens = next_tokens * unfinished_sequences + self.pad_idx * (1 - unfinished_sequences)

        decoder_inputs = torch.cat([decoder_inputs, next_tokens[:, None]], dim=-1)

        # set sentence to finished if eos_idx was found
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_idx_tensor.shape[0], 1).ne(eos_idx_tensor.unsqueeze(1)).prod(dim=0)
        )

        if use_cache:
            # only need the last tokens
            input_ids = next_tokens[:, None]
        else:
            input_ids = decoder_inputs

        # all sentences have eos_idx
        if unfinished_sequences.max() == 0:
            finished = True

        if decoder_inputs.shape[-1] >= max_gen_len:
            finished = True

        if finished:
            break

    return decoder_inputs
```
When the cache is used, `input_ids = next_tokens[:, None]`, which ensures that only the most recently predicted token is fed in at each step.
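To see end to end that feeding only the newest token gives the same result, here is a toy, decoder-only greedy loop in NumPy with random weights. It is a hypothetical single-attention-layer model, not the article's encoder-decoder `Transformer`; `toy_decoder_step` and `greedy` are illustrative names. With `use_cache=True` it passes only the last token plus its position and reuses the cached keys/values; with `use_cache=False` it re-feeds the whole prefix under a causal mask.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, D, BOS = 20, 16, 1
emb = rng.normal(size=(VOCAB, D))
pos = rng.normal(size=(32, D))  # toy position table (an assumption, not sinusoidal)
wq, wk, wv = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
w_out = rng.normal(size=(D, VOCAB))


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def toy_decoder_step(new_tokens, start_pos, cache):
    """One toy attention layer; new_tokens occupy positions start_pos, start_pos + 1, ..."""
    x = emb[new_tokens] + pos[start_pos:start_pos + len(new_tokens)]
    q, k, v = x @ wq, x @ wk, x @ wv
    if cache is not None:
        # prepend the cached keys/values of all earlier positions
        k = np.concatenate([cache[0], k])
        v = np.concatenate([cache[1], v])
    scores = q @ k.T / np.sqrt(D)
    # causal mask: a query at absolute position p may only see keys at positions <= p
    query_pos = start_pos + np.arange(len(new_tokens))
    future = np.arange(k.shape[0])[None, :] > query_pos[:, None]
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v @ w_out, (k, v)


def greedy(max_len, use_cache):
    output = [BOS]
    input_ids, cache = [BOS], None
    while len(output) < max_len:
        start_pos = len(output) - len(input_ids)
        logits, new_cache = toy_decoder_step(input_ids, start_pos, cache)
        next_token = int(np.argmax(logits[-1]))
        output.append(next_token)
        if use_cache:
            # like input_ids = next_tokens[:, None]: feed only the newest token
            cache, input_ids = new_cache, [next_token]
        else:
            # re-feed the whole prefix and recompute everything
            input_ids = list(output)
    return output


print(greedy(10, use_cache=False) == greedy(10, use_cache=True))  # True
```

The cached path never builds a mask that hides anything, since its only query is at the newest position; the mask code is shared just to keep a single step function.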
Finally, we run inference on the test set to measure how much the KV cache speeds things up:
Without the KV cache (`'use_kv_cache': False`):

```
$ python train.py
source tokenizer size: 32000
target tokenizer size: 32000
Loads cached dataframes.
The model has 93255680 trainable parameters
begin train with arguments: {'d_model': 512, 'n_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'd_ff': 2048, 'dropout': 0.1, 'max_positions': 5000, 'source_vocab_size': 32000, 'target_vocab_size': 32000, 'attention_bias': False, 'pad_idx': 0, 'dataset_path': 'nlp-in-action/transformers/transformer/data/wmt', 'src_tokenizer_file': 'nlp-in-action/transformers/transformer/model_storage/source.model', 'tgt_tokenizer_path': 'nlp-in-action/transformers/transformer/model_storage/target.model', 'model_save_path': 'nlp-in-action/transformers/transformer/model_storage/best_transformer.pt', 'dataframe_file': 'dataframe.{}.pkl', 'use_dataframe_cache': True, 'cuda': True, 'num_epochs': 40, 'batch_size': 32, 'gradient_accumulation_steps': 1, 'grad_clipping': 0, 'betas': (0.9, 0.98), 'eps': 1e-09, 'label_smoothing': 0, 'warmup_steps': 4000, 'warmup_factor': 0.5, 'only_test': True, 'max_gen_len': 60, 'use_wandb': False, 'patient': 5, 'calc_bleu_during_train': True, 'use_kv_cache': False}
total train steps: 221200
0%| | 0/1580 [00:00<?, ?it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1580/1580 [17:25<00:00, 1.51it/s]
TEST loss=0.0021 bleu score: 26.74
```

With the KV cache (`'use_kv_cache': True`):

```
$ python train.py
source tokenizer size: 32000
target tokenizer size: 32000
Loads cached dataframes.
The model has 93255680 trainable parameters
begin train with arguments: {'d_model': 512, 'n_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'd_ff': 2048, 'dropout': 0.1, 'max_positions': 5000, 'source_vocab_size': 32000, 'target_vocab_size': 32000, 'attention_bias': False, 'pad_idx': 0, 'dataset_path': 'transformers/transformer/data/wmt', 'src_tokenizer_file': 'transformers/transformer/model_storage/source.model', 'tgt_tokenizer_path': 'transformers/transformer/model_storage/target.model', 'model_save_path': 'transformers/transformer/model_storage/best_transformer.pt', 'dataframe_file': 'dataframe.{}.pkl', 'use_dataframe_cache': True, 'cuda': True, 'num_epochs': 40, 'batch_size': 32, 'gradient_accumulation_steps': 1, 'grad_clipping': 0, 'betas': (0.9, 0.98), 'eps': 1e-09, 'label_smoothing': 0, 'warmup_steps': 4000, 'warmup_factor': 0.5, 'only_test': True, 'max_gen_len': 60, 'use_wandb': False, 'patient': 5, 'calc_bleu_during_train': True, 'use_kv_cache': True}
total train steps: 221200
0%| | 0/1580 [00:00<?, ?it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1580/1580 [13:37<00:00, 1.93it/s]
TEST loss=0.0021 bleu score: 26.74
```
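Here we load the best model from earlier training. Both runs reach the same BLEU score of 26.74, and with the KV cache the inference time (single GPU) drops from 17:25 to 13:37, almost 4 minutes faster.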
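The reported speedup can be checked with a few lines of arithmetic, using the times and `it/s` rates copied from the two test logs (`to_seconds` is an illustrative helper).

```python
def to_seconds(mmss: str) -> int:
    # "17:25" -> 1045
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


no_cache, with_cache = to_seconds("17:25"), to_seconds("13:37")
saved = no_cache - with_cache
print(saved, divmod(saved, 60))         # 228 (3, 48)
print(round(no_cache / with_cache, 2))  # 1.28
print(round(1.93 / 1.51, 2))            # 1.28
```

Both the wall-clock time and the `it/s` rate point to roughly a 1.28x speedup for this run, i.e. 3 minutes 48 seconds saved.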
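The KV cache trades space for time, so how much memory does it take?
From the code above, every token stores 4 vectors per layer: 2 keys and 2 values (self-attention and cross-attention). The number of bytes is: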
$$4 \cdot 4 \cdot \text{num\_layers} \cdot \text{num\_heads} \cdot \text{d\_head}$$
The first 4 is the number of vectors; the second 4 is the number of bytes per element in float32; the cache is kept for every layer; and each vector has size $\text{num\_heads} \times \text{d\_head}$.
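In the base setting (6 layers, d_model = 512, i.e. 8 heads of size 64) with batch size 1, one token needs 48 KB of GPU memory, so generating a sequence of 512 tokens needs 24 MB. That looks small, but the footprint grows linearly with the number of layers, the hidden size, the sequence length and the batch size, so for large models it becomes significant. Strictly speaking, the two cross-attention vectors are cached once per source token rather than per generated token, so only the self-attention half grows with the output length; the formula treats all four as per-token for a simple estimate.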
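The formula can be wrapped in a small calculator. `kv_cache_bytes` is a hypothetical helper whose defaults mirror the article's assumptions (4 cached vectors per token, float32); it reproduces the 48 KB and 24 MB figures for the base setting.

```python
def kv_cache_bytes(num_layers, num_heads, d_head, seq_len=1, batch_size=1,
                   vectors_per_token=4, bytes_per_elem=4):
    # vectors_per_token=4: self-attn k, v plus cross-attn k, v (the article's simplification)
    # bytes_per_elem=4: float32
    return (vectors_per_token * bytes_per_elem * num_layers
            * num_heads * d_head * seq_len * batch_size)


per_token = kv_cache_bytes(num_layers=6, num_heads=8, d_head=512 // 8)
print(per_token, per_token / 1024)  # 49152 48.0
total = kv_cache_bytes(num_layers=6, num_heads=8, d_head=64, seq_len=512)
print(total / 1024 ** 2)  # 24.0
```

Passing `bytes_per_elem=2` estimates a float16 cache, and scaling `batch_size` shows how quickly batched inference multiplies the footprint.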
This time we train on multiple GPUs and use the KV cache when decoding for the dev-set BLEU score:
```
$ sh train.sh
Number of GPUs used: 3
Running DDP on rank 2.
0%| | 0/1844 [00:00<?, ?it/s]
Running DDP on rank 1.
0%| | 0/1844 [00:00<?, ?it/s]
Running DDP on rank 0.
source tokenizer size: 32000
target tokenizer size: 32000
Loads cached train dataframe.
Loads cached dev dataframe.
The model has 93255680 trainable parameters
begin train with arguments: {'d_model': 512, 'n_heads': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6, 'd_ff': 2048, 'dropout': 0.1, 'max_positions': 5000, 'source_vocab_size': 32000, 'target_vocab_size': 32000, 'attention_bias': False, 'pad_idx': 0, 'dataset_path': 'nlp-in-action/transformers/transformer/data/wmt', 'src_tokenizer_file': 'nlp-in-action/transformers/transformer/model_storage/source.model', 'tgt_tokenizer_path': 'nlp-in-action/transformers/transformer/model_storage/target.model', 'model_save_path': 'nlp-in-action/transformers/transformer/model_storage/best_transformer.pt', 'dataframe_file': 'dataframe.{}.pkl', 'use_dataframe_cache': True, 'cuda': True, 'num_epochs': 40, 'batch_size': 32, 'gradient_accumulation_steps': 1, 'grad_clipping': 0, 'betas': (0.9, 0.98), 'eps': 1e-09, 'label_smoothing': 0, 'warmup_steps': 4000, 'warmup_factor': 0.5, 'only_test': False, 'max_gen_len': 60, 'use_wandb': False, 'patient': 5, 'calc_bleu_during_train': True, 'use_kv_cache': True}
total train steps: 73760
[GPU0] TRAIN loss=7.033506, learning rate=0.0001612: 100%|██████████| 1844/1844 [03:57<00:00, 7.76it/s]
[GPU1] TRAIN loss=7.085324, learning rate=0.0001612: 100%|██████████| 1844/1844 [03:57<00:00, 7.76it/s]
[GPU2] TRAIN loss=6.532835, learning rate=0.0001612: 100%|██████████| 1844/1844 [03:57<00:00, 7.76it/s]
0%| | 0/264 [00:00<?, ?it/s]
| ID | GPU | MEM |
------------------
| 0 | 0% | 22% |
| 1 | 87% | 80% |
| 2 | 83% | 72% |
| 3 | 87% | 74% |
begin evaluate
100%|██████████| 264/264 [00:07<00:00, 36.57it/s]
100%|██████████| 264/264 [00:07<00:00, 36.18it/s]
calculate bleu score for dev dataset
100%|██████████| 264/264 [00:07<00:00, 35.56it/s]
100%|██████████| 264/264 [02:47<00:00, 1.57it/s]
100%|██████████| 264/264 [02:51<00:00, 1.54it/s]
100%|██████████| 264/264 [02:52<00:00, 1.53it/s]
[GPU1] end of epoch 1 [ 421s]| train loss: 8.0776 | valid loss: 7.1336 | valid bleu_score 0.42
[GPU0] end of epoch 1 [ 421s]| train loss: 8.0674 | valid loss: 7.1126 | valid bleu_score 0.41
Save model with best bleu score :0.41
[GPU0] end of epoch 2 [ 403s]| train loss: 6.5031 | valid loss: 5.8428 | valid bleu_score 6.66
Save model with best bleu score :6.66
[GPU0] end of epoch 3 [ 400s]| train loss: 5.2757 | valid loss: 4.6797 | valid bleu_score 16.64
Save model with best bleu score :16.64
[GPU0] end of epoch 4 [ 400s]| train loss: 4.2989 | valid loss: 4.1087 | valid bleu_score 21.78
Save model with best bleu score :21.78
[GPU0] end of epoch 5 [ 396s]| train loss: 3.7218 | valid loss: 3.8263 | valid bleu_score 23.51
Save model with best bleu score :23.51
[GPU0] end of epoch 6 [ 396s]| train loss: 3.3296 | valid loss: 3.6755 | valid bleu_score 24.84
Save model with best bleu score :24.84
[GPU0] end of epoch 8 [ 391s]| train loss: 2.8033 | valid loss: 3.5605 | valid bleu_score 25.86
Save model with best bleu score :25.86
[GPU0] end of epoch 10 [ 386s]| train loss: 2.4323 | valid loss: 3.5600 | valid bleu_score 26.43
Save model with best bleu score :26.43
[GPU0] end of epoch 11 [ 400s]| train loss: 2.2831 | valid loss: 3.5782 | valid bleu_score 26.91
Save model with best bleu score :26.91
[GPU0] end of epoch 12 [ 390s]| train loss: 2.1463 | valid loss: 3.6085 | valid bleu_score 26.77
[GPU0] end of epoch 13 [ 397s]| train loss: 2.0249 | valid loss: 3.6398 | valid bleu_score 26.61
[GPU0] end of epoch 14 [ 389s]| train loss: 1.9126 | valid loss: 3.6763 | valid bleu_score 26.41
[GPU0] end of epoch 15 [ 388s]| train loss: 1.8102 | valid loss: 3.7161 | valid bleu_score 26.15
| ID | GPU | MEM |
------------------
| 0 | 1% | 22% |
| 1 | 81% | 81% |
| 2 | 80% | 75% |
| 3 | 89% | 89% |
[GPU0] end of epoch 16 [ 399s]| train loss: 1.7163 | valid loss: 3.7508 | valid bleu_score 26.38
stop from early stopping.
```
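Each epoch is roughly 30 seconds faster than in the previous article. For the first epoch, training took about 3 min 57 s, while decoding the dev set for the BLEU score took only about 2 min 50 s. Training also needed one epoch fewer than in the previous article.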
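Note that, for performance reasons, the results are not fully deterministic even though a random seed is set, so they may differ slightly between runs. For full reproducibility, see https://pytorch.org/docs/stable/notes/randomness.html .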