【动手学强化学习】番外8-IPPO应用框架学习与复现

文章目录

  • 一、待解决问题
    • 1.1 问题描述
    • 1.2 解决方法
  • 二、方法详述
    • 2.1 必要说明
      • (1)MAPPO 与 IPPO 算法的区别在于什么地方?
      • (2)IPPO 算法应用框架主要参考来源
    • 2.2 应用步骤
      • 2.2.1 搭建基础环境
      • 2.2.2 IPPO 算法实例复现
        • (1)源码
        • (2)Combat环境补充
        • (3)代码结果
      • 2.2.3 代码框架理解
  • 三、疑问
  • 四、总结


一、待解决问题

1.1 问题描述

在Combat环境中应用了MAPPO算法,在同样环境中学习并复现IPPO算法。

1.2 解决方法

(1)搭建基础环境。
(2)IPPO 算法实例复现。
(3)代码框架理解

二、方法详述

2.1 必要说明

(1)MAPPO 与 IPPO 算法的区别在于什么地方?

源文献链接:The Surprising Effectiveness of PPO in Cooperative, Multi-Agent Games
源文献原文如下:

为清楚起见,我们将具有 集中价值函数 输入的 PPO 称为 MAPPO (Multi-Agent PPO)。
将 策略和价值函数均具有本地输入 的 PPO 称为 IPPO (Independent PPO)。

总结而言,就是critic网络不再是集中式的了。
因此,IPPO 相对于 MAPPO 可能会更加占用计算、存储资源,毕竟每个agent都会拥有各自的critic网络

(2)IPPO 算法应用框架主要参考来源

其一,《动手学强化学习》-chapter 20
其二,深度强化学习(7)多智能体强化学习IPPO、MADDPG

✅非常感谢大佬的分享!!!

2.2 应用步骤

2.2.1 搭建基础环境

这一步骤直接参考上一篇博客,【动手学强化学习】番外7-MAPPO应用框架2学习与复现

2.2.2 IPPO 算法实例复现

(1)源码

智能体对于policy的使用分为separated policyshared policy,即每个agent拥有单独的policy net,所有agent共用一个policy net,二者在源码中都能够使用,对应位置取消注释即可。
MAPPO的不同就在于,每个agent拥有单独的value net。

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
from ma_gym.envs.combat.combat import Combat# PPO算法class PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc2(F.relu(self.fc1(x))))return F.softmax(self.fc3(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc2(F.relu(self.fc1(x))))return self.fc3(x)def compute_advantage(gamma, lmbda, td_delta):td_delta = td_delta.detach().numpy()advantage_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = gamma * lmbda * advantage + deltaadvantage_list.append(advantage)advantage_list.reverse()return torch.tensor(advantage_list, dtype=torch.float)# PPO,采用截断方式
class PPO:def __init__(self, state_dim, hidden_dim, action_dim,actor_lr, critic_lr, lmbda, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), critic_lr)self.gamma = gammaself.lmbda = lmbdaself.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)action_dict = torch.distributions.Categorical(probs)action = action_dict.sample()return action.item()def update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)td_target = rewards + self.gamma * \self.critic(next_states) * (1 - dones)td_delta = td_target - self.critic(states)advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()log_probs = torch.log(self.actor(states).gather(1, actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio * advantagesurr2 = torch.clamp(ratio, 1 - self.eps, 1 +self.eps) * advantage  # 截断action_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()action_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()def show_lineplot(data, name):# 生成 x 轴的索引x = list(range(100))# 创建图形和坐标轴plt.figure(figsize=(20, 6))# 绘制折线图plt.plot(x, data, label=name,marker='o', linestyle='-', linewidth=2)# 添加标题和标签plt.title(name)plt.xlabel('Index')plt.ylabel('Value')plt.legend()# 显示图形plt.grid(True)plt.show()actor_lr = 3e-4
critic_lr = 1e-3
epochs = 10
episode_per_epoch = 1000
hidden_dim = 64
gamma = 0.99
lmbda = 0.97
eps = 0.2
team_size = 2  # 每个team里agent的数量
grid_size = (15, 15)  # 二维空间的大小
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 创建环境
env = Combat(grid_shape=grid_size, n_agents=team_size, n_opponents=team_size)
state_dim = env.observation_space[0].shape[0]
action_dim = env.action_space[0].n# =============================================================================
# # 创建智能体(不参数共享:separated policy)
# agent1 = PPO(
#     state_dim, hidden_dim, action_dim,
#     actor_lr, critic_lr, lmbda, eps, gamma, device
# )
# agent2 = PPO(
#     state_dim, hidden_dim, action_dim,
#     actor_lr, critic_lr, lmbda, eps, gamma, device
# )
# =============================================================================# 创建智能体(参数共享:shared policy)
agent = PPO(state_dim, hidden_dim, action_dim,actor_lr, critic_lr, lmbda, eps, gamma, device
)win_list = []for e in range(epochs):with tqdm(total=episode_per_epoch, desc='Epoch %d' % e) as pbar:for episode in range(episode_per_epoch):# Replay buffer for agent1buffer_agent1 = {'states': [],'actions': [],'next_states': [],'rewards': [],'dones': []}# Replay buffer for agent2buffer_agent2 = {'states': [],'actions': [],'next_states': [],'rewards': [],'dones': []}# 重置环境s = env.reset()terminal = Falsewhile not terminal:# 采取动作(不进行参数共享)# a1 = agent1.take_action(s[0])# a2 = agent2.take_action(s[1])# 采取动作(进行参数共享)a1 = agent.take_action(s[0])a2 = agent.take_action(s[1])next_s, r, done, info = env.step([a1, a2])buffer_agent1['states'].append(s[0])buffer_agent1['actions'].append(a1)buffer_agent1['next_states'].append(next_s[0])# 如果获胜,获得100的奖励,否则获得0.1惩罚buffer_agent1['rewards'].append(r[0] + 100 if info['win'] else r[0] - 0.1)buffer_agent1['dones'].append(False)buffer_agent2['states'].append(s[1])buffer_agent2['actions'].append(a2)buffer_agent2['next_states'].append(next_s[1])buffer_agent2['rewards'].append(r[1] + 100 if info['win'] else r[1] - 0.1)buffer_agent2['dones'].append(False)s = next_s  # 转移到下一个状态terminal = all(done)# 更新策略(不进行参数共享)# agent1.update(buffer_agent1)# agent2.update(buffer_agent2)# 更新策略(进行参数共享)agent.update(buffer_agent1)agent.update(buffer_agent2)win_list.append(1 if info['win'] else 0)if (episode + 1) % 100 == 0:pbar.set_postfix({'episode': '%d' % (episode_per_epoch * e + episode + 1),'winner prob': '%.3f' % np.mean(win_list[-100:]),'win count': '%d' % win_list[-100:].count(1)})pbar.update(1)win_array = np.array(win_list)
# 每100条轨迹取一次平均
win_array = np.mean(win_array.reshape(-1, 100), axis=1)# 创建 episode_list,每组 100 个回合的累计回合数
episode_list = np.arange(1, len(win_array) + 1) * 100
plt.plot(episode_list, win_array)
plt.xlabel('Episodes')
plt.ylabel('win rate')
plt.title('IPPO on Combat(shared policy)')
plt.show()
(2)Combat环境补充

这里还需要说明的是,由于在奖励设置过程中采用了win(获胜),但是Combat环境中step函数并没有返还该值

 buffer_agent1['rewards'].append(r[0] + 100 if info['win'] else r[0] - 0.1)...
win_array = np.array(win_list)
...

因此需要在step()函数中加入对win(获胜)的判断,与return返回值

        # 判断是否获胜win = Falseif all(self._agent_dones):if sum([v for k, v in self.opp_health.items()]) == 0:win = Trueelif sum([v for k, v in self.agent_health.items()]) == 0:win = Falseelse:win = None  # 平局# 将获胜信息添加到 info 中info = {'health': self.agent_health, 'win': win, 'opp_health': self.opp_health, 'step_count': self._step_count}return self.get_agent_obs(), rewards, self._agent_dones, info
(3)代码结果

左图为separated policy下运行结果,右图为shared policy下运行结果。
明显可以看出,separated policy有着更好的效果,但是代价就是在训练过程中会占用更多的资源。
在这里插入图片描述在这里插入图片描述

2.2.3 代码框架理解

2.2.2节源码不难看出,IPPO算法其实就是在PPO算法上改为了多agent的环境,其中policy net,value net的更新原理并没有改变,因此代码框架的理解查看PPO算法原理即可。
参考链接:【动手学强化学习】part8-PPO(Proximal Policy Optimization)近端策略优化算法

三、疑问

  1. 暂无

四、总结

  • IPPO算法相对于MAPPO算法会占用更多的资源,如果环境较为简单,可以采用该算法。如果环境比较复杂,建议先采用MAPPO算法进行训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/78397.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

驱动开发硬核特训 · Day 17:深入掌握中断机制与驱动开发中的应用实战

🎥 视频教程请关注 B 站:“嵌入式 Jerry” 一、前言 在嵌入式驱动开发中,“中断”几乎无处不在。无论是 GPIO 按键、串口通信、网络设备,还是 SoC 上的各种控制器,中断都扮演着核心触发机制的角色。对中断机制掌握程度…

通过门店销售明细表用PySpark得到每月每个门店的销冠和按月的同比环比数据

假设我在Amazon S3上有销售表的Parquet数据文件的路径,包含ID主键、门店ID、日期、销售员姓名和销售额,需要分别用PySpark的SparkSQL和Dataframe API统计出每个月所有门店和各门店销售额最高的人,不一定是一个人,以及他所在的门店…

PostgreSQL 常用日志

PostgreSQL 常用日志详解 PostgreSQL 提供了多种日志类型&#xff0c;用于监控数据库活动、排查问题和优化性能。以下是 PostgreSQL 中最常用的日志类型及其配置和使用方法。 一、主要日志类型 日志类型文件位置主要内容用途服务器日志postgresql-<日期>.log服务器运行…

MySQL 存储过程:解锁数据库编程的高效密码

目录 一、什么是存储过程?二、创建存储过程示例 1:创建一个简单的存储过程示例 2:创建带输入参数的存储过程示例 3:创建带输出参数的存储过程三、调用存储过程调用无参数存储过程调用带输入参数的存储过程调用带输出参数的存储过程四、存储过程中的流控制语句示例 1:使用 …

基于STM32的物流搬运机器人

功能&#xff1a;智能循迹、定距夹取、颜色切换、自动跟随、自动避障、声音夹取、蓝牙遥控、手柄遥控、颜色识别夹取、循迹避障、循迹定距…… 包含内容&#xff1a;完整源码、使用手册、原理图、视频演示、PPT、论文参考、其余资料 资料只私聊

pg_jieba 中文分词

os: centos 7.9.2009 pg: 14.7 pg_jieba 依赖 cppjieba、limonp pg_jieba 下载 su - postgreswget https://github.com/jaiminpan/pg_jieba/archive/refs/tags/vmaster.tar.gzunzip ./pg_jieba-master cd ~/pg_jieba-mastercppjieba、limonp 下载 su - postgrescd ~/pg_jie…

基于Python+Flask的MCP SDK响应式文档展示系统设计与实现

以下是使用Python Flask HTML实现的MCP文档展示系统&#xff1a; # app.py from flask import Flask, render_templateapp Flask(__name__)app.route(/) def index():return render_template(index.html)app.route(/installation) def installation():return render_templa…

【“星睿O6”AI PC开发套件评测】GPU矩阵指令算力,GPU带宽和NPU算力测试

【“星睿O6”AI PC开发套件评测】GPU矩阵指令算力&#xff0c;GPU带宽和NPU算力测试 安谋科技、此芯科技与瑞莎计算机联合打造了面向AI PC、边缘、机器人等不同场景的“星睿O6”开发套件 该套件异构集成了Armv9 CPU核心、Arm Immortalis™ GPU以及安谋科技“周易”NPU 开箱和…

【Go语言】RPC 使用指南(初学者版)

RPC&#xff08;Remote Procedure Call&#xff0c;远程过程调用&#xff09;是一种计算机通信协议&#xff0c;允许程序调用另一台计算机上的子程序&#xff0c;就像调用本地程序一样。Go 语言内置了 RPC 支持&#xff0c;下面我会详细介绍如何使用。 一、基本概念 在 Go 中&…

11、Refs:直接操控元素——React 19 DOM操作秘籍

一、元素操控的魔法本质 "Refs是巫师与麻瓜世界的连接通道&#xff0c;让开发者能像操控魔杖般精准控制DOM元素&#xff01;"魔杖工坊的奥利凡德先生轻抚着魔杖&#xff0c;React/Vue的refs能量在杖尖跃动。 ——以神秘事务司的量子纠缠理论为基&#xff0c;揭示DOM…

MinIO 教程:从入门到Spring Boot集成

文章目录 一. MinIO 简介1. 什么是MinIO&#xff1f;2. 应用场景 二. 文件系统存储发展史1. 服务器磁盘&#xff08;本地存储&#xff09;2. 分布式文件系统(如 HDFS、Ceph、GlusterFS)3. 对象存储&#xff08;如 MinIO、AWS S3&#xff09;4.对比总结5.选型建议6.示例方案 三.…

电竞俱乐部护航点单小程序,和平地铁俱乐部点单系统,三角洲护航小程序,暗区突围俱乐部小程序

电竞俱乐部护航点单小程序开发&#xff0c;和平地铁俱乐部点单系统&#xff0c;三角洲护航小程序&#xff0c;暗区突围俱乐部小程序开发 端口包含&#xff1a; 超管后台&#xff0c; 老板端&#xff0c;打手端&#xff0c;商家端&#xff0c;客服端&#xff0c;管事端&#x…

基于 IPMI + Kickstart + Jenkins 的 OS 自动化安装

Author&#xff1a;Arsen Date&#xff1a;2025/04/26 目录 环境要求实现步骤自定义 ISO安装 ipmitool安装 NFS定义 ks.cfg安装 HTTP编写 Pipeline 功能验证 环境要求 目标服务器支持 IPMI / Redfish 远程管理&#xff08;如 DELL iDRAC、HPE iLO、华为 iBMC&#xff09;&…

如何在SpringBoot中通过@Value注入Map和List并使用YAML配置?

在SpringBoot开发中&#xff0c;我们经常需要从配置文件中读取各种参数。对于简单的字符串或数值&#xff0c;直接使用Value注解就可以了。但当我们需要注入更复杂的数据结构&#xff0c;比如Map或者List时&#xff0c;该怎么操作呢&#xff1f;特别是使用YAML这种更人性化的配…

短信验证码安全实战:三网API+多语言适配开发指南

在短信服务中&#xff0c;创建自定义签名是发送通知、验证信息和其他类型消息的重要步骤。万维易源提供的“三网短信验证码”API为开发者和企业提供了高效、便捷的自定义签名创建服务&#xff0c;可以通过简单的接口调用提交签名给运营商审核。本文将详细介绍如何使用该API&…

RabbitMQ和Seata冲突吗?Seata与Spring中的事务管理冲突吗

1. GlobalTransactional 和 Transactional 是否冲突&#xff1f; 答&#xff1a;不冲突&#xff0c;它们可以协同工作&#xff0c;但作用域不同。 Transactional: 这是 Spring 提供的注解&#xff0c;用于管理单个数据源内的本地事务。在你当前的 register 方法中&#xff0c…

一台服务器已经有个python3.11版本了,如何手动安装 Python 3.10,两个版本共存

环境&#xff1a; debian12.8 python3.11 python3.10 问题描述&#xff1a; 一台服务器已经有个python3.11版本了&#xff0c;如何手动安装 Python 3.10&#xff0c;两个版本共存 解决方案&#xff1a; 1.下载 Python 3.10 源码&#xff1a; wget https://www.python.or…

c++中的enum变量 和 constexpr说明符

author: hjjdebug date: 2025年 04月 23日 星期三 13:40:21 CST description: c中的enum变量 和 constexpr说明符 文章目录 1.Q:enum 类型变量可以有,--操作吗&#xff1f;1.1补充: c/c中enum的另一个细微差别. 2.Q: constexpr 修饰的函数,要求传入的参数必需是常量吗&#xff…

postman工具

postman工具 进入postman官网 www.postman.com/downloads/ https://www.postman.com/downloads/ https://www.postman.com/postman/published-postman-templates/documentation/ae2ja6x/postman-echo?ctxdocumentation Postman Echo is a service you can use to test your …

Spring和Spring Boot集成MyBatis的完整对比示例,包含从项目创建到测试的全流程代码

以下是Spring和Spring Boot集成MyBatis的完整对比示例&#xff0c;包含从项目创建到测试的全流程代码&#xff1a; 一、Spring集成MyBatis示例 1. 项目结构 spring-mybatis-demo/ ├── src/ │ ├── main/ │ │ ├── java/ │ │ │ └── com.example/…