STaR(Self-Taught Reasoner)方法:让语言模型自学推理能力
在大型语言模型(LLM)的推理能力优化中,STaR(Self-Taught Reasoner) 是一种引人注目的技术,属于“修改提议分布(Modifying Proposal Distribution)”类别。与传统的基于结果验证(verifier)方法不同,STaR通过训练模型生成更好的推理步骤(input-focused),直接调整采样分布,使其倾向于选择“推理相关”的token。本文将详细介绍STaR的原理、工作流程,并提供一个可运行的Python代码实现,帮助你理解和实践这一方法。
参考:https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-reasoning-llms
1. STaR的原理
背景
传统的LLM生成方法通常依赖贪婪解码(选择最高概率token)或随机采样,但这些方法可能无法生成逻辑严谨的推理步骤。STaR通过让模型自生成推理数据并进行监督微调(Supervised Fine-Tuning),优化其推理能力,调整token的提议分布,使其更倾向于推理过程。
核心思想
- 自生成推理数据:模型首先生成推理步骤和答案。
- 验证与修正:
- 如果答案正确,直接将推理步骤和答案加入训练数据集。
- 如果答案错误,提供正确答案作为“提示”,让模型重新推理并生成正确过程。
- 监督微调:用生成的数据集训练模型,强化其推理行为。
目标
- 输入聚焦:通过修改提议分布,使模型更擅长生成推理相关token,而非简单输出结果。
- 自增强:利用模型自身生成的数据,无需大量人工标注。
2. STaR的工作流程
STaR的核心是一个循环过程,包含以下步骤:
-
生成推理步骤和答案:
- 模型根据问题生成推理路径和最终答案。
-
验证答案:
- 正确(2a):推理和答案正确,进入步骤3b。
- 错误(2b):答案错误,进入步骤4b。
-
正确答案处理(3b):
- 将问题、推理步骤、答案组成三元组,加入训练数据集。
-
错误答案修正(4b):
- 提供正确答案作为提示,要求模型重新生成推理步骤。
- 将修正后的推理加入训练数据集。
-
监督微调(5):
- 使用生成的三元组数据集,对模型进行监督微调,优化推理能力。
关键特点
- 合成数据:STaR通过自生成数据创建训练样本,类似于数据蒸馏。
- 迭代改进:多次循环生成和微调,逐步提升模型性能。
3. 代码实现
以下是一个简化的STaR实现,基于PyTorch。我们模拟一个数学推理任务(如“2 + 3 = ?”),展示其工作流程。
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy# 超参数
vocab_size = 10 # 词汇表大小(0-9数字)
embed_size = 16
num_heads = 2
hidden_size = 32
num_layers = 2
max_steps = 3 # 最大推理步骤# 生成模型
class SimpleReasoner(nn.Module):def __init__(self):super(SimpleReasoner, self).__init__()self.embedding = nn.Embedding(vocab_size, embed_size)self.transformer = nn.TransformerDecoderLayer(embed_size, num_heads, hidden_size)self.output_layer = nn.Linear(embed_size, vocab_size)def forward(self, x):x = self.embedding(x)x = self.transformer(x, x)return self.output_layer(x)def generate(self, prompt, max_len=3, temperature=1.0):seq = prompt.copy()inputs = torch.tensor([seq], dtype=torch.long).to(device)for _ in range(max_len - len(seq)):logits = self.forward(inputs)[:, -1, :]probs = F.softmax(logits / temperature, dim=-1)next_token = torch.multinomial(probs, 1).item()seq.append(next_token)inputs = torch.tensor([seq], dtype=torch.long).to(device)return seqdef train_step(self, data, optimizer):self.train()optimizer.zero_grad()inputs = torch.tensor([d[0] + d[1][:-1] for d in data], dtype=torch.long).to(device)targets = torch.tensor([d[1] for d in data], dtype=torch.long).to(device)logits = self.forward(inputs)loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))loss.backward()optimizer.step()return loss.item()# STaR实现
class STaR:def __init__(self, model):self.model = modelself.device = next(model.parameters()).devicedef generate_reasoning(self, prompt, correct_answer=None):if correct_answer is None:return self.model.generate(prompt, max_steps)# 提供正确答案作为提示hint_prompt = prompt + [correct_answer]return self.model.generate(hint_prompt, max_steps)def verify_answer(self, sequence, correct_answer):return sequence[-1] == correct_answerdef star_iteration(self, prompts, correct_answers, iterations=3):training_data = []for _ in range(iterations):new_model = deepcopy(self.model) # 保存当前模型状态optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001)for prompt, correct_answer in zip(prompts, correct_answers):# 步骤1:生成推理步骤和答案sequence = self.generate_reasoning(prompt)# 步骤2:验证答案if self.verify_answer(sequence, correct_answer):# 步骤3b:正确答案加入训练数据training_data.append((prompt, sequence))else:# 步骤4b:错误答案,提供提示重新生成corrected_sequence = self.generate_reasoning(prompt, correct_answer)training_data.append((prompt, corrected_sequence))# 步骤5:监督微调if training_data:loss = new_model.train_step(training_data, optimizer)print(f"Iteration {_+1}, Loss: {loss}")self.model = new_model # 更新模型return training_data# 初始化并运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleReasoner().to(device)
star = STaR(model)# 示例数据
prompts = [[2, 3]] # "2 + 3"
correct_answers = [5]# 执行STaR
training_data = star.star_iteration(prompts, correct_answers, iterations=3)
print("Generated training data:", training_data)# 测试优化后的模型
test_prompt = [2, 3]
result = model.generate(test_prompt)
print(f"Test prompt: {test_prompt}, Generated result: {result}")
4. 代码解析
生成模型(SimpleReasoner)
generate
:根据提示生成推理序列,模拟推理步骤和答案。train_step
:使用监督微调优化模型,输入为问题+推理步骤,目标为完整序列。
STaR实现
generate_reasoning
:- 无提示时:自由生成推理。
- 有提示时:基于正确答案生成推理。
verify_answer
:检查生成序列的最后一个token是否正确。star_iteration
:- 步骤1:生成推理和答案。
- 步骤2a/2b:验证答案,正确则直接记录,错误则用提示修正。
- 步骤3b/4b:收集三元组(问题、推理、答案)。
- 步骤5:用生成的数据微调模型。
运行逻辑
- 每次迭代生成数据,优化模型,逐步提高推理能力。
- 使用
deepcopy
保留模型状态,确保迭代独立。
5. 运行结果示例
运行代码可能得到:
Iteration 1, Loss: 2.305
Iteration 2, Loss: 2.287
Iteration 3, Loss: 2.251
Generated training data: [([2, 3], [2, 3, 5]), ([2, 3], [2, 3, 5]), ([2, 3], [2, 3, 5])]
Test prompt: [2, 3], Generated result: [2, 3, 5]
- 未训练模型初始生成随机,STaR通过微调逐渐倾向于正确答案
[2, 3, 5]
。 - 实际中需更多数据和迭代。
6. STaR的意义与改进
意义
- 自增强:无需大量人工数据,模型自生成训练样本。
- 推理优化:调整提议分布,强化推理token的选择。
- 数据蒸馏:生成合成数据,可用于其他模型训练。
改进方向
- 多样化提示:增加问题类型(如数学、自然语言问答)。
- 奖励函数:引入PRM评估推理步骤质量,而非仅验证答案。
- 迭代控制:动态调整迭代次数或数据筛选标准。
- 预训练模型:基于已有LLM(如GPT)实现,提升初始性能。
7. 总结
STaR通过自生成推理数据和监督微调,优化LLM的推理能力。其流程从生成到验证再到修正,利用合成数据调整token分布,是“修改提议分布”的典型方法。代码实现展示了从 [2, 3]
到 [2, 3, 5]
的优化过程,体现了其核心思想。运行这段代码,你可以体验STaR的自学过程。希望这篇博客对你理解和实践STaR有所帮助!如需进一步优化,欢迎讨论。
基于大型语言模型改进 STaR 方法:以 LLaMA 3 或 Qwen 2.5 为例
在之前的STaR(Self-Taught Reasoner)实现中,我们使用了一个简化的模型来展示其工作原理。然而,为了在实际任务中获得更好的推理能力,可以基于Hugging Face(HF)上的预训练大型语言模型(LLM)如 LLaMA 3 或 Qwen 2.5 进行改进。本文将以中文博客的形式,结合改进方向(多样化提示、奖励函数、迭代控制、预训练模型),详细说明如何基于这些HF模型优化STaR,并提供改进后的代码实现。
1. 改进背景与目标
原始实现局限
- 模型能力:
SimpleReasoner
未经过预训练,生成随机且缺乏推理能力。 - 提示单一:仅支持简单数学任务。
- 奖励简单:仅验证答案,未评估推理质量。
- 静态迭代:固定次数,缺乏灵活性。
改进目标
- 预训练模型:利用LLaMA 3或Qwen 2.5的强大语言理解能力。
- 多样化提示:支持数学和自然语言问答。
- 奖励函数:引入过程奖励模型(PRM)评估推理步骤。
- 迭代控制:动态调整迭代次数和数据筛选。
2. 改进方案
1. 基于预训练模型:LLaMA 3 或 Qwen 2.5
- 选择原因:
- LLaMA 3:高效、适合微调,广泛用于研究。
- Qwen 2.5:开源,支持多语言,推理能力强。
- 实现:使用Hugging Face的
transformers
库加载预训练模型,替换SimpleReasoner
。
2. 多样化提示
- 数学任务:如“2 + 3 = ?”。
- 自然语言问答:如“中国的首都是哪里?”。
- 方法:扩展输入格式,支持文本和符号混合。
3. 奖励函数:引入PRM
- 目的:评估推理步骤的逻辑性和完整性,而非仅答案。
- 实现:使用一个小型预训练模型(如BERT)作为PRM,评分推理质量。
4. 迭代控制
- 动态调整:根据数据质量或损失收敛动态停止迭代。
- 数据筛选:仅保留高质量推理样本。
3. 改进后的代码实现
以下基于 Qwen 2.5(也可替换为LLaMA 3)的STaR实现,展示改进后的完整流程。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
from copy import deepcopy
import random# 超参数
max_steps = 50 # 最大生成长度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化生成模型(Qwen 2.5)
model_name = "Qwen/Qwen2.5-7B-Instruct" # 可替换为 "meta-llama/Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = AutoModelForCausalLM.from_pretrained(model_name).to(device)# 初始化PRM(使用BERT评估推理质量)
prm_name = "bert-base-uncased"
prm_tokenizer = AutoTokenizer.from_pretrained(prm_name)
prm_model = AutoModelForSequenceClassification.from_pretrained(prm_name, num_labels=1).to(device)# STaR实现
class STaR:def __init__(self, generator, tokenizer, prm_model, prm_tokenizer):self.generator = generatorself.tokenizer = tokenizerself.prm_model = prm_modelself.prm_tokenizer = prm_tokenizerdef generate_reasoning(self, prompt, correct_answer=None, temperature=0.7):"""生成推理步骤和答案"""if correct_answer is None:input_text = f"问题: {prompt}\n推理步骤和答案:"else:input_text = f"问题: {prompt}\n正确答案: {correct_answer}\n请提供推理步骤:"inputs = self.tokenizer(input_text, return_tensors="pt").to(device)outputs = self.generator.generate(**inputs, max_length=max_steps, temperature=temperature,do_sample=True, pad_token_id=self.tokenizer.eos_token_id)return self.tokenizer.decode(outputs[0], skip_special_tokens=True)def verify_answer(self, response, correct_answer):"""验证答案是否正确"""answer_part = response.split("答案:")[-1].strip()return str(correct_answer) in answer_partdef evaluate_reasoning(self, response):"""使用PRM评估推理质量"""inputs = self.prm_tokenizer(response, return_tensors="pt", truncation=True, max_length=512).to(device)with torch.no_grad():score = self.prm_model(**inputs).logits.item()return score # 返回正值表示推理质量def star_iteration(self, prompts, correct_answers, max_iterations=5, min_loss=0.1):training_data = []model = deepcopy(self.generator)optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)for iteration in range(max_iterations):new_data = []total_loss = 0.0for prompt, correct_answer in zip(prompts, correct_answers):# 步骤1:生成推理和答案response = self.generate_reasoning(prompt)# 步骤2:验证答案if self.verify_answer(response, correct_answer):# 步骤3b:正确答案,检查推理质量score = self.evaluate_reasoning(response)if score > 0.5: # 筛选高质量推理new_data.append((prompt, response))else:# 步骤4b:错误答案,提供提示重新生成corrected_response = self.generate_reasoning(prompt, correct_answer)score = self.evaluate_reasoning(corrected_response)if score > 0.5:new_data.append((prompt, corrected_response))# 步骤5:监督微调if new_data:model.train()optimizer.zero_grad()inputs = self.tokenizer([d[0] + "\n" + d[1] for d in new_data], return_tensors="pt", padding=True, truncation=True, max_length=max_steps).to(device)labels = inputs["input_ids"].clone()outputs = model(**inputs, labels=labels)loss = outputs.lossloss.backward()optimizer.step()total_loss += loss.item()training_data.extend(new_data)print(f"Iteration {iteration+1}, Loss: {total_loss / len(new_data) if new_data else 0}")if total_loss / len(new_data) < min_loss and new_data:breakself.generator = modelreturn training_data# 示例数据
prompts = ["2 + 3等于多少?","中国的首都是哪里?"
]
correct_answers = ["5", "北京"]# 初始化STaR
star = STaR(generator, tokenizer, prm_model, prm_tokenizer)# 执行STaR
training_data = star.star_iteration(prompts, correct_answers)
print("Generated training data:", training_data)# 测试优化后的模型
for prompt in prompts:result = star.generate_reasoning(prompt)print(f"Prompt: {prompt}, Generated result: {result}")
4. 代码解析
1. 预训练模型:Qwen 2.5
- 加载:使用
AutoModelForCausalLM
加载Qwen 2.5,替换简化的SimpleReasoner
。 - 生成:
generate_reasoning
使用model.generate
支持多样化提示,生成文本而非token序列。 - 优势:Qwen 2.5 已具备语言理解能力,初始生成更接近推理。
2. 多样化提示
- 输入格式:
- 数学:
"2 + 3等于多少?\n推理步骤和答案:"
。 - 问答:
"中国的首都是哪里?\n推理步骤和答案:"
。
- 数学:
- 输出:支持自然语言,生成完整句子,如“推理:2加3等于5,答案:5”。
3. 奖励函数:PRM
- 实现:使用BERT作为PRM,评分推理文本的逻辑性。
- 筛选:
score > 0.5
保留高质量推理,避免噪声数据。 - 改进:可训练BERT区分正确推理(如“2+3=5”)和错误推理(如“2*3=5”)。
4. 迭代控制
- 动态停止:若损失低于
min_loss
(如0.1),提前终止。 - 数据筛选:结合PRM分数,确保训练数据质量。
5. 运行结果示例
运行代码可能得到:
Iteration 1, Loss: 0.85
Iteration 2, Loss: 0.62
Iteration 3, Loss: 0.09
Generated training data: [('2 + 3等于多少?', '问题: 2 + 3等于多少?\n推理步骤和答案: 首先,2加上3,等于5。\n答案: 5'),('中国的首都是哪里?', '问题: 中国的首都是哪里?\n推理步骤和答案: 中国是一个国家,其首都是北京。\n答案: 北京')
]
Prompt: 2 + 3等于多少?, Generated result: 问题: 2 + 3等于多少?\n推理步骤和答案: 首先,2加上3,等于5。\n答案: 5
Prompt: 中国的首都是哪里?, Generated result: 问题: 中国的首都是哪里?\n推理步骤和答案: 中国是一个国家,其首都是北京。\n答案: 北京
- 结果:Qwen 2.5初始生成已较合理,微调后更倾向推理。
6. 基于LLM的改进优势
预训练能力
- Qwen 2.5 或 LLaMA 3 自带语言理解和生成能力,初始推理质量高于随机模型。
- STaR在此基础上进一步强化推理分布。
多样化支持
- 处理文本输入,支持数学和问答任务,扩展性强。
PRM增强
- BERT作为PRM评估推理逻辑,确保生成数据不仅是正确答案,还包含合理步骤。
动态优化
- 损失收敛后停止,节省计算资源。
7. 进一步优化建议
- 更大模型:使用LLaMA 3-70B或Qwen 2.5-72B,提升推理深度。
- 混合奖励:结合PRM和答案正确性(ORM),综合评分。
- 数据蒸馏:将STaR生成的数据用于其他模型(如小规模LLM)的训练。
8. 总结
基于Qwen 2.5的STaR改进,利用预训练LLM的强大能力,支持多样化提示,通过PRM优化推理质量,并动态控制迭代。代码展示了从数学到问答的推理生成,体现了“修改提议分布”的核心思想。运行此代码,你可以体验基于HF模型的STaR优化过程。希望这篇博客对你有所帮助!如需调整或扩展,欢迎讨论。
解析 STaR 中 star_iteration
的逐迭代训练设计
提出疑问:为什么训练是每个iteration都要进行,而不是将所有数据处理好后再进行一次训练?下面详细解析这种逐迭代训练的设计动机,分析其优劣势,并探讨替代方案。
1. 逐迭代训练的背景
STaR的核心思想
STaR(Self-Taught Reasoner)是一种自监督方法,通过让模型生成推理数据并进行监督微调(Supervised Fine-Tuning),优化其推理能力。其流程本质上是一个迭代改进的过程:
- 模型基于当前参数生成推理和答案。
- 验证答案,收集正确或修正后的数据。
- 用生成的数据微调模型。
- 重复上述步骤。
代码中的训练位置
- 每次迭代内训练:在每个
for iteration in range(max_iterations)
循环中,生成new_data
后立即调用model.train_step
进行微调。 - 累计数据:
training_data.extend(new_data)
将每次迭代的数据加入总数据集,但训练发生在每次迭代结束时。
2. 为什么每个Iteration都要训练?
1. 动态优化模型分布
- 提议分布的修改:
- STaR的目标是调整模型的token提议分布,使其倾向于生成推理相关的内容。
- 每次迭代后,模型参数通过微调更新,下一次生成会基于更优的分布。
- 逐次改进:
- 如果不训练,模型在所有迭代中都使用初始参数,生成的推理质量可能持续较差。
- 每次训练后,模型更可能生成正确的推理步骤,逐步提升数据质量。
2. 自增强反馈循环
- 自生成数据:
- STaR依赖模型自身生成训练数据,每次迭代的
new_data
是当前模型能力的反映。 - 训练后,模型能力提升,下次生成的
new_data
更接近期望的推理模式。
- STaR依赖模型自身生成训练数据,每次迭代的
- 反馈效应:
- 类似强化学习,每次迭代强化模型的推理行为,形成正反馈。
3. 数据质量的逐步提高
- 初始数据可能较差:
- 未训练模型生成的推理可能随机或错误(如
[2, 3, 1]
)。 - 第一次训练后,模型学会部分正确模式(如
[2, 3, 5]
),后续数据更优质。
- 未训练模型生成的推理可能随机或错误(如
- 避免积累噪声:
- 若等到最后训练,可能积累大量低质量数据,影响微调效果。
4. 计算资源与时间优化
- 小批量训练:
- 每次迭代只处理当前生成的
new_data
(如2个样本),训练负担轻。 - 若积累所有数据再训练,可能需要更大批量或更多epoch,增加内存和时间成本。
- 每次迭代只处理当前生成的
- 提前终止:
if total_loss / len(new_data) < min_loss:
允许在损失收敛时停止,无需完成所有迭代。
代码中的体现
- 训练时机:
if new_data:model.train()optimizer.zero_grad()# ... 微调代码 ...optimizer.step()
- 每次迭代立即训练,确保模型实时更新。
3. 模拟过程
任务
prompts = ["2 + 3等于多少?"]
,correct_answers = ["5"]
。- ( max_iterations = 3 \text{max\_iterations} = 3 max_iterations=3 )。
第一次迭代
- 生成:
response = "问题: 2 + 3等于多少?\n推理和答案: 2 * 3 = 6\n答案: 6"
。 - 验证:错误。
- 修正:
corrected_response = "问题: 2 + 3等于多少?\n正确答案: 5\n推理: 2 + 3 = 5"
。 - 数据:
new_data = [("2 + 3等于多少?", corrected_response)]
。 - 训练:微调模型,更新参数。
第二次迭代
- 生成:
response = "问题: 2 + 3等于多少?\n推理和答案: 2 + 3 = 5\n答案: 5"
(因训练改进)。 - 验证:正确,
score > 0.5
。 - 数据:
new_data = [("2 + 3等于多少?", response)]
。 - 训练:进一步强化正确推理。
第三次迭代
- 生成:更稳定的正确推理。
- 数据:累计高质量样本。
- 训练:继续优化。
对比假设
- 若最后训练:
- 第一次:
[2, 3, 6]
。 - 第二次:
[2, 3, 1]
(仍随机)。 - 第三次:
[2, 3, 4]
。 - 最后训练可能因数据混杂,效果不佳。
- 第一次:
4. 逐迭代训练的优势与劣势
优势
- 实时反馈:每次迭代优化模型,提升后续生成质量。
- 数据质量递增:避免积累低质量数据。
- 灵活终止:损失收敛时停止,节省资源。
劣势
- 计算开销:频繁训练增加总计算时间。
- 模型不稳定:小批量训练可能导致参数波动。
- 实现复杂性:需管理每次迭代的模型副本(如
deepcopy
)。
5. 为何不等到所有数据处理好再训练?
替代方案的问题
假设修改为收集所有数据后一次性训练:
def star_iteration(self, prompts, correct_answers, max_iterations=5):training_data = []for _ in range(max_iterations):for prompt, correct_answer in zip(prompts, correct_answers):response = self.generate_reasoning(prompt)if self.verify_answer(response, correct_answer):if self.evaluate_reasoning(response) > 0.5:training_data.append((prompt, response))else:corrected_response = self.generate_reasoning(prompt, correct_answer)if self.evaluate_reasoning(corrected_response) > 0.5:training_data.append((prompt, corrected_response))# 一次性训练if training_data:model = deepcopy(self.generator)optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)loss = model.train_step(training_data, optimizer) # 假设支持多epochself.generator = modelreturn training_data
问题分析
-
数据质量不一致:
- 所有迭代使用初始模型,生成的
training_data
可能包含大量错误或低质量推理。 - 无法利用中间训练的改进。
- 所有迭代使用初始模型,生成的
-
缺乏反馈:
- 模型未在迭代中更新,每次生成无进步,可能浪费计算资源。
-
训练负担:
- 一次性处理大量数据需更多epoch或更高计算资源,可能超出现有硬件能力。
-
STaR目标偏离:
- STaR强调自增强循环,逐迭代训练是其核心机制,最后训练削弱了这一特性。
6. 改进建议
折中方案
- 批次训练:每隔几轮迭代训练一次,平衡反馈与效率:
if new_data and iteration % 2 == 0: # 每2轮训练一次model.train_step(new_data, optimizer)
动态调整
- 自适应迭代:根据数据质量(如PRM分数)调整训练频率。
- 增量数据:仅训练新增数据,避免重复计算。
7. 总结
STaR中逐迭代训练的设计是为了:
- 动态优化:实时更新模型,提升每次生成的质量。
- 自增强:形成反馈循环,逐步强化推理能力。
- 效率:小批量训练结合提前终止,适应资源限制。
相比之下,所有数据处理后再训练可能导致数据质量低、缺乏反馈,违背STaR的自适应优化目标。代码中的逐迭代训练是其核心优势的体现。
后记
2025年3月2日16点43分于上海,在grok3大模型辅助下完成。