【QMIX】一种基于Value-Based多智能体算法

文章目录

      • 1. QMIX 解决了什么问题(Motivation)
      • 2. QMIX 怎样解决团队收益最大化问题(Method)
        • 2.1 算法大框架 —— 基于 AC 框架的 CTDE(Centralized Training Distributed Execution) 模式
        • 2.2 Agent RNN Network
        • 2.3 Mixing Network
        • 2.4 模型更新流程
      • 3. QMIX 效果

QMIX 是一种基于 Value-Based 的多智能体强化学习算法(MARL),其基本思想来源于 Actor-Critic 与 DQN 的结合。使用中心式学习(Centralized Learning)分布式执行(Distributed Execution)的方法,利用中心式 Critic 网络接受全局状态用于指导 Actor 进行更新。QMIX 中 Critic 网络的更新方式和 DQN 相似,使用 TD-Error 进行网络自更新。除此之外,QMIX 中为 Critic 网络设立了 evaluate net 和 target net, 这和 DQN 中的设计思想完全相符。

1. QMIX 解决了什么问题(Motivation)

QMIX 是一种解决多智能体强化学习问题的算法,对于大多数多智能体强化学习问题(MARL)都面临着同样一个问题:信度分配(也叫回报分配)

这是指,当多个 Agent 在同时执行任务时,我们应该怎样合理的去评价每一个 Agent 的行为效用,举个例子:

假设我们现在正在训练一个算法模型,使用该算法模型去玩 MOBA 类游戏(DOTA 或者 LOL),算法模型需要同时操控 5 个英雄。在训练过程中遇到了这样一个情况:我方 3 个英雄迎面撞上了敌方 1 个英雄。此时,算法模型控制 1 号英雄和 2 号英雄对敌方英雄发起进攻,但却让 3 号英雄撤退。那么最终,因为 2 打 1 的局面,我方成功击败对方英雄,获得了 10 分的奖励分(Reward),那么我们该怎样为我方的这 3 个英雄进行奖励分配?

在上面案例中,我们很明显能看出,在人数占优势的情况下,算法选择让 1 号和 2 号英雄一起发起进攻是一次正确的尝试,而让 3 号英雄尝试撤退显然就不那么明智了。由于对 1 号和 2 号的正确决策,使得整个指挥策略得到了正向的奖励分(Positive Reward),但显然我们不能直接将这个正向奖励分同时应用到这 3 个英雄上。

我们希望被正确决策的英雄(1 号和 2 号)获得较高的奖励分,而被错误决策的英雄(3 号)获得负的惩罚分,即最后的期望得分可能为:1 号(8分),2 号(8分),3 号(-6分)。

三个英雄的得分总和加起来还是 10 分,只是每个英雄能够按照自己的实际情况获得对应的合理奖励分。

这就是 回报分配 的概念。

回报分配通常分为两种类型: 自下而上类型 和 自上而下类型。

  • 自上而下类型:这种类型通常指我们只能拿到一个团队的最终得分,而无法获得每一个 Agent 的独立得分,因此我们需要把团队回报(Team Reward)合理的分配给每一个独立的 Agent(Individual Reward),这个过程通常也叫 “独立回报分配”(Individual Reward Assign)。上述例子就属于这种类型,典型的代表算法为 COMA算法。

  • 自下而上类型:另外一种类型恰恰相反,指当我们只能获得每个 Agent 的独立回报(Individual)时,如何使得整个团队的团队得分(Team Reward)最大化。

QMIX 算法解决的是上述第二种类型的问题,即,在获得各 Agent 的独立回报的情况下,如何使得整个团队的团队收益最大化问题


2. QMIX 怎样解决团队收益最大化问题(Method)

2.1 算法大框架 —— 基于 AC 框架的 CTDE(Centralized Training Distributed Execution) 模式

多智能体强化学习(MARL)训练中面临的最大问题是:训练阶段和执行阶段获取的信息可能存在不对等问题。即,在训练的时候我们可以获得大量的全局信息(事实证明,只有获取足够的信息模型才能被有效训练)。

但在最终应用模型的时候,我们是无法获取到训练时那么多的全局信息的,因此,人们提出两个训练网络:一个为中心式训练网络(Critic),该网络只在训练阶段存在,获取全局信息作为输入并指导 Agent 行为控制网络(Actor)进行更新;另一个为行为控制网络(Actor),该网络也是最终被应用的网络,在训练和应用阶段都保持着相同的数据输入。

AC 算法的应用非常广泛,QMIX 在设计时同样借鉴了 AC 的 “中心式网络” 和 “分布式执行器” 的想法,整个网络包含了 Mixing Network(类比 Critic 网络)和 Agent RNN Network(类比 Actor 网络),整个网络架构图如下所示:

下面我们分别来看看 Mixing Network 和 RNN Network 的详细设计。

2.2 Agent RNN Network

QMIX 中每一个 Agent 都由 RNN 网络控制,在训练时你可以为每一个 Agent 个体都训练一个独立的 RNN 网络,同样也可以所有 Agent 复用同一个 RNN 网络,这取决于你自己的设计。

RNN 网络一共包含 3 层,输入层(MLP)→ 中间层(GRU)→ 输出层(MLP),实现代码如下:

class RNN(nn.Module):# 所有 Agent 共享同一网络, 因此 input_shape = obs_shape + n_actions + n_agents(one_hot_code)def __init__(self, input_shape, args):super().__init__()self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)     # GRUCell(input_size, hidden_size)self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)def forward(self, obs, hidden_state):x = F.relu(self.fc1(obs))h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)h = self.rnn(x, h_in)                # GRUCell 的输入要求(current_input, last_hidden_state)q = self.fc2(h)                      # h 是这一时刻的隐状态,用于输到下一时刻的RNN网络中去,q 是真实行为Q值输出return q, h

2.3 Mixing Network

Mixing 网络相当于 Critic 网络,同时接收 Agent RNN Network 的 Q 值和当前全局状态 sts_tst ,输出在当前状态下所有 Agent 联合行为 uuu 的行为效用值 QtotQ_{tot}Qtot

Mixing 同样使用神经网络结构,不同的是,上图中蓝色部分(中间层神经元)的权重(weights)和偏差(bias)均由右边红色的神经网络产生。即,Mixing 网络中实际包含两个神经网络,红色参数生成网络 & 蓝色推理网络。

  • 参数生成网络: 接收全局状态 sts_tst,生成蓝色网络中的神经元权重(weights)和偏差(bias)。
  • 推理网络:接收所有 Agent 的行为效用值 QQQ,并将参数生成网络生成的权重和偏差赋值到网络自身,从而推理出全局效用 QtotQ_{tot}Qtot

下图是推理网络示意图,只含有一个隐层,与隐层相连接的 weights 和 bias 均由参数生成网络生成,每一层需要的 weights 和 bias 维度如下图所示:

结合上图,我们来看看 Mixing 网络实现代码:

class QMixNet(nn.Module):def __init__(self, arglist):super().__init__()self.arglist = arglist# 因为生成的 hyper_w1 需要是一个矩阵,而 pytorch 神经网络只能输出一个向量,# 所以就先输出长度为需要的 矩阵行*矩阵列 的向量,然后再转化成矩阵# hyper_w1 网络用于输出推理网络中的第一层神经元所需的 weights,# 推理网络第一层需要 qmix_hidden * n_agents 个偏差值,因此 hyper_w1 网络输出维度为 qmix_hidden * n_agentsself.hyper_w1 = nn.Sequential(nn.Linear(arglist.state_shape, arglist.hyper_hidden_dim),nn.ReLU(),nn.Linear(arglist.hyper_hidden_dim, arglist.n_agents * arglist.qmix_hidden_dim))# hyper_w2 生成推理网络需要的从隐层到输出 Q 值的所有 weights,共 qmix_hidden 个self.hyper_w2 = nn.Sequential(nn.Linear(arglist.state_shape, arglist.hyper_hidden_dim),nn.ReLU(),nn.Linear(arglist.hyper_hidden_dim, arglist.qmix_hidden_dim))# hyper_b1 生成第一层网络对应维度的偏差 biasself.hyper_b1 = nn.Linear(arglist.state_shape, arglist.qmix_hidden_dim)# hyper_b2 生成对应从隐层到输出 Q 值层的 biasself.hyper_b2 =nn.Sequential(nn.Linear(arglist.state_shape, arglist.qmix_hidden_dim),nn.ReLU(),nn.Linear(arglist.qmix_hidden_dim, 1))def forward(self, q_values, states):  # states的shape为(episode_num, max_episode_len, state_shape)# 传入的q_values是三维的,shape为(episode_num, max_episode_len, n_agents)episode_num = q_values.size(0)q_values = q_values.view(-1, 1, self.arglist.n_agents)  # (episode_num * max_episode_len, 1, n_agents)states = states.reshape(-1, self.arglist.state_shape)  # (episode_num * max_episode_len, state_shape)w1 = torch.abs(self.hyper_w1(states))b1 = self.hyper_b1(states)w1 = w1.view(-1, self.arglist.n_agents, self.arglist.qmix_hidden_dim)b1 = b1.view(-1, 1, self.arglist.qmix_hidden_dim)hidden = F.elu(torch.bmm(q_values, w1) + b1)	# torch.bmm(a, b) 计算矩阵 a 和矩阵 b 相乘w2 = torch.abs(self.hyper_w2(states))b2 = self.hyper_b2(states)w2 = w2.view(-1, self.arglist.qmix_hidden_dim, 1)b2 = b2.view(-1, 1, 1)q_total = torch.bmm(hidden, w2) + b2q_total = q_total.view(episode_num, -1, 1)return q_total

2.4 模型更新流程

至此,我们已经了解了 QMIX 中主要网络的结构了,现在我们来看看训练过程中这些神经网络是如何进行参数更新的吧。

QMIX 的更新方式和 DQN 非常类似,设定 evaluate Net 和 target Net,并利用 TD-Error 完成参数更新:

loss=TDError=Qtot(evalutate)−(r+γQtot(target))loss = TDError = Q_{tot}(evalutate) - (r + \gamma Q_{tot}(target)) loss=TDError=Qtot(evalutate)(r+γQtot(target))

由上述公式可以看出,一共存在两个 Mixing 网络(evaluate & target),两个网络分别用于产生 Qtot(evaluate)Q_{tot}(evaluate)Qtot(evaluate)Qtot(target)Q_{tot}(target)Qtot(target),两个网络接收不同的输入:

  • eval 网络: 接收在状态 sss 下每个 Agent RNN Network 所选行为的 QQQ作为输入,输出 Qtot(evaluate)Q_{tot}(evaluate)Qtot(evaluate)
  • target 网络:接收在状态 snexts_{next}snext 下每个 Agent RNN Network 所有行为中最大的 QQQ作为输入,输出 Qtot(target)Q_{tot}(target)Qtot(target)

实现代码如下:

    def learn(self, batch):episode_num = batch['o'].shape[0]self.init_hidden(episode_num)# 把 batch 里的数据转化成 tensorfor key in batch.keys():if key == 'u':batch[key] = torch.tensor(batch[key], dtype=torch.long)else:batch[key] = torch.tensor(batch[key], dtype=torch.float32)s, s_next, u, r, avail_u, avail_u_next, terminated = batch['s'], batch['s_next'], batch['u'], \batch['r'],  batch['avail_u'], batch['avail_u_next'],\batch['terminated']# 得到每个 agent 对应的 Q 值列表q_evals, q_targets = self.get_q_values(batch)# 取出每个 agent 所选择动作的对应 Q 值q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)# 得到target_q,取所有行为中最大的 Q 值q_targets[avail_u_next == 0.0] = - 9999999      # 如果该行为不可选,则把该行为的Q值设为极小值,保证不会被选到q_targets = q_targets.max(dim=3)[0]# qmix更新过程,evaluate网络输入的是每个agent选出来的行为的q值,target网络输入的是每个agent最大的q值,和DQN更新方式一样q_total_eval = self.eval_qmix_net(q_evals, s)q_total_target = self.target_qmix_net(q_targets, s_next)targets = r + self.arglist.gamma * q_total_target * (1 - terminated)td_error = (q_total_eval - targets.detach())# 不能直接用mean,因为还有许多经验是没用的,所以要求和再比真实的经验数,才是真正的均值loss = (masked_td_error ** 2).sum() / mask.sum()self.optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.arglist.grad_norm_clip)self.optimizer.step()# 在指定周期更新 target network 的参数if train_step > 0 and train_step % self.arglist.target_update_cycle == 0:self.target_rnn.load_state_dict(self.eval_rnn.state_dict())self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())

3. QMIX 效果

下图是 QMIX 论文中给出的 QMIX 与其他算法之间的效果对比图:

可以看出,QMIX 相较于 IQL 有明显大幅度的提升,并且比 VDN 具有更优的效果。QMIX 实质上是 VDN 的一个改进版本,在 VDN 中直接将每个 Agent 的 QQQ 值相加得到 QtotQ_{tot}Qtot,而在 QMIX 中,利用两个神经网络,结合每个 Agent 的 QQQ 值与全局状态 sts_tst 共同推理出全局效用 QtotQ_{tot}Qtot,从结果来看确实比 VDN 在效果上有一定的提升。



QMIX 论文链接: https://arxiv.org/pdf/1803.11485.pdf
QMIX 实现代码:https://github.com/oxwhirl/smac

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/290317.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

增强型的for循环linkedlist_LinkedList的复习

先摘选一段Testpublic void test_LinkedList() { // 初始化100万数据 List list new LinkedList(1000000);// 遍历求和int sum 0;for (int i 0; i sum list.get(i); }}乍一看可能觉得没什么问题,但是这个遍历求和会非常慢。主要因为链表的数据结构…

linux之用echo输入数据到文本末尾以及用open ssl命令在证书文件里面获取公钥

1、用echo输入数据到文本末尾 我们知道清空一个文本快速的方法如下 echo "" > file 我们可以用echo输入数字到文本末尾,记住是 >> echo "hello word" >> file 2、如果用open ssl命令在证书文件里面提取公钥 证书文件内容要记…

3月更新来了!Windows 11正式版22000.556发布

面向 Windows 11 正式版用户,微软现已发布累积更新 KB5011493,更新后版本号升级至 Build 22000.556。主要变化1.微软正在改变 Windows 11 "开始"菜单中推荐模块有关 Office 文件的打开方式。如果文件被同步到 OneDrive,“开始”菜单…

[C/C++]重读《The C Programming Language》

第一次读这本书的时候是大三初,现在打算重读一遍!。 第一章 导言 1. 学习一门新程序设计语言的唯一途径就是用它来写程序。 2. 每个程序都从main函数的起点开始执行。 3. 在C语言中,所有变量必须先声明后使用。 4. C语言中的基本数据类型的大…

Max Points on a Line

Given n points on a 2D plane, find the maximum number of points that lie on the same straight line. 参考&#xff1a;http://blog.csdn.net/doc_sgl/article/details/17103427 AC的代码&#xff1a; #include<iostream> #include<map> #include<vector&g…

115怎么利用sha1下载东西_618“甩”度娘,拥抱115,体验和价格才是王道

网盘价钱​前天618&#xff0c;圈子里的朋友几乎都“甩”了度娘一巴掌&#xff0c;我才知道115搞活动&#xff0c;由原来500元1年的钻石会员&#xff0c;变成500元3年&#xff0c;算起来每天不到0.5元&#xff0c;确实比度娘实惠了很多&#xff0c;而且活动持续到6月底。自从发…

安装宝塔面板

安装宝塔面板&#xff1a; 1. 宝塔面板网站&#xff1a; https://www.bt.cn/ 2.安装教程 https://www.bt.cn/bbs/thread-1186-1-1.html 3.1 使用远程工具连接执行以下命令 yum install -y wget && wget -O install.sh http://download.bt.cn/install/install.sh &&…

C和指针之字符串编程练习1

1、问题 //编写一个程序,从标准的输入读取一些字符,并统计下各类字符所占的百分比 //控制字符、空白字符、数字、小写字母、大写字母、标点符号、不可打印的字符 2、代码实现 #include <stdio.h> #include <ctype.h>//编写一个程序,从标准的输入读取一些字符,…

【COMA】一种将团队回报拆分为独立回报的多智能体算法

文章目录1. COMA 解决了什么问题&#xff08;Motivation&#xff09;2. COMA 怎么解决独立回报分配问题&#xff08;Method&#xff09;2.1 核心思想 counterfactual baseline 的提出2.2 算法大框架 —— 基于 AC 框架的 CTDE&#xff08;Centralized Training Distributed Exe…

C#解析Markdown文档,实现替换图片链接操作

前言又是好久没写博客了其实也不是没写&#xff0c;是最近在「做一个博客」&#xff0c;从2月21日开始&#xff0c;大概一个多星期的时间&#xff0c;疯狂刷进度&#xff0c;边写代码边写了一整系列的博客开发笔记&#xff0c;目前为止已经写了16篇了&#xff0c;然后上3月之后…

LoadRunner测试下载功能点脚本(方法一)

性能需求&#xff1a;对系统某页面中&#xff0c;点击下载功能做并发测试&#xff0c;以获取在并发下载文件的情况下系统的性能指标。 备注&#xff1a;页面上点击下载时的文件可以是word、excel、pdf等。 问题1&#xff1a;录制完下载的场景后&#xff0c;发现脚本里面并没有包…

海南橡胶机器人成本_「图说」海垦看点:海南橡胶联合北京理工华汇智能科技首创我国林间智能割胶机器人...

1 海垦南繁产业集团长期以来高度重视改善职工居住条件&#xff0c;于去年启动了海燕队保障性住房项目&#xff0c;项目建成后将有效解决职工住房问题。图为近日正在加紧施工的建设工地。 蒙胜国 摄2 海南橡胶联合北京理工华汇智能科技有限公司&#xff0c;研发出来的最新一代林…

C和指针之字符串编程练习10(判断字符串是否是回文数)

1、问题 //如果参数字符串是个回文,函数就返回真,否则返回假。回文就是指一个字符串从左向右和从右向左读是一样的。函数应该忽略所有的非字母字符,而且在进行字符比较时不用区分大小写。 2、代码实现 #include <stdio.h> #include <ctype.h>//如果参数字符串是…

数据挖掘在轨迹信息上的应用实验

文章目录1. 实验概览2. 数据集下载3. 数据预处理3.1 异常点去除3.2 停留点检测与环绕点检测3.3 轨迹分段4. 基于轨迹信息的数据挖掘4.1 路口检测4.1.1 地图分割与轨迹点速度计算4.2 偏好学习通常&#xff0c;我们将一个连续的GPS信号点序列称为一个轨迹&#xff08;Trajectory&…

Python中如何把一个UTC时间转换为本地时间

需求&#xff1a; 将20141126010101格式UTC时间转换为本地时间。 在网上搜了好长时间都没有找到完美的解决方案。有的引用了第三方库&#xff0c;这就需要在现网安装第三方的软件。这个是万万不可的。因为真实环境不一定允许你随便使用root用户安装Python模块。最终找到了一个不…

Avalonia跨平台入门第二十三篇之滚动字幕

在前面分享的几篇中咱已经玩耍了Popup、ListBox多选、Grid动态分、RadioButton模板、控件的拖放效果、控件的置顶和置底、控件的锁定、自定义Window样式、动画效果、Expander控件、ListBox折叠列表、聊天窗口、ListBox图片消息、窗口抖动、语音发送、语音播放、语音播放问题、玩…

oracle dba 手动创建数据实例

2019独角兽企业重金招聘Python工程师标准>>> 1.手动建库大致步骤 设置环境变量.bash_profile创建目录结构创建参数文件(位置:$ORACLE_HOME/dbs)生成密码文件执行建库脚本创建数据字典其他设置2.DBCA 脚本创建 2.1设置系统环境变量 ORACLE_HOME/app/oracle/11g/11.2.…

解决 ubuntu 14.04.1 下一个sublime text3 3065 中国输入的问题

你看今天 sublime text3 我以前有没有3059 的 它有支持3065该。 因此&#xff0c;为了支持subl 对中国输入法的实现 &#xff0c;下面的操作步骤把我的记录供大家使用 有一个完整的教程&#xff1a; http://www.360doc.com/content/14/0329/08/13087748_364608018.shtml# 可…

C和指针之字符串实现my_strrchr(char *str, int ch)的函数

1、问题 编写一个叫my_strrchr(char *str, int ch)的函数&#xff0c;这个函数类似strchr函数&#xff0c;知识它返回的是一个指向ch字符在&#xff0c;str字符串中最后一次出现(最右边)的位置的指针 2、代码实现 #include <stdio.h> #include <string.h>/** 编写…

asp 强制转换浮点数值_C/C++中浮点数的编码存储

浮点数也称做实型数据(实数)&#xff0c;形式上就是数学中的小数。浮点型数据有两种表达方式&#xff1a; 一种是用数字和小数点表示的&#xff0c;如123.456&#xff1b; 另一种是用指数方式表示&#xff0c;如1.2e-6 或1.2E-6(1.2*10-6)。在计算机中实数是如何存储的呢&#…