在habitat中训练一个模型需要指定配置文件,(根据目前的学习)一般要指定两个yaml文件:
- 一个是训练的配置文件
- 一个是任务的配置文件
举例如下:
import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torchif __name__ == "__main__":run_type = "train" #指定是训练还是评估#指定训练配置文件config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")#下面是在代码中对一些配置参数进行修改config.defrost()config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"config.freeze()random.seed(config.TASK_CONFIG.SEED)np.random.seed(config.TASK_CONFIG.SEED)torch.manual_seed(config.TASK_CONFIG.SEED)if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():torch.set_num_threads(1)trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)###config.TRAINER_NAME指定模型名字assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"trainer = trainer_init(config)if run_type == "train":trainer.train()elif run_type == "eval":trainer.eval()
上面所指定的训练文件ppo_pointnav_example.yaml中有一个配置项如下:
BASE_TASK_CONFIG_PATH: "../configs/tasks/pointnav.yaml"
从上面的代码可以看出来在代码中指定训练的配置文件,在训练配置文件中配置任务配置文件。
训练过程肯定要指定数据集(TASK_CONFIG.DATASET.DATA_PATH)(在训练配置文件中配置还是在任务配置文件中配置?目前至少看到在任务配置文件中是可以的)。
如果TASK_CONFIG.DATASET.DATA_PATH没有重新指定,会有默认值(目前知道有些默认值是从…/habitat-lab/habitat_baselines/config/default.py 中定义的)。
如果是点导航任务,需要同时指定正确的DATA_PATH和SCENES_DIR,否则会报错Could not find dataset file
具体原因见下面的代码
文件位置:.../habitat/datasets/pointnav/pointnav_dataset.py
@registry.register_dataset(name="PointNav-v1")
class PointNavDatasetV1(Dataset):r"""Class inherited from Dataset that loads Point Navigation dataset."""episodes: List[NavigationEpisode]content_scenes_path: str = "{data_path}/content/{scene}.json.gz"@staticmethoddef check_config_paths_exist(config: Config) -> bool:return os.path.exists(config.DATA_PATH.format(split=config.SPLIT)) and os.path.exists(config.SCENES_DIR)@classmethoddef get_scenes_to_load(cls, config: Config) -> List[str]:r"""Return list of scene ids for which dataset has separate files withepisodes."""dataset_dir = os.path.dirname(config.DATA_PATH.format(split=config.SPLIT))if not cls.check_config_paths_exist(config):raise FileNotFoundError(f"Could not find dataset file `{dataset_dir}`")