Step-DPO 论文——数学大语言模型理解

论文题目:STEP-DPO: STEP-WISE PREFERENCE OPTIMIZATION FOR LONG-CHAIN REASONING OF LLMS

翻译为中文就是:“LLMs长链推理的逐步偏好优化”

论文由港中文贾佳亚团队推出,基于推理步骤的大模型优化策略,能够像老师教学生一样优化大模型。

Qwen2-72B-Instruct模型作为基础模型进行微调优化后,其数学成绩超越了GPT-4、Gemini1.5-Pro、Claude3-Opus等闭源模型。

论文链接:https://arxiv.org/pdf/2406.18629

 代码仓库:https://github.com/dvlab-research/Step-DPO

1. 介绍

大语言模型(LLMs)在数学推理上具有重大挑战,这是由于数学需要精确的推理链。然而,直接偏好优化(DPO)对长链数学推理的益处有限,因为采用DPO的模型难以识别错误答案中的详细错误。

所以作者提出了Step-DPO方法,它将整个答案划分多个步骤作答(Step1, Step2, Step3, ...),大大提高的模型的精度。

在MATH数据集上,在Qwen2-7B-Instruct上准确率从53.0% 提升到58.6%,GSM8K数据集,准确率从85.5%提升到87.9% 。使用 Qwen2-72B-Instruct模型,在MATH和GSM8K上分别取得 70.8%94.0%的准确率。

1.1 像教育学生一样训练大模型

数学推理被认为是大语言模型(LLMs)中一种关键的长链推理能力。由于需要广泛的思维链(CoT),这项任务尤其具有挑战性,其中可能包括许多推理步骤,这些步骤中的任何错误都可能导致最终得不到正确答案。 

(1)首先,最常用的方法就是监督微调(SFT),使用各种数据增强对齐来微调模型。然而,当SFT数据达到一定数量时,模型经常出现幻觉,性能也随之趋于饱和。一个潜在的原因是,随着首选输出的概率增加,不理想输出的概率也会增加。这种现象使得模型在长链推理中更容易出错。

(2)最近,直接偏好优化(DPO)(Rafailov et al., 2024)被提出用于使用偏好对数据进行对齐(每个偏好对都包含一个输入提示、偏好输出及非偏好输出),因其简单性而广受欢迎。尽管DPO在Chat聊天任务中很有效,但它对长链(long-chain)数学任务效果不明显。如下图2所示。

(3)于是作者提出了Step-DPO,基于推理步骤的直接偏好优化。Step-DPO 逐步检查每个步骤的答案是否正确,这使得模型能够轻松定位错误Step,以进行有效的优化,显著增强了长链推理

2.  STEP-DPO 公式

2.1 DPO

我们先看到DPO的优化目标函数:

\begin{aligned} L_{DPO}(\theta)=-E_{(x,y_{win},y_{lose})\backsim D}[log \sigma (\beta log \frac {\pi_{\theta} (y_{win} \mid x)}{\pi_{ref}(y_{win \mid x})} - \beta log \frac{\pi_{\theta}(y_{lose} \mid x)}{\pi_{ref}(y_{lose} \mid x)})] \end{aligned}

其中,\ x 是输入提示 ,\ y_{win}, y_{lose} 分别表示正确的回答、错误的回答, \ D 是偏好对数据集。 \sigma 表示 sigmoid 函数, \pi_{\theta}\pi_{ref} 分别表示当前要优化的微调模型 以及训练过程中保存不变的参照模型, \beta 是一个超参数用来控制距离。

2.2 Step-DPO

我们再看到Step-DPO,它不再像DPO从整体对比答案,而是将每个推理步骤视为一个基本单元,对比单个推理步骤,更精细地提升模型的推理能力。目标优化公式:

\begin{aligned} L(\theta)=-E_{(x,s_{1 \backsim k-1},s_{win}, s_{lose})\backsim D}[log \sigma (\beta log \frac {\pi_{\theta} (s_{win} \mid x; s_{1 \backsim k-1})}{\pi_{ref}(s_{win} \mid x; s_{1 \backsim k-1})} - \beta log \frac{\pi_{\theta}(s_{lose} \mid x; s_{1 \backsim k-1})}{\pi_{ref}(s_{lose} \mid x; s_{1 \backsim k-1})})] \end{aligned}

回答 \ y 可以分解为多个步骤 \ y=s_{1}, ..., s_n\ x 表示输入提示。Step-DPO 优化目标就是最大化正确的下一个推理步骤 \ s_{win} 的概率,最小化错误步骤 \ s_{lose} 的概率,如图3所示。

3. 分布式数据构建

Step-DPO 的训练数据集是怎样的呢?每个数据样本中应该包含下面4项:

1)prompt \ x

2)初始推理步骤 \ s_{1 \backsim k-1}

3)首选推理步骤  \ s_{win}

4)不需要(错误)的推理步骤 \ s_{lose}

如下图所示:

(1)错误收集

首先,我们收集数学问题问答的数据集 \ D_0 = \{ (x, \hat{y}) \} ,x 是数学问题,\ \hat{y} 是真实答案。

然后,使用初始(参照)模型 \ \pi_{ref} 来得到每个数学问题 x 的答案。

在进行模型推理之前,添加思维链(CoT)前缀作为提示,比如:“Let‘s think step by step. Step 1:”,以确保模型的推理结果被结构化为多个推理步骤。

模型推理完成之后可得到每个数学问题x的推理结果y,然后选择与真实答案 \ \hat{y} 不一致的那些结果,汇总得到数据集 \ D_1

\begin{aligned} D_1 = \{ (x, \hat{y}, y) \mid x \in D_0 \} \end{aligned}

(2)错误步骤定位

假设每个错误的推理结果都被明确地表示为 推理步骤序列 \ y = s_1, s_2, ..., s_n ,随后需要人工或利用GPT-4验证每个推理步骤的正确性,直到找到第一个错误步骤 \ s_k ,选择 \ s_k 作为错误的推理步骤 \ s_{loss} 。这样得到一个包含错误步骤的数据集 \ D_2

D_2 = \{ (x, \hat{y}, s_{1 \backsim k-1}, s_{loss}) \mid x \in D_1 \}

(3)步骤修正

为了获得 \ D_2 中每个样本的相应正确推理步骤,需要通过用 提示x 和前面的正确推理步骤 \ s_{1 \backsim k-1} 通过模型 \pi_{ref} 来采样多个输出 \ y_{cont} ,该过程被表述为:

y_{cont} \backsim \pi_{ref} (y \mid x; s_{1 \backsim k-1})

随后,保留那些最终答案与实际情况相匹配的输出。我们选择 \ y_{cont} 中的第一个推理步骤作为 \ s_{win} ,从而得到最终的数据集D:

D = \{ (x, s_{1 \backsim k-1}, s_{lose}, s_{win} \mid x \in D_2 ) \}

数据样本示例如 Figure 5 所示。

4. 实验结果

Step-DPO 可以在SFT模型或现有的开源 Instruct 模型上进行微调,仅通过 10K 数据以及数百个训练步数,即可去得大幅度数学能力提升。

其中 Qwen2-72B-Instruct + Step-DPO 取得了 70.8%94.0% 准确率在 MATH 和 GSM8K 数据集上。

在难度较高的包含数学竞赛题 Odyssey-MATH 榜单上也有显著提升。

突出了 Step-DPO 强大泛化能力,模型更加鲁棒,减少幻觉的产生。

如下三个例子:

1. 假设h(x)=f-1(x),如果h(2)=10,h(10)=1,h(1)=2,求f(f(10))

2. t的平方根大于2且小于3.5,满足这一条件的整数t有多少个?

下面比较难的数学竞赛题也能做对

3. 在所有非增函数f:{1,2,…,10}→{1,2,…,10}中,有些函数有固定点,另一些没有,这两种函数的数量相差多少?


参考:

https://github.com/dvlab-research/Step-DPO

贾佳亚团队新作:10k数据让大模型数学能力超GPT-4

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/48904.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

String 和StringBuilder字符串操作快慢的举例比较

System.currentTimeMillis(); //当前时间与1970年1月1日午夜UTC之间的毫秒差。public class HelloWorld {public static void main(String[] args) {String s1 "";StringBuilder s2 new StringBuilder("");long time System.currentTimeMillis();long s…

git命令学习分享

分布式版本控制系统,本地仓库和远程仓库相互独立。 使用repository仓库进行控制,可以对里面的文件进行跟踪,复原。 git config --global --list:查看git配置列表 cd ** :进入** cd .. :退回上一级 echo…

AI Agent项目探索与实践记录

AI Agent项目探索与实践记录 1. 概述2. 总体结构2.1 记忆模块2.2 模型服务模块2.2.1 LLM服务2.2.2 retrieval服务2.2.3 rerank服务 2.3 Agent系统2.3.1 Planner2.3.2 Code/SQL Generator2.3.3 Code Executor2.3.4 Responser2.3.5 Round Compressor2.3.6 New Turn Discriminator…

基于Llama Index构建RAG应用(Datawhale AI 夏令营)

前言 Hello,大家好,我是GISer Liu😁,一名热爱AI技术的GIS开发者,本文参与活动是2024 DataWhale AI夏令营;😲 在本文中作者将通过: Gradio、Streamlit和LlamaIndex介绍 LlamaIndex 构…

linux文本查看命令

在Linux中,查找文件通常使用几个不同的命令,具体取决于你的需求和上下文。以下是一些最常用的命令: find 命令: find 是最强大和灵活的命令之一,用于在目录树中搜索文件,并执行对找到的文件执行指定的操作…

全局 loading

好久不见! 做项目中一直想用一个统一的 loading 状态控制全部的接口加载,但是一直不知道怎么处理,最近脑子突然灵光了一下想到了一个办法。 首先设置一个全局的 loading 状态,优先想到的就是 Pinia 然后因为页面会有很多接口会…

数据结构——栈(链式结构)

一、栈的链式存储结构 如果一个栈存在频繁进栈和出栈操作,可以考虑链式结构。 栈的链式存储结构是指使用链表来实现栈这种数据结构。在链式存储结构中,栈的每个元素被封装成一个节点,节点之间通过指针相连,形成一个链表。栈顶元…

Linux下开放指定端口

比如需要开放82端口: #查询是否开通 firewall-cmd --query-port82/tcp#开放端口82 firewall-cmd --zonepublic --add-port82/tcp --permanent#重新加载防火墙 firewall-cmd --reload

java学习--代码块

package com.block.test01; class Main{public static void main(String[] args) {Block block new Block("你好,李焕英");new Block("你好",12,24);} } public class Block {String name;int begin_time;int end_time; //如果在调用构造器时都…

华盈生物-20K人类蛋白组芯片的超凡应用:揭秘蛋白质的神奇世界

各位科研小伙伴们,欢迎再次来到我们的科学探险之旅!今天,我们要深入探讨一项超级实用的科研工具——20K人类蛋白组芯片。通过这款芯片,你可以揭开蛋白质世界的神秘面纱,探索各种有趣的应用方向。准备好了吗&#xff1f…

在python中使用正则表达式

正则表达式是什么?就是要寻找的数据的规律,使用正则表达式的步骤有三 第一,寻找规律,第二使用正则符号表示规律,第三,提取信息 看下面的代码 import re wenzhang (小草偷偷地从土里钻出来,嫩…

Leetcode 3228. Maximum Number of Operations to Move Ones to the End

Leetcode 3228. Maximum Number of Operations to Move Ones to the End 1. 解题思路2. 代码实现 题目链接:3228. Maximum Number of Operations to Move Ones to the End 1. 解题思路 这一题不难分析得到,要获得最多的操作次数,只需要从左…

数据结构---散列表(哈希表)

什么是哈希表 1、哈希表(Hash Table):也叫做散列表。是根据关键码值(Key Value)直接进行访问的数据结构。 2、哈希表通过「键 key 」和「映射函数 Hash(key) 」计算出对应的「值 value」,把关键码值映射到…

SwiftUI 5.0(iOS 17)滚动视图的滚动目标行为(Target Behavior)解惑和实战

概览 在 SwiftUI 的开发过程中我们常说:“屏幕不够,滚动来凑”。可见滚动视图对于超长内容的呈现有着多么秉轴持钧的重要作用。 这不,从 SwiftUI 5.0(iOS 17)开始苹果又为滚动视图增加了全新的功能。但是官方的示例可…

【Node.js】调试 Node.js 程序

调试 Node.js 程序可以使⽤以下⽅法: console.log():使⽤ console.log() 打印变量或者调试信息,可以在控制台中查看输出的结果。debugger:在代码中使⽤ debugger 命令设置断点,当程序执⾏到该点时会暂停,可…

Linux----Mplayer音视频库的移植

想要播放视频音乐就得移植相关库到板子上 Mplayer移植需要依赖以下源文件:(从官网获取或者网上) 1、zlib-1.2.3.tar.gz :通用的内存空间的压缩库。 2、libpng-1.2.57.tar.gz :png格式图片的压缩或解压库 3、Jpegsrc.v9b.tar.gz : jpeg格式图片的压…

Unity3D 如何自动点击UIElement.Button类型的按钮详解

前言 在Unity3D开发中,自动点击UI界面上的按钮是一个常见的需求,特别是在自动化测试、演示脚本或游戏AI控制等场景中。Unity的UI系统(UGUI)提供了灵活的接口来实现这一功能。下面将详细介绍如何在Unity中自动点击UIElement.Butto…

数据结构day3

一、思维导图 二、顺序表实现学生管理系统 //头文件 #ifndef TEST_H #define TEST_H #define MAX_SIZE 100//定义学生类型 typedef struct {char name[20]; //姓名int age; //年龄double score; //分数 }datatype;//定义班级类型 typedef struct {datatype student[MAX…

CDGA数据治理:突破卡点堵点,解决确权难、流通交易难问题

随着大数据时代的来临,数据已成为推动社会进步和经济发展的重要力量。然而,数据治理中的卡点堵点问题,特别是确权难、流通交易难,正成为制约数据要素市场健康发展的瓶颈。本文将探讨这些问题,并提出相应的解决方案。 确…

uniapp写登陆|微信小程序登录和微信h5登录使用同一个页面

文章目录 导文微信小程序登录先写一个样式代码实现详细解释: 微信h5登录先写一个样式代码实现1. checkWeChatCode()2. getWeChatCode()页面获取登陆后的code 导文 微信小程序登录怎么实现? 微信h5登录怎么实现? 用uniapp写同一个页面&#xf…