Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning

Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning

image.png
Github:https://github.com/YuxiXie/MCTS-DPO

一、动机

大语言模型在偏好对齐环节可以提高模型的性能。目前有诸多工作尝试将偏好对齐通过迭代的形式进行改进:

It involves a cycle that begins with the current policy, progresses through the collection and analysis of data to generate new preference data, and uses this data to update the policy

目前有一些工作尝试这种迭代式对齐:

  • Reinforced self-training (rest) for language modeling
  • Large language models can self-improve
  • Self-rewarding language models

在强化学习生态中,一个典型的工作AlphaZero就是采用这种迭代式的训练,其通过结合神经网络、强化学习以及蒙特卡洛树搜索(MCTS)实现这个迭代式过程。

然而在强化学习中,MCTS是一个N-step自举法,即对整个trajectory进行多步Reward计算。然而在大模型中,如何使用这种N-step自举法?目前很多preference pair都是站在样本(instance-lever)层面进行打标的,这会导致损失一些细节信息,对于MCTS这种需要以step层面进行学习的方式会不友好。

Conventionally, preference data is collected at the instance level. The instance-level approach employs sparse supervision, which can lose important information and may not optimally leverage the potential of MCTS in improving the LLMs

二、方法

本文提出一种迭代式的DPO算法,通过MCTS算法来抽取偏好数据,并用迭代式地训练Policy模型。整个流程大致如下所示:

  • 首先初始化一个policy模型 π θ ( 0 ) \pi_{\theta^{(0)}} πθ(0),以及一个prompt数据集 D P \mathcal{D}_{\mathcal{P}} DP
  • 在第 i i i次迭代时,先采样一组prompt,并使用上一轮的policy模型 π θ ( i − 1 ) \pi_{\theta^{(i-1)}} πθ(i1)为每个prompt生成若干个response;
  • 使用一个不断进化的Reward标准来抽取偏好数据 D i \mathcal{D}_i Di
  • 基于这个新的偏好数据,训练新的policy模型 π θ ( i ) \pi_{\theta^{(i)}} πθ(i)

这一过程比较类似于online版本的DPO偏好训练。
image.png
在抽取偏好数据时,采用MCTS算法,将instance-level的偏好转换为step-wise。

2.1 MCTS获得 Step-wise偏好数据

假设 x x x为prompt, s t s_t st表示大模型生成推理过程中reasoning chain中的前 t t t步, a a a表示从 s t s_t st进入下一个时刻 s t + 1 s_{t+1} st+1的动作,换句话说 a a a表示当前时刻要执行的推理步骤(动作)。
为此,当前所有可能的推理步骤(动作空间)可以表示为 π θ ( a ∣ x , s t ) \pi_{\theta}(a|x, s_t) πθ(ax,st)。MCTS会根据当前已知的状态预测未来N步骤之后的奖励情况,例如预测下一最佳推理状态可表示为:
image.png
其中 Q ( s t , a ) Q(s_t, a) Q(st,a)表示基于当前已有的reasoning chain,完成当前的推理 a a a后会得到的Reward奖励值。 N ( s t ) 1 + N ( s t + 1 ) \frac{\sqrt{N(s_t)}}{1+N(s_{t+1})} 1+N(st+1)N(st) 则用于平衡探索(exploring)与利用(exploiting)之间的关系,image.png

  • 探索:更多地试探其他可能的推理路径;
  • 利用:取奖励最大的动作 a a a作为下一步的推理。

为了确保在树搜索过程中,在搜索过程中,采用Self-evaluation。evaluation的模板 prompt eval \text{prompt}_{\text{eval}} prompteval如下所示:
image.png
基于这个evaluation prompt prompt eval \text{prompt}_{\text{eval}} prompteval,让当前的额policy模型充当一个evaluator,对当前前 t t t步骤的推理结果 s t s_t st进行预测,得到一个score:

C ( s t ) = π θ ( A ∣ prompt e v a l , x , s t ) \mathcal{C}(s_t)=\pi_{\theta}(\text{A}|\text{prompt}_{eval}, x, s_t) C(st)=πθ(Aprompteval,x,st)

Self-evaluation相关工作:Decomposition enhances reasoning via self-evaluation guided decoding.

另外,如果大模型生成的 s t s_t st在格式上完成了推理(即整个生成已经完成)且正确,那么记作 O ( s t ) = 1 \mathcal{O}(s_t)=1 O(st)=1,若未完成则为0,若推理结果错误则为-1。
为此,可以得到一个Reward打分函数:

R ( s t ) = O ( s t ) + C ( s t ) R(s_t) = \mathcal{O}(s_t) + \mathcal{C}(s_t) R(st)=O(st)+C(st)

当整个树搜索完成扩张(Expand)之后,后面需要进行回溯(Backup),更新公式如下:
image.png
其中 N ( s t ) N(s_t) N(st)是一个计数器。
假设整棵树的深度为 T T T,即reasoning chain最多有 T T T个步骤。在每个步骤时 t ∈ [ 1 , T ] t\in[1, T] t[1,T],都将会构建一个pair,其中正样本为具有最高 Q Q Q值的路径,负样本则为具有最低 Q Q Q值的路径。因此,最终可以获得 T T T个pair。

2.2 迭代式DPO

考虑到偏好数据中可能会存在噪声,此时采用conservation version DPO

参考文献:A note on dpo with noisy preferences & relationship to ipo.
conservation version是指对于一些可能是噪声的偏好数据,将 ( y w > y l ) (y_w>y_l) (yw>yl)逆转为 ( y l > y w ) (y_l>y_w) (yl>yw)

定义一个标签平滑系数:
image.png
DPO中两个偏好样本之间的Reward差为:

h π θ y w , y l = log ⁡ π θ ( y w ∣ x ) π r e f ( y w ∣ x ) − log ⁡ π θ ( y l ∣ x ) π r e f ( y l ∣ x ) h^{y_w, y_l}_{\pi_{\theta}}=\log{\frac{\pi_{\theta}(y_w|x)}{\pi_{ref}(y_w|x)}} - \log{\frac{\pi_{\theta}(y_l|x)}{\pi_{ref}(y_l|x)}} hπθyw,yl=logπref(ywx)πθ(ywx)logπref(ylx)πθ(ylx)

那么,loss定义为:
image.png
相当于一部分 ( y w , y l ) (y_w, y_l) (yw,yl)会被反转。

三、实验

基座模型:Mistral-7B,
基座模型进行SFT训练,训练数据为:https://huggingface.co/datasets/akjindal53244/Arithmo-Data
训练设备:4台A100(40G)
训练细节:
image.png
数据集:

  • GSM8K、MATH
  • ARC、CSQA、OpenBookQA、AI2Science

实验结果:
image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/50757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

装修行业办公家具销售公司网站带模版 附带完整的源代码包以及搭建部署教程

系统概述 这款网站源码系统是针对装修行业办公家具销售公司的特点定制的,它融合了现代化的设计理念和先进的技术架构,旨在为用户提供极佳的浏览和购物体验。系统采用了响应式设计,能够自适应不同设备的屏幕尺寸,确保用户在手机、…

【ai】 2005年 rule based expert system学习笔记1

PPT 是2005年的? Negnevitsky, Pearson Education 使用两种推理引擎的选择 backward chaining(逆向链接)推理过程 backward chaining(逆向链接)推理过程的GPT解释 这幅图展示了一个基于规则的专家系统如何通过backward chaining(逆向链接)推理过程来达到最终的推理目标…

大数据-55 Kafka sh脚本使用 与 JavaAPI使用 topics.sh producer.sh consumer.sh kafka-clients

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

禁毒教育展厅应如何创新展示方式,提升教育意义?

为了深刻揭示毒品的危害,促进禁毒知识的广泛传播,并显著提升公众的防范意识,禁毒教育展厅的推广举措正紧锣密鼓地展开。在这一关键进程中,展厅的空间布局与内容设计的合理性与针对性成为了至关重要的环节。接下来,我们…

angular入门基础教程(二)第一个angular组件

ng中的语法跟vue中是一样的插值语法,其实也是早期vue抄的ng的思路,使用{{variable}}形式,vue借鉴了ng和react,这个我们就不多了。 新建一个子组件 在项目根目录下面,执行 ng g component ./components/UserList这样…

【RL】强化学习入门:从基础到应用

本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章强化学习: 强化学习…

JDBC介绍及使用

目录 JDBC概述 JDBC概念 JDBC本质 JDBC好处 JDBC快速入门 JDBC API详解 DriverManager Connection Statement ResultSet PreparedStatement 数据库连接池 数据库连接池简介 数据库连接池实现 Driud使用 JDBC练习 JDBC概述 JDBC概念 JDBC 就是使用Java语言操作…

Docker Compose V2 安装 ClickHouse v20.6.8.5 经验分享

前言 ClickHouse 是一款开源的分布式列式数据库管理系统,专门设计用于高性能的大数据分析和查询。 目前项目中用到的一个场景是将mongo的数据同步到clickhouse,使用clickhouse做报表,后续也将分享同步和使用方案 使用 Docker Compose 部署单机版,小项目和自己测试够用了,生…

海外短剧CPS系统,平台短剧出海推广方案

随着国内短剧市场的蓬勃发展与国际化趋势的加速,海外观众对于高质量、富有创意的短剧内容需求日益增长。在此背景下,搭建一个高效、便捷的海外短剧CPS(Cost Per Sales,按销售分润)分销系统平台,能为内容创作…

实战内测-某内测项目站点FUZZ到Sql注入

0x1 前言 下面给师傅们分享的案例呢是前段时间实战的一个站点,也是我朋友前段时间让我测的一个站点。整体的测试流程也还算ok,然后里面有些细节要是对师傅们有帮助可以收藏下,后面主要是利用FUZZ打了一个sql注入漏洞上去。 0x2 fuzz和sql结…

Halcon Blob分析

斑点分析的思路:在图像中,相关对象的像素可以通过其灰度值来识别。例如下图的组织颗粒。这些颗粒是凉的,而液体是暗的,通过选择明亮像素(阈值),可以很容易地检测到颗粒。在需要应用中,这种简单的暗像素和亮…

HarmonyOS持久化存储数据Preference

Preference首选项 首选项:首选项为应用提供Key-Value键值型的数据处理能力,支持应用持久化轻量级数据,并对其修改和查询。数据存储形式为键值对,键的类型为字符串型,值的存储数据类型包括数字型、字符型、布尔型以及这…

【优秀python web设计】基于Python flask的猫眼电影可视化系统,可视化用echart,前端Layui,数据库用MySQL,包括爬虫

1 绪论 1.1 设计背景及目的 猫眼电影作为国内知名的电影信息网站,拥有海量的电影信息、票房数据和用户评价数据。这些数据对于电影市场的研究和分析具有重要意义。然而,由于数据的复杂性和数据来源的多样性,如何有效地采集、存储和展示这些数…

复现波恩大学的“LiDiff:基于扩散模型实现3D LiDAR场景补全!”(点云补全)项目

本文的主要工作就是复现下述论文中的算法。 该论文全称:Scaling Diffusion Models to Real-World 3D LiDAR Scene Completion 一、准备工作 首先通读readme.md文件的内容,了解所需要的相关依赖和数据等内容。 一定要多读几遍,不要扫一眼就…

[Linux安全运维] LAMP 环境搭建保姆级教学(Apache + MySQL + PHP) ~~

LAMP LAMP 是一种网站技术,可以实现动态的网站页面部署。 1. LAMP概述 1 .1构成 Linux: 简介: Linux 是一种开源的操作系统,以其稳定性和安全性而著称。在 LAMP 堆栈中,它作为服务器操作系统运行。作用: 为应用程序提供一个稳定、安全的运…

【linux】在多核CPU下,好像看到不同进程在不同CPU调度

在2353这行打印的情况来看,操作系统好像给不同的进程分配不同的CPU,从上图来看,同一个进程好像基本使用的相同的CPU: 其实摸索syscall文件系统操作,本意是想找到内核文件系统中文件的创建,写入,…

3DMAX神经网络插件Neuron使用方法详解

3DMAX神经网络插件Neuron使用方法 3DMAX神经网络插件Neuron,从一系列样条曲线创建具有分支结构的几何体。适用于如神经网络、血管、树枝等形状的3D建模。 【适用版本】 3dMax2016及更高(不仅限于此范围) 【安装方法】 Neuron插件无需安装&a…

windows 暂停更新

使用windows 系统的伙伴都深受其扰,动不动就要强制更新,并且无法长时间关闭更新。这里推荐一个工具来禁止更新。越来越多的工程师可能会逐渐放弃windows ,真的太冗杂了,linux 的桌面和命令行越来越好用。 下载地址 https://github.com/WereD…

Renesa Version Board开发RT-Thread 之I2C驱动应用(SHT20)

目录 概述 1 硬件接口介绍 1.1 Version Board上的I2C硬件接口 1.2 SHT20 1.2.1 SHT20简介 1.2.2 SHT-20模块电路 2 软件实现 2.1 软件版本信息 2.2 RT-Thread Studio创建项目 2.3 FSP配置I2C接口 2.4 使能Sensor驱动 3 RT-Thread驱动架构 3.1 接口函数 3.1.1 …

增量学习中Task incremental、Domain incremental、Class incremental 三种学习模式的概念及代表性数据集?

1 概念 在持续学习领域,Task incremental、Domain incremental、Class incremental 是三种主要的学习模式,它们分别关注不同类型的任务序列和数据分布变化。 1.1 Task Incremental Learning (Task-incremental) 任务增量学习,也称为任务增…