本文对habitat环境中的baseline点导航PPO的模型的构建和训练进行总结
0 训练代码
这个代码在上一篇文章出现过,再粘贴过来,如下:
import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torchif __name__ == "__main__":run_type = "train"config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")config.defrost()config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"config.freeze()random.seed(config.TASK_CONFIG.SEED)np.random.seed(config.TASK_CONFIG.SEED)torch.manual_seed(config.TASK_CONFIG.SEED)if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():torch.set_num_threads(1)trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)print('trainer_init:',trainer_init)assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"trainer = trainer_init(config)if run_type == "train":trainer.train()elif run_type == "eval":trainer.eval()
1 trainer
在配置文件中要配置TRAINER_NAME,上面代码中的config.TRAINER_NAME是ppo,然后通过
trainer = trainer_init(config)
这句话,trainer就成了一个rl.ppo.ppo_trainer.PPOTrainer
对象。在......./habitat_baselines/rl/ppo/ppo_trainer.py
中定义
2 训练过程
我们看到训练过程调用了rl.ppo.ppo_trainer.PPOTrainer
中的train()
方法
trainer.train()
TODO:对train()
方法进行分析
3 模型结构定义
PPO是actor_critic结构,需要两个网络一个actor网络,一个critic网络。这两个网络可以共享参数也可以不共享参数。habitat中的ppo在特征提取阶段采用了参数共享,然后分出了两个头。
@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):r"""Trainer class for PPO algorithmPaper: https://arxiv.org/abs/1707.06347."""supported_tasks = ["Nav-v0"]SHORT_ROLLOUT_THRESHOLD: float = 0.25_is_distributed: boolenvs: VectorEnvagent: PPOactor_critic: NetPolicydef __init__(self, config=None):super().__init__(config)self.actor_critic = Noneself.agent = Noneself.envs = Noneself.obs_transforms = []self._static_encoder = Falseself._encoder = Noneself._obs_space = None# Distributed if the world size would be# greater than 1self._is