PPO和文本生成

策略梯度

策略梯度(Policy Gradient)方法梯度的计算如下:
E ( a t , s t ) ∈ π θ [ A ^ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \nabla_ \theta \log \pi_\theta(a_t | s_t)] E(at,st)πθ[A^tθlogπθ(atst)] A ^ t \hat A_t A^t是优势函数(advantage function) A t A_t At的估计。
A t = Q ( s t , a t ) − V ( s t ) A_t=Q(s_t, a_t)-V(s_t) At=Q(st,at)V(st)优势函数计算的是,在该状态下采取这个行动的奖励与在该状态下的平均奖励的差值。
上面的导数可以通过对下面的目标求导获得:
L P G ( θ ) = E ( a t , s t ) ∈ π θ [ A ^ t log ⁡ π θ ( a t ∣ s t ) ] L^{PG}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \log \pi_\theta(a_t | s_t)] LPG(θ)=E(at,st)πθ[A^tlogπθ(atst)]

PPO(Proximal Policy Optimization)

PPO有两个形式,其中一种形式PPO_CLIP的优化目标函数是:
L C L I P ( θ ) = E ( a t , s t ) ∈ π θ [ min ⁡ ( r t ( θ ) A ^ t , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] (1) L^{CLIP}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t)] \tag{1} LCLIP(θ)=E(at,st)πθ[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)](1)其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)} rt(θ)=πθold(atst)πθ(atst)
PPO算法中的advantage用下面的公式估计:
A ^ t = δ t + ( γ λ ) δ t + 1 + ⋯ + ( γ λ ) T − t + 1 δ T − 1 \hat A^t = \delta^t + (\gamma \lambda)\delta_{t+1} + \cdots+ (\gamma \lambda)^{T-t+1}\delta_{T-1} A^t=δt+(γλ)δt+1++(γλ)Tt+1δT1其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)
通常情况下,我们用一个网络学习策略和价值函数,这样策略和价值函数能共享参数,那么就需要结合策略代理和价值函数误差项的损失函数。再加上熵奖励(entropy bonus)来以确保足够的探索,优化目标变为:
L C L I P + V F + S ( θ ) = E ( a t , s t ) ∈ π θ [ L t C L I P ( θ ) − c 1 L t V F ( θ ) + c 2 S [ π θ ] ( s t ) ] L^{CLIP+VF+S}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t)] LCLIP+VF+S(θ)=E(at,st)πθ[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)]其中 L t V F ( θ ) = ( V θ ( s t ) − V t t a r g ) 2 L_t^{VF}(\theta)=(V_\theta(s_t)-V_t^{targ})^2 LtVF(θ)=(Vθ(st)Vttarg)2是价值函数的误差项,S是entropy bonus。

文本生成

在文本生成的情况下,给一个prompt,生成完整的response,是一个episode。动作空间是vocabulary。每生成一个词是一个时间步。

公式(1)需要advantage的估计,为了计算advantage,我们需要定义奖励(reward) r r r和估计状态价值函数 V ( s ) V(s) V(s)

用于强化学习的reward计算如下:
R ( x , y ) = r ( x , y ) − β log ⁡ π ( y ∣ x ) ρ ( y ∣ x ) R(x,y) = r(x,y) - \beta\log\frac{\pi(y|x)}{\rho(y|x)} R(x,y)=r(x,y)βlogρ(yx)π(yx)x是问题,y是回答, r ( x , y ) r(x,y) r(x,y)是reward model的输出,也就是下面代码中的score。注意这里reward model的输出称之为score,送入强化学习部分的才称为reward。 π ( y ∣ x ) \pi(y|x) π(yx)是要学习的生成模型, ρ ( y ∣ x ) \rho(y|x) ρ(yx)是参数固定的原始生成模型。
在trl库中reward的计算如下:

   def compute_rewards(self,scores: torch.FloatTensor,logprobs: torch.FloatTensor,ref_logprobs: torch.FloatTensor,masks: torch.LongTensor,):"""Compute per token rewards from scores and KL-penalty.Args:scores (`torch.FloatTensor`):Scores from the reward model, shape (`batch_size`)logprobs (`torch.FloatTensor`):Log probabilities of the model, shape (`batch_size`, `response_length`)ref_logprobs (`torch.FloatTensor`):Log probabilities of the reference model, shape (`batch_size`, `response_length`)"""rewards, non_score_rewards = [], []for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):# compute KL penalty (from difference in logprobs)kl = self._kl_penalty(logprob, ref_logprob)non_score_reward = -self.kl_ctl.value * klnon_score_rewards.append(non_score_reward)reward = non_score_reward.clone()last_non_masked_index = mask.nonzero()[-1]# reward is preference model score + KL penaltyreward[last_non_masked_index] += scorerewards.append(reward)return torch.stack(rewards), torch.stack(non_score_rewards)

可以看到上面的实现中,只将reward model的score添加到最后一个token的reward上,其他token的reward来自当前模型和 原始生成模型之间KL散度。这么做是为了减轻奖励模型的过度优化问题。

在trl库中用一个网络AutoModelForCausalLMWithValueHead学习策略 π θ ( s ) \pi_\theta(s) πθ(s)和状态价值函数 V ( s ) V(s) V(s)。AutoModelForCausalLMWithValueHead在普通AutoModelForCausalLM模型上了一个线性层nn.Linear(hidden_size, 1),用于估计状态价值函数 V ( s ) V(s) V(s)
普通AutoModelForCausalLM模型估计token概率即可作为策略 π θ ( s ) \pi_\theta(s) πθ(s)

在trl库中advantage的计算如下:

    def compute_advantages(self: torch.FloatTensor,values: torch.FloatTensor, # AutoModelForCausalLMWithValueHead输出的状态价值估计Vrewards: torch.FloatTensor, # compute_rewards函数计算得到的rewardsmask: torch.FloatTensor,):lastgaelam = 0advantages_reversed = []gen_len = rewards.shape[-1]values = values * maskrewards = rewards * maskfor t in reversed(range(gen_len)):nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelamadvantages_reversed.append(lastgaelam)advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)returns = advantages + valuesadvantages = masked_whiten(advantages, mask)advantages = advantages.detach()return values, advantages, returns

完整的PPO算法如下:
在这里插入图片描述

Reference

Proximal Policy Optimization Algorithms
Fine-Tuning Language Models from Human Preferences
Training language models to follow instructions with human feedback
https://github.com/huggingface/trl

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx的重定向

URI&#xff1a;统一资源标识符&#xff0c;是一种字符串标识&#xff0c;主要是用于标识抽象的或者是物理资源&#xff08;主要是指一些文件视频等等&#xff09; 常用的Nginx正则表达式 ^ 匹配输入字符串的起始位置&#xff08;以......开头&#xff09; $ 匹配输入…

07 |「广播接收器 」

前言 实践是最好的学习方式&#xff0c;技术也如此。 文章目录 前言一、二、实践1、发送和接收系统广播2、发送和接收自定义广播 一、 广播是 Android 系统和 Android 应用程序在发生可能影响其他应用程序组件功能的事件时发送的消息&#xff1b;广播是Android系统中的一种进程…

FreeRTOS( 任务与中断优先级,临界保护)

资料来源于硬件家园&#xff1a;资料汇总 - FreeRTOS实时操作系统课程(多任务管理) 目录 一、中断优先级 1、NVIC基础知识 2、FreeRTOS配置NVIC 3、SVC、PendSV、Systick中断 4、不受FreeRTOS管理的中断 5、STM32CubeMX配置 二、任务优先级 1、任务优先级说明 2、任务…

【LeetCode】144. 二叉树的前序遍历、94. 二叉树的中序遍历、145. 二叉树的后序遍历

作者&#xff1a;小卢 专栏&#xff1a;《Leetcode》 喜欢的话&#xff1a;世间因为少年的挺身而出&#xff0c;而更加瑰丽。 ——《人民日报》 144. 二叉树的前序遍历 144. 二叉树的前序遍历 题目&#xff1a; 给你二叉树的根节点 root &…

保姆级Arcgis安装图文安装教程

参考视频&#xff1a;【钟老师arcGIS从放弃到入门】02软件下载与安装_哔哩哔哩_bilibili 安装包在视频简介中有 注释&#xff1a;安装过程中有犯错误&#xff0c;请耐心看完一遍再跟着操作 &#xff08;一&#xff09;安装包下载 下载视频中分享的压缩包(压缩包密码&#x…

window下部署Yapi接口管理系统部署总结

window下部署Yapi接口管理系统部署总结 YApi 是高效、易用、功能强大的 api 管理平台&#xff0c;旨在为开发、产品、测试人员提供更优雅的接口管理服务。可以帮助开发者轻松创建、发布、维护 API&#xff0c;YApi 还为用户提供了优秀的交互体验&#xff0c;开发人员只需利用平…

后端开发8.品牌模块

概述 简介 效果图 数据库设计 DROP TABLE IF EXISTS `goods_brand`;CREATE TABLE `goods_brand` ( `goodsBrandId` int(11) NOT NULL AUTO_IN

04-4_Qt 5.9 C++开发指南_时间日期与定时器

文章目录 1. 时间日期相关的类2. 源码2.1 可视化UI设计2.2 dialog.h2.3 dialog.cpp 1. 时间日期相关的类 时间日期是经常遇到的数据类型&#xff0c;Qt 中时间日期类型的类如下。 QTime:时间数据类型&#xff0c;仅表示时间&#xff0c;如 15:23:13。 QDate:日期数据类型&…

【资料分享】全志科技T507-H工业核心板规格书

1 核心板简介 创龙科技SOM-TLT507是一款基于全志科技T507-H处理器设计的4核ARM Cortex-A53全国产工业核心板&#xff0c;主频高达1.416GHz。核心板CPU、ROM、RAM、电源、晶振等所有元器件均采用国产工业级方案&#xff0c;国产化率100%。 核心板通过邮票孔连接方式引出MIPI C…

TCP通信——多线程并发回环服务器

思路 首先要考虑到服务器的流程&#xff0c;TCP服务器端程序流程&#xff1a; socketbind绑定listen监听accept等待连接 多线程并发服务器需要通过多个线程实现与多个客户端的连接&#xff0c;当每次有一个客户端连接来时&#xff0c;创建一个线程&#xff0c;用于与客户端的…

C语言判断文件是否存在之stat、fopen、access

一、stat 头文件 sys/stat.h unistd.h 函数原型 结构体struct stat说明 struct stat {dev_t st_dev; //device 文件的设备编号ino_t st_ino; //inode 文件的i-nodemode_t st_mode; //protection 文件的类型和存取的权限nlink_t st_nlink; //number of hard links 连到该文件…

ol问题总结二

一、加载坐标系是4326格式的&#xff0c;使用wfsServer发布的服务&#xff0c;图层加载失败&#xff1b;坐标系是3857格式的。图层加载正常 原因&#xff1a;4326格式的&#xff0c;发布出来的&#xff0c;经纬度是颠倒的 解决方案一&#xff1a;将经纬度进行反转 <templa…

QGIS开发五:使用UI文件

前面我们说了在创建项目时创建的是一个空项目&#xff0c;即不使用 Qt 提供的综合开发套件 Qt Creator&#xff0c;也不使用 Qt Visual Studio Tools 这类工具。 但是后面发现&#xff0c;如果我想要有更加满意的界面布局&#xff0c;还是要自己写一个UI文件&#xff0c;如果不…

深度对话|如何设计合适的网络经济激励措施

近日&#xff0c;我们与Mysten Labs的首席经济学家Alonso de Gortari进行了对话&#xff0c;讨论了如何在网络运营商和参与者之间找到激励措施的平衡&#xff0c;以及Sui的经济如何不断发展。 是什么让您选择将自己的经济学背景应用于区块链和Web3领域&#xff1f; 起初&…

微信个人小程序申请 (AppID 和 AppSecret)

1. 登录微信公众平台 https://mp.weixin.qq.com/cgi-bin/loginpage?url%2Fcgi-bin%2Fhome%3Ft%3Dhome%2Findex%26lang%3Dzh_CN%26token%3D47421820 2. 右上角立即注册 3. 注册类型选择小程序 4. 账号信息 5. 邮箱激活 6. 小程序发布流程 7. 小程序信息 (前往填写) 8. 获取小程…

【一】初步认识数据库

数据库概览数据库 缘起表(Table)的理解用表来定义数据库数据库系统的理解概念层次的理解实例层次的理解 数据库管理系统的理解从用户角度看从系统实现角度看典型的数据库管理系统 数据库语言数据库定义、操纵、控制语言数据库语言 VS 高级语言 内容回顾练习 数据库概览 走马观…

QT笔记——QT自定义事件

我们有时候想发送自定义事件 1&#xff1a;创建自定义事件&#xff0c;首先我们需要知道它的条件 1&#xff1a;自定义事件需要继承QEvent 2&#xff1a;事件的类型需要在 QEvent::User 和 QEvent::MaxUser 范围之间&#xff0c;在QEvent::User之前 是预留给系统的事件 3&#…

【PostgreSQL】几个提高性能的小特性

一、LOCALE 与 “operator class” 在PostgreSQL里&#xff0c;LOCALE默认使用C的本地化规则。LOCALE是一种文化偏好的区域设置&#xff0c;包括字母表、排序、数字格式等。 LOCALE里有一个比较重要的规则LC_COLLATE&#xff0c;即排序方式(Collation)&#xff0c;它会对数据…

前端先行模拟接口(mock+expres+json)

目录 mock模拟数据&#xff1a;data/static.js 路由&#xff1a;index.js 服务器&#xff1a;server.js yarn /node 启动服务器&#xff1a;yarn start 客户端&#xff1a;修改代理路径(修改设置后都要重启才生效) 示例 后端框架express构建服务器 前端发起请求 静态数…

smtplib.SMTPHeloError: (500, b‘Error: bad syntax‘)

如果你编写邮件收发工具的时候,有可能会遇到这个问题。这里直接给出解决办法。 目录 1、检查系统版本 2、点击右侧的更改适配器选项