了解PPO算法(Proximal Policy Optimization)

Proximal Policy Optimization (PPO) 是一种强化学习算法,由 OpenAI 提出,旨在解决传统策略梯度方法中策略更新过大的问题。PPO 通过引入限制策略更新范围的机制,在保证收敛性的同时提高了算法的稳定性和效率。

PPO算法原理

PPO 算法的核心思想是通过优化目标函数来更新策略,但在更新过程中限制策略变化的幅度。具体来说,PPO 引入了裁剪(Clipping)和信赖域(Trust Region)的思想,以确保策略不会发生过大的改变。

PPO算法公式

PPO 主要有两种变体:裁剪版(Clipped PPO)和信赖域版(Adaptive KL Penalty PPO)。本文重点介绍裁剪版的 PPO。

  • 旧策略:

    \pi_{\theta_{\text{old}}}(a|s)

    其中,\theta_{\text{old}}​ 是上一次更新后的策略参数。

  • 计算概率比率:

    r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}
  • 裁剪后的目标函数:

    L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right]

    其中,\hat{A}_t​ 是优势函数(Advantage Function),\epsilon 是裁剪范围的超参数,通常取值为0.2。

  • 更新策略参数:

    a_{\text{new}} = \arg\max_{\theta} L^{\text{CLIP}}(\theta)
PPO算法的实现

下面是用Python和TensorFlow实现 PPO 算法的代码示例:

import tensorflow as tf
import numpy as np
import gym# 定义策略网络
class PolicyNetwork(tf.keras.Model):def __init__(self, action_space):super(PolicyNetwork, self).__init__()self.dense1 = tf.keras.layers.Dense(128, activation='relu')self.dense2 = tf.keras.layers.Dense(128, activation='relu')self.logits = tf.keras.layers.Dense(action_space, activation=None)def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.logits(x)# 定义值函数网络
class ValueNetwork(tf.keras.Model):def __init__(self):super(ValueNetwork, self).__init__()self.dense1 = tf.keras.layers.Dense(128, activation='relu')self.dense2 = tf.keras.layers.Dense(128, activation='relu')self.value = tf.keras.layers.Dense(1, activation=None)def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)return self.value(x)# 超参数
learning_rate = 0.0003
clip_ratio = 0.2
epochs = 10
batch_size = 64
gamma = 0.99# 创建环境
env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n# 创建策略和值函数网络
policy_net = PolicyNetwork(n_actions)
value_net = ValueNetwork()# 优化器
policy_optimizer = tf.keras.optimizers.Adam(learning_rate)
value_optimizer = tf.keras.optimizers.Adam(learning_rate)def get_action(observation):logits = policy_net(observation)action = tf.random.categorical(logits, 1)return action[0, 0]def compute_advantages(rewards, values, next_values, done):advantages = []gae = 0for i in reversed(range(len(rewards))):delta = rewards[i] + gamma * next_values[i] * (1 - done[i]) - values[i]gae = delta + gamma * gaeadvantages.insert(0, gae)return np.array(advantages)def ppo_update(observations, actions, advantages, returns):with tf.GradientTape() as tape:old_logits = policy_net(observations)old_log_probs = tf.nn.log_softmax(old_logits)old_action_log_probs = tf.reduce_sum(old_log_probs * tf.one_hot(actions, n_actions), axis=1)logits = policy_net(observations)log_probs = tf.nn.log_softmax(logits)action_log_probs = tf.reduce_sum(log_probs * tf.one_hot(actions, n_actions), axis=1)ratio = tf.exp(action_log_probs - old_action_log_probs)surr1 = ratio * advantagessurr2 = tf.clip_by_value(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantagespolicy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))policy_grads = tape.gradient(policy_loss, policy_net.trainable_variables)policy_optimizer.apply_gradients(zip(policy_grads, policy_net.trainable_variables))with tf.GradientTape() as tape:value_loss = tf.reduce_mean((returns - value_net(observations))**2)value_grads = tape.gradient(value_loss, value_net.trainable_variables)value_optimizer.apply_gradients(zip(value_grads, value_net.trainable_variables))# 训练循环
for epoch in range(epochs):observations = []actions = []rewards = []values = []next_values = []dones = []obs = env.reset()done = Falsewhile not done:obs = obs.reshape(1, -1)observations.append(obs)action = get_action(obs)actions.append(action)value = value_net(obs)values.append(value)obs, reward, done, _ = env.step(action.numpy())rewards.append(reward)dones.append(done)if done:next_values.append(0)else:next_value = value_net(obs.reshape(1, -1))next_values.append(next_value)returns = compute_advantages(rewards, values, next_values, dones)advantages = returns - valuesobservations = np.concatenate(observations, axis=0)actions = np.array(actions)returns = np.array(returns)advantages = np.array(advantages)ppo_update(observations, actions, advantages, returns)print(f'Epoch {epoch+1} completed')
总结

PPO 算法通过引入裁剪机制和信赖域约束,限制了策略更新的幅度,提高了训练过程的稳定性和效率。其简单而有效的特性使其成为目前强化学习中最流行的算法之一。通过理解并实现 PPO 算法,可以更好地应用于各种强化学习任务,提升模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42633.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【docker 把系统盘空间耗没了!】windows11 更改 ubuntu 子系统存储位置

系统:win11 ubuntu 22 子系统,docker 出现问题:系统盘突然没空间了,一片红 经过排查,发现 AppData\Local\packages\CanonicalGroupLimited.Ubuntu22.04LTS_79rhkp1fndgsc\ 这个文件夹竟然有 90GB 下面提供解决办法 步…

Spring-AOP(二)

作者:月下山川 公众号:月下山川 1、什么是AOP AOP(Aspect Oriented Programming)是一种设计思想,是软件设计领域中的面向切面编程,它是面向对象编程的一种补充和完善,它以通过预编译方式和运行期…

【课程总结】Day13(下):人脸识别和MTCNN模型

前言 在上一章课程【课程总结】Day13(上):使用YOLO进行目标检测,我们了解到目标检测有两种策略,一种是以YOLO为代表的策略:特征提取→切片→分类回归;另外一种是以MTCNN为代表的策略:先图像切片→特征提取→分类和回归。因此,本章内容将深入了解MTCNN模型,包括:MTC…

使用jdk11运行javafx程序和jdk11打包jre包含javafx模块

我们都知道jdk11是移除了javafx的,如果需要使用javafx,需要单独下载。 这就导致我们使用javafx开发的桌面程序使用jdk11时提示缺少javafx依赖。但这是可以通过下面的方法解决。 一,使用jdk11运行javafx程序 我们可以通过设置vmOptions来使用jdk11运行javafx程序 1,添加j…

【RAG KG】GraphRAG开源:查询聚焦摘要的图RAG方法

前言 传统的 RAG 方法在处理针对整个文本语料库的全局性问题时存在不足,例如查询:“数据中的前 5 个主题是什么?” 对于此类问题,是因为这类问题本质上是查询聚焦的摘要(Query-Focused Summarization, QFS&#xff09…

嵌入式单片机,两者有什么关联又有什么区别?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」, 点个关注在评论区回复“666”之后私信回复“666”,全部无偿共享给大家!!!使用单片机是嵌入式系统的…

CurrentHashMap巧妙利用位运算获取数组指定下标元素

先来了解一下数组对象在堆中的存储形式【数组长度,数组元素类型信息等】 【存放元素对象的空间】 Ma 基础信息实例数据内存填充Mark Word,ClassPointer,数组长度第一个元素第二个元素固定的填充内容 所以我们想要获取某个下标的元素首先要获取这个元素的起始位置…

Sorted Set 类型命令(命令语法、操作演示、命令返回值、时间复杂度、注意事项)

Sorted Set 类型 文章目录 Sorted Set 类型zadd 命令zrange 命令zcard 命令zcount 命令zrevrange 命令zrangebyscore 命令zpopmax 命令bzpopmax 命令zpopmin 命令bzpopmin 命令zrank 命令zscore 命令zrem 命令zremrangebyrank 命令zremrangebyscore 命令zincrby 命令zinterstor…

线程池案例

秒杀 需求 10个礼物20个客户抢随机10个客户获取礼物&#xff0c;另外10无法获取礼物 任务类 记得给共享资源加锁 public class MyTask implements Runnable{// 礼物列表private ArrayList<String> gifts ;// 用户名private String username;public MyTask( String user…

android Dialog全屏沉浸式状态栏实现

在Android中&#xff0c;创建沉浸式状态栏通常意味着让状态栏背景与应用的主题颜色一致&#xff0c;并且让对话框在状态栏下面显示&#xff0c;而不是浮动。为了实现这一点&#xff0c;你可以使用以下代码片段&#xff1a; 1、实际效果图&#xff1a; 2、代码实现&#xff1a;…

揭秘GPT-4o:未来智能的曙光

引言 近年来&#xff0c;人工智能&#xff08;AI&#xff09;的发展突飞猛进&#xff0c;尤其是自然语言处理&#xff08;NLP&#xff09;领域的进步&#xff0c;更是引人注目。在这一背景下&#xff0c;OpenAI发布的GPT系列模型成为了焦点。本文将详细探讨最新的模型GPT-4o&a…

Unity海面效果——6、反射和高光

Unity引擎制作海面效果 大家好&#xff0c;我是阿赵。 上一篇的结束时&#xff0c;海面效果已经做成这样了&#xff1a; 这个Shader的复杂程度已经比较高了&#xff1a; 不过还有一些美中不足的地方。 1、 海平面没有反射到天空球 2、 在近岸边看得到水底的部分&#xff0c;水…

一些关于C++的基础知识

引言&#xff1a;C兼容C的大部分内容&#xff0c;但其中仍有许多小细节的东西需要大家注意 一.C的第一个程序 #include <iostream> using namespace std;int main() {cout << "hello world!" << endl;return 0; } 第一次看这个是否感觉一头雾水…

数据挖掘——matplotlib

matplotlib概述 Mat指的是Matlab&#xff0c;plot指的是画图&#xff0c;lib即library&#xff0c;顾名思义&#xff0c;matplotlib是python专门用于开发2D图表的第三方库&#xff0c;使用之前需要下载该库&#xff0c;使用pip命令即可下载。 pip install matplotlib1、matpl…

elasticsearch SQL:在Elasticsearch中启用和使用SQL功能

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

服务注册Eureka

目录 一、背景 1、概念 2、CAP 理论 3、常见的注册中心 二、Eureka 三、搭建 Eureka Server 1、搭建注册中心 四、服务注册 五、服务发现 六、Eureka 和 Zooper 的区别 一、背景 1、概念 远程调用就类似于一种通信 例如&#xff1a;当游客与景区之间进行通信&…

Xubuntu24.04之设置高性能模式两种方式(二百六十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+AOSP…

苍穹外卖--新增员工

代码开发 package com.sky.controller.admin;import com.sky.constant.JwtClaimsConstant; import com.sky.dto.EmployeeDTO; import com.sky.dto.EmployeeLoginDTO; import com.sky.entity.Employee; import com.sky.properties.JwtProperties; import com.sky.result.Result…

Springboot各个版本维护时间

Springboot各个版本维护时间

MQTT教程--服务器使用EMQX和客户端使用MQTTX

什么是MQTT MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种轻量级、基于发布-订阅模式的消息传输协议&#xff0c;适用于资源受限的设备和低带宽、高延迟或不稳定的网络环境。它在物联网应用中广受欢迎&#xff0c;能够实现传感器、执行器和其它设备…