前言:
ray.rllib的算法配置方式有多种,网上的不同教程各不相同,有的互不兼容,本文汇总罗列了多种算法配置方式,给出推荐,并在最后给出可运行代码。
四种配置方式
方法1
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法
config = PPOConfig()\.rollouts(num_rollout_workers = 2)\.resources(num_gpus=0)\.environment(env="CartPole-v1")
algo = config.build()
缺点:不能在每行配置后面添加注释, 否则报错。
方法2
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法
algo = (PPOConfig().rollouts(num_rollout_workers=1) ## 注释.resources(num_gpus=0).environment(env="CartPole-v1").build()
)
用"()"把配置过程括起来,每行后面可以添加注释,不报错。官方教程使用的该种方式。
方式3:推荐
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法2
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path ## 设置过程文件的存储路径
algo = config.build()
优点:每一行是一个完整的命令, 后面可以添加注释,可以直接给config类的成员变量赋值。比如上面代码示例中的:config.output = storage_path , 直接配置存储路径,而不用去寻找output变量属于哪一个PPOConfig子模块。
方式4:
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_printstorage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
os.makedirs(storage_path, exist_ok=True)
config = {"env":"CartPole-v1","env_config":{}, ## 用于传递给env的信息"frame_work":"torch","num_gpus":0,"num_workers":2,"num_cpus_per_worker":1,"num_envs_per_worker":1,"num_gpus_per_worker":0,"lr":0.001,"model":{"fcnet_hiddens":[256,256,64],"fcnet_activation":"tanh","custom_model_config":{},"custom_model":None},"output":storage_path
}
algo = PPO(config=config) ## 构建算法
这种方式在ray1.4版本之前使用较多,是唯一的配置方式。随着ray的更新迭代,用class封装了configDict, 即上面的方法1,方法2,方法3所用的方式。用 PPOConfig 进行配置后,最终也是转成方法4中的字典传递给算法使用, 但是相比方法4的字典, 方法1、2、3可以在编程时有语法提示,告诉你有哪几个成员变量或成员函数可以用于设计config。
现在仍旧有很多人用方法4配置rllib算法,我认为这是从老版本传递下来的一种习惯,新上手的人建议使用 AlgorithmConfig的方式配置算法。
汇总代码:
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print
import os ## 配置算法1
# config = PPOConfig()\
# .rollouts(num_rollout_workers = 2)\
# .resources(num_gpus=0)\
# .environment(env="CartPole-v1")
# algo = config.build()# ## 配置算法2
# algo = (
# PPOConfig()
# .rollouts(num_rollout_workers=1)
# .resources(num_gpus=0)
# .environment(env="CartPole-v1")
# .build()
# )# ## 配置算法3
# storage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
# os.makedirs(storage_path, exist_ok=True)
# config = PPOConfig()
# config = config.rollouts(num_rollout_workers=1)
# config = config.resources(num_gpus=0)
# config = config.environment(env="CartPole-v1")
# config.output = storage_path
# algo = config.build()## 配置算法 4
storage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
os.makedirs(storage_path, exist_ok=True)
config = {"env":"CartPole-v1","env_config":{}, ## 用于传递给env的信息"frame_work":"torch","num_gpus":0,"num_workers":2,"num_cpus_per_worker":1,"num_envs_per_worker":1,"num_gpus_per_worker":0,"lr":0.001,"model":{"fcnet_hiddens":[256,256,64],"fcnet_activation":"tanh","custom_model_config":{},"custom_model":None},"output":storage_path
}
algo = PPO(config=config) ## 构建算法## 训练模型. 每个 iter 里重复执行多次 episode. 直到满足条件, 比如新增采样量达到一定体量。
for i in range(2):result = algo.train()print(pretty_print(result))## 保存模型
checkpoint_dir = algo.save().checkpoint.path
## algo.save()用于实现存储checkpoint, 后面跟着的.checkpoint.path用于返回存储路径
print(f"Checkpoint saved in directory {checkpoint_dir}")