07、基于LunarLander登陆器的强化学习案例(含PYTHON工程)

07、基于LunarLander登陆器的强化学习(含PYTHON工程)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。全部工程可从最上方链接下载。

基于TENSORFLOW2.10

0、实践背景

gym的LunarLander是一个用于强化学习的经典环境。在这个环境中,智能体(agent)需要控制一个航天器在月球表面上着陆。航天器的动作包括向上推进、不进行任何操作、向左推进或向右推进。环境的状态包括航天器的位置、速度、方向、是否接触到地面或月球上空等。

智能体的任务是在一定的时间内通过选择正确的动作使航天器安全着陆,并且尽可能地消耗较少的燃料。如果航天器着陆时速度过快或者与地面碰撞,任务就会失败。智能体需要通过不断地尝试和学习来选择最优的动作序列,以完成这个任务。

下面是训练的结果:
在这里插入图片描述

1、实现原理

1.1 强化学习

强化学习实现原理主要包括以下几个方面:

智能体与环境交互:强化学习中的智能体(agent)通过与环境不断地进行交互,学习一个从环境到动作的映射,学习的目标就是使累计回报最大化。
试错学习:强化学习是一种试错学习,智能体需要在各种状态(环境)下尝试所有可以选择的动作,通过环境给出的反馈(即奖励)来判断动作的优劣,最终获得环境和最优动作的映射关系(即策略)。
奖励函数与策略更新:强化学习算法的核心在于定义奖励函数,并通过不断迭代来更新策略,从而实现最优化的决策。
状态获取:智能体需要通过传感器等手段获取当前环境的状态信息,如图像、声音等。

1.2 软更新

软更新(Soft Updates)技术是一种在强化学习中常用的技术,特别是在Q-learning算法中。该技术的主要目的是提高学习过程的稳定性。

在强化学习中,我们通常有一个主要的网络(如Q-network)来学习并更新其权重。然而,如果我们直接使用这个网络来估计Q值并选择动作,同时也在每个步骤中更新其权重,这可能会导致学习过程的不稳定。因为网络权重的连续变化会导致Q值的波动,从而使得学习策略变得不一致。

为了解决这个问题,软更新技术被引入。其基本思想是创建一个额外的网络,通常被称为目标网络(Target Network),该网络的结构与主要网络相同,但其权重的更新是缓慢的,即它不会在每个步骤中都进行更新。相反,目标网络的权重会在主要网络经过一定数量的步骤或达到一定的条件后才进行更新。这通常是通过将主要网络的权重与目标网络的权重进行某种形式的平均来实现的。

由于目标网络的权重更新是缓慢的,因此它提供的Q值估计更为稳定。这有助于使学习过程更加稳定,因为即使主要网络的权重发生显著变化,目标网络的权重也只会有较小的变化,从而减少了Q值的波动:

1.3 贪婪策略

训练时,每一步并不完全采用最优行为,有一定可能尝试新的动作:

def get_action(q_values, epsilon=0):if random.random() > epsilon:return np.argmax(q_values.numpy()[0])else:return random.choice(np.arange(4))

2、强化学习实现步骤

2.1、导入相关机器学习使用的包
# 导入时间处理库  
import time  
# 从collections模块导入双端队列和命名元组  
from collections import deque, namedtuple  
# 导入用于开发和比较强化学习算法的库  
import gym  
# 导入数值计算库,以np作为别名  
import numpy as np  
# 导入Python图像处理库中的Image模块  
import PIL.Image  
# 导入机器学习框架  
import tensorflow as tf  
# 导入自定义的Lunar Lander工具库  
import Lunar_Lander_utils  
# 从Keras库导入顺序模型类  
from keras import Sequential  
# 从Keras层模块导入全连接层和输入层类  
from keras.layers import Dense, Input  
# 从Keras损失模块导入均方误差损失函数  
from keras.losses import MSE  
# 从Keras优化器模块导入Adam优化器  
from keras.optimizers import Adam
2.2、LunarLander登陆器环境加载

在gym库中的使用指导可以参考:LunarLander

我们关注的是可以从这个交互接口中得到什么和控制什么,对于此处的登陆器,我们关注可以得到它的哪些状态和对其进行那些操作
在这里插入图片描述
依据官方手册,存在四种可用的离散动作:不执行任何操作、启动左方向引擎、启动主引擎、启动右方向引擎。能够得到的状态是一个8维向量,包括着陆器在x和y方向上的坐标、x和y方向上的线速度、角度、角速度,以及两个布尔值,表示每个着陆腿是否与地面接触。

# 使用gym库创建一个名为'LunarLander-v2'的环境,并设置渲染模式为'rgb_array'  
# 'rgb_array'模式返回一个numpy数组,表示环境的RGB图像  
env = gym.make('LunarLander-v2', render_mode='rgb_array')  # 重置环境到初始状态,并返回初始状态  
env.reset()  # 使用PIL库(Python Imaging Library)从环境的渲染数组创建一个图像  
PIL.Image.fromarray(env.render())  # 获取观测空间(状态)的尺寸,这是一个8维向量  
state_size = env.observation_space.shape  # 获取动作空间的数量,这表示有多少种可能的离散动作可以选择  
num_actions = env.action_space.n  # 打印状态空间和动作空间的信息  
print('State Shape:', state_size)  
print('Number of actions:', num_actions)  
2.3、创建神经网络结构-使用软更新
# 创建一个名为Q-Network的神经网络  
q_network = Sequential([Input(shape=state_size),  # 输入层,形状由state_size定义  Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  Dense(units=num_actions, activation='linear'),  # 输出层,单元数由num_actions定义,使用线性激活函数  
])# 这里是软更新的网络(Target Q-Network)  
target_q_network = Sequential([Input(shape=state_size),  # 输入层,形状由state_size定义  Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  Dense(units=128, activation='relu'),  # 全连接层,128个单元,使用ReLU激活函数  Dense(units=num_actions, activation='linear'),  # 输出层,单元数由num_actions定义,使用线性激活函数  
])
2.4、强化学习的误差计算与梯度下降

首先是误差计算的函数,这边的Q-learning算法类似于一种迭代算法,
在这里插入图片描述
这就好像我们在高中学习的数组题目中,已经知道了an和an+1的关系式,去求解详细的an的表达式。此处误差计算的代码如下(值得注意的是,下一步的回报Q(s’,a’)是使用Target Q-Network计算的,而当前步的是使用Q-Network网络计算的):

def compute_loss(experiences, gamma, q_network, target_q_network):  """  计算损失函数。  参数:  experiences: 一个包含["state", "action", "reward", "next_state", "done"]的namedtuples的元组  gamma: (浮点数) 折扣因子。  q_network: (tf.keras.Sequential) 用于预测q_values的Keras模型  target_q_network: (tf.keras.Sequential) 用于预测目标的Keras模型  返回:  loss: (TensorFlow Tensor(shape=(0,), dtype=int32)) y目标与Q(s,a)值之间的均方误差。  """  # 解压经验元组的小批量数据  states, actions, rewards, next_states, done_vals = experiences  # 计算最大的Q^(s,a),reduce_max用于求最大值  max_qsa = tf.reduce_max(target_q_network(next_states), axis=-1)  # 如果回合结束,设置y = R,否则设置y = R + γ max Q^(s,a)。  y_targets = rewards + (gamma * max_qsa * (1 - done_vals))  # 获取q_values  q_values = q_network(states)  q_values = tf.gather_nd(q_values, tf.stack([tf.range(q_values.shape[0]),  tf.cast(actions, tf.int32)], axis=1))  # 计算损失  loss = MSE(y_targets, q_values)  return loss

学习算法的定义如下所示,使用了软更新技术:


def agent_learn(experiences, gamma):"""  更新Q网络的权重。  参数:  experiences: 一个包含["state", "action", "reward", "next_state", "done"]的namedtuples的元组  gamma: (浮点数) 折扣因子。  """# 使用tf.GradientTape()来计算损失相对于权重的梯度  with tf.GradientTape() as tape:# 调用compute_loss函数计算损失  loss = compute_loss(experiences, gamma, q_network, target_q_network)# 使用GradientTape计算损失相对于q_network的可训练变量的梯度  gradients = tape.gradient(loss, q_network.trainable_variables)# 使用优化器应用梯度,从而更新q_network的权重  optimizer.apply_gradients(zip(gradients, q_network.trainable_variables))# 使用软更新技术将q_network的权重更新至target_q_network  Lunar_Lander_utils.update_target_network(q_network, target_q_network)

Lunar_Lander_utils.update_target_network(q_network, target_q_network)是软更新的关键所在:

def update_target_network(q_network, target_q_network):for target_weights, q_net_weights in zip(target_q_network.weights, q_network.weights):target_weights.assign(TAU * q_net_weights + (1.0 - TAU) * target_weights)
2.5、强化学习的训练过程

在这里插入图片描述

# 重置环境至初始状态并获得初始状态  
state,_ = env.reset()  
total_points = 0  # 这里进行一次模拟,最多运行max_num_timesteps个时间步  
for t in range(max_num_timesteps):  # 从当前状态S使用ε-贪婪策略选择一个动作A  # 从元组中提取NumPy数组  # (注:这部分代码被注释掉了,所以下面的state_array并不会实际运行)  # if state[0].shape == ():  #     state_array = state  # else:  #     state_array = state[0]  # 将state_array转换为NumPy数组  state_qn = np.expand_dims(state, axis=0)  # 得到每个动作的回报数值,是一个1x4的数组,分别表示4个action的回报  q_values = q_network(state_qn)  # 此处实行贪婪策略,从当前最优action和随机action中选择  action = Lunar_Lander_utils.get_action(q_values, epsilon)  # 执行上述动作后得到的新状态、奖励、是否完成等信息  next_state, reward, done, _, _ = env.step(action)  # 将经验元组(S,A,R,S')存储在记忆缓冲区中  # 使用memory存储历史数据  memory_buffer.append(experience(state, action, reward, next_state, done))  # 只在特定的时间步进行更新  update = Lunar_Lander_utils.check_update_conditions(t, NUM_STEPS_FOR_UPDATE, memory_buffer)  if update:  # 从D中随机抽取小批量的经验元组(S,A,R,S')  # 只随机取MINIBATCH_SIZE个数据进行一次训练  experiences = Lunar_Lander_utils.get_experiences(memory_buffer)  # 设置y目标,执行梯度下降步骤,并更新网络权重  agent_learn(experiences, GAMMA)  state = next_state.copy()  total_points += reward  if done:  break  # 将本次总得分添加到历史得分中  
total_point_history.append(total_points)  
# 计算最近num_p_av次得分的平均值  
av_latest_points = np.mean(total_point_history[-num_p_av:])  # 更新ε值  
epsilon = Lunar_Lander_utils.get_new_eps(epsilon)

3、LunarLander文件解释

Lunar_Lander.py:运行此文件进行训练
lunar_lander_model.h5:Lunar_Lander.py训练得到的模型文件
Lunar_Lander_test.py:此文件调用h5模型并运行模拟器,将数据打包成视频格式,视频位于Lunar_Lander_videos文件夹
Lunar_Lander_utils.py:函数库

注意:运行Lunar_Lander_test.py出现长时间(大于20s)无返回0的情况,需要重新运行。这是因为LunarLander一直悬浮在空中了(相当于直升机了)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/193409.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十五届蓝桥杯模拟赛(第二期)

大家好,我是晴天学长,本次分享,制作不易,本次题解只用于学习用途,如果有考试需要的小伙伴请考完试再来看题解进行学习,需要的小伙伴可以点赞关注评论一波哦!后续会继续更新第三期的。&#x1f4…

解决uview中uni-popup弹出层不能设置高度问题

开发场景:点击条件筛选按钮,在弹出的popup框中让用户选择条件进行筛选 但是在iphone12/13pro展示是正常,但是切换至其他手机型号就填充满了整个屏幕,需要给这个弹窗设置一个固定的高度 iphone12/13pro与其他型号手机对比 一开始…

Linux环境下 make/makefile、文件时间属性 详解!!!

1.项目自动化构建工具make/makefile 1.为什么要有make/makefile 我们先写一个简单的代码,然后编译生成一个可执行程序,下面的内容我们需要知道gcc识和编译链接的一些知识,不清楚的朋友们可以点这里http://t.csdnimg.cn/0QvL8 我们知道要想生…

Java 数据结构篇-用链表、数组实现队列(数组实现:循环队列)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 队列的说明 1.1 队列的几种常用操作 2.0 使用链表实现队列说明 2.1 链表实现队列 2.2 链表实现队列 - 入栈操作 2.3 链表实现队列 - 出栈操作 2.4 链表实现队列 …

9-1定义一个结构体计算该日是本年中的第几天。

#include<stdio.h> struct {int year;int month;int day; }date; int main(){int days;printf("输入年月日&#xff1a;\n");scanf("%d,%d,%d",&date.year,&date.month,&date.day);switch(date.month){case 1:daysdate.day; break;case…

【Element-ui】Checkbox 多选框 与 Input 输入框

文章目录 前言一、Checkbox 多选框1.1 基础用法1.2 禁用状态1.3 多选框组1.4 indeterminate 状态1.5 可选项目数量的限制1.6 按钮样式1.7 带有边框1.8 Checkbox Events1.9 Checkbox Attributes 二、Input 输入框2.1 基础用法2.2 禁用状态2.3 可清空2.4 密码框2.5 带 icon 的输入…

nexus私服开启HTTPS

maven3.8.1以上不允许使用HTTP服务的仓库地址&#xff0c;如果自己搭建的私服需要升级为HTTPS或做一些设置&#xff0c;如果要升级HTTPS服务有两种方式&#xff1a;1、使用Nginx开启HTTPS并反向代理nexus&#xff1b;2、直接在nexus开启HTTPS。这里介绍第二种方式 1、在ssl目录…

计算机网络的分类

目录 一、按照传输介质进行分类 1、有线网络 2、无线网络 二、按照使用者进行分类 1、公用网 (public network) 2、专用网(private network) 三、按照网络规模和作用范围进行分类 1、PAN 个人局域网 2、LAN 局域网 3、MAN 城域网 4、 WAN 广域网 5、Internet 因特…

ChatGPT 的 18 种玩法,你还不会用吗?

你确定&#xff0c;你会使用 ChatGPT 了吗&#xff1f; 今天给大家整理了 18 种 ChatGPT 的用法&#xff0c;看看有哪些方法是你能得上的。 用之前我们可以打开R5Ai平台&#xff0c;可以免费使用目前所有的大模型 地址&#xff1a;R5Ai.com 语法更正 用途&#xff1a;文章…

【vue】尚硅谷vue3学习笔记

Vue3快速上手 1.Vue3简介 2020年9月18日&#xff0c;Vue.js发布3.0版本&#xff0c;代号&#xff1a;One Piece&#xff08;海贼王&#xff09;耗时2年多、2600次提交、30个RFC、600次PR、99位贡献者github上的tags地址&#xff1a;https://github.com/vuejs/vue-next/release…

mysql(八)docker版Mysql8.x设置大小写忽略

Mysql 5.7设置大小写忽略可以登录到Docker内部&#xff0c;修改/etc/my.cnf添加lower_case_table_names1&#xff0c;并重启docker使之忽略大小写。但MySQL8.0后不允许这样&#xff0c;官方文档记录&#xff1a; lower_case_table_names can only be configured when initializ…

机器人与3D视觉 Robotics Toolbox Python 一 安装 Robotics Toolbox Python

一 安装python 库 前置条件需要 Python > 3.6&#xff0c;使用pip 安装 pip install roboticstoolbox-python测试安装是否成功 import roboticstoolbox as rtb print(rtb.__version__)输出结果 二 Robotics Toolbox Python样例程序 加载机器人模型 加载由URDF文件定义…

【算法每日一练]-图论(保姆级教程篇12 tarjan篇)#POJ3352道路建设 #POJ2553图的底部 #POJ1236校园网络 #缩点

目录 POJ3352&#xff1a;道路建设 思路&#xff1a; POJ2553&#xff1a;图的底部 思路&#xff1a; POJ1236校园网络 思路&#xff1a; 缩点&#xff1a; 思路&#xff1a; POJ3352&#xff1a;道路建设 由于道路要维修&#xff0c;维修时候来回都不能走&#xff0c;现要…

MDK提示:在多字节的目标代码中,没有此Unicode 字符可以映射到的字符

MDK警告提示在多字节的目标代码中&#xff0c;没有此Unicode 字符可以映射到的字符 警告提示&#xff1a; 在写MDK的工程代码时&#xff0c;发现代码中引入的头文件前方出现一些红色的叉叉&#xff0c;但是编译工程并不报错&#xff0c;功能也能正常执行的&#xff0c;只是提…

JS利用时间戳倒计时案例

我们在逛某宝&#xff0c;或者逛某东时&#xff0c;我们时常看到一个倒计时&#xff0c;时间一到就开抢&#xff0c;这个倒计时是如何做的呢&#xff1f;让我为大家介绍一下。 理性分析一下&#xff1a; 1.用将来时间减去现在时间就是剩余的时间 2.核心&#xff1a;使用将来的时…

C指针介绍(1)

文章目录 每日一言指针的简单介绍内存和地址指针在内存中的存储指针的定义和声明泛型指针 指针的关系运算算数运算关系运算 结语 每日一言 ⭐「 一声梧叶一声秋&#xff0c;一点芭蕉一点愁&#xff0c;三更归梦三更后。 」–水仙子夜雨-徐再思 指针的简单介绍 C语言指针是C语…

人工智能轨道交通行业周刊-第67期(2023.11.27-12.3)

本期关键词&#xff1a;列车巡检机器人、城轨智慧管控、制动梁、断路器、AICC大会、Qwen-72B 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro…

算法工程师面试八股(搜广推方向)

文章目录 机器学习线性和逻辑回归模型逻辑回归二分类和多分类的损失函数二分类为什么用交叉熵损失而不用MSE损失&#xff1f;偏差与方差Layer Normalization 和 Batch NormalizationSVM数据不均衡特征选择排序模型树模型进行特征工程的原因GBDTLR和GBDTRF和GBDTXGBoost二阶泰勒…

React使报错不再白屏

如果代码中出现问题导致报错&#xff0c;通常会使页面报错&#xff0c;导致白屏 function Head() {// 此时模拟报错导致的白屏return <div>Head --- {content}</div> } export default () > {return (<><div>下面是标题</div><Head />…

若依框架分页

文章目录 一、分页功能解析1.前端代码分析2.后端代码分析3. LIMIT含义 二、自定义MyPage,多态获取total1.定义MyPage类和对应的调用方法 一、分页功能解析 1.前端代码分析 页面代码 封装的api请求 接口请求 2.后端代码分析 controller代码 - startPage() getDataTable(…