根据输入的prompt,生成一段指定长度的文字。Llama跑起来太慢了,这里用GPT-2作为列子。
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torchtokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
generated_ids = model.generate(input_ids, max_length=prompt_len+max_gen_len, pad_token_id=tokenizer.eos_token_id)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
past_key_values = None
print('Prefill Stage..')
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
print('Decoding/Generating Stage..')
for _ in range(max_gen_len - 1):outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)past_key_values = outputs.past_key_values # if use_cache=False, past_key_values=Nonepred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)generated_ids.append(pred_token_idx.item())
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(torch.Tensor(generated_ids), skip_special_tokens=True)
print('generated_text:', prompt_text + generated_text)
这里提供了两种方法实现文本生成:
- model.generate():给模型输入prompt,一次性得到所有输出的token,最方便的写法
- for loop:这是StreamingLLM中给的代码例子,也揭示了自回归生成的原理。首先是prefill阶段,输入prompt,得到KV cache和生成的第一个token;然后是decoding/generating阶段,开始自回归生成token,每次生成的模型输入是当前新token和KV cache,每生成一个token都会自动更新KV cache
最终,可以看到两种方法生成的文本是一模一样的:
进一步探究自回归过程中维度的变化:
这个就是标准的自回归生成任务了,不管是GPT还是Llama,都是如此(至少PyTorch版本都是这样的,Flax版本的KV cache有点奇怪,用的lax.dynamic_update_slice(cached_key.value, key, indices),KV cache的维度并没有随着token的生成而增加…不太明白)。