I. Hierarchical Reinforcement Learning Principles
1. The Core Idea of Hierarchical Learning
Hierarchical Reinforcement Learning (HRL) tackles complex, long-horizon tasks through temporal abstraction and task decomposition. The core idea is summarized in the table below (a small illustrative sketch of the option abstraction follows the table):
| Aspect | Traditional RL | Hierarchical RL |
| --- | --- | --- |
| Policy structure | A single policy outputs actions directly | A high-level policy selects options |
| Time scale | Decisions at every single step | The high level decides over long horizons; the low level executes the steps |
| Typical use cases | Simple short-horizon tasks | Complex long-horizon tasks (e.g., maze navigation, robot manipulation) |
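To make the option abstraction concrete, here is a minimal, purely illustrative sketch (not part of the implementation below): an option bundles an initiation condition, an intra-option policy, and a termination condition. The `reach` rule is a hypothetical toy example that assumes the first three and last three observation entries hold the gripper and goal positions.

```python
from dataclasses import dataclass
from typing import Callable
import numpy as np

# Minimal sketch of the option abstraction <initiation, policy, termination> (illustrative only).
@dataclass
class Option:
    initiation: Callable[[np.ndarray], bool]     # states where the option may be invoked
    policy: Callable[[np.ndarray], np.ndarray]   # intra-option policy: state -> action
    termination: Callable[[np.ndarray], float]   # probability of ending the option in a state

# Hypothetical "reach" option: always applicable, moves toward the goal,
# terminates once the gripper is within 2 cm of the goal position.
# (Assumes obs[:3] is the gripper position and obs[-3:] is the goal position.)
reach = Option(
    initiation=lambda s: True,
    policy=lambda s: np.clip(s[-3:] - s[:3], -1.0, 1.0),
    termination=lambda s: float(np.linalg.norm(s[-3:] - s[:3]) < 0.02),
)
```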
2. The Option-Critic Algorithm Framework
Option-Critic is a representative HRL algorithm. Its core components are (the corresponding gradient theorems are given after this list):
- a high-level policy over options $\pi_\Omega(\omega \mid s)$ that picks which option to run;
- one intra-option policy $\pi_\omega(a \mid s)$ per option that outputs primitive actions;
- one termination function $\beta_\omega(s)$ per option that decides when the option ends;
- a critic that estimates option values $Q_\Omega(s, \omega)$ and drives the updates.
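For reference, the two policy-gradient results from the original Option-Critic paper (Bacon et al., 2017) take, up to the discounted state-option weighting, the following form: the intra-option policy parameters $\theta$ follow the gradient of the option value, while the termination parameters $\vartheta$ descend the option advantage:

$$
\frac{\partial Q_\Omega(s,\omega)}{\partial \theta}
= \mathbb{E}\!\left[\frac{\partial \log \pi_{\omega,\theta}(a\mid s)}{\partial \theta}\,Q_U(s,\omega,a)\right],
\qquad
\frac{\partial U(\omega,s')}{\partial \vartheta}
= -\,\mathbb{E}\!\left[\frac{\partial \beta_{\omega,\vartheta}(s')}{\partial \vartheta}\,A_\Omega(s',\omega)\right]
$$

where $Q_U(s,\omega,a)$ is the value of taking action $a$ in the context of option $\omega$, and $A_\Omega(s',\omega)=Q_\Omega(s',\omega)-V_\Omega(s')$ is the option advantage. The implementation in this article uses a simplified actor-critic-style update rather than these exact gradients.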
II. Option-Critic Implementation Steps (Based on Gymnasium)
We take a multi-stage Meta-World robotic-arm task as the example and implement Option-Critic in four steps (a quick sanity check of the environment interface is sketched right after this list):
1. Define the option set: three options, `reach` (approach the target), `grasp` (grasp the object), and `move` (move it to the goal).
2. Build the policy networks: a high-level policy, per-option intra-option policies, and termination networks.
3. Hierarchical interaction training: the high level selects an option; the low level executes it for multiple steps.
4. Joint gradient update: optimize the high-level and low-level policies together.
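Before wiring up the networks, it can help to sanity-check the environment interface. The short sketch below (an assumption-laden check, not part of the training code) just prints the observation and action dimensions that the networks in Section III consume; it assumes `metaworld` and `gymnasium` are installed and that your metaworld version follows the Gymnasium-style `reset`/`step` API, as the full implementation below does.

```python
# Quick sanity check of the Meta-World environment interface (illustrative;
# assumes a metaworld version that follows the Gymnasium reset/step API).
import numpy as np
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()
obs, info = env.reset()
print("observation dim:", obs.shape)          # state_dim for the policy networks
print("action dim:", env.action_space.shape)  # action_dim for the option policies

# One random step to confirm the 5-tuple step API used in Section III
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
print("reward after one random step:", reward)
```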
III. Code Implementation
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical, Normal
import gymnasium as gym
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time

# ================== Configuration ==================
class OptionCriticConfig:
    num_options = 3        # number of options (reach, grasp, move)
    option_length = 20     # maximum number of steps per option execution
    hidden_dim = 128       # hidden layer width
    lr_high = 1e-4         # high-level policy learning rate
    lr_option = 3e-4       # option policy learning rate
    gamma = 0.99           # discount factor
    entropy_weight = 0.01  # entropy regularization weight
    max_episodes = 5000    # maximum number of training episodes
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
```python
# ================== High-Level Policy Network ==================
class HighLevelPolicy(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, num_options),
        )

    def forward(self, state):
        return self.net(state)

# ================== Intra-Option Policy Network ==================
class OptionPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, action_dim),
        )

    def forward(self, state):
        return self.net(state)

# ================== Termination Network ==================
class TerminationNetwork(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, 1),
            nn.Sigmoid(),  # outputs the termination probability
        )

    def forward(self, state):
        return self.net(state)
```
```python
# ================== Training System ==================
class OptionCriticTrainer:
    def __init__(self):
        # Initialize the environment
        self.env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()
        # Handle dict observation spaces (Meta-World normally returns a flat Box)
        if isinstance(self.env.observation_space, gym.spaces.Dict):
            self.state_dim = sum(
                self.env.observation_space.spaces[key].shape[0]
                for key in ['observation', 'desired_goal']
            )
            self.process_state = self._process_dict_state
        else:
            self.state_dim = self.env.observation_space.shape[0]
            self.process_state = lambda x: x
        self.action_dim = self.env.action_space.shape[0]

        # Initialize networks
        self.high_policy = HighLevelPolicy(
            self.state_dim, OptionCriticConfig.num_options
        ).to(OptionCriticConfig.device)
        self.option_policies = nn.ModuleList([
            OptionPolicy(self.state_dim, self.action_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])
        self.termination_networks = nn.ModuleList([
            TerminationNetwork(self.state_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])

        # Optimizers
        self.optimizer_high = optim.Adam(
            self.high_policy.parameters(), lr=OptionCriticConfig.lr_high
        )
        self.optimizer_option = optim.Adam(
            list(self.option_policies.parameters()) + list(self.termination_networks.parameters()),
            lr=OptionCriticConfig.lr_option,
        )

    def _process_dict_state(self, state_dict):
        return np.concatenate([state_dict['observation'], state_dict['desired_goal']])

    def select_option(self, state):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        logits = self.high_policy(state)
        dist = Categorical(logits=logits)
        option = dist.sample()
        return option.item(), dist.log_prob(option)

    def select_action(self, state, option):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        action_mean = self.option_policies[option](state)
        dist = Normal(action_mean, torch.ones_like(action_mean))  # continuous actions, fixed std
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)  # sum over action dimensions -> scalar
        return action.cpu().numpy(), log_prob

    def should_terminate(self, state, current_option):
        # Sample a Bernoulli termination decision from the termination network.
        # (In this simplified version the termination networks are only used for
        # sampling; no loss term updates them.)
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        terminate_prob = self.termination_networks[current_option](state)
        return torch.bernoulli(terminate_prob).item() == 1

    def train(self):
        for episode in range(OptionCriticConfig.max_episodes):
            state_dict, _ = self.env.reset()
            state = self.process_state(state_dict)
            episode_reward = 0
            current_option, log_prob_high = self.select_option(state)
            option_step = 0
            while True:
                # Execute the intra-option policy and clip to the action bounds
                action, log_prob_option = self.select_action(state, current_option)
                action = np.clip(action, self.env.action_space.low, self.env.action_space.high)
                next_state_dict, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                next_state = self.process_state(next_state_dict)
                episode_reward += reward

                # Decide whether to terminate the current option
                terminate = self.should_terminate(next_state, current_option) or \
                            (option_step >= OptionCriticConfig.option_length)

                if terminate or done:
                    # TD error: the high-level network's outputs are treated as option values Q(s, w)
                    with torch.no_grad():
                        state_t = torch.FloatTensor(state).to(OptionCriticConfig.device)
                        next_state_t = torch.FloatTensor(next_state).to(OptionCriticConfig.device)
                        current_value = self.high_policy(state_t)[current_option].item()
                        next_value = 0.0 if done else self.high_policy(next_state_t).max().item()
                    delta = reward + OptionCriticConfig.gamma * next_value - current_value

                    # High-level policy gradient step
                    loss_high = -log_prob_high * delta
                    self.optimizer_high.zero_grad()
                    loss_high.backward()
                    self.optimizer_high.step()

                    # Option policy gradient step: TD error plus an entropy bonus for exploration
                    loss_option = -log_prob_option * delta
                    entropy = -log_prob_option * torch.exp(log_prob_option.detach())  # single-sample estimate
                    loss_option_total = loss_option - OptionCriticConfig.entropy_weight * entropy
                    self.optimizer_option.zero_grad()
                    loss_option_total.backward()
                    self.optimizer_option.step()

                    # Pick a new option if the episode continues
                    if not done:
                        state = next_state
                        current_option, log_prob_high = self.select_option(next_state)
                        option_step = 0
                    else:
                        break
                else:
                    option_step += 1
                    state = next_state

            if (episode + 1) % 100 == 0:
                print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
```
```python
if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"Start time: {start_str}")
    print("Initializing environment...")
    trainer = OptionCriticTrainer()
    trainer.train()
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"Finish time: {end_str}")
    print(f"Training finished, elapsed: {end - start:.2f} s")
```
IV. Key Code Walkthrough
1. High-level option selection, `select_option`: chooses an option from the current state and returns the option ID together with the log-probability of that choice.
2. Intra-option policy execution, `select_action`: generates an action for the current option; continuous action spaces are handled with a Gaussian distribution.
3. Termination check, `should_terminate`: samples from the termination network's output probability to decide whether the current option should end.
4. Gradient update logic (written out as formulas right after this list):
   - High-level policy: updated from the TD error of the option value.
   - Option policy: updated from the TD error plus entropy regularization.
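Written out, the simplified updates used in the code above are (with $Q_\Omega(s,\cdot)$ approximated by the high-level network's outputs and $\lambda$ the entropy weight):

$$
\delta = r + \gamma \max_{\omega'} Q_\Omega(s',\omega') - Q_\Omega(s,\omega),\qquad
L_{\text{high}} = -\,\delta\,\log \pi_\Omega(\omega\mid s),\qquad
L_{\text{option}} = -\,\delta\,\log \pi_\omega(a\mid s) - \lambda\,\hat{\mathcal{H}}
$$

where $\hat{\mathcal{H}}$ is the single-sample entropy estimate computed in the code from the sampled action's log-probability.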
V. Sample Training Output
```
Start time: 2025-03-24 08:29:46
Initializing environment...
Episode 100 | Reward: 2.7
Episode 200 | Reward: 4.9
Episode 300 | Reward: 2.2
Episode 400 | Reward: 2.8
Episode 500 | Reward: 3.0
Episode 600 | Reward: 3.3
Episode 700 | Reward: 3.2
Episode 800 | Reward: 4.7
Episode 900 | Reward: 5.3
Episode 1000 | Reward: 7.5
Episode 1100 | Reward: 6.3
Episode 1200 | Reward: 3.7
Episode 1300 | Reward: 7.8
Episode 1400 | Reward: 3.8
Episode 1500 | Reward: 2.4
Episode 1600 | Reward: 2.3
Episode 1700 | Reward: 2.5
Episode 1800 | Reward: 2.7
Episode 1900 | Reward: 2.7
Episode 2000 | Reward: 3.9
Episode 2100 | Reward: 4.5
Episode 2200 | Reward: 4.1
Episode 2300 | Reward: 4.7
Episode 2400 | Reward: 4.0
Episode 2500 | Reward: 4.3
Episode 2600 | Reward: 3.8
Episode 2700 | Reward: 3.3
Episode 2800 | Reward: 4.6
Episode 2900 | Reward: 5.2
Episode 3000 | Reward: 7.7
Episode 3100 | Reward: 7.8
Episode 3200 | Reward: 3.3
Episode 3300 | Reward: 5.3
Episode 3400 | Reward: 4.5
Episode 3500 | Reward: 3.9
Episode 3600 | Reward: 4.1
Episode 3700 | Reward: 4.0
Episode 3800 | Reward: 5.2
Episode 3900 | Reward: 8.2
Episode 4000 | Reward: 2.2
Episode 4100 | Reward: 2.2
Episode 4200 | Reward: 2.2
Episode 4300 | Reward: 2.2
Episode 4400 | Reward: 6.9
Episode 4500 | Reward: 5.6
Episode 4600 | Reward: 2.0
Episode 4700 | Reward: 1.6
Episode 4800 | Reward: 1.7
Episode 4900 | Reward: 1.9
Episode 5000 | Reward: 3.1
Finish time: 2025-03-24 12:41:48
Training finished, elapsed: 15122.31 s
```
In the next article, we will explore Inverse Reinforcement Learning (Inverse RL) and implement the GAIL algorithm!
Notes
- Install the dependencies: `pip install metaworld gymnasium torch`
- Meta-World depends on MuJoCo (older mujoco-py setups additionally need a license key); if you use mujoco-py, point it at your MuJoCo installation: `export MUJOCO_PY_MUJOCO_PATH=/path/to/mujoco`
- Training takes a long time (GPU acceleration recommended): `CUDA_VISIBLE_DEVICES=0 python option_critic.py`