强化学习--多维动作状态空间的设计

目录

  • 一、离散动作
  • 二、连续动作
    • 1、例子1
    • 2、知乎给出的示例
    • 2、github里面的代码

免责声明:以下代码部分来自网络,部分来自ChatGPT,部分来自个人的理解。如有其他观点,欢迎讨论!

一、离散动作

注意:本文均以PPO算法为例。

# time: 2023/11/22 21:04
# author: YanJPimport torch
import torch
import torch.nn as nn
from torch.distributions import Categoricalclass MultiDimensionalActor(nn.Module):def __init__(self, input_dim, output_dims):super(MultiDimensionalActor, self).__init__()# Define a shared feature extraction networkself.feature_extractor = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU())# Define individual output layers for each action dimensionself.output_layers = nn.ModuleList([nn.Linear(64, num_actions) for num_actions in output_dims])def forward(self, state):# Feature extractionfeatures = self.feature_extractor(state)# Generate Categorical objects for each action dimensioncategorical_objects = [Categorical(logits=output_layer(features)) for output_layer in self.output_layers]return categorical_objects# 定义主函数
def main():# 定义输入状态维度和每个动作维度的动作数input_dim = 10output_dims = [5, 8]  # 两个动作维度,分别有 3 和 4 个可能的动作# 创建 MultiDimensionalActor 实例actor_network = MultiDimensionalActor(input_dim, output_dims)# 生成输入状态(这里使用随机数据作为示例)state = torch.randn(1, input_dim)# 调用 actor 网络categorical_objects = actor_network(state)# 输出每个动作维度的采样动作和对应的对数概率for i, categorical in enumerate(categorical_objects):sampled_action = categorical.sample()log_prob = categorical.log_prob(sampled_action)print(f"Sampled action for dimension {i+1}: {sampled_action.item()}, Log probability: {log_prob.item()}")if __name__ == "__main__":main()#Sampled action for dimension 1: 1, Log probability: -1.4930928945541382
#Sampled action for dimension 2: 3, Log probability: -2.1875085830688477

注意代码中categorical函数的两个不同传入参数的区别:参考链接
简单来说,logits是计算softmax的,probs直接就是已知概率的时候传进去就行。

二、连续动作

参考链接:github、知乎
为什么取对数概率?参考回答
在这里插入图片描述

1、例子1

先看如下的代码:

# time: 2023/11/21 21:33
# author: YanJP
#这是对应多维连续变量的例子:
# 参考链接:https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/utils.py
# https://www.zhihu.com/question/417161289
import torch.nn as nn
import torch
class Policy(nn.Module):def __init__(self, in_dim, n_hidden_1, n_hidden_2, num_outputs):super(Policy, self).__init__()self.layer = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.ReLU(True),nn.Linear(n_hidden_1, n_hidden_2),nn.ReLU(True),nn.Linear(n_hidden_2, num_outputs))class Normal(nn.Module):def __init__(self, num_outputs):super().__init__()self.stds = nn.Parameter(torch.zeros(num_outputs))  #创建一个可学习的参数 def forward(self, x):dist = torch.distributions.Normal(loc=x, scale=self.stds.exp())action = dist.sample((every_dimention_output,))  #这里我觉得是最重要的,不填sample的参数的话,默认每个分布只采样一个值!!!!!!!!return actionif __name__ == '__main__':policy = Policy(4,20,20,5)normal = Normal(5) #设置5个维度every_dimention_output=10  #每个维度10个输出observation = torch.Tensor(4)action = normal.forward(policy.layer( observation))print("action: ",action)
  • self.stds.exp(),表示求指数,因为正态分布的标准差都是正数。
  • action = dist.sample((every_dimention_output,))这里最重要!!!

2、知乎给出的示例


class Agent(nn.Module):def __init__(self, envs):super(Agent, self).__init__()self.actor_mean = nn.Sequential(layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),nn.Tanh(),layer_init(nn.Linear(64, 64)),nn.Tanh(),layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),)self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))def get_action_and_value(self, x, action=None):action_mean = self.actor_mean(x)action_logstd = self.actor_logstd.expand_as(action_mean)action_std = torch.exp(action_logstd)probs = Normal(action_mean, action_std)if action is None:action = probs.sample()return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

这里的np.prod(envs.single_action_space.shape),表示每个维度的动作数相乘,然后初始化这么多个actor网络的标准差和均值,最后action里面的sample就是采样这么多个数据。(感觉还是拉成了一维计算)

2、github里面的代码

github

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta,Normalclass GaussianActor_musigma(nn.Module):def __init__(self, state_dim, action_dim, net_width):super(GaussianActor_musigma, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.sigma_head = nn.Linear(net_width, action_dim)def forward(self, state):a = torch.tanh(self.l1(state))a = torch.tanh(self.l2(a))mu = torch.sigmoid(self.mu_head(a))sigma = F.softplus( self.sigma_head(a) )return mu,sigmadef get_dist(self, state):mu,sigma = self.forward(state)dist = Normal(mu,sigma)return distdef deterministic_act(self, state):mu, sigma = self.forward(state)return mu

上述代码主要是通过设置mu_head 和sigma_head的个数,来实现多维动作。

class GaussianActor_mu(nn.Module):def __init__(self, state_dim, action_dim, net_width, log_std=0):super(GaussianActor_mu, self).__init__()self.l1 = nn.Linear(state_dim, net_width)self.l2 = nn.Linear(net_width, net_width)self.mu_head = nn.Linear(net_width, action_dim)self.mu_head.weight.data.mul_(0.1)self.mu_head.bias.data.mul_(0.0)self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)def forward(self, state):a = torch.relu(self.l1(state))a = torch.relu(self.l2(a))mu = torch.sigmoid(self.mu_head(a))return mudef get_dist(self,state):mu = self.forward(state)action_log_std = self.action_log_std.expand_as(mu)action_std = torch.exp(action_log_std)dist = Normal(mu, action_std)return distdef deterministic_act(self, state):return self.forward(state)
class Critic(nn.Module):def __init__(self, state_dim,net_width):super(Critic, self).__init__()self.C1 = nn.Linear(state_dim, net_width)self.C2 = nn.Linear(net_width, net_width)self.C3 = nn.Linear(net_width, 1)def forward(self, state):v = torch.tanh(self.C1(state))v = torch.tanh(self.C2(v))v = self.C3(v)return v

上述代码只定义了mu的个数与维度数一样,std作为可学习的参数之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/160810.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用SD-WAN提升物流供应链网络效率

案例背景 本次分享的物流供应链企业是一家国际性的大型企业,专注于提供全球范围内的物流和供应链解决方案。案例用户在不同国家和地区均设有多个分支机构和办公地点,以支持客户需求和业务运营。 在过去,该企业用户使用传统的MPLS网络来连接各…

OceanBase:04-单机在线转分布式部署

目录 1.当前部署情况 2.单Zone多OBServer模式 3.多Zone多OBServer模式 3.1 集群规划 3.2 安装OBServer程序 3.3 新增Zone 3.4 启动Zone 3.5 向Zone新增OBserver节点 3.6重复3.2~3.5新增其他Zone 4.扩充资源 OceanBase 数据库为单机分布式一体化架构,支持单…

纯干货丨电脑监控软件有哪些(三款电脑监控软件大盘点)

电脑监控软件在日常生活和工作中的应用越来越广泛。这些软件可以帮助我们监控电脑的使用情况,保护电脑的安全,提高工作效率。本文将介绍一些高人气的电脑监控软件,并分享一些纯干货。 1、 域之盾软件----电脑监控系统 是一款功能强大的电脑监…

Linux输入设备应用编程(触摸屏获取坐标信息)

上一章学习了开发板外接键盘并获取键盘的的输入 Linux输入设备应用编程(键盘,按键)-CSDN博客 本章编写触摸屏应用程序,获取触摸屏的坐标信息并将其打印出来 目录 一 触摸屏数据分析(触摸,点击&#xff…

采用connector-c++ 8.0操作数据库

1.下载最新的Connector https://dev.mysql.com/downloads/connector/cpp/,下载带debug的库。 解压缩到本地,本次使用的是带debug模式的connector库: 注:其中mysqlcppconn与mysqlcppconn8的区别是: 2.在cmakelist…

请简要说明 Mysql 中 MyISAM 和 InnoDB 引擎的区别

“请简要说明 Mysql 中 MyISAM 和 InnoDB 引擎的区别”。 屏幕前有多少同学在面试过程与遇到过类似问题, 可以在评论区留言:遇到过。 考察目的 对于 xxxx 技术的区别,在面试中是很常见的一个问题 一般情况下,面试官会通过这类…

SpringBoot监听器解析

监听器模式介绍 监听器模式的要素 事件监听器广播器触发机制 SpringBoot监听器实现 系统事件 事件发送顺序 监听器注册 监听器注册和初始化器注册流程类似 监听器触发机制 获取监听器列表核心流程: 通用触发条件: 自定义监听器实现 实现方式1 实现监听器接口: Order(1) …

[操作系统]进程和线程

目录 1.什么是进程 1.1进程控制块抽象 1.2 CPU 分配 —— 进程调度(Process Scheduling) 1.3内存分配 —— 内存管理(Memory Manage) 1.4进程间通信(Inter Process Communication) 2.线程 2.1概念 2.2为什么要有线程 2.3线…

论文阅读 Forecasting at Scale (二)

最近在看时间序列的文章,回顾下经典 论文地址 项目地址 Forecasting at Scale 3.2、季节性 3.3、假日和活动事件3.4、模型拟合3.5、分析师参与的循环建模4、自动化预测评估4.1、使用基线预测4.2、建模预测准确性4.3、模拟历史预测4.4、识别大的预测误差 5、结论6、致…

【Python】重磅!这本30w人都在看的Python数据分析畅销书更新了!

Python 语言极具吸引力。自从 1991 年诞生以来,Python 如今已经成为最受欢迎的解释型编程语言。 【文末送书】今天推荐一本Python领域优质数据分析书籍,这本30w人都在看的书,值得入手。 目录 作译者简介主要变动导读视频购书链接文末送书 pan…

【计算机方向】通信、算法、自动化、机器人、电子电气、计算机工程、控制工程、计算机视觉~~~~~合集!!!

◆本文为大家梳理了近期可投的EI国际会议,涵盖计算机各个学科方向,均可EI检索 本期EI会议汇总合集涵盖领域:计算机视觉、物联网、算法、通信、智能技术、人工智能、人机交互、机器人、电子电气等众多领域! 本期所推荐的EI会议有…

ros2不同机器通讯时IP设置

看到这就是不同机器的IP地址,为了避免在路由器为不同的机器使用DHCP分配到上面的地址,可以设置DHCP分配的范围:(我的路由器是如下设置的,一般路由器型号都不一样,自己找一下) 防火墙设置-----&…

Leetcode—13.罗马数字转整数【简单】

2023每日刷题(三十七) Leetcode—13.罗马数字转整数 算法思想 当前位置的元素比下个位置的元素小,就减去当前值,否则加上当前值 实现代码 int getValue(char c) {switch(c) {case I:return 1;case V:return 5;case X:return 1…

wpf使用CefSharp.OffScreen模拟网页登录,并获取身份cookie

目录 框架信息&#xff1a;MainWindow.xamlMainWindow.xaml.cs爬取逻辑模拟登录拦截请求Cookie获取 CookieVisitorHandle 框架信息&#xff1a; CefSharp.OffScreen.NETCore 119.1.20 MainWindow.xaml <Window x:Class"Wpf_CHZC_Img_Identy_ApiDataGet.MainWindow&qu…

22LLMSecEval数据集及其在评估大模型代码安全中的应用:GPT3和Codex根据LLMSecEval的提示生成代码和代码补全,CodeQL进行安全评估

LLMSecEval: A Dataset of Natural Language Prompts for Security Evaluations 写在最前面主要工作 课堂讨论大模型和密码方向&#xff08;没做&#xff0c;只是一个idea&#xff09; 相关研究提示集目标NL提示的建立NL提示的建立流程 数据集数据集分析 存在的问题 写在最前面…

使用Python画一棵树

&#x1f38a;专栏【不单调的代码】 &#x1f354;喜欢的诗句&#xff1a;更喜岷山千里雪 三军过后尽开颜。 &#x1f386;音乐分享【如愿】 &#x1f970;欢迎并且感谢大家指出我的问题 文章目录 &#x1f339;Turtle模块&#x1f384;效果&#x1f33a;代码&#x1f6f8;代码…

【tomcat】java.lang.Exception: Socket bind failed: [730048

项目中一些旧工程运行情况处理 问题 1、启动端口占用 2、打印编码乱码 ʮһ&#xfffd;&#xfffd; 13, 2023 9:33:26 &#xfffd;&#xfffd;&#xfffd;&#xfffd; org.apache.coyote.AbstractProtocol init &#xfffd;&#xfffd;&#xfffd;&#xfffd;: Fa…

oracle面试相关的,Oracle基本操作的SQL命令

文章目录 数据库-Oracle〇、Oracle用户管理一、Oracle数据库操作二、Oracle表操作1、创建表2、删除表3、重命名表4、增加字段5、修改字段6、重名字段7、删除字段8、添加主键9、删除主键10、创建索引11、删除索引12、创建视图13、删除视图 三、Oracle操作数据1、数据查询2、插入…

Connect-The-Dots_2

Connect-The-Dots_2 一、主机发现和端口扫描 主机发现&#xff0c;靶机地址192.168.80.148 arp-scan -l端口扫描 nmap -A -p- -sV 192.168.80.148开放端口 21/tcp open ftp vsftpd 2.0.8 or later 80/tcp open http Apache httpd 2.4.38 ((Debian)) 111/tcp …

循环队列详解!!c 语言版本(两种方法)双向链表和数组法!!

目录 1.什么是循环队列 2.循环队列的实现&#xff08;两种方法&#xff09; 第一种方法 数组法 1.源代码 2.源代码详解&#xff01;&#xff01; 1.创造队列空间和struct变量 2.队列判空 3.队列判满&#xff08;重点&#xff09; 4.队列的元素插入 5.队列的元素删除 …