强化学习系列--带基准线的REINFORCE算法

强化学习系列--带基准线的REINFORCE算法

  • 介绍
    • 示例代码

介绍

在强化学习中,带基准线的REINFORCE算法是一种用于求解策略梯度的方法。REINFORCE算法(也称为蒙特卡洛策略梯度算法)可以用于训练能够从环境中学习的策略。带基准线的REINFORCE算法是对经典REINFORCE算法的改进,通过引入一个基准线来减小方差,加速学习的过程

REINFORCE算法通过采样轨迹并利用蒙特卡洛方法来估计策略梯度。该算法的目标是最大化期望回报,即最大化累积奖励的期望值。具体来说,REINFORCE算法使用以下更新规则来更新策略参数:

Δ θ = α ∑ t = 0 T ∇ θ log ⁡ ( π ( a t ∣ s t ) ) G t \Delta\theta = \alpha \sum_{t=0}^{T} \nabla\theta \log(\pi(a_t|s_t)) G_t Δθ=αt=0Tθlog(π(atst))Gt

其中, Δ θ \Delta\theta Δθ是策略参数的更新量, α \alpha α是学习率, ∇ θ \nabla\theta θ是对策略参数的梯度, π ( a t ∣ s t ) \pi(a_t|s_t) π(atst)是在状态 s t s_t st下选择动作 a t a_t at的概率, G t G_t Gt是从时间步 t t t开始的累积奖励。

带基准线的REINFORCE算法则在更新规则中引入了一个基准线 b ( s t ) b(s_t) b(st),用来减小方差。基准线可以是任何函数,通常选择一个与状态有关的函数,如状态值函数 V ( s t ) V(s_t) V(st)。更新规则变为:

Δ θ = α ∑ t = 0 T ∇ θ log ⁡ ( π ( a t ∣ s t ) ) ( G t − b ( s t ) ) \Delta\theta = \alpha \sum_{t=0}^{T} \nabla\theta \log(\pi(a_t|s_t)) (G_t - b(s_t)) Δθ=αt=0Tθlog(π(atst))(Gtb(st))

通过减去基准线的值,可以降低更新的方差,加速学习过程。基准线可以看作是估计的奖励期望的偏差,通过减去这个偏差,可以更准确地估计策略梯度。

带基准线的REINFORCE算法的步骤如下:

  1. 初始化策略参数 θ \theta θ和基准线 b ( s ) b(s) b(s)
  2. 与环境交互,采样轨迹,记录奖励和状态序列。
  3. 对于每个时间步 t t t,计算状态 s t s_t st下选择动作 a t a_t at的概率 π ( a t ∣ s t ) \pi(a_t|s_t) π(atst)和基准线 b ( s t ) b(s_t) b(st)的值。
  4. 计算策略梯度 ∇ θ log ⁡ ( π ( a t ∣ s t ) ) ( G t − b ( s t ) ) \nabla\theta \log(\pi(a_t|s_t)) (G_t - b(s_t)) θlog(π(atst))(Gtb(st))
  5. 更新策略参数 θ \theta θ θ = θ + α ∇ θ log ⁡ ( π ( a t ∣ s t ) ) ( G t − b ( s t ) ) \theta = \theta + \alpha \nabla\theta \log(\pi(a_t|s_t)) (G_t - b(s_t)) θ=θ+αθlog(π(atst))(Gtb(st))
  6. 重复步骤2-5,直到达到停止条件。

带基准线的REINFORCE算法通过减小方差来加速学习过程,提高了算法的稳定性和收敛性。选择合适的基准线函数对算法的表现有重要影响,通常需要通过实验来确定。

示例代码

下面是一个简单的Python代码实现带基准线的REINFORCE算法,其中包含了策略网络和基线网络:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gym# 定义策略网络
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, 32)self.fc2 = nn.Linear(32, output_dim)def forward(self, x):x = F.relu(self.fc1(x))x = F.softmax(self.fc2(x), dim=-1)return x# 定义基线网络
class BaselineNetwork(nn.Module):def __init__(self, input_dim):super(BaselineNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, 32)self.fc2 = nn.Linear(32, 1)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x# 定义带基准线的REINFORCE算法
def reinforce(env, policy_net, baseline_net, num_episodes, gamma=0.99):optimizer_policy = optim.Adam(policy_net.parameters(), lr=0.01)optimizer_baseline = optim.Adam(baseline_net.parameters(), lr=0.01)for episode in range(num_episodes):state = env.reset()log_probs = []rewards = []while True:state_tensor = torch.FloatTensor(state)action_probs = policy_net(state_tensor)action_dist = torch.distributions.Categorical(action_probs)action = action_dist.sample()log_prob = action_dist.log_prob(action)log_probs.append(log_prob)next_state, reward, done, _ = env.step(action.item())rewards.append(reward)if done:breakstate = next_statereturns = []cumulative_return = 0for r in reversed(rewards):cumulative_return = r + gamma * cumulative_returnreturns.insert(0, cumulative_return)returns = torch.FloatTensor(returns)log_probs = torch.stack(log_probs)# 计算基准线的值values = baseline_net(torch.FloatTensor(state))advantages = returns - values# 更新策略网络policy_loss = -(log_probs * advantages.detach()).mean()optimizer_policy.zero_grad()policy_loss.backward()optimizer_policy.step()# 更新基线网络baseline_loss = F.mse_loss(returns, values)optimizer_baseline.zero_grad()baseline_loss.backward()optimizer_baseline.step()if episode % 10 == 0:print(f"Episode {episode}: policy_loss = {policy_loss}, baseline_loss = {baseline_loss}")# 创建环境和网络
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.npolicy_net = PolicyNetwork(input_dim, output_dim)
baseline_net = BaselineNetwork(input_dim)# 运行带基准线的REINFORCE算法
reinforce(env, policy_net, baseline_net, num_episodes=1000)

这段代码使用PyTorch实现了带基准线的REINFORCE算法,并应用于OpenAI Gym中的CartPole环境。首先定义了策略网络和基线网络的结构,然后在reinforce函数中进行算法的训练过程。在每个回合中,通过策略网络选择动作,并计算动作的概率和对数概率。同时,也记录下每个动作的奖励。当回合结束后,根据累积奖励计算优势函数,并使用优势函数更新策略网络和基线网络的参数。算法通过反复迭代多个回合来提高策略的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/53040.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch建立MyDataLoader过程详解

简介 torch.utils.data.DataLoader(dataset, batch_size1, shuffleNone, samplerNone, batch_samplerNone, num_workers0, collate_fnNone, pin_memoryFalse, drop_lastFalse, timeout0, worker_init_fnNone, multiprocessing_contextNone, generatorNone, *, prefetch_factorN…

使用Python爬虫定制化开发自己需要的数据集

在数据驱动的时代,获取准确、丰富的数据对于许多项目和业务至关重要。本文将介绍如何使用Python爬虫进行定制化开发,以满足个性化的数据需求,帮助你构建自己需要的数据集,为数据分析和应用提供有力支持。 1.确定数据需求和采集目…

flutter ios webview不能打开http地址

参考 1、iOS添加信任 webview_flutter 在使用过程中会iOS出现无法加载HTTP请求的情况&#xff0c; 但是Flutter 却可以加载HTTP请求。这就与两个的框架有关了&#xff0c;Flutter是独立于UIKit框架的。 解决方案就是在iOS 的info.plist中添加对HTTP的信任。 <key>NSApp…

拼多多淘宝大量缓存商品数据用什么格式提供比较好?

众所周知&#xff0c;淘宝拼多多是我国主流的电商平台&#xff0c;其上有大量的商品数据。很多商家会通过API来访问他们的商品数据&#xff0c;根据API的调用次数收费。第三方数据公司提供电商数据接口API&#xff0c;采集实时数据。但是&#xff0c;在他们的服务器上有大量的缓…

【2023钉钉杯复赛】A题 智能手机用户监测数据分析 Python代码分析

【2023钉钉杯复赛】A题 智能手机用户监测数据分析 Python代码分析 1 题目 一、问题背景 近年来&#xff0c;随着智能手机的产生&#xff0c;发展到爆炸式的普及增长&#xff0c;不仅推动了中 国智能手机市场的发展和扩大&#xff0c;还快速的促进手机软件的开发。近年中国智能…

【教程】Java 集成Mongodb

【教程】Java 集成Mongodb 依赖 <dependency><groupId>org.mongodb</groupId><artifactId>mongo-java-driver</artifactId><version>3.12.14</version></dependency> <dependency><groupId>cn.hutool</groupId…

网络安全应急响应预案培训

应急响应预案的培训是为了更好地应对网络突发状况&#xff0c;实施演 练计划所做的每一项工作&#xff0c;其培训过程主要针对应急预案涉及的相 关内容进行培训学习。做好应急预案的培训工作能使各级人员明确 自身职责&#xff0c;是做好应急响应工作的基础与前提。应急响应…

CleanMyMac2024永久版Mac清理工具

Mac电脑作为相对封闭的一个系统&#xff0c;它会中毒吗&#xff1f;如果有一天Mac电脑产生了疑似中毒或者遭到恶意不知名攻击的现象&#xff0c;那又应该如何从容应对呢&#xff1f;这些问题都是小编使用Mac系统一段时间后产生的疑惑&#xff0c;通过一番搜索研究&#xff0c;小…

人机识别:走近智能时代的大门

在当今数字化快速发展的时代&#xff0c;人机识别技术正成为引领人工智能革命的重要一环。人机识别&#xff0c;即通过计算机视觉和模式识别技术&#xff0c;使机器能够自动识别、分析、理解和处理人类的信息&#xff0c;逐渐渗透到我们的生活和工作中。从简单的人脸识别到更复…

Redis 7 教程 数据类型 基础篇

🌹 引导 Commands | Redishttps://redis.io/commands/Redis命令中心(Redis commands) -- Redis中国用户组(CRUG)Redis命令大全,显示全部已知的redis命令,redis集群相关命令,近期也会翻译过来,Redis命令参考,也可以直接输入命令进行命令检索。

图为科技_边缘计算在智能安防领域的作用

边缘计算在智能安防领域发挥着重要的作用。智能安防系统通常需要处理大量的图像、视频和传感器数据&#xff0c;并对其进行实时分析和处理。边缘计算可以将计算和数据处理功能移动到离数据源更接近的地方&#xff0c;例如摄像头、传感器设备或安防终端。 以下是边缘计算在智能…

网络爬虫到底是个啥?

网络爬虫到底是个啥&#xff1f; 当涉及到网络爬虫技术时&#xff0c;需要考虑多个方面&#xff0c;从网页获取到最终的数据处理和分析&#xff0c;每个阶段都有不同的算法和策略。以下是这些方面的详细解释&#xff1a; 网页获取&#xff08;Web Crawling&#xff09;&#x…

10 - 网络通信优化之通信协议:如何优化RPC网络通信?

微服务框架中 SpringCloud 和 Dubbo 的使用最为广泛&#xff0c;行业内也一直存在着对两者的比较&#xff0c;很多技术人会为这两个框架哪个更好而争辩。 我记得我们部门在搭建微服务框架时&#xff0c;也在技术选型上纠结良久&#xff0c;还曾一度有过激烈的讨论。当前 Sprin…

URI、URL、URIBuilder、UriBuilder、UriComponentsBuilder说明及基本使用

之前想过直接获取url通过拼接字符串的方式实现,但是这种只是暂时的,后续地址如果有变化或参数很多,去岂不是要拼接很长,由于这些等等原因,所以找了一些方法实现 java.net.URI URI全称是Uniform Resource Identifier,也就是统一资源标识符,它是一种采用特定的语法标识一…

强化学习时序差分学习方法--SARSA算法

强化学习时序差分学习方法--SARSA算法 介绍示例代码 介绍 SARSA&#xff08;State-Action-Reward-State-Action&#xff09;是一种强化学习算法&#xff0c;用于解决马尔可夫决策过程&#xff08;MDP&#xff09;中的问题。SARSA算法属于基于值的强化学习算法&#xff0c;用于…

Redis添加LocalDateTime时间序列化/反序列化Java 8报‘jackson-datatype-jsr310’问题

错误信息&#xff1a; com.fasterxml.jackson.databind.exc.InvalidDefinitionException: Java 8 date/time type java.time.LocalDateTime not supported by default: add Module "com.fasterxml.jackson.datatype:jackson-datatype-jsr310" to enable handling (t…

Navicat 连接 mysql 问题

需要将mysql配置文件设置为远程任意ip可登陆&#xff0c;注释掉一下两行配置 # bind-address>->--- 127.0.0.1 # mysqlx-bind-address>-- 127.0.0.1Cant connect to MySQL server on "192.168.137.139 (10013 "Unknown error") 检查Navicat是否联网H…

OSCS开源安全周报第 56 期:Apache Airflow Spark Provider 任意文件读取漏洞

本周安全态势综述 OSCS 社区共收录安全漏洞 3 个&#xff0c;公开漏洞值得关注的是 Apache NiFi 连接 URL 验证绕过漏洞(CVE-2023-40037)、PowerJob 未授权访问漏洞(CVE-2023-36106)、Apache Airflow Spark Provider 任意文件读取漏洞(CVE-2023-40272)。 针对 NPM 、PyPI 仓库…

stm32之点亮LED

今天&#xff0c;记录一下stm32如何点亮一个LED,程序本身十分简单&#xff0c;但主要是学习编程的格式。 led.h #ifndef _led_H #define _led_H#include "system.h"/* LED时钟端口、引脚定义 */ #define LED1_PORT GPIOB #define LED1_PIN GPIO_Pin_5 #d…

开发一款AR导览导航小程序多少钱?ar地图微信小程序 ar导航 源码

随着科技的不断发展&#xff0c;增强现实&#xff08;AR&#xff09;技术在不同领域展现出了巨大的潜力。AR导览小程序作为其中的一种应用形式&#xff0c;为用户提供了全新的观赏和学习体验。然而&#xff0c;开发一款高质量的AR导览小程序需要投入大量的时间、人力和技术资源…