【MARL】MADDPG + attention 实现(+论文解读)

文章目录

  • 前言
  • 注意力机制
  • 论文里的attention
    • 回顾知识-MADDPG
    • 讲解
      • 1.Q的定义
      • 2.Q的恒等式
      • 3.论文里的attention
      • 4.好处
  • 实现 和 修改
    • 结果展示
    • 原论文代码 翻改版
    • 修改后
    • 原maddpg代码


前言

导师让在MADDPG上加一个注意力机制,试了很多种,下面的参考的论文的效果最好,先把其思路记录下来。
之后有时间再试试自注意力机制。

参考论文:
Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG
论文代码:github

注意力机制

注意力机制是什么?最初是在NLP领域中为了解决lstm中序列的前后文本因为位置跨度大而不能理解正确输出后文的问题。
著名文章:Attention is all your need
讲的比较好的文章:
1.Transformer:注意力机制(attention)和自注意力机制(self-attention)的学习总结
2.详解Transformer中Self-Attention以及Multi-Head Attention

我这里的个人理解:

注意力机制:Attention(Q, K, V) = softmax(Q*K^T/sqrt(d_k)) * V (以缩放点积注意力机制举例)


Q :查询(自主性提示,有主观意识成分) K : 键(非自主性提示,客观存在的特征) V: 值 (K中实际信息,客观存在)
针对两条序列,Q 和 K (假设Q:5x1x6 ,K:5x3x6 ,V:5x3x6)
注:batch_size x len x hidden_dim
具象理解:其中Q中的1可以具象化为我要找的一件东西(这个东西有6个特征),K中的3为我已有的三件物品的线索(也有6个特征),我拿我心中这个东西的特征和线索特征去匹配。
5可以表示为我找了5次东西。


注意力分数 (打分器 ) attention_score:Q*K^T/sqrt(d_k) 对Q和K做一个加权求和,输出一个值,这个值可以说是Q和K的相似程度,两者相似程度越高,这个值就越大,并且其他不相似的元素也对其有贡献(即软注意力,硬注意力是0,1 关系)。
注:这里使用了缩放点积的方法:缩放:sqrt(d_k) (防止后续softmax过大以丧失梯度)点积:Q*K^T
关键疑问:为什么这里明明是矩阵乘法(行列元素对应相乘并相加:实际就是加权求和),确说成是点积(对应元素相乘)?
回答:矩阵乘法当的第一步是行列元素相乘,可以看作是对应行元素与对应列元素的点积。
此时:得到5x1x3
具象理解:其中1x3 可以看作我得到了这3个线索的相似度分数,假设分数[3,4,5]。


注意力权重 attention_weight :softmax(Q*K^T/sqrt(d_k)) 对分数进行一个softmax函数,实际上就是把各个分数转换成(0,1)之间的概率分布,且其概率分布和为1。
此时:得到5x1x3
具象理解:其中1x3为输出的是我每个线索我要给予多少关注程度。假设是[0.2,0.3,0.5]。


注意力值 atten_value :softmax(Q*K^T/sqrt(d_k)) * V 将关注度与V还是进行一个加权求和(矩阵乘法)
此时:5x1x6
具象理解:我第一个线索中对应的东西的6个特征均给予0.2程度的关心,第2个线索对应的东西的6个特征均给予0.3程度的关心,第3个是均给予0.5程度的关心。最后找到了这个‘东西’。
最后找到的东西和原来的东西不一样但相似度高,比方说:我想找青苹果,最后找到了红苹果(即本质一样)。
在翻译领域上的具体实现是:翻译:apple 为 苹果。


总结
输入:Q,K序列,返回:在不同K‘地址’的V‘物品’里‘合成的像Q的’物品’
该模型可以动态地决定在不同位置上分配多少注意力,从而关注更需要关注的特征。
K和Q的匹配过程决定了“在哪里看”,而V则决定了“看到什么”。


补充:K-head注意力机制:只是在注意力机制基础上,用到的hidden层变成了K个
比如:原来Q:5x1x6 -> Kx5x1x6 同理K,V也是。
最后输出前再堆叠(stack)一下->5x1x (6xK)
其他常见的做法:未来避免隐层变多造成的计算过慢,将原来的hidden切成(view)K个,比如说三个,5x1x6 -> 5x1x(2x3) 。
好处是让注意力不只是注意一个敌方,注意多个地方。


注:由于注意力机制的第一步和最后一步其实都是矩阵相乘(加权求和),也就是matmul()或者mul().sum(-1),所以会见到这两种代码的排列组合,不要怀疑,这两种写法都是对的。

自注意力机制:(施工中)
简单理解:
与注意力机制的区别在于:只输入一条序列,输入的Q,K一样,关注自己相似的地方。
具体:The animal didn’t cross the street,because it was too tired。这里animal 和 it 是同一个。

论文里的attention

参考论文:
Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG

回顾知识-MADDPG

1.MADDPG采用了集中式训练,分布式执行(centralized training with decentralized execution,CTDE)的框架 集中式含义:在训练时使用所有智能体的状态和动作集合,并不是指训练出的critic网络一样。
疑问:为何训练出的critic网络不一样? 答:每个智能体的奖励函数不一样,或者有可能done也不一样。
分布式含义:在执行时只使用当前智能体的状态,所以训练出的actor肯定不一样的。

2.采用Actor-Critic框架,如何训练critic和actor?
critic网络:用深度神经网络拟合动作价值函数。
与DDPG一样,使用TD(0)算法来迭代实现贝尔曼状态动作价值函数。
在这里插入图片描述
这里的TD目标实际上代替了总回报值Gt。
于是自然推出损失函数为当前动作价值和TD目标的均方差->以实现估计状态动作价值函数
在这里插入图片描述
具体理解:拟合的Q(s,a)为:当前长期回报总值,y为当前奖励值和下一步的未来长期回报值。
本质上是对未来回报的预测,告诉智能体在某个状态下采取某个动作的潜在好处。
TD目标值并不总是大于当前Q值估计,大于时,则说明某个状态好,小于时,则说明某个状态差。
TD目标的计算是为了提供一个更接近真实长期回报的估计,目标是正确认识到可以达到的Q值。
具象理解:相当于是训练一个评论家,告诉你,我最大可以拿到多少奖励。


actor网络:用深度神经网络拟合动作值/动作概率分布/动作均值和方差。
在这里插入图片描述
一般来说:即使用当前策略下的动作(替换经验池中抽取当前智能体的动作为当前状态下的动作),损失函数为-Q(s,a)->加上负号以使用梯度下降。实际实现Q值最大化,即回报值最大化。

具象理解:训练一个玩家,使奖励达到最大。

其他:使用A-C算法,取代了传统DQN算法中显示使用max a′Q(s t+1,a ′) 选取最大Q值的策略,而是学习了一个策略。
注:先更新critic,后更新actor有助于学习到更好的策略

讲解

没按文章顺序解读,按自己的理解解读

1.Q的定义

和实际上的attention 不一样,论文里并不是直接加进去,而是巧妙利用了注意力机制里的一些特性,重新定义了Q函数,达到了神奇的效果。

MADDPG论文中定义的Q 为
Q = Q i u ( s , a ∣ a i = u i ( o i ) ) Q = Q_i^u(s,a|a_i=u_i(o_i)) Q=Qiu(s,aai=ui(oi))
其中u为actor的策略,s为所有智能体的状态,a为所有智能体的动作,u_i为当前智能体的策略。
这里的Q即在更新actor时要最大化的Q。
更新critic时的Q为 Q = Q i u ( s , a ) Q = Q_i^u(s,a) Q=Qiu(s,a) ,即没有后续的条件( a i = u i ( o i ) a_i=u_i(o_i) ai=ui(oi))。

论文里定义的Q,将动作价值函数Q(s,a)定义为:(更新critic时的Q)
Q = Q i u i ∣ u − i ( s , a i ) Q = Q_i^{u_i|u_{-i}}(s,a_i) Q=Qiuiui(s,ai)
其中 u i u_i ui为当前actor的策略, u − i u_{-i} ui表示其他智能体策略。
更新actor时和上述MADDPG一样,使用 a i = u i ( o i ) a_i=u_i(o_i) ai=ui(oi)

论文是这样阐述的:
在这里插入图片描述
回想一下智能体的环境,都是被a(所有智能体的动作)来影响的,
那么,从智能体i的角度来看,即就是当前智能体的角度来看,我在s(所有智能体的状态)下做出自己动作(a_i)的结果,取决于其他智能体的动作。

论文作者意思就是说,这个环境都是被动作影响的,那么我的动作是在其他智能体的动作影响下的环境下做出的动作,那么我的动作实际上取决于其他智能体的动作。

因此将Q定义为如上形式。

2.Q的恒等式

此时,我们的目标是也就是最大化这个定义的Q值,也就是论文(下图)这个argmax的形式。
在这里插入图片描述
从数学上,因为在其他智能体采取其他动作的条件下 ∣ u − i |u_{-i} ui,意味着我们要考虑所有可能的动作组合的概率分布,故可以显式的写成上述(6)、(7)式。
Σ a ⃗ − i ∈ A ⃗ − i [ π ⃗ − i ( a ⃗ − i ∣ s ) Q i π i ( s , a i , a ⃗ − i ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[\vec{\pi}_{-i}(\vec{a}_{-i}|s)Q_{i}^{\pi_{i}}(s,a_{i},\vec{a}_{-i})] Σa iA i[π i(a is)Qiπi(s,ai,a i)]
A为动作空间。

由于要估计每个其他智能体的动作(即实现 Σ a ⃗ − i ∈ A ⃗ − i \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}} Σa iA i) 论文作者运用了K-head模块来估计。

即: Σ a ⃗ − i ∈ A ⃗ − i [ Q i π i ( s , a i , a ⃗ − i ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[Q_{i}^{\pi_{i}}(s,a_{i},\vec{a}_{-i})] Σa iA i[Qiπi(s,ai,a i)]
可以写成(约等于)
∑ k = 1 K Q i k ( s , a i ∣ a ⃗ − i ; w i ) \sum_{k=1}^KQ_i^k(s,a_i|\vec{a}_{-i};w_i) k=1KQik(s,aia i;wi)
wi为critic网络的参数。

作者这里说
这里由于在生成 Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi)时输入只有s和a_i,而a_-i,用一个额外的隐层h来实现,所以并没有写成 Q i k ( s , a i , a ⃗ − i ; w i ) Q_i^k(s,a_i,\vec{a}_{-i};w_i) Qik(s,ai,a i;wi)

然而我查看了代码,感觉写成后者也没问题,因为在生成时,两者都输入了,且分别用了H和h的隐层。不过我支持作者这样的写法,原因稍后说。

而关于 Σ a ⃗ − i ∈ A ⃗ − i [ π ⃗ − i ( a ⃗ − i ∣ s ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[\vec{\pi}_{-i}(\vec{a}_{-i}|s)] Σa iA i[π i(a is)]的估计
则是作者富有想象力的体现了。

在这里插入图片描述
但是这里作者并没写为什么,所以一直很难搞懂。

作者使用 W i ( w i ) W_i(w_i) Wi(wi)来近似所有其他智能体的动作概率分布,即 π ⃗ − i ( A ⃗ − i ∣ s ) \vec{\pi}_{-i}(\vec{A}_{-i}|s) π i(A is),此时问题变为如何近似 W i ( w i ) W_i(w_i) Wi(wi)。而注意力机制天然适合生成概率分布。

最后Q估计为如下的形式:

在这里插入图片描述
等式左端的Q其就是在Q的定义里描述的Q。

3.论文里的attention

回顾注意力机制中,注意力权重的部分,即找出Q和K序列的相似的权重的序列。
其概率分布和为1,且能动态调整。

假设3个同质智能体,状态维度均为12,动作维度均为1。
此时作者让Q的查询的部分为其余智能体的动作组合(2),K的键的部分为所有智能体状态和当前智能体动作(37)( Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi))。(这里键是客观实在,最后输出时也是只用到了键的s,a_i,所以我支持作者上述写法。)

按照本文上述注意力机制的描述,此时得到了一个2x37(即 W i ( w i ) W_i(w_i) Wi(wi))。即两个1x37,每个(1)其他智能体动作有37个线索,每个线索应该给多少关注[0.11,…,0.2]。

所以此时得到的权重,正好可以看成在维度(2x37)和后面 Q i k Q_i^k Qik(37xhidden)维度相同的 其他智能体动作的概率值。
在这里插入图片描述
注:这里的K-head 和k-head注意力机制不同,这里仅仅是为了拟合多种不同的动作,所以只在键上做了K-head。

至于为什么这里可以拟合成 其他智能体动作的概率(比较难理解,作者也没解释,只是说正好可以这样生成,生成了可以动态调整。)
个人理解:
当前智能体的动作 取决于其他智能体的动作(和环境状态)(1.Q的定义上讲到)。
反过来,
所以其他智能体的动作 取决于当前智能体的动作和所有状态,
也就是说和当前智能体的动作和所有状态相关,
而其他智能体的动作的概率值本来也是和当前智能体的动作和所有状态相关。
两者均相关于同一个东西,那么就可以理解成: A~C B~C => A~B
那么注意力权重就是给出的概率分布,就可以用来近似。
再进一步解释,这里作者在传入Q和K前,分别将数据先进行一个全连接,再进行一个激活函数(hi)处理,可能就是让神经网络尽可能的学习到这个近似。

最后一步,将两者加权求和,这一步和注意力机制的最后一步竟也是一样。
会有一种让人觉得这就是简单在用注意力机制的感觉,不过这里有理有据。

还有一点与attention不一样,注意力机制最后是权重与V加权求和,这里还是和键K相乘,因为这里的K的含义不一样,这里的论文里的K是 Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi)

在这里插入图片描述

4.好处

至此,关键创新点,解读完了,其他基本和MADDPG一致,除了代码那里在 更新critic那块还用了和MAAC一样的方法,(此论文发表比MAAC早,应该是早已有的方法)即求损失函数之和来更新critic。

此外,论文还解释了K在动作是离散空间下,不必是|A_-i|,因为只有一小部分的动作是重要的。
在这里插入图片描述

还研究了在连续空间下,K-head也是可以有效作用的。
在这里插入图片描述
最终的好处。
1.关注了其他智能体的动作来更新critic,缓解了环境非稳态的问题。
理由是在原来的算法下,不管其他队友的情况下,总有确定性的概率会导致相同的奖励和相同的下一个状态。

在这里插入图片描述
2.可以动态调节,动作的概率分布,更容易适用于不同策略(由于attention的性质),
即使其他智能体的策略,已经改变,当前的动作价值也不需要改变(因为已经学到了),提供了一个稳定的良好的Q值。
在这里插入图片描述
3.这种方法的近似,相当于以往Q= V+A的近似,比单单的全连接要好。
在这里插入图片描述
之后就是训练曲线展示,在合作导航和捕食者两个基准环境测试(maddpg论文里用的环境)下,证明了比MADDPG好。

实现 和 修改

当然,肯定是好的,因为我也用过了。
两者除了critic网络架构上的区别,其他参数均一致的情况下,且都调整为个人认为理想的网络层数的情况下,效果如下:(论文里的critic网路结构被我修改后的结果,修改前实验效果稍许不如)

结果展示

在这里插入图片描述
黄色为maddpg+attention
红色为maddpg。
可以看出黄色明显优于红色。

原论文代码:github
为方便理解代码,我将其先修改为一般的attention作为block参数共享的代码。
我这里的critic是单个单个更新的。环境是参照prttingzoo搭建的。

注意力机制 基本上是可以即插即用的,x是cat(s,a),只要把自己的network替换成如下形式就行。

一些得根据自己修改的地方
智能体数目:3
状态维度:12
动作维度:1
且这里的三个都是同样的状态维度和动作维度

论文版本的代码,
至于agent_id,agents,只是为了得到当前智能体是第几个智能体agent_id_index,
我这里是agent_id 是智能体的名字’Red-0’
agents是keys是agent_id,values是agent类的字典。

原论文代码 翻改版

class Attention2(nn.Module):def __init__(self, encoder_input_dim, decoder_input_dim, hidden_dim, head_count):super(Attention2, self).__init__()self.fc_encoder_input = nn.Linear(encoder_input_dim, hidden_dim)self.fc_encoder_heads = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(head_count)])self.fc_decoder_input = nn.Linear(decoder_input_dim, hidden_dim)def forward(self, encoder_input, decoder_input):''' encoder_input 由所有智能体的状态和当前智能体动作组成,decoder_input 由其余智能体的动作组成'''# encoder_input shape: (batch_size, input_dim)encoder_h = F.relu(self.fc_encoder_input(encoder_input))# encoder_h shape: (batch_size, hidden_dim)encoder_heads = torch.stack([F.relu(head(encoder_h)) for head in self.fc_encoder_heads], dim=0)# encoder_heads shape: (head_count, batch_size, hidden_dim)# decoder_input shape: (batch_size, input_dim)decoder_H = F.relu(self.fc_decoder_input(decoder_input))# decoder_H shape: (batch_size, hidden_dim)''' enocde_heads 用作键值对 decoder_H 用作查询 '''scores = torch.sum(torch.mul(encoder_heads, decoder_H), dim=2)# scores shape: (head_count, batch_size)attention_weights = F.softmax(scores.permute(1, 0), dim=1).unsqueeze(2)# attention_weights shape: (batch_size, head_count, 1)contextual_vector = torch.matmul(encoder_heads.permute(1, 2, 0), attention_weights).squeeze()# contextual_vector shape: (batch_size, hidden_dim)return contextual_vectorclass MLPNetworkWithAttention2(nn.Module):def __init__(self, in_dim, out_dim,hidden_dim = 128 ,head_count = 8 ):super(MLPNetworkWithAttention2, self).__init__()#self.args = args # 3为智能体个数 12为状态维度 1为动作维度 self.fc_obs = nn.Linear(12, hidden_dim) self.fc_action = nn.Linear(1, hidden_dim)self.attention_modules = Attention2(hidden_dim * (3 + 1), hidden_dim * (3 - 1),hidden_dim, head_count)  #3为智能体数量self.fc_qvalue = nn.Linear(hidden_dim, out_dim) def forward(self, x, agent_id, agents):agent_id_list = list(agents.keys())agent_id_index = agent_id_list.index(agent_id) #获取agent_id在agents中的索引 按照顺序排agent_n = len(agent_id_list) #智能体数量3 #12为state_dim #3*12=36out_obs_list = [F.relu(self.fc_obs(x[:,:12])) , F.relu(self.fc_obs(x[:,12:24])) , F.relu(self.fc_obs(x[:,24:36]))]               # out_obs_list shape: [(batch_size, hidden_dim), ...] #即 batch_size * hidden_dim * agent_countout_action_list = [F.relu(self.fc_action(x[:,36:37])) , F.relu(self.fc_action(x[:,37:38])) , F.relu(self.fc_action(x[:,38:39]))]# out_action_list shape: [(batch_size, hidden_dim), ...]encoder_input = torch.cat(out_obs_list + [out_action_list[agent_id_index]], dim=1)# encoder_input shape: (batch_size, hidden_dim * (agent_count + 1))decoder_input = torch.cat(out_action_list[:agent_id_index] + out_action_list[agent_id_index+1:], dim=1)# decoder_input shape: (batch_size, hidden_dim * (agent_count - 1))contextual_vector = self.attention_modules(encoder_input, decoder_input)# contextual_vector shape: (batch_size, hidden_dim)qvalue = self.fc_qvalue(contextual_vector)# qvalue shape: (batch_size, 1)return qvalue

修改后

我修改后的,
由于,我发现这里注意力机制中的隐层维度会随着 智能体数量的提高 而变高,可能会造成过拟合的现象,以及认为传入Q,K的数据,不需要进行relu的操作,因为在attention机制里已有一层relu,故修改如下:

## 注意力机制改2_ --Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG 论文 改版
class Attention2_(nn.Module):def __init__(self, encoder_input_dim, decoder_input_dim, hidden_dim, head_count):super(Attention2_, self).__init__()self.fc_encoder_input = nn.Linear(encoder_input_dim, hidden_dim)self.fc_encoder_heads = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(head_count)]) ##self.fc_decoder_input = nn.Linear(decoder_input_dim, hidden_dim)def forward(self, encoder_input, decoder_input):''' encoder_input 由所有智能体的状态和当前智能体动作组成,decoder_input 由其余智能体的动作组成'''# encoder_input shape: (batch_size, input_dim)encoder_h = F.relu(self.fc_encoder_input(encoder_input))# encoder_h shape: (batch_size, hidden_dim)encoder_heads = torch.stack([F.relu(head(encoder_h)) for head in self.fc_encoder_heads], dim=0)# encoder_heads shape: (head_count, batch_size, hidden_dim)# decoder_input shape: (batch_size, input_dim)decoder_H = F.relu(self.fc_decoder_input(decoder_input))# decoder_H shape: (batch_size, hidden_dim)''' enocde_heads 用作键值对 decoder_H 用作查询 '''scores = torch.sum(torch.mul(encoder_heads, decoder_H), dim=2)# scores shape: (head_count, batch_size) <- before sum (head_count, batch_size, hidden_dim) attention_weights = F.softmax(scores.permute(1, 0), dim=1).unsqueeze(2)# attention_weights shape: (batch_size, head_count, 1)contextual_vector = torch.matmul(encoder_heads.permute(1, 2, 0), attention_weights).squeeze()# contextual_vector shape: (batch_size, hidden_dim)return contextual_vectorclass MLPNetworkWithAttention2_(nn.Module):def __init__(self, in_dim, out_dim , hidden_dim = 128 ,head_count = 8 ):'''在Attention2中 hidden_dim = 128 ,head_count = 8  效果最好'''super(MLPNetworkWithAttention2_, self).__init__()'''#self.args = args # 3为智能体个数 12为状态维度 1为动作维度 self.fc_obs = nn.Linear(12, hidden_dim) self.fc_action = nn.Linear(1, hidden_dim)'''self.attention_modules = Attention2_(hidden_dim , hidden_dim ,hidden_dim, head_count) self.fc_qvalue = nn.Linear(hidden_dim, out_dim) #所有智能体的状态和当前智能体动作 维度self.fc1 = torch.nn.Linear(37, hidden_dim)#其余智能体的动作 维度self.fc2 = torch.nn.Linear(2, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim)def forward(self, x,agent_id,agents):agent_id_list = list(agents.keys())agent_id_index = agent_id_list.index(agent_id) #获取agent_id在agents中的索引 按照顺序排agent_n = len(agent_id_list) #智能体数量 #12为state_dim #3*12=36'''改out_obs_list = [F.relu(self.fc_obs(x[:,:12])) , F.relu(self.fc_obs(x[:,12:24])) , F.relu(self.fc_obs(x[:,24:36]))]               # out_obs_list shape: [(batch_size, hidden_dim), ...] #即 batch_size * hidden_dim * agent_countout_action_list = [F.relu(self.fc_action(x[:,36:37])) , F.relu(self.fc_action(x[:,37:38])) , F.relu(self.fc_action(x[:,38:39]))]# out_action_list shape: [(batch_size, hidden_dim), ...]encoder_input = torch.cat(out_obs_list + [out_action_list[agent_id_index]], dim=1)# encoder_input shape: (batch_size, hidden_dim * (agent_count + 1))decoder_input = torch.cat(out_action_list[:agent_id_index] + out_action_list[agent_id_index+1:], dim=1)# decoder_input shape: (batch_size, hidden_dim * (agent_count - 1))'''# 所有智能体的动作对应列action_list = [x[:,36:37],x[:,37:38],x[:,38:39]]     encoder_input = self.fc1(torch.cat((x[:,:self.all_obs_dim],action_list[agent_id_index]),1)) #batch_size * 37 -> batch_size * hidden_dimdecoder_input = self.fc2(torch.cat((action_list[:agent_id_index]+action_list[agent_id_index+1:]),1)) #batch_size * 2 -> batch_size * hidden_dim# 要满足 encoder_input shape: (batch_size, hidden_dim) decoder_input shape: (batch_size, hidden_dim) contextual_vector = self.attention_modules(encoder_input, decoder_input)# contextual_vector shape: (batch_size, hidden_dim)t1 = F.relu(self.fc3(contextual_vector))qvalue = self.fc_qvalue(t1)# qvalue shape: (batch_size, 1)return qvalue

原maddpg代码

class MLPNetwork(nn.Module):def __init__(self, in_dim, out_dim,hidden_dim_1=256, hidden_dim_2=128,non_linear=nn.ReLU()):super(MLPNetwork, self).__init__()self.net = nn.Sequential(nn.Linear(in_dim, hidden_dim_1),non_linear,nn.Linear(hidden_dim_1, hidden_dim_2),non_linear,nn.Linear(hidden_dim_2, out_dim),).apply(self.init) #apply(self.init)是在初始化模块的权重和偏置时调用init方法@staticmethoddef init(m):"""init parameter of the module"""gain = nn.init.calculate_gain('relu') #zh-cn:计算增益if isinstance(m, nn.Linear):torch.nn.init.xavier_uniform_(m.weight, gain=gain)#这行代码使用 Xavier 均匀分布初始化方法来初始化模块的权重(m.weight)。Xavier 初始化方法旨在使得网络各层的激活值和梯度的方差在传播过程中保持一致,有助于加速网络的收敛。gain 参数是根据 ReLU 激活函数的特性调整的。m.bias.data.fill_(0.01) #zh-cn:这行代码使用常数 0.01 来初始化模块的偏置(m.bias)。def forward(self, x):return self.net(x)

我是把上述代码替换为上述修改版,运行代码得到的结果展示。

效果确实是有的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/50170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++——保持原有库头文件不变的情况下,成功编译运行工程

问&#xff1a;想要保持原来库方式&#xff0c;应该怎么操作呢&#xff1f; 答&#xff1a;如果想保持原来的方式&#xff0c;则只需要将 库所在路径 tracker/detector/rknn_model_zoo/utils 加入到 工程库包含中即可。

基于jeecgboot-vue3的Flowable流程-自定义业务表单流程历史信息显示

因为这个项目license问题无法开源&#xff0c;更多技术支持与服务请加入我的知识星球。 1、对于自定义业务表单的流程历史记录信息做了调整&#xff0c;增加显示自定义业务表单 <el-tab-pane label"表单信息" name"form"><div v-if"customF…

ESP32开发进阶:OLED屏幕显示旋转的3D模型

一、硬件接线 我选择的是最常见的一块板子&#xff1a;ESP-WROOM-32&#xff0c;硬件接线如下&#xff1a; 21 - SDA 22 - SCL 二、Arduino端代码 我们使用Arduino和Adafruit SSD1306库在OLED显示屏上绘制和旋转一个3D立方体。 首先&#xff0c;定义立方体顶点和…

CSS(七)——CSS 列表和CSS Table(表格)

目录 CSS 列表 列表 作为列表项标记的图像 列表 - 简写属性 移除默认设置 所有的CSS列表属性 CSS 表格 表格边框 折叠边框&#xff08;border-collapse&#xff09; 表格宽度和高度 表格文字对齐 表格填充 表格颜色 CSS 列表 CSS 列表属性作用如下&#xff1a; 设…

C#开发的全屏图片切换效果应用 - 开源研究系列文章 - 个人小作品

这天无聊&#xff0c;想到上次开发的图片显示软件《 PhotoNet看图软件 》&#xff0c;然后想到开发一个全屏图片切换效果的应用&#xff0c;类似于屏幕保护程序&#xff0c;于是就写了此博文。这个应用比较简单&#xff0c;主要是全屏切换换图片效果的问题。 1、 项目目录&…

【Vue3】watch 监视 ref 定义的数据

【Vue3】watch 监视 ref 定义的数据 背景简介开发环境开发步骤及源码参数说明 背景 随着年龄的增长&#xff0c;很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0c;技术出身的人总是很难放下一些执念&#xff0c;遂将这些知识整理成文&#xff0c;以纪念曾经努…

【C++进阶学习】第八弹——红黑树的原理与实现——探讨树形结构存储的最优解

二叉搜索树&#xff1a;【C进阶学习】第五弹——二叉搜索树——二叉树进阶及set和map的铺垫-CSDN博客 AVL树&#xff1a; ​​​​​​【C进阶学习】第七弹——AVL树——树形结构存储数据的经典模块-CSDN博客 前言&#xff1a; 在前面&#xff0c;我们已经学习了二叉搜索树和…

PCIe 6.0为什么需要14-bit tag

1.TLP中的tag是什么 在PCIe TLP&#xff08;Transaction Layer Packet&#xff09;中&#xff0c;tag是分配给特定Non-Posted Request的编号&#xff0c;协议要求CPL/CPLD中的tag 与对应non-post request TLP中的tag保持一致&#xff0c;因此Requester可以使用tag来识别CPL…

免费【2024】springboot 趵突泉景区的智慧导游小程序

博主介绍&#xff1a;✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流✌ 技术范围&#xff1a;SpringBoot、Vue、SSM、HTML、Jsp、PHP、Nodejs、Python、爬虫、数据可视化…

十、SpringBoot 统⼀功能处理【拦截器、统一数据返回格式、统一异常处理】

十、SpringBoot 统⼀功能处理 1. 拦截器【HandlerInterceptor、WebMvcConfig】1.1 拦截器快速⼊⻔⾃定义拦截器&#xff1a;实现HandlerInterceptor接⼝&#xff0c;并重写其所有⽅法注册配置拦截器&#xff1a;实现WebMvcConfigurer接⼝&#xff0c;并重写addInterceptors⽅法…

堆(c++)

堆是计算机科学中一类特殊的数据结构的统称。堆通常是一个可以被看做一棵树的数组对象。 堆总是满足下列性质&#xff1a; 堆中某个节点的值总是不大于或不小于其父节点的值&#xff1b;堆总是一棵完全二叉树。 常见的堆有二叉堆、斐波那契堆等。 堆是非线性数据结构&#…

初识C++ · map和set的使用

目录 前言&#xff1a; 1 set 2 map 前言&#xff1a; 在前面阶段&#xff0c;我们已经学习了stl里面的部分容器&#xff0c;比如vector,list,deque等&#xff0c;这些容器都被称为序列式容器&#xff0c;也就是每个值之间式没有关联的&#xff0c;那么今天介绍的容器&…

IGV.js | 载入自己下载的gtf文件

1.安装 htslib-1.20 https://www.htslib.org/doc/tabix.html J3$ cd ~/Downloads/ $ wget https://github.com/samtools/htslib/releases/download/1.20/htslib-1.20.tar.bz2 $ tar jxvf htslib-1.20.tar.bz2编译安装&#xff1a; $ cd htslib-1.20/ $ ./configure --prefix/…

vue的三大核心知识点

响应式&#xff1a; 监听data属性getter setter(包括数组)模板编译&#xff1a; 模板到render函数再到vnodevdom&#xff1a; patch(elem, vnode)和patch(vnode, newVnode) vue组件初次渲染过程 解析模板为render函数&#xff08;或在开发环境已完成&#xff0c;vue-loader&a…

WIX Toolset 3.11 对本地化的支持方案

1.准备主题文件和本地化文件 WIX Toolset种主题文件为xml文件&#xff0c;负责配置控件的布局&#xff0c; 本地化文件为wxl文件&#xff0c;负责配置待加载的字符串&#xff0c;主题文件根据ID加载需要显示的文字内容。考虑到英文和中文字符长度大小不一&#xff0c;所以这里…

渗透测试——prime1靶场实战演练{常用工具}端口转发

文章目录 概要信息搜集 概要 靶机地址&#xff1a;https://www.vulnhub.com/entry/prime-1,358 信息搜集 nmap 扫网段存活ip及端口 找到除了网关外的ip&#xff0c;开放了80端口&#xff0c;登上去看看 是一个网站&#xff0c;直接上科技扫一扫目录 python dirsearch.py -u …

尝试带你理解 - 进程地址空间,写时拷贝

序言 在上一篇文章 进程概念以及进程状态&#xff0c;我们提到了 fork 函数&#xff0c;该函数可以帮我们创建一个子进程。在使用 fork 函数时&#xff0c;我们会发现一些奇怪的现象&#xff0c;举个栗子&#xff1a; 1 #include <stdio.h>2 #include <unistd.h>3 …

跟《经济学人》学英文:2024年07月20日这期 The Russell 2000 puts in a historic performance

Why investors have fallen in love with small American firms The Russell 2000 puts in a historic performance 罗素2000指数&#xff1a; 罗素2000指数&#xff08;英语&#xff1a;Russell 2000 Index&#xff09;为罗素3000指数中收录市值最小的2000家&#xff08;排序…

学习笔记 韩顺平 零基础30天学会Java(2024.7.25)

P425 枚举类引出 举了一个例子&#xff0c;季节类创建对象&#xff0c;但是根据Java的规则&#xff0c;可以设置春夏秋冬以外的对象&#xff0c;而且可以修改&#xff0c;这样就会不符合实际&#xff0c;因此引出枚举 P426 自定义枚举类 1.构造器私有化&#xff0c;使外面没有办…

深入解析 GPT-4o mini:强大功能与创新应用

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由 JohnKi 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f4e2;未来很长&#…