诸神缄默不语-个人CSDN博文目录
开放式文本生成会偏好采样方法。
由于我要下班了,所以本文主要就写了第五节。别的内容请大家参考第六节给出的参考资料。
文章目录
- 1. 贪心搜索
- 2. beam search
- 3. top-k sampling
- 4. top-p sampling
- 5. 代码实践:transformers的generate()函数
- 6. 参考资料
1. 贪心搜索
2. beam search
3. top-k sampling
4. top-p sampling
5. 代码实践:transformers的generate()函数
我用GPT-2写了一篇代码示例,介绍了不同解码策略的实现方案。可参考:https://github.com/PolarisRisingWar/all-notes-in-one/blob/main/decode_examples_in_GPT2.ipynb
(注意pad_token_id那个只有GPT-2需要,别的模型大多不需要)
这个代码我没怎么写注释,如果有人有需要的话我以后补上吧。
generate()
生成的结果经decode()
即可得到原始文本(output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
)
贪心搜索:greedy_output = model.generate(input_ids, max_length=50)
入参:
- max_length:包括context的总长度
- max_new_tokens
- num_beams:beam search的beam数
- early_stopping:设置num_beams时可用
- no_repeat_ngram_size:设置禁止ngram重复出现
- num_return_sequences:返回多少个序列
- do_sample:根据概率采样
- top_k:对概率前k的token的概率重新归一化
- top_p:对大于p的token的概率重新归一化
- temperature:降低temperature会极化token概率,导致抽样随机性减小(当temperature→0时,抽样策略趋近于贪心搜索)
- min_length 用于强制模型在达到 min_length 之前不生成 EOS。这在摘要场景中使用得比较多,但如果用户想要更长的文本输出,也会很有用。
- repetition_penalty 可用于对生成重复的单词这一行为进行惩罚。它首先由 Keskar 等人 (2019) 引入,在 Welleck 等人 (2019) 的工作中,它是训练目标的一部分。它可以非常有效地防止重复,但似乎对模型和用户场景非常敏感,其中一个例子见 Github 上的 讨论。
- attention_mask 可用于屏蔽填充符。
- pad_token_id、bos_token_id、eos_token_id: 如果模型默认没有这些 token,用户可以手动选择其他 token id 来表示它们。
6. 参考资料
- 如何生成文本:通过 Transformers 用不同的解码方法生成文本:我感觉这篇非常浅显易懂,建议直接看这篇。如果有读者有需要或者我认为有需要的话,我后期可能会用我的叙述思路再讲一遍这几种文本解码策略,但我觉得这篇已经很直白了。