模型Alignment之RLHF与DPO

1. RLHF (Reinforcement Learning from Human Feedback)

RLHF 是一种通过人类反馈来强化学习的训练方法,它能够让语言模型更好地理解和执行人类指令。

RLHF 的三个阶段

RLHF 的训练过程一般分为三个阶段:

  1. 监督微调(Supervised Fine-Tuning, SFT)

    • 目的:让模型初步具备按照人类指令生成文本的能力。
    • 数据:使用大量的人工标注数据,包含输入的 prompt 和对应的期望输出。
    • 训练:将这些数据作为监督学习任务,对预训练的大语言模型进行微调。
    • 结果:得到一个初步具有指令跟随能力的模型。
  2. 奖励模型训练(Reward Model Training)

    • 目的:训练一个模型来评估模型生成的文本质量。
    • 数据:收集模型在 SFT 阶段生成的多个不同回复,由人类标注人员对这些回复进行排序,以表示它们相对于给定 prompt 的优劣。
    • 训练:将这些排序数据作为训练数据,训练一个奖励模型。奖励模型的输出是一个标量值,表示生成的文本的质量。
    • 结果:得到一个能够对文本质量进行打分的奖励模型。
  3. 强化学习微调(Reinforcement Learning Fine-Tuning)

    • 目的:使用奖励模型的反馈来进一步优化模型的生成能力。
    • 方法:采用强化学习算法(如 PPO),将语言模型作为策略,奖励模型作为价值函数。
    • 过程
      • 模型生成文本。
      • 奖励模型对生成的文本打分。
      • 根据奖励信号,更新模型的参数,使其生成更高质量的文本。
    • 结果:得到一个在人类反馈下表现更优的语言模型。

技术细节

  • 奖励模型:奖励模型通常是一个分类模型,它学习将不同的文本输出映射到一个连续的奖励值。常用的模型架构包括:
    • 基于 Transformer 的模型:与语言模型类似,具有强大的序列处理能力。
    • 对比学习模型:通过比较不同文本输出的相似性来学习奖励函数。
  • 强化学习算法:PPO(Proximal Policy Optimization)是一种常用的强化学习算法,它能够在保证策略稳定性的同时,高效地更新策略。
  • 数据收集:在 RLHF 的过程中,需要不断地收集新的数据来训练奖励模型和更新策略。这些数据可以来自以下几个方面:
    • 人工标注:由人类标注人员对模型生成的文本进行评估。
    • 用户反馈:收集用户在实际使用中的反馈。
    • 模型自生成:模型通过自生成的方式产生大量数据。

2. PPO在RLHF中的应用

PPO算法概述

PPO(Proximal Policy Optimization)是一种常用的强化学习算法,在RLHF中,它被用来优化语言模型,使其生成的文本能最大化人类反馈的奖励。

核心思想:

  • 策略更新: 通过不断调整模型的参数,使得模型生成的文本能获得更高的奖励。
  • 近端策略更新: 为了保证策略的稳定性,PPO限制了新旧策略之间的差异,避免模型发生剧烈变化。

PPO在RLHF中的具体步骤

  1. 采样数据:

    • 使用当前的语言模型生成多个文本样本。
    • 将这些样本输入到奖励模型中,获得对应的奖励分数。
  2. 计算优势函数:

    • 优势函数表示一个动作的好坏程度相对于平均动作的偏离程度。
    • 在RLHF中,优势函数可以表示为:
      • 优势函数 = 奖励 - 基线
    • 基线通常是所有样本奖励的平均值或一个估计值。
  3. 更新策略:

    • 概率比: 计算新旧策略下,生成相同文本的概率比。
    • 裁剪概率比: 为了防止策略更新过大,将概率比裁剪到一个合理范围内。
    • 计算损失函数:
      • 损失函数通常包含两项:
        • 策略损失: 鼓励模型生成高奖励的文本。
        • KL散度: 限制新旧策略之间的差异。
    • 更新模型参数: 使用梯度下降法来最小化损失函数,从而更新模型的参数。

损失函数的具体形式

PPO的损失函数可以写成如下形式:

L(θ) = 𝔼[min(r_t(θ) * A_t, clip(r_t(θ), 1 - ε, 1 + ε) * A_t)] - β * KL[π_θ, π_θ_old]
  • r_t(θ): 概率比,表示新旧策略下生成相同动作的概率比。
  • A_t: 优势函数。
  • clip: 裁剪操作,将概率比裁剪到[1-ε, 1+ε]的范围内。
  • β: KL散度的系数,用于控制新旧策略之间的差异。
  • KL[π_θ, π_θ_old]: 新旧策略之间的KL散度。

  • 第一项: 鼓励模型生成高奖励的文本。当优势函数为正时,希望概率比越大越好;当优势函数为负时,希望概率比越小越好。
  • 第二项: 限制新旧策略之间的差异,保证策略的稳定性。

3. DPO (Direct Preference Optimization)

DPO的工作原理

DPO的核心思想是:通过比较不同文本生成的优劣,直接优化模型参数。具体来说,DPO会收集大量的文本对,其中每一对文本代表着人类对两个文本的偏好。然后,DPO会训练模型,使得模型能够对新的文本对进行排序,并尽可能地与人类的偏好一致。

DPO与RLHF的区别

特点RLHFDPO
奖励模型需要训练奖励模型无需训练奖励模型
优化目标最大化奖励信号直接优化人类偏好
训练过程两阶段训练(预训练+强化学习)单阶段训练
  • 与RLHF相比,DPO旨在简化过程,直接针对用户偏好优化模型,而不需要复杂的奖励建模和策略优化
  • 换句话说,DPO专注于直接优化模型的输出,以符合人类的偏好或特定目标
  • 如下所示是DPO如何工作的概述

DPO没有再去训练一个奖励模型,使用奖励模型更新大模型,而是直接对LLM进行微调。
实现DPO损失的具体公式如下所示:

  • “期望值” E \mathbb{E} E是统计学术语,表示随机变量的平均值或平均值(括号内的表达式);优化 − E -\mathbb{E} E使模型更好地与用户偏好保持一致
  • π θ \pi_{\theta} πθ变量是所谓的策略(从强化学习借用的一个术语),表示我们想要优化的LLM; π r e f \pi_{ref} πref是一个参考LLM,这通常是优化前的原始LLM(在训练开始时, π θ \pi_{\theta} πθ π r e f \pi_{ref} πref通常是相同的)
  • β \beta β是一个超参数,用于控制 π θ \pi_{\theta} πθ和参考模型之间的分歧;增加 β \beta β增加差异的影响
    π θ \pi_{\theta} πθ π r e f \pi_{ref} πref在整体损失函数上的对数概率,从而增加了两个模型之间的分歧
  • logistic sigmoid函数 σ ( ⋅ ) \sigma(\centerdot) σ()将首选和拒绝响应的对数优势比(logistic sigmoid函数中的项)转换为概率分数

DPO需要两个LLMs,一个策略(policy)模型(我们想要优化的模型)还有一个参考(reference)模型(原始的模型,保持不变)。
我们得到两个模型的输出后,对其输出的结果计算softmax并取log,然后通过target取出预测目标对应的数值。(其实就是做了一个交叉熵,和交叉熵的计算过程一模一样)。通过这个过程我们可以得到每个模型在每个回答上的 π \pi π,于是代入公式计算结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54621.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TensorRT-LLM——优化大型语言模型推理以实现最大性能的综合指南

引言 随着对大型语言模型 (LLM) 的需求不断增长,确保快速、高效和可扩展的推理变得比以往任何时候都更加重要。NVIDIA 的 TensorRT-LLM 通过提供一套专为 LLM 推理设计的强大工具和优化,TensorRT-LLM 可以应对这一挑战。TensorRT-LLM 提供了一系列令人印…

.net core8 使用JWT鉴权(附当前源码)

说明 该文章是属于OverallAuth2.0系列文章,每周更新一篇该系列文章(从0到1完成系统开发)。 该系统文章,我会尽量说的非常详细,做到不管新手、老手都能看懂。 说明:OverallAuth2.0 是一个简单、易懂、功能强…

YOLOv8——测量高速公路上汽车的速度

引言 在人工神经网络和计算机视觉领域,目标识别和跟踪是非常重要的技术,它们可以应用于无数的项目中,其中许多可能不是很明显,比如使用这些算法来测量距离或对象的速度。 测量汽车速度基本步骤如下: 视频采集&#x…

游戏如何应对云手机刷量问题

云手机的实现原理是依托公有云和 ARM 虚拟化技术,为用户在云端提供一个安卓实例,用户可以将手机上的应用上传至云端,再通过视频流的方式,远程实时控制云手机。 市面上常见的几款云手机 原本需要手机提供的计算、存储等能力都改由…

python文件读写知识简记

简单记录一下python文件读写相关知识 一、打开文件 python使用open函数打开文件,函数原型如下 open(file, moder, buffering-1, encodingNone, errorsNone, newline None, closefdTrue, openerNone) file 文件地址 mode 文件打开模式,可设定为如下的…

深度学习实战:UNet模型的训练与测试详解

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 目录 1、云实例:配置选型与启动1.1 登录注册1.2 配置 SSH 密钥对1.3 创建实例1.4 登录云实例 2、云存储:数据集上传…

【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(上)

系列文章目录 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(上) 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(下) 文章目录 系列文章目录前言一、ArkTS基本介绍1、 ArkTS组成2、组件参数和属性2.1、区…

YOLOV8 OpenCV + usb 相机 实时识别

1 OpenCV 读相机 import cv2cap cv2.VideoCapture(0) while (1):# get a frameret, frame cap.read()# show a framecv2.imshow("capture", frame)if cv2.waitKey(1) & 0xFF ord(q):# cv2.imwrite("/opt/code/image/fangjian2.jpeg", frame)#passb…

Linux基础知识-1

Linux和Windows最大差异:目录。在Windows中,磁盘是被分成了很多区的,比如C盘,D盘,不同的文件放在不同的盘下面。下图为Windows的磁盘管理,可以看到磁盘0被划分为了不同的区域,C盘,D盘等&#xf…

[深度学习]Pytorch框架

1 深度学习简介 应用领域:语音交互、文本处理、计算机视觉、深度学习、人机交互、知识图谱、分析处理、问题求解2 发展历史 1956年人工智能元年2016年国内开始关注深度学习2017年出现Transformer框架2018年Bert和GPT出现2022年,chatGPT出现,进入AIGC发展阶段3 PyTorch框架简…

2024 年 CSS 终于增加了垂直居中特性,效率翻倍!

在 2024 年的Chrome 123 版本中&#xff0c; CSS 原生可以使用 1 个 CSS 属性 align-content: center进行垂直居中。 有何魅力&#xff1f; 这个特性的魅力在哪儿呢&#xff1f;我举例给你看一下 <div style"align-content:center; height:200px; background: #614e…

计算机网络:物理层 --- 基本概念、编码与调制

目录 一. 物理层的基本概念 二. 数据通信系统的模型 三. 编码 3.1 基本概念 3.2 不归零制编码 3.3 归零制编码 3.4 曼切斯特编码 3.5 差分曼切斯特编码 ​编辑 四. 调制 4.1 调幅 4.2 调频 4.3 调相 4.4 混合调制 今天我们讲的是物理…

影刀RPA实战:网页爬虫之携程酒店数据

1.实战目标 大家对于携程并不陌生&#xff0c;我们出行定机票&#xff0c;住酒店&#xff0c;去旅游胜地游玩&#xff0c;都离不开这样一个综合性的网站为我们提供信息&#xff0c;同时&#xff0c;如果你也是做旅游的公司&#xff0c;那携程就是一个业界竞争对手&#xff0c;…

[Spring]Spring MVC 请求和响应及用到的注解

文章目录 一. Maven二. SpringBoot三. Spring MVC四. MVC注解1. RequestMapping2. RequestParam3. PathVariable4. RequestPart5. CookieValue6. SessionAttribute7. RequestHeader8. RestController9. ResponseBody 五. 请求六. 响应 一. Maven Maven是⼀个项⽬管理⼯具。基于…

Python | Leetcode Python题解之第421题数组中两个数的最大异或值

题目&#xff1a; 题解&#xff1a; class Trie:def __init__(self):# 左子树指向表示 0 的子节点self.left None# 右子树指向表示 1 的子节点self.right Noneclass Solution:def findMaximumXOR(self, nums: List[int]) -> int:# 字典树的根节点root Trie()# 最高位的二…

Java基础知识扫盲

目录 Arrays.sort的底层实现 BigDecimal(double)和BigDecimal(String)有什么区别 Char可以存储一个汉字吗 Java中的Timer定时调度任务是咋实现的 Java中的序列化机制是咋实现的 Java中的注解是干嘛的 Arrays.sort的底层实现 Arrays.sort是Java中提供的对数组进行排序的…

LabVIEW编程能力如何能突飞猛进

要想让LabVIEW编程能力实现突飞猛进&#xff0c;需要采取系统化的学习方法&#xff0c;并结合实际项目进行不断的实践。以下是一些提高LabVIEW编程能力的关键策略&#xff1a; 1. 扎实掌握基础 LabVIEW的编程本质与其他编程语言不同&#xff0c;它是基于图形化的编程方式&…

使用 UWA Gears 定位游戏内存问题

UWA Gears 是UWA最新发布的无SDK性能分析工具。针对移动平台&#xff0c;提供了实时监测和截帧分析功能&#xff0c;帮助您精准定位性能热点&#xff0c;提升应用的整体表现。 内存不足、内存泄漏和过度使用等问题&#xff0c;常常导致游戏出现卡顿、崩溃&#xff0c;甚至影响…

CSS | 如何来避免 FOUC(无样式内容闪烁)现象的发生?

一、什么是 FOUC(无样式内容闪烁)? ‌FOUC&#xff08;Flash of Unstyled Content&#xff09;是指网页在加载过程中&#xff0c;由于CSS样式加载延迟或加载顺序不当&#xff0c;导致页面出现闪烁或呈现出未样式化的内容的现象。‌ 这种现象通常发生在HTML文档已经加载&…

Redis数据结构之哈希表

这里的哈希表说的是value的类型是哈希表 一.相关命令 1.hset key field value 一次可以设置多个 返回值是设置成功的个数 注意&#xff0c;哈希表中的键值对&#xff0c;键是唯一的而值可以重复 所以有下面的结果&#xff1a; key中原来已经有了f1&#xff0c;所以再使用hse…