模型Alignment之RLHF与DPO

1. RLHF (Reinforcement Learning from Human Feedback)

RLHF 是一种通过人类反馈来强化学习的训练方法,它能够让语言模型更好地理解和执行人类指令。

RLHF 的三个阶段

RLHF 的训练过程一般分为三个阶段:

  1. 监督微调(Supervised Fine-Tuning, SFT)

    • 目的:让模型初步具备按照人类指令生成文本的能力。
    • 数据:使用大量的人工标注数据,包含输入的 prompt 和对应的期望输出。
    • 训练:将这些数据作为监督学习任务,对预训练的大语言模型进行微调。
    • 结果:得到一个初步具有指令跟随能力的模型。
  2. 奖励模型训练(Reward Model Training)

    • 目的:训练一个模型来评估模型生成的文本质量。
    • 数据:收集模型在 SFT 阶段生成的多个不同回复,由人类标注人员对这些回复进行排序,以表示它们相对于给定 prompt 的优劣。
    • 训练:将这些排序数据作为训练数据,训练一个奖励模型。奖励模型的输出是一个标量值,表示生成的文本的质量。
    • 结果:得到一个能够对文本质量进行打分的奖励模型。
  3. 强化学习微调(Reinforcement Learning Fine-Tuning)

    • 目的:使用奖励模型的反馈来进一步优化模型的生成能力。
    • 方法:采用强化学习算法(如 PPO),将语言模型作为策略,奖励模型作为价值函数。
    • 过程
      • 模型生成文本。
      • 奖励模型对生成的文本打分。
      • 根据奖励信号,更新模型的参数,使其生成更高质量的文本。
    • 结果:得到一个在人类反馈下表现更优的语言模型。

技术细节

  • 奖励模型:奖励模型通常是一个分类模型,它学习将不同的文本输出映射到一个连续的奖励值。常用的模型架构包括:
    • 基于 Transformer 的模型:与语言模型类似,具有强大的序列处理能力。
    • 对比学习模型:通过比较不同文本输出的相似性来学习奖励函数。
  • 强化学习算法:PPO(Proximal Policy Optimization)是一种常用的强化学习算法,它能够在保证策略稳定性的同时,高效地更新策略。
  • 数据收集:在 RLHF 的过程中,需要不断地收集新的数据来训练奖励模型和更新策略。这些数据可以来自以下几个方面:
    • 人工标注:由人类标注人员对模型生成的文本进行评估。
    • 用户反馈:收集用户在实际使用中的反馈。
    • 模型自生成:模型通过自生成的方式产生大量数据。

2. PPO在RLHF中的应用

PPO算法概述

PPO(Proximal Policy Optimization)是一种常用的强化学习算法,在RLHF中,它被用来优化语言模型,使其生成的文本能最大化人类反馈的奖励。

核心思想:

  • 策略更新: 通过不断调整模型的参数,使得模型生成的文本能获得更高的奖励。
  • 近端策略更新: 为了保证策略的稳定性,PPO限制了新旧策略之间的差异,避免模型发生剧烈变化。

PPO在RLHF中的具体步骤

  1. 采样数据:

    • 使用当前的语言模型生成多个文本样本。
    • 将这些样本输入到奖励模型中,获得对应的奖励分数。
  2. 计算优势函数:

    • 优势函数表示一个动作的好坏程度相对于平均动作的偏离程度。
    • 在RLHF中,优势函数可以表示为:
      • 优势函数 = 奖励 - 基线
    • 基线通常是所有样本奖励的平均值或一个估计值。
  3. 更新策略:

    • 概率比: 计算新旧策略下,生成相同文本的概率比。
    • 裁剪概率比: 为了防止策略更新过大,将概率比裁剪到一个合理范围内。
    • 计算损失函数:
      • 损失函数通常包含两项:
        • 策略损失: 鼓励模型生成高奖励的文本。
        • KL散度: 限制新旧策略之间的差异。
    • 更新模型参数: 使用梯度下降法来最小化损失函数,从而更新模型的参数。

损失函数的具体形式

PPO的损失函数可以写成如下形式:

L(θ) = 𝔼[min(r_t(θ) * A_t, clip(r_t(θ), 1 - ε, 1 + ε) * A_t)] - β * KL[π_θ, π_θ_old]
  • r_t(θ): 概率比,表示新旧策略下生成相同动作的概率比。
  • A_t: 优势函数。
  • clip: 裁剪操作,将概率比裁剪到[1-ε, 1+ε]的范围内。
  • β: KL散度的系数,用于控制新旧策略之间的差异。
  • KL[π_θ, π_θ_old]: 新旧策略之间的KL散度。

  • 第一项: 鼓励模型生成高奖励的文本。当优势函数为正时,希望概率比越大越好;当优势函数为负时,希望概率比越小越好。
  • 第二项: 限制新旧策略之间的差异,保证策略的稳定性。

3. DPO (Direct Preference Optimization)

DPO的工作原理

DPO的核心思想是:通过比较不同文本生成的优劣,直接优化模型参数。具体来说,DPO会收集大量的文本对,其中每一对文本代表着人类对两个文本的偏好。然后,DPO会训练模型,使得模型能够对新的文本对进行排序,并尽可能地与人类的偏好一致。

DPO与RLHF的区别

特点RLHFDPO
奖励模型需要训练奖励模型无需训练奖励模型
优化目标最大化奖励信号直接优化人类偏好
训练过程两阶段训练(预训练+强化学习)单阶段训练
  • 与RLHF相比,DPO旨在简化过程,直接针对用户偏好优化模型,而不需要复杂的奖励建模和策略优化
  • 换句话说,DPO专注于直接优化模型的输出,以符合人类的偏好或特定目标
  • 如下所示是DPO如何工作的概述

DPO没有再去训练一个奖励模型,使用奖励模型更新大模型,而是直接对LLM进行微调。
实现DPO损失的具体公式如下所示:

  • “期望值” E \mathbb{E} E是统计学术语,表示随机变量的平均值或平均值(括号内的表达式);优化 − E -\mathbb{E} E使模型更好地与用户偏好保持一致
  • π θ \pi_{\theta} πθ变量是所谓的策略(从强化学习借用的一个术语),表示我们想要优化的LLM; π r e f \pi_{ref} πref是一个参考LLM,这通常是优化前的原始LLM(在训练开始时, π θ \pi_{\theta} πθ π r e f \pi_{ref} πref通常是相同的)
  • β \beta β是一个超参数,用于控制 π θ \pi_{\theta} πθ和参考模型之间的分歧;增加 β \beta β增加差异的影响
    π θ \pi_{\theta} πθ π r e f \pi_{ref} πref在整体损失函数上的对数概率,从而增加了两个模型之间的分歧
  • logistic sigmoid函数 σ ( ⋅ ) \sigma(\centerdot) σ()将首选和拒绝响应的对数优势比(logistic sigmoid函数中的项)转换为概率分数

DPO需要两个LLMs,一个策略(policy)模型(我们想要优化的模型)还有一个参考(reference)模型(原始的模型,保持不变)。
我们得到两个模型的输出后,对其输出的结果计算softmax并取log,然后通过target取出预测目标对应的数值。(其实就是做了一个交叉熵,和交叉熵的计算过程一模一样)。通过这个过程我们可以得到每个模型在每个回答上的 π \pi π,于是代入公式计算结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54621.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TensorRT-LLM——优化大型语言模型推理以实现最大性能的综合指南

引言 随着对大型语言模型 (LLM) 的需求不断增长,确保快速、高效和可扩展的推理变得比以往任何时候都更加重要。NVIDIA 的 TensorRT-LLM 通过提供一套专为 LLM 推理设计的强大工具和优化,TensorRT-LLM 可以应对这一挑战。TensorRT-LLM 提供了一系列令人印…

.net core8 使用JWT鉴权(附当前源码)

说明 该文章是属于OverallAuth2.0系列文章,每周更新一篇该系列文章(从0到1完成系统开发)。 该系统文章,我会尽量说的非常详细,做到不管新手、老手都能看懂。 说明:OverallAuth2.0 是一个简单、易懂、功能强…

基于物联网技术的智能运动检测仪设计(微信小程序)(230)

文章目录 一、前言1.1 项目介绍【1】开发背景【2】项目实现的功能【3】项目硬件模块组成1.2 设计思路【1】整体设计思路【2】整体构架1.3 项目开发背景【1】选题的意义【2】可行性分析【3】参考文献【4】摘要【5】项目背景1.4 开发工具的选择【1】设备端开发【2】微信小程序开发…

9.22学习记录

进程间通信方式 管道、有名管道、共享内存、消息队列、信号、信号量、套接字 JVM内存模型 私有:程序计数器、本地方法栈、虚拟机栈 公有部分:堆、方法区 equals和hashcode有什么区别和联系? equals默认比较两个对象的引用,但…

YOLOv8——测量高速公路上汽车的速度

引言 在人工神经网络和计算机视觉领域,目标识别和跟踪是非常重要的技术,它们可以应用于无数的项目中,其中许多可能不是很明显,比如使用这些算法来测量距离或对象的速度。 测量汽车速度基本步骤如下: 视频采集&#x…

记录一次ubuntu /mysql/redis/nginx等 系统安装

没想到还会做一次系统安装配置类的工作,没办法,碰到问题了,总得解决。 安装 &网络配置 从网上下载了ubuntu 18.04.6的安装包,用UltraISO做安装盘,到服务器上修改了下启动顺序,ubuntu的安装非常简单&a…

算法打卡:第十一章 图论part05

今日收获:并查集理论基础,寻找存在的路径 1. 并查集理论基础(from代码随想录) (1)应用场景:判断两个元素是否在同一个集合中 (2)原理讲解:通过一个一维数组…

游戏如何应对云手机刷量问题

云手机的实现原理是依托公有云和 ARM 虚拟化技术,为用户在云端提供一个安卓实例,用户可以将手机上的应用上传至云端,再通过视频流的方式,远程实时控制云手机。 市面上常见的几款云手机 原本需要手机提供的计算、存储等能力都改由…

python文件读写知识简记

简单记录一下python文件读写相关知识 一、打开文件 python使用open函数打开文件,函数原型如下 open(file, moder, buffering-1, encodingNone, errorsNone, newline None, closefdTrue, openerNone) file 文件地址 mode 文件打开模式,可设定为如下的…

深度学习实战:UNet模型的训练与测试详解

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 目录 1、云实例:配置选型与启动1.1 登录注册1.2 配置 SSH 密钥对1.3 创建实例1.4 登录云实例 2、云存储:数据集上传…

优选算法之 分治-快排

目录 一、颜色分类 1. 题目链接:75. 颜色分类 2. 题目描述: 3. 解法(快排思想 - 三指针法使数组分三块) 🌴算法思路: 🌴算法流程: 🌴算法代码: 二、快…

python写windows抓包工具, 直接使用windows api

主要使用python自带的ctypes和wintypes进行类型转换和交互 # python 3.11.7 import ctypes from ctypes import wintypes import inspect import socketdef log(data):print("----------------log start---------------")try:for attr, value in inspect.getmembers…

【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(上)

系列文章目录 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(上) 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器(下) 文章目录 系列文章目录前言一、ArkTS基本介绍1、 ArkTS组成2、组件参数和属性2.1、区…

YOLOV8 OpenCV + usb 相机 实时识别

1 OpenCV 读相机 import cv2cap cv2.VideoCapture(0) while (1):# get a frameret, frame cap.read()# show a framecv2.imshow("capture", frame)if cv2.waitKey(1) & 0xFF ord(q):# cv2.imwrite("/opt/code/image/fangjian2.jpeg", frame)#passb…

Linux基础知识-1

Linux和Windows最大差异:目录。在Windows中,磁盘是被分成了很多区的,比如C盘,D盘,不同的文件放在不同的盘下面。下图为Windows的磁盘管理,可以看到磁盘0被划分为了不同的区域,C盘,D盘等&#xf…

[深度学习]Pytorch框架

1 深度学习简介 应用领域:语音交互、文本处理、计算机视觉、深度学习、人机交互、知识图谱、分析处理、问题求解2 发展历史 1956年人工智能元年2016年国内开始关注深度学习2017年出现Transformer框架2018年Bert和GPT出现2022年,chatGPT出现,进入AIGC发展阶段3 PyTorch框架简…

Leetcode 第 139 场双周赛题解

Leetcode 第 139 场双周赛题解 Leetcode 第 139 场双周赛题解题目1:3285. 找到稳定山的下标思路代码复杂度分析 题目2:3286. 穿越网格图的安全路径思路代码复杂度分析 题目3:3287. 求出数组中最大序列值思路代码复杂度分析 题目4:…

2024 年 CSS 终于增加了垂直居中特性,效率翻倍!

在 2024 年的Chrome 123 版本中&#xff0c; CSS 原生可以使用 1 个 CSS 属性 align-content: center进行垂直居中。 有何魅力&#xff1f; 这个特性的魅力在哪儿呢&#xff1f;我举例给你看一下 <div style"align-content:center; height:200px; background: #614e…

计算机网络:物理层 --- 基本概念、编码与调制

目录 一. 物理层的基本概念 二. 数据通信系统的模型 三. 编码 3.1 基本概念 3.2 不归零制编码 3.3 归零制编码 3.4 曼切斯特编码 3.5 差分曼切斯特编码 ​编辑 四. 调制 4.1 调幅 4.2 调频 4.3 调相 4.4 混合调制 今天我们讲的是物理…

影刀RPA实战:网页爬虫之携程酒店数据

1.实战目标 大家对于携程并不陌生&#xff0c;我们出行定机票&#xff0c;住酒店&#xff0c;去旅游胜地游玩&#xff0c;都离不开这样一个综合性的网站为我们提供信息&#xff0c;同时&#xff0c;如果你也是做旅游的公司&#xff0c;那携程就是一个业界竞争对手&#xff0c;…