【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

目录

  • 4 GRPO源码分析
    • 4.1 数据类 `GRPOScriptArguments`
    • 4.2 系统提示字符串 `SYSTEM_PROMPT`
    • 4.3 奖励函数
      • 4.3.1 accuracy_reward函数
      • 4.3.2 verify函数
      • 4.3.3 format_reward函数
    • 4.4 将数据集格式化为对话形式
    • 4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"},
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:reward = 1.0print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):gold = [gold]
    if not isinstance(target, list):target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/895813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【杂谈】加油!!!!

为了在三月底前系统准备Java后端开发的面试和笔试&#xff0c;以下是分阶段的高效学习计划&#xff1a; 一、知识体系构建&#xff08;第1-2周&#xff09; 核心基础强化 Java基础&#xff08;每日1.5小时&#xff09;&#xff1a; 重点掌握&#xff1a;JVM内存模型&#xff0…

python旅游推荐系统+爬虫+可视化(协同过滤算法)

✅️基于用户的协同过滤算法 ✅️有后台管理 ✅️2w多数据集 这个旅游数据分析推荐系统采用了Python语言、Django框架、MySQL数据库、requests库进行网络爬虫开发、机器学习中的协同过滤算法、ECharts数据可视化技术&#xff0c;以实现从网站抓取旅游数据、个性化推荐和直观展…

HarmonyNext上传用户相册图片到服务器

图片选择就不用说了&#xff0c;直接用 无须申请权限 。 上传图片&#xff0c;步骤和android对比稍微有点复杂&#xff0c;可能是为了安全性考虑&#xff0c;需要将图片先拷贝到缓存目录下面&#xff0c;然后再上传&#xff0c;当然你也可以转成Base64&#xff0c;然后和服务…

同为科技智能PDU助力Deepseek人工智能和数据交互的快速发展

1 2025开年&#xff0c;人工智能领域迎来了一场前所未有的变革。Deepseek成为代表“东方力量”的开年王炸&#xff0c;不仅在国内掀起了技术热潮&#xff0c;并且在全球范围内引起了高度关注。Deepseek以颠覆性技术突破和现象级应用场景席卷全球&#xff0c;这不仅重塑了产业格…

二、QEMU NFS 环境搭建

​ 在上一章节中&#xff0c;我们已经成功完成了内核和 busybox 环境的配置。为了进一步提高开发效率&#xff0c;我们可以使用 NFS&#xff08;Network File System&#xff09;来挂载根目录。NFS 允许我们将本地文件系统通过网络共享给虚拟机使用&#xff0c;这样在开发过程中…

.NET 9.0 的 Blazor Web App 项目中 EF Core 【事务】使用备忘

一、DbContext.Database.BeginTransactionAsync() 模式 1. 注意事项&#xff1a;连接字符串中启用了 MARS&#xff08;Multiple Active Result Sets&#xff1a;MultipleActiveResultSetsTrue &#xff09;后&#xff0c;无法创建 保存点&#xff08;保存点与 SQL Server 的多…

记一次 Git Fetch 后切换分支为空的情况

Git Fetch 后切换分支为空的情况 在使用 Git 时&#xff0c;我遇到这样的情况&#xff1a;执行 git fetch 后切换分支&#xff0c;发现工作目录是空的&#xff0c;没有任何文件&#xff0c;所以插眼记录一下。 原因分析 git fetch 的作用&#xff1a;git fetch 只会从远程仓库…

UMLS数据下载及访问

UMLS数据申请 这个直接在官网上申请即可&#xff0c;记得把地址填全&#xff0c;基本都会拿到lisence。 UMLS数据访问 UMLS的数据访问分为网页访问&#xff0c;API访问以及数据下载后的本地访问&#xff0c;网页访问&#xff0c;API访问按照官网的指示即可&#xff0c;这里主…

使用 Docker 部署 Apache Spark 集群教程

简介 Apache Spark 是一个强大的统一分析引擎&#xff0c;用于大规模数据处理。本文将详细介绍如何使用 Docker 和 Docker Compose 快速部署一个包含一个 Master 节点和两个 Worker 节点的 Spark 集群。这种方法不仅简化了集群的搭建过程&#xff0c;还提供了资源隔离、易于扩…

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中&#xff0c;往往需要采集多路AD信号&#xff0c;且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时&#xff0c;如何根据需求来设置主要功能模块ADC&GPT&#xff0c;包括采样通道打包和分组&#xff0c;GPT触发启动…

20250217 随笔 redis非原子性操作简述

从你提供的文本来看&#xff0c;核心是 Redis 作为缓存的检查机制&#xff0c;以及非原子性操作导致的不一致性问题。 我们可以拆解为两个部分来理解&#xff1a; &#x1f4cc; 1. 逻辑&#xff1a;先查 Redis&#xff0c;再决定是否注册 逻辑流程 先查询 Redis 是否有某个 …

git-提交时间和作者时间的区别

1.介绍 定义介绍 提交时间&#xff08;Committer Date&#xff09;&#xff1a;决定了提交在 Git 历史中的位置&#xff0c;通常影响 GitHub 上提交显示的顺序。 作者时间&#xff08;Author Date&#xff09;&#xff1a;虽然不影响提交的排序&#xff0c;但在每个提交详情页…

PHP框架入门指南:从零构建现代Web应用

一、为什么需要PHP框架? 1.1 传统PHP开发的痛点 重复造轮子:用户认证、表单验证等基础功能需要反复开发代码混乱:缺乏统一结构导致维护困难安全漏洞:手动处理SQL注入/XSS攻击效率低下扩展性差:耦合代码难以适应业务增长1.2 框架的核心价值 标准化架构:MVC模式强制代码分…

Leetcode 146 LRU缓存 的三种解法

146. LRU 缓存 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;则返回关键字的值&#xff0…

尚硅谷 java 学习Day19 抽象类与抽象方法、接口、内部类

6-5 抽象类(abstract)与抽象方法&#xff08;important&#xff09; 一、什么叫抽象类&#xff1a; 有时候将一个父类设计的非常抽象&#xff0c;以至于它没有具体的实例&#xff0c;这样的类称为抽象类 abstract关键字的使用&#xff1a; ​ 1、abstract:抽象的 ​ 2、abs…

【LeetCode Hot100 链表(上)】相交链表、反转链表、回文链表、环形链表、合并两个有序链表、两数相加

链表 1. 相交链表问题描述解决思路代码实现 2. 反转链表问题描述解决思路代码实现 3. 回文链表问题描述解决思路代码实现 4. 环形链表问题描述解决思路代码实现 5. 环形链表II问题描述解决思路代码实现 6. 合并两个有序链表问题描述解决思路代码实现 7. 两数相加问题描述解决思…

【Python pro】基本数据类型

一、数字类型 1.1 数字类型的组成 1.1.1 整数 &#xff08;1&#xff09;十进制&#xff0c;二进制0b&#xff0c;八进制0o&#xff0c;十六进制0x print(16 0b10000 0o20 0x10) # 输出&#xff1a;True&#xff08;2&#xff09;十进制转其他进制 a bin(16) b oct(1…

拯救者电脑在重装系统之后电源计划丢失Fn+Q切换不了模式怎么恢复?

参考联想知识库的一下链接&#xff1a; https://iknow.lenovo.com.cn/detail/196192 其中下载的解压文件后的文件需要复制粘贴到D盘的根目录下&#xff0c;再来运行文件。若在生成的log文件中看到导入成功以及控制面板中看到已添加的电源计划即可 如果还是无效可是试试以下的…

ubuntu 执行 sudo apt-get update 报错

记录一下&#xff0c;遇到这个问题了&#xff0c;网络上看到的解决办法&#xff0c;亲测有效 执行sudo apt-get update ,却报以下错误&#xff0c;“SECURITY: URL redirect target contains control characters rejecting ” 经检查发现&#xff0c;/etc/apt/source.list 下的…

深度集成DeepSeek大模型:WebSocket流式聊天实现

目录 5分钟快速接入DeepSeek大模型&#xff1a;WebSocket实时聊天指南创建应用开发后端代码 (Python/Node.js)结语 5分钟快速接入DeepSeek大模型&#xff1a;WebSocket实时聊天指南 创建应用 访问DeepSeek官网 前往 DeepSeek官网。如果还没有账号&#xff0c;需要先注册一个。…