多智能体连续行为空间问题求解——MADDPG

目录

      • 1. 问题出现:连续行为空间出现
      • 2. DDPG 算法
        • 2.1 DDPG 算法原理
        • 2.2 DDPG 算法实现代码
          • 2.2.1 Actor & Critic
          • 2.2.2 Target Network
          • 2.2.3 Memory Pool
          • 2.2.4 Update Parameters(evaluate network)
          • 2.2.5 Update Parameters(target network)
      • 3. MADDPG 算法
        • 3.1 Actor 网络定义
        • 3.2 Critic 网络定义
        • 3.3 Update Parameters 过程

MADDPG 是一种针对多智能体、连续行为空间设计的算法。MADDPG 的前身是DDPG,DDPG 算法旨在解决连续性行为空间的强化学习问题,而 MADDPG 是在 DDPG 的基础上做了改进,使其能够适用于多智能体之间的合作任务学习。本文先从 DDPG 引入,接着再介绍如何在 DDPG 算法上进行修改使其变成 MADDPG 算法。

1. 问题出现:连续行为空间出现

Q-Learning 算法是强化学习中一种常用的方法,但传统的 Q-Learning 需要枚举所有的状态空间并建立 Q-Table,为了解决庞大不可枚举的状态空间问题,DQN 被人们设计出来,利用神经网络近似拟合的方法来避免了穷举所有可能的状态空间。但 DQN 算法有一个问题,那就是在计算当前 Q 值的时候需要求出下一个状态中每一个动作的值函数,选择最大的动作值函数值来进行计算。

Qπ(st,at)=R(st,at)+γmaxaQπ(st+1,at+1)Q^{\pi}(s_t, a_t) = R(s_t, a_t) + \gamma max_aQ^{\pi}(s_{t+1}, a_{t+1}) Qπ(st,at)=R(st,at)+γmaxaQπ(st+1,at+1)

在 Actor-Critic 算法中同样会面临这个问题,更新 critic 网络时候需要计算下一个状态下所有行为的Q值并取其平均值,计算公式如下:

Qπ(st,at)=R(st,at)+γEπ[Qπ(st+1,at+1)]Q^{\pi}(s_t, a_t) = R(s_t, a_t) + \gamma E_{\pi}[Q^{\pi}(s_{t+1}, a_{t+1})] Qπ(st,at)=R(st,at)+γEπ[Qπ(st+1,at+1)]

其中 Eπ[Qπ(st+1,at+1)]E_{\pi}[Q^{\pi}(s_{t+1}, a_{t+1})]Eπ[Qπ(st+1,at+1)] 是枚举所有动作的得分效用并乘上对应动作的选取概率(当然在 AC 中可以直接通过拟合一个 V(s)V(s)V(s) 来近似替代枚举结果)。那么不管是 DQN 还是 AC 算法,都涉及到需要计算整个行为空间中所有行为的效用值,一旦行为空间演变为连续型的就无法使用以上算法,因为无法穷举所有的行为并计算所有行为的值之和了。为此,在解决连续行为空间问题的时候,我们需要一种新的算法,能够不用穷举所有行为的值就能完成算法更新,DDPG 的出现很好的解决了这个问题。

2. DDPG 算法

2.1 DDPG 算法原理

DPG(Deterministic Policy Gradient)算法是一种 “确定性行为策略” 算法,我们之前问题的难点在于对于连续的庞大行为空间,我们无法一一枚举所有可能的行为。因此,DPG 认为,在求取下一个状态的状态值时,我们没有必要去计算所有可能的行为值并跟据每个行为被采取的概率做加权平均,我们只需要认为在一个状态下只有可能采取某一个确定的行为 aaa即该行为 aaa 被采取的概率为百分之百,这样就行了,于是整个 Q 值计算函数就变成了:

Qμ(st,at)=R(st,at)+γQμ(st+1,μ(st+1))]Q^{\mu}(s_t, a_t) = R(s_t, a_t) + \gamma Q^{\mu}(s_{t+1}, \mu{(s_{t+1})})] Qμ(st,at)=R(st,at)+γQμ(st+1,μ(st+1))]

即,原本的行为 aaa 是由随机策略 π\piπ 进行概率选择,而现在这个行为由一个确定性策略 μ\muμ 来选择,确定性策略是指只要输入一个状态就一定能得到唯一一个确定的输出行为,而随机性策略指的是输入一个状态,输出的是整个行为空间的所有行为概率分布。DDPG 是 DPG 算法上融合进神经网络技术,变成了 Deep Deterministic Policy Gradient,其整体思路和 DPG 是一致的。

2.2 DDPG 算法实现代码

DDPG 沿用了 Actor-Critic 算法结构,在代码中也存在一个 Actor 和一个 Critic,Actor 负责做行为决策,而 Critic 负责做行为效用评估,这里使用 DDPG 学习玩 gym 中一个倒立摆的游戏,游戏中的 action 为顺时针或逆时针的旋转力度,旋转力度是一个连续行为,力的大小是一个连续的随机变量,最终期望能够通过不断学习后算法能够学会如何让杆子倒立在上面静止不动,如下图所示:

2.2.1 Actor & Critic

我们先来看看在 DDPG 中 Actor 和 Critic 分别是怎么实现的的,Actor 和 Critic 的定义如下(代码参考自这里):

class Actor(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Actor, self).__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.linear2 = nn.Linear(hidden_size, hidden_size)self.linear3 = nn.Linear(hidden_size, output_size)def forward(self, s):x = F.relu(self.linear1(s))x = F.relu(self.linear2(x))x = torch.tanh(self.linear3(x))return xclass Critic(nn.Module):def __init__(self, input_size, hidden_size, output_size):super().__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.linear2 = nn.Linear(hidden_size, hidden_size)self.linear3 = nn.Linear(hidden_size, output_size)def forward(self, s, a):x = torch.cat([s, a], 1)    # DDPG与普通AC算法的不同之处x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return x

Actor 的设计和以往相同,没什么太大变化。
Critic 的实现有了一些改变,在 forward 函数中,原始的 critic 只用传入状态sss,输出所有动作的效用值,但由于这是连续动作空间,无法输出每一个行为的值,因此 critic 网络改为接收一个状态 sss 和一个具体行为 aaa 作为输入,输出的是具体行为 aaa 在当前状态 sss 下的效用值,即 critic 网络输出维度为1

2.2.2 Target Network

除了在 critic 网络上有了改变之外,DDPG 在整个算法层面上也做了修改。DDPG 参照了 DQN 的方式,为了算法添加了 target network,即固定住一个 target 网络产生样本,另一个 evaluate 网络不断更新迭代的思想,因此整个算法包含 4 个网络:

actor = Actor(s_dim, 256, a_dim)
actor_target = Actor(s_dim, 256, a_dim)
critic = Critic(s_dim+a_dim, 256, a_dim)    # 输入维度是 状态空间 + 行为空间
critic_target = Critic(s_dim+a_dim, 256, a_dim)

值得注意的是,在上述 critic 网络中输入的是 s_dim + a_dim,为什么是加 a_dim 呢?因为在 DDPG 算法中,critic 网络评判的是一组行为的效用值,即如果有(油门、方向盘)这两个行为的话,那么传入的应该是(油门大小、方向盘转动度数)这一组行为,critic 网络对这一组动作行为做一个效用评判

2.2.3 Memory Pool

之前提到 DDPG 算法借用了 DQN 思想,除了加入了 Target 网络之外还引入了 Memory Pool 机制,将收集到的历史经验存放到记忆库中,在更新的时候取一个 batch 的数据来计算均值,memory pool 代码如下:

# 经验池
buffer = []# 往经验池存放经验数据
def put(self, *transition): if len(self.buffer)== self.capacity:self.buffer.pop(0)self.buffer.append(transition)
2.2.4 Update Parameters(evaluate network)

在定义好了这些结构之后,我们就开始看看如何进行梯度更新吧。所需要更新参数的网络一共有 4 个,2 个 target network 和 2 个 evaluate network,target network 的更新是在训练迭代了若干轮后将 evaluate network 当前的参数值复制过去即可,只不过这里并不是直接复制,会做一些处理,这里我们先来看 evaluate network 是如何进行参数更新的,actor 和 critic 的更新代码如下 :

def critic_learn():a1 = self.actor_target(s1).detach()y_true = r1 + self.gamma * self.critic_target(s1, a1).detach()	# 下一个状态的目标状态值y_pred = self.critic(s0, a0)		# 下一个状态的预测状态值loss_fn = nn.MSELoss()loss = loss_fn(y_pred, y_true)self.critic_optim.zero_grad()loss.backward()self.critic_optim.step()def actor_learn():loss = -torch.mean( self.critic(s0, self.actor(s0)) )self.actor_optim.zero_grad()loss.backward()self.actor_optim.step()

我们先来看 critic 的 learn 函数,loss 函数比较的是 用当前网络预测当前状态的Q值利用回报R与下一状态的状态值之和 之间的 error 值,现在问题在于下一个状态的状态值如何计算,在 DDPG 算法中由于确定了在一种状态下只会以100%的概率去选择一个确定的动作,因此在计算下一个状态的状态值的时候,直接根据 actor 网络输出一个在下一个状态会采取的行为,把这个行为当作100%概率的确定行为,并根据这个行为和下一刻的状态输入 critic 网络得到下一个状态的状态值,最后通过计算这两个值的差来进行反向梯度更新(TD-ERROR)。

再来看看 actor 的 learn 函数,actor 还是普通的更新思路 —— actor 选择一个可能的行为,通过 reward 来决定增加选取这个 action 的概率还是降低选择这个 action 的概率。而增加/减少概率的多少由 critic 网络来决定,若 critic 网络评判出来当前状态下采取当前行为会得到一个非常高的正效用值,那么梯度更新后 actor 下次采取这个行为的概率就会大幅度增加。而传统的 actor 在进行行为选择时神经网络会输出每一个行为的被采取概率,按照这些概率来随机选择一个行为,但在 DDPG 算法中,所有行为都是被确定性选择的,不会存在随机性,因此在代码中传入的是经过 actor 后得到的输出行为,认为该行为就是100%被确定性选择的,没有之前的按概率选择行为这一个环节了。 选好行为后和当前状态一起传给 critic 网络做效用值评估。

2.2.5 Update Parameters(target network)

Target Network 在 DDPG 算法中沿用了 DQN 的思路,在迭代一定的轮数后,会从 evaluate network 中 copy 参数到自身网络中去。但是不同的是,DDPG 在进行参数复制的时候选择的是 soft update 的方式,即,在进行参数复制的时候不是进行直接复制值,而是将 target net 和 evaluate net 的参数值以一定的权重值加起来,融合成新的网络参数,代码如下:

def soft_update(net_target, net, tau):for target_param, param  in zip(net_target.parameters(), net.parameters()):target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

参数 tau 是保留程度参数,tau 值越大则保留的原网络的参数的程度越大。

3. MADDPG 算法

在理解了 DDPG 算法后,理解 MADDPG 就比较容易了。MADDPG 是 Multi-Agent 下的 DDPG 算法,主要针对于多智能体之间连续行为进行求解。MADDPG 同样沿用了 AC 算法的架构,和 DDPG 相比只是在 Critic 网络上的输入做了一些额外信息的添加,下面结合实际代码来分析:

3.1 Actor 网络定义

class Actor(nn.Module):def __init__(self, args, agent_id):""" 网络层定义部分 """super(Actor, self).__init__()self.fc1 = nn.Linear(args.obs_shape[agent_id], 64)	# 定义输入维度self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 64)self.action_out = nn.Linear(64, args.action_shape[agent_id])	# 定义输出维度def forward(self, x):""" 网络前向传播过程定义 """x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))actions = torch.tanh(self.action_out(x))return actions

上面是 MADDPG 中 actor 网络的定义代码,由于一个场景中可能存在多种不同的智能体,其观测空间维度与行为空间维度都不尽相同,因此在进行 actor 定义时需传入每个智能体自身所符合的维度信息,如上述代码一样,通过 agent_id 来获取具体的智能体信息,前向传播过程与 DDPG 相同,没有什么特殊之处。

3.2 Critic 网络定义

class Critic(nn.Module):def __init__(self, args):super(Critic, self).__init__()self.max_action = args.high_actionself.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 64)	# 定义输入层维度(联合观测+联合行为)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 64)self.q_out = nn.Linear(64, 1)def forward(self, state, action):state = torch.cat(state, dim=1)		# 联合观测action = torch.cat(action, dim=1)	# 联合行为x = torch.cat([state, action], dim=1)	# 联合观测 + 联合行为x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))q_value = self.q_out(x)return q_value

Critic 的代码如上,可见 MADDPG 中的 Critic 是一个中心化网络,即传入的不只是当前 Agent 的(s,a)信息,还加入了其他 Agent 的(s,a)信息。这种做法在多智能体算法中不算新奇了,在训练学习阶段利用中心化的评价网络来指导 Actor 的更新在许多多智能体算法当中都用到了这个技巧。值得一提的是,由于 Critic 需要指导 Actor 的更新,所以理论上需要让 Critic 比 Actor 更快的收敛,因此通常 Critic 的 learning rate 需要设置的比 Actor 要稍大些

3.3 Update Parameters 过程

下面我们来看看 Actor 和 Critic 的更新过程:

  • Critic 更新
index = 0
for agent_id in range(self.args.n_agents):""" 获取下一时刻所有智能体的联合行为决策u_next """if agent_id == self.agent_id:u_next.append(self.actor_target_network(o_next[agent_id]))else:u_next.append(other_agents[index].policy.actor_target_network(o_next[agent_id]))index += 1""" 下一时刻的q值以及target q值 """
q_next = critic_target_network(o_next, u_next)		# 联合观测、联合行为
target_q = r + gamma * q_next""" 当前状态的q值 """
q_value = critic_network(o, u)critic_loss = (target_q - q_value).pow(2).mean()    # TD-Error 更新法self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()

上面是 Critic 的更新过程,Critic 的更新很好理解,利用联合观测来确定联合行为(DPG中一个观测就对应一个具体的行为),输入到 Critic 网络中进行计算,最后利用 TD-Error 进行梯度更新。

  • Actor 更新
""" 重新选择联合动作中当前agent的动作,其他agent的动作不变 """
u[self.agent_id] = self.actor_network(o[self.agent_id])
actor_loss = - self.critic_network(o, u).mean()""" 网络更新 """
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()

Actor 在进行更新的时候,首先把当前 Agent 的当前行为替换成了另外一个行为,再用新的联合行为去预估 Critic 的值,新的联合行为中其他 Agent 的行为是保持不变的。那么这里为什么要单独改变自身 Agent 的行为呢?这是因为 MADDPG 是一种 off-policy 的算法,我们所取的更新样本是来自 Memory Pool 中的,是以往的历史经验,但我们现在自身的 Policy 已经和之前的不一样了(已经进化过了),因此需要按照现在的 Policy 重新选择一个行为进行计算。这和 PPO 算法中的 Importance Sampling 的思想一样,PPO 是采用概率修正的方式来解决行为不一致问题,而 MADDPG 中干脆直接就舍弃历史旧行为,按照当前策略重采样一次行为来进行计算。

  • Target 网络更新

和 DDPG 一样,MADDPG 中针对 Actor 和 Critic 的 target 网络也是采用 soft update 的,具体内容参见 2.2.5 小节。




以上就是 MADDPG 的全部内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/290330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在.NET 6 中如何创建和使用 HTTP 客户端 SDK

如今,基于云、微服务或物联网的应用程序通常依赖于通过网络与其他系统通信。每个服务都在自己的进程中运行,并解决一组有限的问题。服务之间的通信是基于一种轻量级的机制,通常是一个 HTTP 资源 API。从.NET 开发人员的角度来看,我…

ttl接地是高电平还是低电平_功放技术参数1——高电平

在汽车音响中的功放或者DSP再或者是DSP功放中我们都会遇到高电平信号或者低电平信号输入,我们该如何判断主机输出的到底是高电平信号还是低电平信号呢?我们可以用一个很简单的方法来鉴定,那就是主机输出能够直接驱动喇叭的为高电平信号输出&a…

MultiProcessing中主进程与子进程之间通过管道(Pipe)通信

Python 中 Multiprocessing 实现进程通信1. 如何建立主进程与子进程之间的通信管道?2. 为什么一定要将Pipe中的某些端close()?本文参考自:python 学习笔记 - Queue & Pipes,进程间通讯 1. 如何建立主进程与子进程之间的通信管道&#xf…

如何为 .NET 项目自定义强制代码样式规则

前言每个人都有自己的代码样式习惯:命名约定、大括号、空格、换行等。但是,作为一个团队来说,应该使用同样的代码样式规则。这样可以有效减少编译器的警告/建议,保证阅读代码的人员理解一致。今天我们介绍一种为单独的 .NET 项目定义代码样式…

我是如何帮助创业公司改进企业工作的

前段时间在一家创业公司实习,几十个人的团队,正处在规模逐渐扩大的阶段,但是整个公司的协作工作和日常管理却越来越麻烦,鉴于我以前对Saas和协作平台都有过一点研究,于是leader叫我去找一个“简单,好用&…

PHP单例模式(精讲)

2019独角兽企业重金招聘Python工程师标准>>> 首先我们要明确单例模式这个概念,那么什么是单例模式呢? 单例模式顾名思义,就是只有一个实例。作为对象的创建模式,单例模式确保某一个类只有一个实例,而且自行…

【QMIX】一种基于Value-Based多智能体算法

文章目录1. QMIX 解决了什么问题(Motivation)2. QMIX 怎样解决团队收益最大化问题(Method)2.1 算法大框架 —— 基于 AC 框架的 CTDE(Centralized Training Distributed Execution) 模式2.2 Agent RNN Netw…

增强型的for循环linkedlist_LinkedList的复习

先摘选一段Testpublic void test_LinkedList() { // 初始化100万数据 List list new LinkedList(1000000);// 遍历求和int sum 0;for (int i 0; i sum list.get(i); }}乍一看可能觉得没什么问题,但是这个遍历求和会非常慢。主要因为链表的数据结构…

3月更新来了!Windows 11正式版22000.556发布

面向 Windows 11 正式版用户,微软现已发布累积更新 KB5011493,更新后版本号升级至 Build 22000.556。主要变化1.微软正在改变 Windows 11 "开始"菜单中推荐模块有关 Office 文件的打开方式。如果文件被同步到 OneDrive,“开始”菜单…

[C/C++]重读《The C Programming Language》

第一次读这本书的时候是大三初,现在打算重读一遍!。 第一章 导言 1. 学习一门新程序设计语言的唯一途径就是用它来写程序。 2. 每个程序都从main函数的起点开始执行。 3. 在C语言中,所有变量必须先声明后使用。 4. C语言中的基本数据类型的大…

115怎么利用sha1下载东西_618“甩”度娘,拥抱115,体验和价格才是王道

网盘价钱​前天618,圈子里的朋友几乎都“甩”了度娘一巴掌,我才知道115搞活动,由原来500元1年的钻石会员,变成500元3年,算起来每天不到0.5元,确实比度娘实惠了很多,而且活动持续到6月底。自从发…

安装宝塔面板

安装宝塔面板: 1. 宝塔面板网站: https://www.bt.cn/ 2.安装教程 https://www.bt.cn/bbs/thread-1186-1-1.html 3.1 使用远程工具连接执行以下命令 yum install -y wget && wget -O install.sh http://download.bt.cn/install/install.sh &&…

【COMA】一种将团队回报拆分为独立回报的多智能体算法

文章目录1. COMA 解决了什么问题(Motivation)2. COMA 怎么解决独立回报分配问题(Method)2.1 核心思想 counterfactual baseline 的提出2.2 算法大框架 —— 基于 AC 框架的 CTDE(Centralized Training Distributed Exe…

C#解析Markdown文档,实现替换图片链接操作

前言又是好久没写博客了其实也不是没写,是最近在「做一个博客」,从2月21日开始,大概一个多星期的时间,疯狂刷进度,边写代码边写了一整系列的博客开发笔记,目前为止已经写了16篇了,然后上3月之后…

LoadRunner测试下载功能点脚本(方法一)

性能需求:对系统某页面中,点击下载功能做并发测试,以获取在并发下载文件的情况下系统的性能指标。 备注:页面上点击下载时的文件可以是word、excel、pdf等。 问题1:录制完下载的场景后,发现脚本里面并没有包…

海南橡胶机器人成本_「图说」海垦看点:海南橡胶联合北京理工华汇智能科技首创我国林间智能割胶机器人...

1 海垦南繁产业集团长期以来高度重视改善职工居住条件,于去年启动了海燕队保障性住房项目,项目建成后将有效解决职工住房问题。图为近日正在加紧施工的建设工地。 蒙胜国 摄2 海南橡胶联合北京理工华汇智能科技有限公司,研发出来的最新一代林…

数据挖掘在轨迹信息上的应用实验

文章目录1. 实验概览2. 数据集下载3. 数据预处理3.1 异常点去除3.2 停留点检测与环绕点检测3.3 轨迹分段4. 基于轨迹信息的数据挖掘4.1 路口检测4.1.1 地图分割与轨迹点速度计算4.2 偏好学习通常,我们将一个连续的GPS信号点序列称为一个轨迹(Trajectory&…

Avalonia跨平台入门第二十三篇之滚动字幕

在前面分享的几篇中咱已经玩耍了Popup、ListBox多选、Grid动态分、RadioButton模板、控件的拖放效果、控件的置顶和置底、控件的锁定、自定义Window样式、动画效果、Expander控件、ListBox折叠列表、聊天窗口、ListBox图片消息、窗口抖动、语音发送、语音播放、语音播放问题、玩…

oracle dba 手动创建数据实例

2019独角兽企业重金招聘Python工程师标准>>> 1.手动建库大致步骤 设置环境变量.bash_profile创建目录结构创建参数文件(位置:$ORACLE_HOME/dbs)生成密码文件执行建库脚本创建数据字典其他设置2.DBCA 脚本创建 2.1设置系统环境变量 ORACLE_HOME/app/oracle/11g/11.2.…

asp 强制转换浮点数值_C/C++中浮点数的编码存储

浮点数也称做实型数据(实数),形式上就是数学中的小数。浮点型数据有两种表达方式: 一种是用数字和小数点表示的,如123.456; 另一种是用指数方式表示,如1.2e-6 或1.2E-6(1.2*10-6)。在计算机中实数是如何存储的呢&#…