paper:Understanding the effects of RLHF on LLM generalisation and diversity
0 背景知识
标准的RLHF
finetuning pipeline一般包含3个阶段:
- supervised fine-tuning (
SFT
)。对预训练的模型进行用language modeling的方式进行微调。 - reward modeling (
RM
)。对预训练模型用predict human preferences的方式微调。 - reinforcement learning (
RL
)。结合上述两个模型用on-policy RL算法(如PPO
)进行微调。
1 Motivation
虽然目前RLHF(reinforcement learning from human feedback)成为LLM
训练不可缺少的一部分。但目前并没有详细研究RLHF
到底对LLM
哪一方面有益 or 有害。为了提升对RLHF
不同阶段收益的认知,本文从实验上系统探究了RLHF
的三个阶段supervised fine-tuning (SFT
), reward modeling(RW
), RLHF
对LLM
泛化性(generalisation)和生成多样性(diversity)的影响。
2 实验设置
实验的基座模型:LLaMA7B和OPT
不同stage的训练方法:
stage | 训练方法 |
---|---|
SFT | 给定input-ouput pair用language modeling的方式进行微调。input是context,cross-entry loss作用做output。 |
RW | 用(Stiennon et al., 2022).的方法进行训练。 |
RLHF | 用PPO算法训练 |
由于reward-model(RW)的输出是一个score,作者用best-of-N(BoN
)来作为RW的输出。简单来说BoN
就是:先从SFT
模型生成N个结果,然后用RW模型对这N个结果进行打分,取分数最大的一个。作者的实验设置 N = 16 N=16 N=16,采样的temperature为 0.7 0.7 0.7。
数据集
作者在两类任务评估generalisation和diversity。其中in-domain数据都用来训练。
in-domain | out of domain | |
---|---|---|
summary | TL;DR dataset | CNN/DailyMail |
Instruction Following | AlpacaFarm | Alpaca Eval/0 Sequential Instructions |
3 Result
3.1 generalisation的评测结果
作者用GPT4来评估summary的好坏。用到的prompt如下所示,简单来说就是同时将输入、参考答案、不同模型的summary放入到prompt中,让GPT4对结果进行排序。
<|im_start|>system You are a helpful assistant, that ranks models by the quality of their answers.
<|im_end|>
<|im_start|>user
Which of the following summaries does a better job of summarizing the most important points in the given news article, without including unimportant or irrelevant details? A good summary is both precise and concise.
Post: """{instruction}"""
Summary A: { "model": "model_1", "summary": """{output_1}""" }
Summary B: { "model": "model_2", "summary": """{output_2}""" }
Now please rank the models by the quality of their summaries, so that the model with rank 1 has the best summary. Then return a list of the model names and ranks, i.e., produce the following output:
[ {’model’: <model-name>, ’rank’: <model-rank>}, {’model’: <model-name>, ’rank’: <model-rank>}
]
Your response must be a valid Python dictionary and should contain nothing else because we will directly execute it in Python. Please provide the ranking that the majority of humans would give.
<|im_end|>
3.1.1 summary任务的generalisation能力评估
从结果可见summary任务的generalisation的排序是:
in domain dataset(TL;DR): BoN > RLHF > SFT
out of domain dataset: BoN > RLHF > SFT
从中可以看出做了RLHF能有效提升SFT模型的泛化能力,但还比不上基于RW的BoN。其实挺好奇RLHF的BoN指标,作者并没有给出。
3.1.2 Instruction Following任务的generalisation能力评估
从结果可见,不论是in-domain还是out-of-domain数据集generalisation的排序都是RLHF>BoN>SFT,明显感受到RLHF对指令的理解更具优势。
3.2 diversity的评测结果
作者用distinct N-grams、Sentence-BERT embedding cosine similarity和NLI diversity作为diversity的评估指标。简单介绍一下:
指标名称 | 计算方式 |
---|---|
distinct N-grams | 计算中不重复n-gram与所有n-gram的比值。文中作者分别计算n=1,2,…,5再对结果取平均 |
Sentence-BERT embedding cosine similarity | 分别将同一个prompt的多个输出,送入到sentence-BERT中提取embedding。两两计算embedding的相似性,随后用1减去作为diversity的评估。 |
NLI(natural language inference) diversity | 分别将同一个prompt的多个输出组成pair送入到NLI模型中来预测是entailments还是contradictions。contradiction越多说明多样性越强 |
作者只在summary任务中评估diversity能力。这是因为作者在Instruction Following任务中没有看到meaningful difference。作者猜测,这是由于设置的diversity指标偏好相对短的模型输出,但Instruction Following的输出都相对较长,并且SFT和RLHF的输出长度偏好也不同(RLHF倾向更长的输出),导致计算的diversity不够准确。
作者原话
- We ran some initial experiments evaluating diversity for the instruction-following models, but we did not see any meaningful differences. We hypothesise this is due to the diversity metrics we use being designed for settings where the model output is relatively short (e.g. a single sentence), whereas in the instruction-following setting outputs are generally much longer.
3.2.1 summary任务的diversity能力评估
实验用了N个prompt,每个prompt生成K条数据。(N=500, K=16)
指标进一步定义:
- EAD称之为syntactic diversity,其综合考虑上述3个指标
- Sent BERT称之为semantic diversity,只考虑Sentence-BERT embedding cosine similarity
- NLI称之为logical diversity,只考虑NLI指标。
统计方法:
- Per-input diversity:所有的样本的指标取平均。
- cross-input diversity: 每个N取最大的值进行统计。
从上述结果可知:
经过RLHF微调后,多样性会降低。
小结
- SFT后用RW+BoN进行推理多样性和泛化性都较好。但需要推理多次,开销大。
- RLHF的泛化性能比SFT好,但多样性会有所降低。