深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。

策略梯度算法的基本概念

在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 $\pi_\theta(a|s)$,其中$\theta$ 是策略的参数,$s$ 是状态,$a$ 是行动。

目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数$J(\theta)$,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。

策略梯度的数学推导

假设我们的目标函数 $J(\theta)$ 定义为:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]

其中$\tau$ 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),$R(\tau)$ 是该轨迹的总回报。根据策略的定义,我们有:

\pi_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)

因此,目标函数可以重写为:

J(\theta) = \sum_{\tau} \pi_\theta(\tau) R(\tau)

为了最大化$J(\theta)$,我们需要计算其梯度 $\nabla_\theta J(\theta)$

\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} \pi_\theta(\tau) R(\tau) = \sum_{\tau} \nabla_\theta \pi_\theta(\tau) R(\tau)

使用概率分布的梯度性质,我们有:

\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau)

因此,梯度可以表示为:

\nabla_\theta J(\theta) = \sum_{\tau} \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]

这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 $\pi_\theta$ 中采样多个轨迹 $\tau$,然后计算平均值。

策略梯度算法的实现

我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略$\pi_\theta(a|s)$

步骤 1:环境设置

首先,设置环境和参数:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimenv = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义

定义一个简单的策略网络:

class PolicyNetwork(nn.Module):def __init__(self, state_dim, n_actions):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, n_actions)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return torch.softmax(x, dim=-1)policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹

编写函数来从策略中采样轨迹:

def sample_trajectory(env, policy, max_steps=1000):state = env.reset()states, actions, rewards = [], [], []for _ in range(max_steps):state = torch.FloatTensor(state).unsqueeze(0)probs = policy(state)action = np.random.choice(n_actions, p=probs.detach().numpy()[0])next_state, reward, done, _ = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:breakstate = next_statereturn states, actions, rewards
步骤 4:计算回报和梯度

计算每个状态的回报,并使用策略梯度定理更新策略:

def compute_returns(rewards, gamma=0.99):returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)return returnsdef update_policy(policy, optimizer, states, actions, returns):returns = torch.FloatTensor(returns)loss = 0for state, action, G in zip(states, actions, returns):state = state.squeeze(0)probs = policy(state)log_prob = torch.log(probs[action])loss += -log_prob * Goptimizer.zero_grad()loss.backward()optimizer.step()
步骤 5:训练策略

将上述步骤组合在一起,训练策略网络:

num_episodes = 1000
for episode in range(num_episodes):states, actions, rewards = sample_trajectory(env, policy)returns = compute_returns(rewards)update_policy(policy, optimizer, states, actions, returns)if episode % 100 == 0:print(f"Episode {episode}, total reward: {sum(rewards)}")
总结

通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38004.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python3的内置函数和使用方法】

目录 Python 特点 Python 中文编码 Python 变量类型 Python列表 Python 元组 元组是另一个数据类型,类似于 List(列表) Python 字典 Python数据类型转换 Python 运算符 Python算术运算符 Python比较运算符 Python赋值运算符 Pyt…

(笔记)CentOS7上安装neovim

sudo yum install epel-release sudo yum install snapd sudo systemctl enable --now snapd.socket sudo ln -s /var/lib/snapd/snap /snap sudo snap install nvim --classic nvim ok,搞定 如果之前用yum安装了旧版本的neovim往下看(之前没有安装…

一篇就够了,为你答疑解惑:锂电池一阶模型-离线参数辨识(附代码)

锂电池一阶模型-参数离线辨识 背景模型简介数据收集1. 最大可用容量实验2. 开路电压实验3. 混合动力脉冲特性实验离线辨识对应模型对应代码总结下期预告文章字数有点多,耐心不够的谨慎点击阅读。 下期继续讲解在线参数辨识方法。 背景 最近又在开始重新梳理锂电池建模仿真与S…

使用stat()函数的例子

代码&#xff1a; #include <sys/types.h> #include <sys/stat.h> #include <unistd.h> #include <stdio.h>int main(void) {struct stat st;if(-1stat("test.txt",&st)){printf("获得文件状态失败\n");return -1;}printf(&q…

【Rust】——所有的模式语法

&#x1f4bb;博主现有专栏&#xff1a; C51单片机&#xff08;STC89C516&#xff09;&#xff0c;c语言&#xff0c;c&#xff0c;离散数学&#xff0c;算法设计与分析&#xff0c;数据结构&#xff0c;Python&#xff0c;Java基础&#xff0c;MySQL&#xff0c;linux&#xf…

Unidbg调用-补环境V2

1.B站 内部依赖自定义的SignedQuery对象,需要找到apk中的类并补充环境。 package com.nb.demo;import com.github.unidbg.AndroidEmulator

llama3模型部署时遇到的问题及解决方案

在llama3模型部署时&#xff0c;会遇到一系列问题&#xff0c;这里就作者所遇到的问题与解决方法分享一下。 注意&#xff1a;这里是从llama3 github主页上给的方法一步步做的&#xff0c;不适用于其他部署大模型的方法。 文章目录 ERROR 403&#xff1a;Forbidden安装依赖时出…

洛谷 P1548 [NOIP1997 普及组] 棋盘问题

题目 洛谷 P1548 [NOIP1997 普及组] 棋盘问题 [NOIP1997 普及组] 棋盘问题 题目背景 NOIP1997 普及组第一题 题目描述 设有一个 N M N \times M NM 方格的棋盘 ( 1 ≤ N ≤ 100 , 1 ≤ M ≤ 100 ) (1≤N≤100,1≤M≤100) (1≤N≤100,1≤M≤100) 求出该棋盘中包含有多少个正…

牛客C++刷题记录

C 运算符优先级 运算符优先级顺口溜&#xff1a;淡云一笔&#xff0c;鞍落三服。 淡&#xff1a;单目运算符&#xff1b; 云&#xff1a;算数运算符&#xff1b; 一&#xff1a;移位运算符&#xff1b; 笔&#xff1a;比较运算符&#xff1b; 鞍&#xff1a;按位运算符&a…

MySQL高级-MVCC-undo log 版本链

文章目录 1、undo log2、undo log 版本链2.1、然后&#xff0c;有四个并发事务同时在访问这张表。2.1.1、修改id为30记录&#xff0c;age改为32.1.2、修改id为30记录&#xff0c;name改为A32.1.3、修改id为30记录&#xff0c;age改为10 2.2、总结 1、undo log 回滚日志&#xf…

文件系统(操作系统实验)

实验内容 &#xff08;1&#xff09;在内存中开辟一个虚拟磁盘空间作为文件存储器&#xff0c; 在其上实现一个简单单用户文件系统。 在退出这个文件系统时&#xff0c;应将改虚拟文件系统保存到磁盘上&#xff0c; 以便下次可以将其恢复到内存的虚拟空间中。 &#xff08;2&…

算法训练(leetcode)第二十一天 | 93. 复原 IP 地址、78. 子集、90. 子集 II

刷题记录 93. 复原 IP 地址78. 子集90. 子集 II 93. 复原 IP 地址 leetcode题目地址 题目有一个很重要的要求&#xff1a;你 不能 重新排序或删除 s 中的任何数字。你可以按 任何 顺序返回答案。 也就是说ip地址中需要包含整个字符串中的字符且顺序不可变。 ip地址的每一个数…

数字孪生煤矿智能化综合管控平台

煤矿可视化通过图扑 HT 实现实时数据集成和三维建模仿真&#xff0c;呈现井下环境、设备状态和生产状况等多维度数据&#xff0c;帮助管理人员进行直观监控和精准分析。该技术提升了运营效率和安全水平&#xff0c;为煤矿作业提供了智能化的管理解决方案&#xff0c;有助于减少…

黑马点评DAY1|Redis入门、Redis安装

什么是Redis&#xff1f; redis是一种键值型数据库&#xff0c;内部所存的数据都是键值对的形式&#xff0c;例如&#xff0c;我们可以把一个用户数据存储为如下格式&#xff1a; 键值id$1600name张三age21 但是这样的存储方式&#xff0c;数据会显得非常松散&#xff0c;因…

云计算HCIE+RHCE学员的学习分享

大一下学期&#xff0c;我从学长嘴里了解到誉天教育&#xff0c;当时准备考RHCE&#xff0c;我也了解了很多培训机构&#xff0c;然后学长强烈给我推荐誉天&#xff0c;我就在誉天报名了RHCE的课程。 通过杨峰老师的教学&#xff0c;我学到了许多Linux知识&#xff0c;也了解了…

笔记本电脑部署VMware ESXi 6.0系统

正文共&#xff1a;888 字 18 图&#xff0c;预估阅读时间&#xff1a;1 分钟 前面我们介绍了在笔记本上安装Windows 11操作系统&#xff08;Windows 11升级不了&#xff1f;但Win10就要停服了啊&#xff01;来&#xff0c;我教你&#xff01;&#xff09;&#xff0c;也介绍了…

【单片机毕业设计选题24037】-基于STM32的电力系统电力参数无线监控系统

系统功能: 系统上电后&#xff0c;OLED显示“欢迎使用电力监控系统请稍后”&#xff0c;两秒后显示“Waiting..”等待ESP8266初始化完成&#xff0c; ESP8266初始化成功后进入正常页面显示&#xff0c; 第一行显示电压值&#xff08;单位V&#xff09; 第二行显示电流值&am…

Java 使用Objects equals 、 != 、equals 比较对象之间的区别?

在Java中&#xff0c;比较对象是否相等的方法主要有三种&#xff1a;Objects.equals() 方法、! 操作符和 equals() 方法。它们之间的区别如下&#xff1a; Objects.equals() 方法&#xff1a; Objects.equals(a, b) 是一个静态方法&#xff0c;用于安全地比较两个对象是否相等。…

FastAPI中的Lifespan和异步上下文管理器:深入理解和实践

FastAPI中的Lifespan和异步上下文管理器&#xff1a;深入理解和实践 FastAPI中的Lifespan和异步上下文管理器&#xff1a;深入理解和实践1. 代码解析2. 异步上下文管理器2.1 什么是异步上下文管理器&#xff1f;2.2 asynccontextmanager装饰器2.3 代码示例 3. FastAPI的Lifespa…

现代信息检索笔记(一)

目录 什么是信息检索 应用一&#xff1a;做搜索引擎 应用二&#xff1a;信息推荐系统 应用三&#xff1a;婚恋网站 信息检索的具体应用 从信息规模上分类 为什么要学习信息检索技术&#xff1f; 市场发展需求大 应用需求多&#xff1a; 课程情况 课程宗旨 国际著名…