深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。

策略梯度算法的基本概念

在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 $\pi_\theta(a|s)$,其中$\theta$ 是策略的参数,$s$ 是状态,$a$ 是行动。

目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数$J(\theta)$,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。

策略梯度的数学推导

假设我们的目标函数 $J(\theta)$ 定义为:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]

其中$\tau$ 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),$R(\tau)$ 是该轨迹的总回报。根据策略的定义,我们有:

\pi_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)

因此,目标函数可以重写为:

J(\theta) = \sum_{\tau} \pi_\theta(\tau) R(\tau)

为了最大化$J(\theta)$,我们需要计算其梯度 $\nabla_\theta J(\theta)$

\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} \pi_\theta(\tau) R(\tau) = \sum_{\tau} \nabla_\theta \pi_\theta(\tau) R(\tau)

使用概率分布的梯度性质,我们有:

\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau)

因此,梯度可以表示为:

\nabla_\theta J(\theta) = \sum_{\tau} \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]

这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 $\pi_\theta$ 中采样多个轨迹 $\tau$,然后计算平均值。

策略梯度算法的实现

我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略$\pi_\theta(a|s)$

步骤 1:环境设置

首先,设置环境和参数:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimenv = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义

定义一个简单的策略网络:

class PolicyNetwork(nn.Module):def __init__(self, state_dim, n_actions):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, n_actions)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return torch.softmax(x, dim=-1)policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹

编写函数来从策略中采样轨迹:

def sample_trajectory(env, policy, max_steps=1000):state = env.reset()states, actions, rewards = [], [], []for _ in range(max_steps):state = torch.FloatTensor(state).unsqueeze(0)probs = policy(state)action = np.random.choice(n_actions, p=probs.detach().numpy()[0])next_state, reward, done, _ = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:breakstate = next_statereturn states, actions, rewards
步骤 4:计算回报和梯度

计算每个状态的回报,并使用策略梯度定理更新策略:

def compute_returns(rewards, gamma=0.99):returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)return returnsdef update_policy(policy, optimizer, states, actions, returns):returns = torch.FloatTensor(returns)loss = 0for state, action, G in zip(states, actions, returns):state = state.squeeze(0)probs = policy(state)log_prob = torch.log(probs[action])loss += -log_prob * Goptimizer.zero_grad()loss.backward()optimizer.step()
步骤 5:训练策略

将上述步骤组合在一起,训练策略网络:

num_episodes = 1000
for episode in range(num_episodes):states, actions, rewards = sample_trajectory(env, policy)returns = compute_returns(rewards)update_policy(policy, optimizer, states, actions, returns)if episode % 100 == 0:print(f"Episode {episode}, total reward: {sum(rewards)}")
总结

通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38004.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python3的内置函数和使用方法】

目录 Python 特点 Python 中文编码 Python 变量类型 Python列表 Python 元组 元组是另一个数据类型,类似于 List(列表) Python 字典 Python数据类型转换 Python 运算符 Python算术运算符 Python比较运算符 Python赋值运算符 Pyt…

一篇就够了,为你答疑解惑:锂电池一阶模型-离线参数辨识(附代码)

锂电池一阶模型-参数离线辨识 背景模型简介数据收集1. 最大可用容量实验2. 开路电压实验3. 混合动力脉冲特性实验离线辨识对应模型对应代码总结下期预告文章字数有点多,耐心不够的谨慎点击阅读。 下期继续讲解在线参数辨识方法。 背景 最近又在开始重新梳理锂电池建模仿真与S…

使用stat()函数的例子

代码&#xff1a; #include <sys/types.h> #include <sys/stat.h> #include <unistd.h> #include <stdio.h>int main(void) {struct stat st;if(-1stat("test.txt",&st)){printf("获得文件状态失败\n");return -1;}printf(&q…

Unidbg调用-补环境V2

1.B站 内部依赖自定义的SignedQuery对象,需要找到apk中的类并补充环境。 package com.nb.demo;import com.github.unidbg.AndroidEmulator

llama3模型部署时遇到的问题及解决方案

在llama3模型部署时&#xff0c;会遇到一系列问题&#xff0c;这里就作者所遇到的问题与解决方法分享一下。 注意&#xff1a;这里是从llama3 github主页上给的方法一步步做的&#xff0c;不适用于其他部署大模型的方法。 文章目录 ERROR 403&#xff1a;Forbidden安装依赖时出…

洛谷 P1548 [NOIP1997 普及组] 棋盘问题

题目 洛谷 P1548 [NOIP1997 普及组] 棋盘问题 [NOIP1997 普及组] 棋盘问题 题目背景 NOIP1997 普及组第一题 题目描述 设有一个 N M N \times M NM 方格的棋盘 ( 1 ≤ N ≤ 100 , 1 ≤ M ≤ 100 ) (1≤N≤100,1≤M≤100) (1≤N≤100,1≤M≤100) 求出该棋盘中包含有多少个正…

MySQL高级-MVCC-undo log 版本链

文章目录 1、undo log2、undo log 版本链2.1、然后&#xff0c;有四个并发事务同时在访问这张表。2.1.1、修改id为30记录&#xff0c;age改为32.1.2、修改id为30记录&#xff0c;name改为A32.1.3、修改id为30记录&#xff0c;age改为10 2.2、总结 1、undo log 回滚日志&#xf…

文件系统(操作系统实验)

实验内容 &#xff08;1&#xff09;在内存中开辟一个虚拟磁盘空间作为文件存储器&#xff0c; 在其上实现一个简单单用户文件系统。 在退出这个文件系统时&#xff0c;应将改虚拟文件系统保存到磁盘上&#xff0c; 以便下次可以将其恢复到内存的虚拟空间中。 &#xff08;2&…

数字孪生煤矿智能化综合管控平台

煤矿可视化通过图扑 HT 实现实时数据集成和三维建模仿真&#xff0c;呈现井下环境、设备状态和生产状况等多维度数据&#xff0c;帮助管理人员进行直观监控和精准分析。该技术提升了运营效率和安全水平&#xff0c;为煤矿作业提供了智能化的管理解决方案&#xff0c;有助于减少…

黑马点评DAY1|Redis入门、Redis安装

什么是Redis&#xff1f; redis是一种键值型数据库&#xff0c;内部所存的数据都是键值对的形式&#xff0c;例如&#xff0c;我们可以把一个用户数据存储为如下格式&#xff1a; 键值id$1600name张三age21 但是这样的存储方式&#xff0c;数据会显得非常松散&#xff0c;因…

云计算HCIE+RHCE学员的学习分享

大一下学期&#xff0c;我从学长嘴里了解到誉天教育&#xff0c;当时准备考RHCE&#xff0c;我也了解了很多培训机构&#xff0c;然后学长强烈给我推荐誉天&#xff0c;我就在誉天报名了RHCE的课程。 通过杨峰老师的教学&#xff0c;我学到了许多Linux知识&#xff0c;也了解了…

笔记本电脑部署VMware ESXi 6.0系统

正文共&#xff1a;888 字 18 图&#xff0c;预估阅读时间&#xff1a;1 分钟 前面我们介绍了在笔记本上安装Windows 11操作系统&#xff08;Windows 11升级不了&#xff1f;但Win10就要停服了啊&#xff01;来&#xff0c;我教你&#xff01;&#xff09;&#xff0c;也介绍了…

【单片机毕业设计选题24037】-基于STM32的电力系统电力参数无线监控系统

系统功能: 系统上电后&#xff0c;OLED显示“欢迎使用电力监控系统请稍后”&#xff0c;两秒后显示“Waiting..”等待ESP8266初始化完成&#xff0c; ESP8266初始化成功后进入正常页面显示&#xff0c; 第一行显示电压值&#xff08;单位V&#xff09; 第二行显示电流值&am…

互联网大厂核心知识总结PDF资料

我们要敢于追求卓越&#xff0c;也能承认自己平庸&#xff0c;不要低估3&#xff0c;5&#xff0c;10年沉淀的威力 hi 大家好&#xff0c;我是大师兄&#xff0c;大厂工作特点是需要多方面的知识和技能。这种学习和积累一般人需要一段的时间&#xff0c;不太可能一蹴而就&…

VMware虚拟机迁移:兼用性踩坑和复盘

文章目录 方法失败情况分析&#xff1a;参考文档 方法 虚拟机关机&#xff0c;整个文件夹压缩后拷贝到新机器中&#xff0c;开机启用即可 成功的情况&#xff1a; Mac (intel i5) -> Mac (intel i7)Mac (intel, MacOS - VMware Fusion) -> DELL (intel, Windows - VMw…

Zynq7000系列FPGA中的DMA控制器简介(二)

AXI互连上的DMA传输 所有DMA事务都使用AXI接口在PL中的片上存储器、DDR存储器和从外设之间传递数据。PL中的从设备通过DMAC的外部请求接口与DMAC通信&#xff0c;以控制数据流。这意味着从设备可以请求DMA交易&#xff0c;以便将数据从源地址传输到目标地址。 虽然DMAC在技术…

mysql5.7安装使用

mysql5.7安装包&#xff1a;百度网盘 提取码: 0000 一、 安装步骤 双击安装文件 选择我接受许可条款–Next 选择自定义安装&#xff0c;下一步 选择电脑对应的系统版本后(我的系统是64位)&#xff0c;点击中间的右箭头&#xff0c;选择Next 选择安装路径–Next 执行…

matlab可以把图像数据转换为小波分析吗

&#x1f3c6;本文收录于《CSDN问答解答》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&…

【后端面试题】【中间件】【NoSQL】ElasticSearch 节点角色、写入数据过程、Translog和索引与分片

中间件的常考方向&#xff1a; 中间件如何做到高可用和高性能的&#xff1f; 你在实践中怎么做的高可用和高性能的&#xff1f; Elasticsearch节点角色 Elasticsearch的节点可以分为很多种角色&#xff0c;并且一个节点可以扮演多种角色&#xff0c;下面列举几种主要的&…

【软件测试】白盒测试(知识点 + 习题 + 答案)

《 软件测试基础持续更新中》 最近大家总是催更……&#xff0c;我也是百忙之中给大家详细总结了白盒测试的重点内容&#xff01; 知识点题型答案&#xff0c;让你用最短的时间&#xff0c;学到最高效的知识&#xff01; 整理不易&#xff0c;求个三连 ₍ᐢ..ᐢ₎ ♡ 目录 一、…