基于PPO的强化学习超级马里奥自动通关

目录

一、环境准备

二、训练思路

1.训练初期:

2.思路整理及改进:

思路一:

思路二:

思路三:

思路四:

3.训练效果:

三、结果分析

四、完整代码

训练代码:

测试代码:


本文将基于强化学习中的PPO算法训练一个自动玩超级马里奥的智能体,用于强化学习的项目实践

一、环境准备

所需环境如下:

pip install nes-py
pip install gym-super-mario-bros
pip install setuptools==65.5.0 "wheel<0.40.0"
pip install gym==0.21.0
pip install stable-baselines3【extra】==1.6.0
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

注意: 在环境配置方面,nes-py库安装的先决条件是 安装Microsoft Visual C++,其下载地址为:Microsoft C++ Build Tools - Visual Studio

在安装Microsoft Visual C++时需选择桌面开发:

二、训练思路

1.训练初期:

使用了最简单的训练框架,并选择PPO算法中较简单的的CnnPolicy网络(可以尝试MlpPolicy和MultiInputPolicy网络)以及马里奥操控中的SIMPLE_MOVEMENT操作模块:

自然,效果是不尽人意的,马里奥在所选关卡的第三根水管处(即最高的那个水管)不断尝试跳跃,直至时间耗尽也未能通过。

2.思路整理及改进:

思路一:

        既然训练效果不佳,是否跟训练轮数有关?固将总训练轮数增加至3000000,并尝试训练。跑出来的模型有所改进,马里奥在成功越过所有水管后,遇到了新的难题——越过两个断崖。至此,无论如何增加轮数,马里奥似乎到了一个瓶颈,固继续进行修改。

思路二:

        在增加训练轮数的基础上,选择对关卡的环境图像进行预处理——使用GrayScaleObservation转换为灰度观察,并保留通道维度。同时,我们对训练参数进行调整:

        尝试训练后,能够得到一个不稳定越过断崖的新模型,但对断崖之后的环境似乎有些陌生,陷入了前半段关卡的“局部最优解”。

思路三:

        由于之前的训练过程中使用了较小的学习率(1e-9),进而使得马里奥在关卡中陷入了局部最优,所以选择对学习率进行微调,使其在最开始的训练阶段使用较大的学习率,在后期减小学习率,从而达到先快速探索参数空间并加速收敛,再提高模型的稳定性和收敛精度。

至此,训练出来的测试模型,奖励反馈有所增长,但实际测试效果与调整前相差不多。

思路四:

        在上述尝试无明显效果后,猜测效果的好坏是否与马里奥的奖励机制有关,固在查阅奖励部分代码后,对“抵达终点”的奖励予以提高,希望对效果有所改善。

然结果并没有明显改观,更换调整方向。分别尝试马里奥的三套运动方式

经过对比,complex_movement的效果远超另外两套,且在前面思路的改动下模型质量有显著提升,固整理上述调整方案,进行底模训练。

3.训练效果:

        以奖励折扣率gamma = 0.9、gae_lambda = 0.9、clip_range = 0.2、步长n_steps = 7168,并用1e-3作为开始训练的学习率,并在训练过程中使其动态地在1e-5,1e-7中调整,修改抵达终点的奖励反馈,同时设置训练轮数为4000000,训练动作组为complex_movement进行训练。得到基础奖励回报为1520的底模,并将其继续用于迁移学习,得到2300的新模型。在实际测试后发现,模型确有改观,固继续将新模型用于训练,最终得到3200的最终模型,其能顺利到达终点并进入关卡的下一阶段。

三、结果分析

        与之前的训练经验相比,使用复杂的动作组未必比简单的动作组训练出的效果差,学习率的调整也是必要的,先用较大学习率打好基础,再有小学习率继续细化模型。同时,要给足够的训练轮数(足够的训练时间)。若是能够把奖励机制更进一步细化增加奖励细节,对其的训练是会更有帮助的。

四、完整代码

训练代码:

from nes_py.wrappers import JoypadSpace
import time
import os
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
from gym.wrappers import GrayScaleObservation
from gym import Wrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import PPO
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback# 定义自定义奖励包装器
class CustomRewardWrapper(Wrapper):def __init__(self, env):super(CustomRewardWrapper, self).__init__(env)self.curr_score = 0def step(self, action):state, reward, done, info = self.env.step(action)# 自定义的奖励reward += (info["score"] - self.curr_score) / 40.self.curr_score = info["score"]if done:if info["flag_get"]:reward += 50else:reward -= 50return state, reward / 10., done, infoclass SaveOnBestTrainingRewardCallback(BaseCallback):"""Callback for saving a model (the check is done every ``check_freq`` steps)based on the training reward (in practice, we recommend using ``EvalCallback``).:param check_freq: (int):param log_dir: (str) Path to the folder where the model will be saved.It must contains the file created by the ``Monitor`` wrapper.:param verbose: (int)"""def __init__(self, check_freq, save_model_dir, verbose=1):super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)self.check_freq = check_freqself.save_path = os.path.join(save_model_dir, './')self.best_model_subdir = os.path.join(self.save_path, 'best_model')self.best_mean_reward = -np.infself.best_model_path = Noneself.best_score_model_path = os.path.join(self.save_path, 'pass_customs_model.zip')  # 增加通关模型路径# def _init_callback(self) -> None:def _init_callback(self):# Create folder if neededif self.save_path is not None:os.makedirs(self.save_path, exist_ok=True)# def _on_step(self) -> bool:def _on_step(self):if self.n_calls % self.check_freq == 0:print('self.n_calls: ', self.n_calls)model_path1 = os.path.join(self.save_path, 'model_{}'.format(self.n_calls))self.model.save(model_path1)# Save the best modelx, y = ts2xy(load_results(monitor_dir), 'timesteps')if len(x) > 0:mean_reward = np.mean(y[-self.check_freq:])if self.verbose > 0:print("Num timesteps: {}, Best mean reward: {:.2f}, Last mean reward: {:.2f}".format(self.n_calls, self.best_mean_reward, mean_reward))if mean_reward > self.best_mean_reward:if self.best_model_path is not None:try:os.remove(self.best_model_path)  # Delete the old best modelexcept OSError:passself.best_mean_reward = mean_reward# Update path for the new best modelself.best_model_path = os.path.join(self.save_path, 'best_model.zip')# Save the new best modelself.model.save(self.best_model_path)if self.verbose > 0:print("New best mean reward: {:.2f} - saving best model".format(mean_reward))# Save the best mean reward to a filereward_record_file = './Mario_model_save/model/mario_model/best_mean_reward.txt'with open(reward_record_file, 'a') as file:# 将最佳平均奖励值和时间戳一同写入文件file.write("New best mean reward: {:.2f} - Recorded at {}\n".format(mean_reward, datetime.now()))return True# 总的训练timesteps
my_total_timesteps = 4000000
# 需要改变学习率的timestep
change_lr_timestep = 2000000# 学习率调度函数
def learning_rate_schedule(progress_remaining):"""参数 progress_remaining 表示剩下的训练进度(从1开始降低到0)。通过训练进度来动态调整学习率。"""current_timestep = my_total_timesteps * (1 - progress_remaining)if current_timestep < change_lr_timestep:return 1e-3  # 1e-3elif change_lr_timestep <= current_timestep <= int(change_lr_timestep * 1.5):return 1e-5else:return 1e-7env = gym_super_mario_bros.make('SuperMarioBros-1-2-v0')
env = JoypadSpace(env, COMPLEX_MOVEMENT)  # 使用复杂的按键映射env = CustomRewardWrapper(env)  # 应用自定义奖励包装器monitor_dir = r'./Mario_model_save/monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)  # 将环境包装为监视器env = GrayScaleObservation(env, keep_dim=True)  # 转换为灰度观察,并保留通道维度
env = DummyVecEnv([lambda: env])  # 创建虚拟环境
env = VecFrameStack(env, 4, channels_order='last')  # 将最近4帧堆叠在一起best_params = {'n_steps': 7168,  # 7168'gamma': 0.9,# 'learning_rate': 1e-3,   # 1e-3, 1e-4, 1e-5'clip_range': 0.2,'gae_lambda': 0.9,
}# 更新best_params中的learning_rate参数
best_params.update({'learning_rate': learning_rate_schedule})tensorboard_log = r'./Mario_model_save/tensorboard_log/'
# 正常训练
model = PPO("CnnPolicy", env, verbose=1,tensorboard_log=tensorboard_log,**best_params)
'''
# 加载预训练模型
pretrained_model_path = r'D:\python_project\Mario\model\mario_model\pretraining_model_4.zip'
model = PPO.load(pretrained_model_path, env=env, tensorboard_log=tensorboard_log, **best_params)'''# 保存模型位置
save_model_dir = r'./Mario_model_save/model/mario_model/'
callback1 = SaveOnBestTrainingRewardCallback(10000, save_model_dir)model.learn(total_timesteps=my_total_timesteps, callback=callback1)
# model.save("mario_model")

测试代码:

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, RIGHT_ONLY, COMPLEX_MOVEMENT
import time
from matplotlib import pyplot as plt
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecFrameStack
import os
from stable_baselines3 import PPOfrom stable_baselines3.common.results_plotter import load_results, ts2xy
import numpy as np
from stable_baselines3.common.callbacks import BaseCallbackenv = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, COMPLEX_MOVEMENT)monitor_dir = r'./Mario/monitor_log/'
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, monitor_dir)env = GrayScaleObservation(env, keep_dim=True)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')save_model_dir = r'model/mario_model/pretraining_model_5.zip'
# save_model_dir = r'./Mario/model/mario_model/pretraining_model.zip'model = PPO.load(save_model_dir)obs = env.reset()
obs = obs.copy()
done = True
while True:if done:state = env.reset()action, _states = model.predict(obs)obs, rewards, done, info = env.step(action)obs = obs.copy()# time.sleep(0.01)env.render()env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/27957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024.ZCPC.M题 计算三角形个数

题目描述&#xff1a; 小蔡有一张三角形的格子纸&#xff0c;上面有一个大三角形。这个边长为 的大三角形&#xff0c; 被分成 个边长为 1 的小三角形(如图一所示)。现在&#xff0c;小蔡选择了一条水平边 删除&#xff08;如图二所示&#xff09;&#xff0c;请你找出图上剩余…

RestTemplate远程请求的艺术

1 简说 编程是一门艺术,追求优雅的代码就像追求优美的音乐。 很多有多年工作经验的开发者,在使用RestTemplate之前常常使用HttpClient,然而接触了RestTemplate之后,却愿意放弃多年相处的“老朋友”,转向RestTemplate。那么一定是RestTemplate有它的魅力,有它的艺术风范。…

【ARM-Linux篇】阿里云人脸识别方案

一、接入阿里云 https://vision.aliyun.com/ 点击“人脸搜索1:N” 点击"立即开通"&#xff1a; 使用阿里云APP/支付宝/钉钉扫码登录&#xff1a; 购买“人脸搜索1:N”能力&#xff0c;第一次购买&#xff0c;可以有5000次的免费使用&#xff1a; 开通完后&#xff…

【踩坑日记】I.MX6ULL裸机启动时由于编译的程序链接地址不对造成的程序没正确运行

1 现象 程序完全正确&#xff0c;但是由于程序链接的位置不对&#xff0c;导致程序没有正常运行。 2 寻找原因 对生成的bin文件进行反汇编&#xff1a; arm-linux-gnueabihf-objdump -D -m arm ledc.elf > ledc.dis查看生成的反汇编文件 发现在在链接的开始地址处&…

Ubuntu基础-VirtualBox安装增强功能

目录 零. 前言 一. 安装 1.点击安装增强功能 2.点击光盘图标 3.复制到新文件夹 4.运行命令 5.重启系统 6.成果展示 二. 打开共享 1.共享粘贴 ​编辑2.共享文件夹 三.总结 安装步骤 打开共享粘贴功能&#xff1a; 打开共享文件夹功能&#xff1a; 零. 前言 在使用…

redis未授权访问

redis数据库基本知识 redis非关系型数据库 redis未授权访问蓝队的成因和危害 漏洞的定义&#xff1a;redis未授权访问漏洞是一个由于redis服务器版本较低&#xff0c;其未设置登录密码导致的登录。 攻击者可以直接利用redis服务器的ip地址和端口完成redis服务器的远程登陆&…

为什么笔记本电脑触控板不工作?这里有你想要的答案和解决办法

序言 你的笔记本电脑触控板停止工作了吗?值得庆幸的是,这个令人沮丧的问题通常很容易解决。以下是笔记本电脑触控板问题的最常见原因和修复方法。 触控板被功能键禁用 大多数(如果不是全部的话)Windows笔记本电脑都将其中一个功能键用于禁用和启用笔记本电脑触控板。按键…

民生银行信用卡中心金融科技24届春招面经

本文介绍2024届春招中&#xff0c;中国民生银行下属信用卡中心的金融科技&#xff08;系统研发方向&#xff09; 岗位2场面试的基本情况、提问问题等。 2024年04月投递了中国民生银行下属信用卡中心的金融科技&#xff08;系统研发方向&#xff09; 岗位&#xff0c;暂时不清楚…

关于反弹shell的学习

今天学习反弹shell&#xff0c;在最近做的ctf题里面越来越多的反弹shell的操作&#xff0c;所以觉得要好好研究一下&#xff0c;毕竟是一种比较常用的操作 什么是反弹shell以及原理 反弹Shell&#xff08;也称为反向Shell&#xff09;是一种技术&#xff0c;通常用于远程访问和…

C++设计模式——Decorator装饰器模式

一&#xff0c;装饰器模式简介 装饰器模式是一种结构型设计模式&#xff0c; 它允许在不改变现有对象的情况下&#xff0c;动态地将功能添加到对象中。 装饰器模式是通过创建具有新行为的对象来实现的&#xff0c;这些对象将原始对象进行了包装。 装饰器模式遵循开放/关闭原…

element-plus 的el-scrollbar滚动条组件

el-scrollbar组件可以替换原生的滚动条&#xff0c;可以设置出现滚动条的高度&#xff0c;若无设置则根据容器自适应。 通过使用 setScrollTop 与 setScrollLeft 方法&#xff0c;可以手动控制滚动条滚动。 scroll 滚动条的滚动事件&#xff0c;会返回滚动条当前的位置。 &l…

snap nextcloud 通过不被信任的域名访问

安装向导 — Nextcloud latest 管理手册 latest 文档 find / -name config.php trusted_domains >array (0 > localhost,1 > server1.example.com,2 > 192.168.1.50,3 > [fe80::1:50], ), vim /var/snap/nextcloud/42567/nextcloud/config/config.php vim /va…

pytorch--Pooling layers

文章目录 1.torch.nn.MaxPool1d()2.torch.nn.MaxPool2d3.torch.nn.AvgPool2d()4.torch.nn.FractionalMaxPool2d()5.torch.nn.AdaptiveMaxPool2d()6.torch.nn.AdaptiveAvgPool2d() 1.torch.nn.MaxPool1d() torch.nn.MaxPool1d() 是 PyTorch 库中的一个类&#xff0c;用于在神经网…

ISP图像算法面试准备(1)

ISP图像算法面试准备 ISP图像算法面试准备(1) 文章目录 ISP图像算法面试准备前言一、ISP流程二、重点关注1. AWB必须在Demosaic之后进行。2. Gamma矫正通常在CCM之前进行 三、如何实现ISP参数自动化调试四、AE&#xff0c;即自动曝光&#xff08;Auto Exposure&#xff09;总结…

【太原理工大学】软件系统安全—分析题

OK了&#xff0c;又是毫无准备的一场仗&#xff0c;我真是ありがとうございます 凸^o^凸 根据前几年传下来的信息&#xff0c;所谓“分析”&#xff0c;就是让你根据情节自行设计&#xff0c;例如如何设计表单等&#xff0c;这类多从实验中出&#xff0c;王老师强调好好做实验一…

Mybatis框架中结果映射resultMap标签方法属性收录

Mybatis框架中结果映射resultMap标签收录 在MyBatis框架中&#xff0c;resultMap 是一种强大的机制&#xff0c;用于将数据库结果集映射到Java对象上。它允许你定义如何将查询结果中的列映射到Java对象的属性上&#xff0c;尤其是当数据库表的字段名与Java对象的属性名不一致时…

HTML静态网页成品作业(HTML+CSS)—— 明星吴磊介绍网页(5个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有5个页面。 二、作品演示 三、代…

TCP与UDP案例

udp不会做拆分整合什么的 多大就是多大

【Spine学习08】之短飘,人物头发动效制作思路

上一节说完了跑步的&#xff0c; 这节说头发发型。 基础过程总结&#xff1a; 1.创建骨骼&#xff08;头发需要在上方加一个总骨骼&#xff09; 2.创建网格&#xff08;并绑定黄线&#xff09; 3.绑定权重&#xff08;发根位置的顶点赋予更多总骨骼的权重&#xff09; 4.切换到…

Orange_Pi_AIpro运行蜂鸟RISC-V仿真

Orange_Pi_AIpro运行蜂鸟RISC-V仿真 突发奇想&#xff0c;试一试Orange Pi AIpro上运行蜂鸟RISC-V的仿真。 准备 默认已经有一个Orange Pi AIpro&#xff0c;并且对设备进行一定的初始化配置&#xff0c;可以参考上一篇博文开源硬件初识——Orange Pi AIpro&#xff08;8T&a…