PPO和文本生成

策略梯度

策略梯度(Policy Gradient)方法梯度的计算如下:
E ( a t , s t ) ∈ π θ [ A ^ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \nabla_ \theta \log \pi_\theta(a_t | s_t)] E(at,st)πθ[A^tθlogπθ(atst)] A ^ t \hat A_t A^t是优势函数(advantage function) A t A_t At的估计。
A t = Q ( s t , a t ) − V ( s t ) A_t=Q(s_t, a_t)-V(s_t) At=Q(st,at)V(st)优势函数计算的是,在该状态下采取这个行动的奖励与在该状态下的平均奖励的差值。
上面的导数可以通过对下面的目标求导获得:
L P G ( θ ) = E ( a t , s t ) ∈ π θ [ A ^ t log ⁡ π θ ( a t ∣ s t ) ] L^{PG}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \log \pi_\theta(a_t | s_t)] LPG(θ)=E(at,st)πθ[A^tlogπθ(atst)]

PPO(Proximal Policy Optimization)

PPO有两个形式,其中一种形式PPO_CLIP的优化目标函数是:
L C L I P ( θ ) = E ( a t , s t ) ∈ π θ [ min ⁡ ( r t ( θ ) A ^ t , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] (1) L^{CLIP}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t)] \tag{1} LCLIP(θ)=E(at,st)πθ[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)](1)其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)} rt(θ)=πθold(atst)πθ(atst)
PPO算法中的advantage用下面的公式估计:
A ^ t = δ t + ( γ λ ) δ t + 1 + ⋯ + ( γ λ ) T − t + 1 δ T − 1 \hat A^t = \delta^t + (\gamma \lambda)\delta_{t+1} + \cdots+ (\gamma \lambda)^{T-t+1}\delta_{T-1} A^t=δt+(γλ)δt+1++(γλ)Tt+1δT1其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)
通常情况下,我们用一个网络学习策略和价值函数,这样策略和价值函数能共享参数,那么就需要结合策略代理和价值函数误差项的损失函数。再加上熵奖励(entropy bonus)来以确保足够的探索,优化目标变为:
L C L I P + V F + S ( θ ) = E ( a t , s t ) ∈ π θ [ L t C L I P ( θ ) − c 1 L t V F ( θ ) + c 2 S [ π θ ] ( s t ) ] L^{CLIP+VF+S}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t)] LCLIP+VF+S(θ)=E(at,st)πθ[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)]其中 L t V F ( θ ) = ( V θ ( s t ) − V t t a r g ) 2 L_t^{VF}(\theta)=(V_\theta(s_t)-V_t^{targ})^2 LtVF(θ)=(Vθ(st)Vttarg)2是价值函数的误差项,S是entropy bonus。

文本生成

在文本生成的情况下,给一个prompt,生成完整的response,是一个episode。动作空间是vocabulary。每生成一个词是一个时间步。

公式(1)需要advantage的估计,为了计算advantage,我们需要定义奖励(reward) r r r和估计状态价值函数 V ( s ) V(s) V(s)

用于强化学习的reward计算如下:
R ( x , y ) = r ( x , y ) − β log ⁡ π ( y ∣ x ) ρ ( y ∣ x ) R(x,y) = r(x,y) - \beta\log\frac{\pi(y|x)}{\rho(y|x)} R(x,y)=r(x,y)βlogρ(yx)π(yx)x是问题,y是回答, r ( x , y ) r(x,y) r(x,y)是reward model的输出,也就是下面代码中的score。注意这里reward model的输出称之为score,送入强化学习部分的才称为reward。 π ( y ∣ x ) \pi(y|x) π(yx)是要学习的生成模型, ρ ( y ∣ x ) \rho(y|x) ρ(yx)是参数固定的原始生成模型。
在trl库中reward的计算如下:

   def compute_rewards(self,scores: torch.FloatTensor,logprobs: torch.FloatTensor,ref_logprobs: torch.FloatTensor,masks: torch.LongTensor,):"""Compute per token rewards from scores and KL-penalty.Args:scores (`torch.FloatTensor`):Scores from the reward model, shape (`batch_size`)logprobs (`torch.FloatTensor`):Log probabilities of the model, shape (`batch_size`, `response_length`)ref_logprobs (`torch.FloatTensor`):Log probabilities of the reference model, shape (`batch_size`, `response_length`)"""rewards, non_score_rewards = [], []for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):# compute KL penalty (from difference in logprobs)kl = self._kl_penalty(logprob, ref_logprob)non_score_reward = -self.kl_ctl.value * klnon_score_rewards.append(non_score_reward)reward = non_score_reward.clone()last_non_masked_index = mask.nonzero()[-1]# reward is preference model score + KL penaltyreward[last_non_masked_index] += scorerewards.append(reward)return torch.stack(rewards), torch.stack(non_score_rewards)

可以看到上面的实现中,只将reward model的score添加到最后一个token的reward上,其他token的reward来自当前模型和 原始生成模型之间KL散度。这么做是为了减轻奖励模型的过度优化问题。

在trl库中用一个网络AutoModelForCausalLMWithValueHead学习策略 π θ ( s ) \pi_\theta(s) πθ(s)和状态价值函数 V ( s ) V(s) V(s)。AutoModelForCausalLMWithValueHead在普通AutoModelForCausalLM模型上了一个线性层nn.Linear(hidden_size, 1),用于估计状态价值函数 V ( s ) V(s) V(s)
普通AutoModelForCausalLM模型估计token概率即可作为策略 π θ ( s ) \pi_\theta(s) πθ(s)

在trl库中advantage的计算如下:

    def compute_advantages(self: torch.FloatTensor,values: torch.FloatTensor, # AutoModelForCausalLMWithValueHead输出的状态价值估计Vrewards: torch.FloatTensor, # compute_rewards函数计算得到的rewardsmask: torch.FloatTensor,):lastgaelam = 0advantages_reversed = []gen_len = rewards.shape[-1]values = values * maskrewards = rewards * maskfor t in reversed(range(gen_len)):nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelamadvantages_reversed.append(lastgaelam)advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)returns = advantages + valuesadvantages = masked_whiten(advantages, mask)advantages = advantages.detach()return values, advantages, returns

完整的PPO算法如下:
在这里插入图片描述

Reference

Proximal Policy Optimization Algorithms
Fine-Tuning Language Models from Human Preferences
Training language models to follow instructions with human feedback
https://github.com/huggingface/trl

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx的重定向

URI&#xff1a;统一资源标识符&#xff0c;是一种字符串标识&#xff0c;主要是用于标识抽象的或者是物理资源&#xff08;主要是指一些文件视频等等&#xff09; 常用的Nginx正则表达式 ^ 匹配输入字符串的起始位置&#xff08;以......开头&#xff09; $ 匹配输入…

FreeRTOS( 任务与中断优先级,临界保护)

资料来源于硬件家园&#xff1a;资料汇总 - FreeRTOS实时操作系统课程(多任务管理) 目录 一、中断优先级 1、NVIC基础知识 2、FreeRTOS配置NVIC 3、SVC、PendSV、Systick中断 4、不受FreeRTOS管理的中断 5、STM32CubeMX配置 二、任务优先级 1、任务优先级说明 2、任务…

【LeetCode】144. 二叉树的前序遍历、94. 二叉树的中序遍历、145. 二叉树的后序遍历

作者&#xff1a;小卢 专栏&#xff1a;《Leetcode》 喜欢的话&#xff1a;世间因为少年的挺身而出&#xff0c;而更加瑰丽。 ——《人民日报》 144. 二叉树的前序遍历 144. 二叉树的前序遍历 题目&#xff1a; 给你二叉树的根节点 root &…

保姆级Arcgis安装图文安装教程

参考视频&#xff1a;【钟老师arcGIS从放弃到入门】02软件下载与安装_哔哩哔哩_bilibili 安装包在视频简介中有 注释&#xff1a;安装过程中有犯错误&#xff0c;请耐心看完一遍再跟着操作 &#xff08;一&#xff09;安装包下载 下载视频中分享的压缩包(压缩包密码&#x…

window下部署Yapi接口管理系统部署总结

window下部署Yapi接口管理系统部署总结 YApi 是高效、易用、功能强大的 api 管理平台&#xff0c;旨在为开发、产品、测试人员提供更优雅的接口管理服务。可以帮助开发者轻松创建、发布、维护 API&#xff0c;YApi 还为用户提供了优秀的交互体验&#xff0c;开发人员只需利用平…

后端开发8.品牌模块

概述 简介 效果图 数据库设计 DROP TABLE IF EXISTS `goods_brand`;CREATE TABLE `goods_brand` ( `goodsBrandId` int(11) NOT NULL AUTO_IN

04-4_Qt 5.9 C++开发指南_时间日期与定时器

文章目录 1. 时间日期相关的类2. 源码2.1 可视化UI设计2.2 dialog.h2.3 dialog.cpp 1. 时间日期相关的类 时间日期是经常遇到的数据类型&#xff0c;Qt 中时间日期类型的类如下。 QTime:时间数据类型&#xff0c;仅表示时间&#xff0c;如 15:23:13。 QDate:日期数据类型&…

【资料分享】全志科技T507-H工业核心板规格书

1 核心板简介 创龙科技SOM-TLT507是一款基于全志科技T507-H处理器设计的4核ARM Cortex-A53全国产工业核心板&#xff0c;主频高达1.416GHz。核心板CPU、ROM、RAM、电源、晶振等所有元器件均采用国产工业级方案&#xff0c;国产化率100%。 核心板通过邮票孔连接方式引出MIPI C…

QGIS开发五:使用UI文件

前面我们说了在创建项目时创建的是一个空项目&#xff0c;即不使用 Qt 提供的综合开发套件 Qt Creator&#xff0c;也不使用 Qt Visual Studio Tools 这类工具。 但是后面发现&#xff0c;如果我想要有更加满意的界面布局&#xff0c;还是要自己写一个UI文件&#xff0c;如果不…

深度对话|如何设计合适的网络经济激励措施

近日&#xff0c;我们与Mysten Labs的首席经济学家Alonso de Gortari进行了对话&#xff0c;讨论了如何在网络运营商和参与者之间找到激励措施的平衡&#xff0c;以及Sui的经济如何不断发展。 是什么让您选择将自己的经济学背景应用于区块链和Web3领域&#xff1f; 起初&…

微信个人小程序申请 (AppID 和 AppSecret)

1. 登录微信公众平台 https://mp.weixin.qq.com/cgi-bin/loginpage?url%2Fcgi-bin%2Fhome%3Ft%3Dhome%2Findex%26lang%3Dzh_CN%26token%3D47421820 2. 右上角立即注册 3. 注册类型选择小程序 4. 账号信息 5. 邮箱激活 6. 小程序发布流程 7. 小程序信息 (前往填写) 8. 获取小程…

【一】初步认识数据库

数据库概览数据库 缘起表(Table)的理解用表来定义数据库数据库系统的理解概念层次的理解实例层次的理解 数据库管理系统的理解从用户角度看从系统实现角度看典型的数据库管理系统 数据库语言数据库定义、操纵、控制语言数据库语言 VS 高级语言 内容回顾练习 数据库概览 走马观…

QT笔记——QT自定义事件

我们有时候想发送自定义事件 1&#xff1a;创建自定义事件&#xff0c;首先我们需要知道它的条件 1&#xff1a;自定义事件需要继承QEvent 2&#xff1a;事件的类型需要在 QEvent::User 和 QEvent::MaxUser 范围之间&#xff0c;在QEvent::User之前 是预留给系统的事件 3&#…

前端先行模拟接口(mock+expres+json)

目录 mock模拟数据&#xff1a;data/static.js 路由&#xff1a;index.js 服务器&#xff1a;server.js yarn /node 启动服务器&#xff1a;yarn start 客户端&#xff1a;修改代理路径(修改设置后都要重启才生效) 示例 后端框架express构建服务器 前端发起请求 静态数…

smtplib.SMTPHeloError: (500, b‘Error: bad syntax‘)

如果你编写邮件收发工具的时候,有可能会遇到这个问题。这里直接给出解决办法。 目录 1、检查系统版本 2、点击右侧的更改适配器选项

0基础学C#笔记08:插入排序法

文章目录 前言一、过程简单描述&#xff1a;二、代码总结 前言 我们在玩打牌的时候&#xff0c;你是怎么整理那些牌的呢&#xff1f;一种简单的方法就是一张一张的来&#xff0c;将每一张牌插入到其他已经有序的牌中的适当位置。当我们给无序数组做排序的时候&#xff0c;为了…

十九、docker学习-Dockerfile

Dockerfile 官网地址 https://docs.docker.com/engine/reference/builder/Dockerfile其实就是我们用来构建Docker镜像的源码&#xff0c;当然这不是所谓的编程源码&#xff0c;而是一些命令的集合&#xff0c;只要理解它的逻辑和语法格式&#xff0c;就可以很容易的编写Docke…

为什么DNS协议运行在UDP之上?

DNS (Domain Name System) 运行在 UDP (User Datagram Protocol) 上主要是出于以下原因&#xff1a; 简单性和效率&#xff1a;UDP 是无连接的&#xff0c;这意味着与建立和维护 TCP 连接相比&#xff0c;UDP 有更少的开销。当 DNS 查询被发送时&#xff0c;它只需要一个小的请…

[ MySQL ] — 数据库环境安装、概念和基本使用

目录 安装MySQL 获取mysql官⽅yum源 安装mysql yum 源 安装mysql服务 启动服务 登录 方法1&#xff1a;获取临时root密码 方法2&#xff1a;无密码 方法3&#xff1a;跳过密码认证 配置my.cnf 卸载环境 设置开机启动(可以不设) 常见问题 安装遇到秘钥过期的问题&…

创建型设计模式:4、建造者模式(Builder Pattern)

目录 1、建造者模式含义 2、建造者模式的讲解 3、使用C实现建造者模式的实例 4、建造者模式的优缺点 5、建造者模式VS工厂模式 1、建造者模式含义 The intent of the Builder design pattern is to separate the construction of a complex object from its representatio…