基于DQN和TensorFlow的LunarLander实现(全代码)

使用深度Q网络(Deep Q-Network, DQN)来训练一个在openai-gym的LunarLander-v2环境中的强化学习agent,让小火箭成功着陆。
下面代码直接扔到jupyter notebook或CoLab上就能跑起来。

在这里插入图片描述

目录

  • 安装和导入所需的库和环境
  • Q网络搭建
  • 经验回放实现
  • DQNAgent实现
  • 训练

安装和导入所需的库和环境

安装和设置所需的库和环境,使其能够在Jupyter Notebook中运行。

!pip install gym
!apt-get install xvfb -y
!pip install pyvirtualdisplay   #用于在没有显示器的环境中创建虚拟显示
!pip install Pillow             #一个图像处理库
!pip install swig
!pip install "gym[box2d]"

创建并启动一个虚拟显示,在没有图形界面的服务器上运行强化学习环境:

from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

引入所需库:

import gym
import time
import tqdm
import numpy as np
from IPython import display as ipydisplay
from PIL import Image

创建一个LunarLander-v2环境的DQN代理:

agent = DQNAgent('LunarLander-v2')total_score, records = agent.simulate(visualize=True)
print(f'Total score {total_score:.2f}')
record_list = []
for i in tqdm.tqdm(range(100)):total_score, _ = agent.simulate(visualize=False)record_list.append(total_score)print(f'Average score in 100 episode {np.mean(record_list):.2f}')

 

Q网络搭建

import tensorflow as tfL = tf.keras.layersdef create_network_model(input_shape: np.ndarray,action_space: np.ndarray,learning_rate=0.001) -> tf.keras.Sequential:model = tf.keras.Sequential([L.Dense(512, input_shape=input_shape, activation="relu"),L.Dense(256, input_shape=input_shape, activation="relu"),L.Dense(action_space)])model.compile(loss="mse",optimizer=tf.optimizers.Adam(lr=learning_rate))return model

 

经验回放实现

经验回放是一种在深度强化学习中常用的技术,主要用于打破数据的相关性和减少过拟合。
在强化学习中,代理通常会在训练过程中与环境进行大量交互,经验回放允许代理存储这些经验,并在后续的训练中反复利用这些数据。这种机制有助于改善学习效率减少数据样本间的时间相关性,提高训练过程的稳定性。

import random
import numpy as np
from collections import namedtuple# 代表每一个样本的 namedtuple,方便存储和读取数据
Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))class ReplayMemory:def __init__(self, max_size):self.max_size = max_sizeself.memory = []def append(self, state, action, reward, next_state, done):"""记录一个新的样本"""sample = Experience(state, action, reward, next_state, done)self.memory.append(sample)# 只留下最新记录的 self.max_size 个样本self.memory = self.memory[-self.max_size:]def sample(self, batch_size):"""按照给定批次大小取样"""samples = random.sample(self.memory, batch_size)batch = Experience(*zip(*samples))# 转换数据为 numpy 张量返回states = np.array(batch.state)actions = np.array(batch.action)rewards = np.array(batch.reward)states_next = np.array(batch.next_state)dones = np.array(batch.done)return states, actions, rewards, states_next, donesdef __len__(self):return len(self.memory)

 

DQNAgent实现

DQNAgent类是DQN算法的核心实现。它包含以下关键部分:
1、初始化:初始化环境、神经网络模型和经验回放缓存。
2、行为选择(choose_action):根据当前状态和ε-greedy策略选择行为。
3、经验回放(replay):从记忆中随机抽取小批量经验进行学习。
4、训练(train):进行多个episode的训练。

from IPython import display
from PIL import Image# 定义超参数
LEARNING_RATE = 0.001
GAMMA = 0.99
EPSILON_DECAY = 0.995
EPSILON_MIN = 0.01class DQNAgent:def __init__(self, env_name):self.env = gym.make(env_name)self.observation_shape = self.env.observation_space.shapeself.action_count = self.env.action_space.nself.model = create_network_model(self.observation_shape, self.action_count)self.memory = ReplayMemory(500000)self.epsilon = 1.0self.batch_size = 64def choose_action(self, state, epsilon=None):"""根据给定状态选择行为- epsilon == 0 完全使用模型选择行为- epsilon == 1 完全随机选择行为"""if epsilon is None:epsilon = self.epsilonif np.random.rand() < epsilon:return np.random.randint(self.action_count)else:q_values = self.model.predict(np.expand_dims(state, axis=0))return np.argmax(q_values[0])def replay(self):"""进行经验回放学习"""# 如果当前经验池经验数量少于批次大小,则跳过if len(self.memory) < self.batch_size:returnstates, actions, rewards, states_next, dones = self.memory.sample(self.batch_size)q_pred = self.model.predict(states)q_next = self.model.predict(states_next).max(axis=1)q_next = q_next * (1 - dones)q_update = rewards + GAMMA * q_nextindices = np.arange(self.batch_size)q_pred[[indices], [actions]] = q_updateself.model.train_on_batch(states, q_pred)def simulate(self, epsilon=None, visualize=True):records = []state = self.env.reset()is_done = Falsetotal_score = 0total_step  = 0while not is_done:action = self.choose_action(state, epsilon)state, reward, is_done, info = self.env.step(action)total_score += rewardtotal_step += 1rgb_array = self.env.render(mode='rgb_array')records.append((rgb_array, action, reward, total_score))if visualize:display.clear_output(wait=True)img = Image.fromarray(rgb_array)# 当前 Cell 中展示图片display.display(img)print(f'Action {action} Action reward {reward:.2f} | Total score {total_score:.2f} | Step {total_step}')time.sleep(0.01)self.env.close()return total_score, recordsdef train(self, episode_count: int, log_dir: str):"""训练方法,按照给定 episode 数量进行训练,并记录训练过程关键参数到 TensorBoard"""# 初始化一个 TensorBoard 记录器file_writer = tf.summary.create_file_writer(log_dir)file_writer.set_as_default()score_list = []best_avg_score = -np.inffor episode_index in range(episode_count):state = self.env.reset()score, step = 0, 0is_done = Falsewhile not is_done:# 根据状态选择一个行为action = self.choose_action(state)# 执行行为,记录行为和结果到经验池state_next, reward, is_done, info = self.env.step(action)self.memory.append(state, action, reward, state_next, is_done)score += rewardstate = state_next# 每 6 步进行一次回放训练# 此处也可以选择每一步回放训练,但会降低训练速度,这个是一个经验技巧if step % 1 == 0:self.replay()step += 1# 记录当前 Episode 的得分,计算最后 100 Episode 的平均得分score_list.append(score)avg_score = np.mean(score_list[-100:])# 记录当前 Episode 得分,epsilon 和最后 100 Episode 的平均得分到 TensorBoardtf.summary.scalar('score', data=score, step=episode_index)tf.summary.scalar('average score', data=avg_score, step=episode_index)tf.summary.scalar('epsilon', data=self.epsilon, step=episode_index)# 终端输出训练进度print(f'Episode: {episode_index} Reward: {score:03.2f} 'f'Average Reward: {avg_score:03.2f} Epsilon: {self.epsilon:.3f}')# 调整 epsilon 值,逐渐减少随机探索比例if self.epsilon > EPSILON_MIN:self.epsilon *= EPSILON_DECAY# 如果当前平均得分比之前有改善,保存模型# 确保提前创建目录 outputs/chapter_15if avg_score > best_avg_score:best_avg_score = avg_scoreself.model.save(f'outputs/chapter_15/dqn_best_{episode_index}.h5')

 

训练

# 使用 LunarLander 初始化 Agent
agent = DQNAgent('LunarLander-v2')
import glob
# 读取现在已经记录的日志数量,避免日志重复记录
tf_log_index = len(glob.glob('tf_dir/lunar_lander/run_*'))
log_dir = f'tf_dir/lunar_lander/run_{tf_log_index}'# 训练 2000 个 Episode
agent.train(20, log_dir)agent.model.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/631928.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法Hot100系列】字母异位词分组

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学习,不断总结,共同进步,活到老学到老导航 檀越剑指大厂系列:全面总结 jav…

环境变量配置文件

1.配置文件简介 source命令 source命令通常用于保留、更改当前shell中的环境变量。 格式&#xff1a; source 配置文件 . 配置文件 环境变量配置文件简介 环境变量配置文件中主要是定义对系统的操作环境生效的系统默认环境变量&#xff0c;比如PATH、HISTSI…

(二十)Flask之上下文管理第一篇(粗糙缕一遍源码)

每篇前言&#xff1a; &#x1f3c6;&#x1f3c6;作者介绍&#xff1a;【孤寒者】—CSDN全栈领域优质创作者、HDZ核心组成员、华为云享专家Python全栈领域博主、CSDN原力计划作者 &#x1f525;&#x1f525;本文已收录于Flask框架从入门到实战专栏&#xff1a;《Flask框架从入…

Ikuai中如何添加/更换虚拟机(图文)

Ikuai配置 分区/格式化硬盘(如果已经格式化&#xff0c;无需再次格式化&#xff0c;直接传送到上传镜像) 上传镜像 ⚠️&#xff1a;如果是压缩格式&#xff0c;需要解压缩后上传&#xff0c;如这里的IMG格式。 创建虚拟机 配置虚拟机&#xff08;等待虚拟机起来后执行&#…

Vulnhub-w1r3s-editable

一、信息收集 端口扫描&#xff0c;ftp允许匿名登录&#xff0c;但是没有得到什么有用的线索 PORT STATE SERVICE VERSION 21/tcp open ftp vsftpd 2.0.8 or later | ftp-syst: | STAT: | FTP server status: | Connected to ::ffff:192.168.1.6 | …

FeatInsight: 基于 OpenMLDB 的特征平台助力高效的特征管理和编排

OpenMLDB 社区新开源了特征平台产品 - FeatInsight&#xff08;https://github.com/4paradigm/FeatInsight&#xff09;&#xff0c;是一个先进的特征存储(Feature Store)服务&#xff0c;基于 OpenMLDB 数据库实现高效的特征管理和编排功能。FeatInsight 特征平台提供简便易用…

JeecgBoot集成东方通TongRDS

TongRDS介绍 TongRDS&#xff08;简称 RDS&#xff09;是分布式内存数据缓存中间件&#xff0c;用于高性能内存数据共享与应用支持。RDS为各类应用提供高效、稳定、安全的内存数据处理能力&#xff1b;同时它支持共享内存的搭建弹性伸缩管理&#xff1b;使业务应用无需考虑各种…

Active Directory监控工具

Active Directory 是 Microsoft 为 Windows 环境实现的 LDAP 目录服务&#xff0c;它允许管理员对用户访问资源和服务实施公司范围的策略。Active Directory 通常安装在 Windows 2003 或 2000 服务器中&#xff0c;它们统称为域控制器。如果 Active Directory 出现故障&#xf…

跑通 yolov5-7.0 项目之训练自己的数据集

yolov5 一、yolov5 源码下载二、配置环境&#xff0c;跑通项目三、训练自己的数据集1、获取验证码数据2、标注图片&#xff0c;准备数据集3、开始训练自己的数据集1、train.py 训练数据集2、val.py 验证测试你的模型3、detect.py 正式用你的模型 四、遇到的报错、踩坑1、import…

电脑内存满了怎么清理内存?试试这6个方法~

内存越大&#xff0c;运行越快&#xff0c;程序之间的切换和响应也会更加流畅。但是随着时间的增加&#xff0c;还是堆积了越来越多的各种文件&#xff0c;导致内存不够用&#xff0c;下面就像大家介绍三种好用的清理内存的方法。 方法一&#xff1a;通过电脑系统自带的性能清理…

vim 编辑器如何同时注释多行以及将多行进行空格

当然可以&#xff0c;以下是我对您的文字进行润色后的版本&#xff1a; 一、场景 YAML文件对空格的要求非常严格&#xff0c;因此在修改YAML时&#xff0c;我们可能需要批量添加空格。 二、操作步骤 请注意&#xff1a;您的所有操作都将以第一行为基准。也就是说&#xff0…

OpenCV-Python(39):Meanshift和Camshift算法

目标 学习了解Meanshift 和Camshift 算法在视频中找到并跟踪目标 Meanshift 原理 Meanshift算法是一种基于密度的聚类算法&#xff0c;用于将数据点划分为不同的类别。它的原理是通过数据点的密度分布来确定聚类中心&#xff0c;然后将数据点移动到离其最近的聚类中心&#…

【代码随想录07】344.反转字符串 541. 反转字符串II 05.替换空格 151.翻转字符串里的单词 55. 右旋转字符串

目录 344. 反转字符串题目描述做题思路参考代码 541. 反转字符串 II题目描述参考代码 05. 替换数字题目描述参考代码 151. 反转字符串中的单词题目描述参考代码 55. 右旋转字符串题目描述参考代码 344. 反转字符串 题目描述 编写一个函数&#xff0c;其作用是将输入的字符串反…

C语言从入门到实战——动态内存管理

动态内存管理 前言一、 为什么要有动态内存分配二、 malloc和free2.1 malloc2.2 free 三、calloc和realloc3.1 calloc3.2 realloc 四、常见的动态内存的错误4.1 对NULL指针的解引用操作4.2 对动态开辟空间的越界访问4.3 对非动态开辟内存使用free释放4.4 使用free释放一块动态开…

用于自动驾驶最优间距选择和速度规划的多配置二次规划(MPQP) 论文阅读

论文链接&#xff1a;https://arxiv.org/pdf/2401.06305.pdf 论文题目&#xff1a;用于自动驾驶最优间距选择和速度规划的多配置二次规划&#xff08;MPQP&#xff09; 1 摘要 本文介绍了用于自动驾驶最优间距选择和速度规划的多配置二次规划&#xff08;MPQP&#xff09;。…

黑马程序员JavaWeb开发|案例:tlias智能学习辅助系统(6)解散部门

指路&#xff08;1&#xff09;&#xff08;2&#xff09;&#xff08;3&#xff09;&#xff08;4&#xff09;&#xff08;5&#xff09;&#x1f447; 黑马程序员JavaWeb开发|案例&#xff1a;tlias智能学习辅助系统&#xff08;1&#xff09;准备工作、部门管理_tlias智能…

MATLAB对话框与菜单设计实验

本文MATLAB源码&#xff0c;下载后直接打开运行即可[点击跳转下载]-附实验报告https://download.csdn.net/download/Coin_Collecter/88740733 一、实验目的 1.掌握建立控件对象的方法。 2.掌握对话框设计方法。 3.掌握菜单设计方法。 二、实验内容 建立如下图所示的菜单。菜单…

15.云原生之k8s容灾与恢复实战

云原生专栏大纲 文章目录 Velero与etcd介绍Velero与etcd备份应用场景Velero与etcd在k8s备份上的区别 Velero备份恢复流程备份工作流程Velero备份时&#xff0c;若k8s集群发送变化&#xff0c;会发发生情况&#xff1f;Velero 备份pv&#xff0c;pv中数据变化&#xff0c;会发发…

uniapp实现微信小程序富文本之mp-html插件详解

uniapp实现微信小程序富文本之mp-html插件 1 文章背景1.1 正则表达式1.2 mp-html插件1.3 uniapp 2 过程详解2.1 下载mp-html插件2.2 项目中引入mp-html2.3 引入正则规范图片自适应2.4 效果展示 3 全部代码 1 文章背景 1.1 正则表达式 正则表达式&#xff0c;又称规则表达式,&…

算法刷题——删除排序链表中的重复元素(力扣)

文章目录 题目描述我的解法思路结果分析 官方题解分析 查漏补缺更新日期参考来源 题目描述 传送门 删除排序链表中的重复元素&#xff1a;给定一个已排序的链表的头 head &#xff0c; 删除所有重复的元素&#xff0c;使每个元素只出现一次 。返回 已排序的链表 。 示例 1&…