基于Contiue来阅读open-r1中的GRPO训练代码

原创 快乐王子HP 快乐王子AI说 2025年04月03日 23:54 广东

前面安装了vscode[1]同时也安装了Coninue的相关插件[2],现在想用它们来阅读一下open-r1项目的代码[3]。

首先,从启动训练开始(以GRPO为例子)

第一步,使用TRL的vLLM后端

CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

第二步,启动GRPO

CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \     accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \     src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml

查看vllm的服务启动帮助文档

usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--host HOST] [--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE]                       [--max_model_len MAX_MODEL_LEN] [--enable_prefix_caching ENABLE_PREFIX_CACHING]

关于zero2.yaml文件

(https://github.com/huggingface/open-r1/blob/main/recipes/accelerate_configs/zero2.yaml)

0

    1.核心配置:    - 使用 DeepSpeed 的 Zero Stage 2 优化 (zero_stage: 2)    - 混合精度训练采用 bf16 (mixed_precision: bf16)    - 单机 8 GPU 训练 (num_machines: 1, num_processes: 8)2.Zero Stage 2 特点:    - 优化器状态分区,减少内存占用    - 没有启用参数或优化器卸载 (offload_optimizer_device: none, offload_param_device: none)    - 比 Stage 3 内存效率稍低,但通信开销更小3.硬件配置:    - 纯 GPU 训练 (use_cpu: false)    - 不涉及 TPU (tpu_* 相关配置均为 false)    - 适合具有 8 个 GPU 的单个节点4.使用场景:    - 中等规模模型训练    - 当 GPU 内存足够容纳模型参数和激活值时    - 需要比 Zero Stage 1 更高的内存效率,但不想承受 Stage 3 的通信开销5.性能考虑:    - bf16 混合精度可以在支持它的硬件上提供良好的训练速度和内存效率    - 8 个 GPU 的配置适合大多数单节点服务器这个配置文件适合在单个多 GPU 节点上训练中等规模模型,在内存效率和通信开销之间取得平衡。

    recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml文件分析

    (https://github.com/huggingface/open-r1/blob/main/recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml)

    1.模型架构:  - 基于1.5B参数的蒸馏版Qwen模型  - 使用Flash Attention 2优化注意力计算  - bfloat16混合精度训练2.训练策略:  - 采用GRPO(可能是一种强化学习优化算法)训练方法  - 结合三种奖励函数:准确性、格式正确性和标签计数  - 使用vLLM加速推理过程3.数据处理:  - 专门设计的复杂对话模板  - 数学领域专用数据集(OpenR1-Math-220k)  - 要求模型以和标签分步输出4.资源利用:  - 梯度检查点和梯度累积优化显存使用  - 适中的batch size(16)和上下文长度(512/2048)5.监控与部署:  - 完整的训练日志记录(W&B)  - 模型自动推送至HuggingFace Hub  - 严格的模型保存策略

    grpo.py文件

    (https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py)

    ```mermaidgraph TD    A[开始] --> B[设置随机种子]    B --> C[配置日志系统]    C --> D[检查检查点]    D --> E[初始化WandB]    E --> F[加载数据集]    F --> G[加载tokenizer]    G --> H[获取奖励函数]    H --> I[格式化对话数据]    I --> J[初始化模型参数]    J --> K[创建GRPOTrainer]    K --> L{是否有检查点?}    L -- 是 --> M[从检查点恢复训练]    L -- 否 --> N[开始新训练]    M --> O[训练模型]    N --> O    O --> P[保存模型和指标]    P --> Q{是否评估?}    Q -- 是 --> R[执行评估]    Q -- 否 --> S    R --> S[保存评估结果]    S --> T{是否推送至Hub?}    T -- 是 --> U[推送模型]    T -- 否 --> V[结束]    U --> V```

    rewards.py

    (https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)

    0

    结合医学场景来探索

    0

      def medical_accuracy_reward(response: str, golden_answer: str) -> float:    """评估医学准确性,需要与标准医学答案对比"""    # 这里可以集成医学知识库或NLP模型进行专业评估    medical_terms_score = calculate_medical_terms_match(response, golden_answer)    treatment_score = evaluate_treatment_correctness(response, golden_answer)    return 0.6 * medical_terms_score + 0.4 * treatment_scoredef safety_reward(response: str) -> float:    """安全性评估:检查是否有危险建议"""    dangerous_keywords = ["自行停药", "未经医生", "高剂量", "随意服用"]    for keyword in dangerous_keywords:        if keyword in response:            return 0.0  # 发现危险建议直接0分    return 1.0def citation_reward(response: str) -> float:    """参考文献引用评估"""    citation_formats = ["[1]", "(Smith et al., 2020)", "根据最新指南"]    return 1.0 if any(fmt in response for fmt in citation_formats) else 0.5def patient_language_reward(response: str) -> float:    """患者友好语言评估"""    complex_terms = ["病理学", "分子机制", "流行病学"]    simplified_explanations = ["简单说", "通俗理解", "换句话说"]        complex_count = sum(term in response for term in complex_terms)    simple_count = sum(term in response for term in simplified_explanations)        if complex_count == 0:         return 1.0    return simple_count / (complex_count + 1)  # 确保至少解释了部分复杂术语def empathy_reward(response: str) -> float:    """同理心评估"""    empathy_keywords = ["理解您", "不用担心", "建议咨询", "我们会帮助"]    return min(1.0, 0.2 * sum(kw in response for kw in empathy_keywords))

      0

      参考:

      [1]vscode安装:https://mp.weixin.qq.com/s/FvqSUrJFFXSVxFpZ6Q2-jg

      [2]vscode上安装Coninue的相关插件:

      https://mp.weixin.qq.com/s/cD-BHkCWQxfeedL3eboaBA

      [3]open-r1项目:https://mp.weixin.qq.com/s/BDDUe1RyIVutucUVA9Yuzg,https://github.com/huggingface/open-r1]

      本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74368.shtml

      如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

      相关文章

      JVM深入原理(六)(二):双亲委派机制

      目录 6.5. 类加载器-双亲委派机制 6.5.1. 双亲委派机制-作用 6.5.2. 双亲委派机制-工作流程 6.5.3. 双亲委派机制-父加载器 6.5.4. 双亲委派机制-面试题 6.5.5. 双亲委派机制-代码主动加载一个类 6.6. 类加载器-打破双亲委派机制 6.6.1. 打破委派-ClassLoader原理 6.6.…

      Linux 文件系统超详解

      一.磁盘 磁盘是计算机的主要存储介质,它可以存储大量二进制数据,即使断电后也可以保证数据不会丢失。下面我们将了解磁盘的物理结构、存储结构以及逻辑结构。 磁盘的存储结构 1. 磁盘寻址的时候,基本单位既不是bit也不是byte,而…

      2025年大模型与Transformer架构:重塑AI未来的科技革命

      引言:一场关于智能的革命 想象一下,当你向一个虚拟助手提问时,它不仅能够准确理解你的需求,还能生成一段流畅且富有逻辑的回答;或者当你上传一张模糊的照片时,系统可以快速修复并生成高清版本——这一切的…

      GO语言学习(16)Gin后端框架

      目录 ☀️前言 1.什么是前端?什么是后端?🌀 2.Gin框架介绍 🌷 3.Gin框架的基本使用 -Hello,World例子🌷 🌿入门示例 - Hello,World 💻补充(一些常用的网…

      深入解析 Git Submodule:从基础到高级操作指南

      深入解析 Git Submodule:从基础到高级操作指南 一、Git Submodule 是什么? git submodule 是 Git 提供的一个强大功能,允许在一个 Git 仓库(主仓库)中嵌入另一个独立的 Git 仓库(子模块)。主仓…

      电子电气架构 --- EEA演进与芯片架构转移

      我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…

      如何用deepseek生成流程图

      软件准备: 在线流程图【Flowchart Maker & Online Diagram Software】或【process on】 步骤: 1、用 【DeepSeek】生成 结构化内容(Mermaid文件) 1.1、向deepseek输入指令:【帮我用mermaind写出“某某”的具体…

      【华为OD技术面试真题 - 技术面】- Java面试题(17)

      华为OD面试真题精选 专栏:华为OD面试真题精选 目录: 2024华为OD面试手撕代码真题目录以及八股文真题目录 文章目录 华为OD面试真题精选虚拟机分区1. **虚拟磁盘分区**2. **虚拟机的内存分区**3. **CPU分配**4. **虚拟网络分区**5. **存储虚拟化和分区**6. **虚拟机分区管理**…

      Linux | I.MX6ULL内核及文件系统源码结构(7)

      01 类型 描述 备注 ARM 交叉编译器 版本:4.9.4 提供软件工具 Uboot 版本:2016.03 提供源码 支持LCD显示;支持网口; 支持 EMMC,NAND FLASH; 支持环境变量修改保存 Linux 内核 版本:4.1.15 提供…

      0基础入门scrapy 框架,获取豆瓣top250存入mysql

      一、基础教程 创建项目命令 scrapy startproject mySpider --项目名称 创建爬虫文件 scrapy genspider itcast "itcast.cn" --自动生成 itcast.py 文件 爬虫名称 爬虫网址 运行爬虫 scrapy crawl baidu(爬虫名) 使用终端运行太麻烦了,而且…

      鸿蒙NEXT小游戏开发:猜小球

      1. 引言 “猜小球”是一个经典的益智游戏,通常由一名表演者和多名参与者共同完成。表演者会将一个小球放在一个杯子下面,然后将三个杯子快速地交换位置,参与者则需要猜出最终哪个杯子下面有小球。本文将介绍如何使用HarmonyOS NEXT技术&…

      网络购物谨慎使用手机免密支付功能

      在数字经济蓬勃发展的当下,“免密支付”成为许多人消费时的首选支付方式。 “免密支付”的存在有其合理性。在快节奏的现代生活中,时间愈发珍贵,每节省一秒都可能带来更高的效率。以日常通勤为例,上班族乘坐交通工具时&#xff0c…

      记录 | Android getWindow().getDecorView().setSystemUiVisibility(...)设置状态栏属性

      纯纯的一边开发一边学习,是小白是菜鸟,单纯的记录和学习,大神勿喷,理解有错望指正~ getWindow().getDecorView().setSystemUiVisibility(…) 该方法用于控制系统 UI(如状态栏、导航栏)的可见性…

      java虚拟机---JVM

      JVM JVM,也就是 Java 虚拟机,它最主要的作用就是对编译后的 Java 字节码文件逐行解释,翻译成机器码指令,并交给对应的操作系统去执行。 JVM 的其他特性有: JVM 可以自动管理内存,通过垃圾回收器回收不再…

      VectorBT:使用PyTorch+LSTM训练和回测股票模型 进阶四

      VectorBT:使用PyTorchLSTM训练和回测股票模型 进阶四 本方案融合 LSTM 时序预测与动态风险控制。系统采用混合架构,离线训练构建多尺度特征工程和双均线策略,结合在线增量更新持续优化模型。技术要点包括三层特征筛选、波动率动态仓位管理、混…

      前端中rem,vh,vw

      1. rem&#xff08;Root EM&#xff09; 参照对象 基准&#xff1a;相对于 根元素&#xff08;<html>&#xff09;的 font-size 计算。 默认情况下&#xff0c;浏览器的根 font-size 为 16px&#xff08;即 1rem 16px&#xff09;&#xff0c;但可通过 CSS 修改&#…

      详解 MySQL 常见的存储引擎及它们之间的区别

      MySQL 支持多种存储引擎&#xff0c;每种引擎针对不同的应用场景提供了特定的特性和优化。下面是几种常见的存储引擎以及它们之间的主要区别&#xff1a; 常见存储引擎 1. InnoDB&#xff08;重点&#xff09; 事务支持&#xff1a; 完全支持 ACID 事务&#xff0c;确保数据一…

      html+css+js 实现一个贪吃蛇小游戏

      目录 游戏简介 游戏功能与特点 如何玩转贪吃蛇 游戏设计与实现 HTML结构 JavaScript核心实现 代码结构&#xff1a; 效果 关于“其他游戏” 游戏简介 贪吃蛇是一款经典的单人小游戏&#xff0c;玩家通过控制蛇的移动&#xff0c;吃掉食物来增加长度&#xff0c;避免撞…

      GLSL(OpenGL 着色器语言)基础语法

      GLSL&#xff08;OpenGL 着色器语言&#xff09;基础语法 GLSL&#xff08;OpenGL Shading Language&#xff09;是 OpenGL 计算着色器的语言&#xff0c;语法类似于 C 语言&#xff0c;但提供了针对 GPU 的特殊功能&#xff0c;如向量运算和矩阵运算。 着色器的开头总是要声明…

      ngx_http_core_merge_srv_conf

      定义在 src\http\ngx_http_core_module.c static char * ngx_http_core_merge_srv_conf(ngx_conf_t *cf, void *parent, void *child) {ngx_http_core_srv_conf_t *prev parent;ngx_http_core_srv_conf_t *conf child;ngx_str_t name;ngx_http_server_name_t…