DQN在Gym的MountainCar环境的实现

DQN on MountainCar

引言

在本次实验里,我构建了DQN和Dueling DQN,并在Gymnasium库的MountainCar环境中对它们展开测试。我通过调整训练任务的超参数,同时设计不同的奖励函数及其对应参数,致力于获取更优的训练效果。最后,将训练结果进行可视化处理并加以比较。
在这里插入图片描述

DQN实现流程

实现方法

  • 实现了DQN类。
  • 实现了经验回放缓冲区(Buffer)类。
  • 编写了训练流程。
  • 将所有超参数集中配置在一处。
  • 设计了程序运行参数,具体如下:
    • --train:训练模式。
    • --test:测试模式。
    • --resume:断点续训模式。
    • --checkpoint:指定用于断点续训或测试模式加载的模型。
    • --task_name:确定任务名称,该名称与保存日志的路径相关。
    • --visualize_train:开启Gym环境的可视化功能。
    • --dueling:使用Dueling DQN网络。
    • --stop_threshold:早停阈值。

遇到的难题

难题一

在这里插入图片描述
在这里插入图片描述

在进行DQN训练时,进程常常在第20多轮(episode)时卡住。起初,我考虑可能是超参数设计的问题,但后来怀疑并非是训练超参数设置不当所致。我了解到Gym环境设定了停止条件,要么是完成任务,要么是达到最大步数。鉴于进程一直卡在第29轮不动,按常理应该早就达到了200步的上限,触发停止条件。然而,我在本地保存的日志文件,其TensorBoard日志文件大小却持续变化,似乎训练仍在进行。由此推测,可能是DQN算法中的经验回放缓冲区部分存在问题,导致缓冲区占用内存过大。

经过更为细致的分析,我发现编写循环终止条件时,仅设置了到达终点这一条件,而未考虑步数达到上限的情况。这就使得每次奖励值都为负几万。因此,应判定每个回合在步数达到上限时也停止。

修改终止条件后,训练能够以正常速度推进。

难题二

由于Gym中封装的MountainCar环境每走一步奖励值为 -1,小车很难学会先向左再向右冲刺的策略。

我先后尝试了多种奖励设计方法:

  1. 提取状态( s t a t e state state)中的位置信息( p o s i t i o n position position),在该环境中, p o s i t i o n position position 取值范围为 [ − 1.2 , 0.6 ] [-1.2, 0.6] [1.2,0.6]。若 p o s i t i o n ≥ 0.4 position \geq 0.4 position0.4,则 r e w a r d reward reward 加 1,以此奖励小车到达更靠右的位置。
  2. 为奖励小车尽可能到达更靠右的位置,直接设计一个关于 p o s i t i o n position position 的一次多项式,即 r e w a r d + = α ( p o s i t i o n + 0.5 ) reward += \alpha(position + 0.5) reward+=α(position+0.5),加 0.5 是因为小车初始位置约为 -0.5。
  3. 奖励小车向右的速度,公式为 r e w a r d + = α ( p o s i t i o n + 0.5 ) + β v e l o c i t y reward += \alpha(position + 0.5) + \beta velocity reward+=α(position+0.5)+βvelocity
  4. 为激励小车更多地探索先向左再向右的路径,设计了一个势能奖励函数,采用二次多项式形式,即 r e w a r d + = α 2 ( p o s i t i o n + 0.5 ) 2 + α 1 ( p o s i t i o n + 0.5 ) + β v e l o c i t y reward += \alpha_2(position + 0.5)^2 + \alpha_1(position + 0.5) + \beta velocity reward+=α2(position+0.5)2+α1(position+0.5)+βvelocity

经过逐步设计与尝试,最终发现第四种方案最为完善。经过参数调整后,成功训练小车到达山顶。
在这里插入图片描述

Dueling DQN实现流程

使用DQN的改进版本Dueling DQN,只需在DQN网络基础上拆分为两个子网络,并进行优势函数的计算。

遇到的难题

在这里插入图片描述

最初,网络结构设定为 2->128->128->128->3,训练结果难以收敛。我推测可能是网络结构过于复杂,导致权重参数极为稀疏。于是,我去掉了两个隐藏层,将网络结构改为 2->128->3,训练效果得到显著提升。
在这里插入图片描述

两种方法训练结果对比

训练曲线比较

我在代码中添加了保存TensorBoard日志的功能,这样可以通过TensorBoard查看不同训练任务的曲线变化。
在这里插入图片描述

从图中可以看出,在同一套奖励函数和参数设置下,两种方法的训练速度和收敛速度相近。

最终策略比较

通过测试功能,可以观察到两种方法训练出的策略。两种方法都能让小车成功到达山顶,但奖励表现存在些许差异。可以发现,Dueling DQN的训练结果得分更高,策略变化的波动更小。

[检查点] 从 runs/Dueling_exp6/model_final.pt 加载模型:回合数 = 2000,探索率(epsilon) = 0.600
进行 10 个回合的测试...测试 #1:奖励值 = -113.00测试 #2:奖励值 = -113.00测试 #3:奖励值 = -113.00测试 #4:奖励值 = -113.00测试 #5:奖励值 = -113.00测试 #6:奖励值 = -112.00测试 #7:奖励值 = -113.00测试 #8:奖励值 = -113.00测试 #9:奖励值 = -92.00测试 #10:奖励值 = -115.00
10 个回合的平均奖励值:-111.00 ± 6.37[检查点] 从 runs/exp6/model_final.pt 加载模型:回合数 = 2000,探索率(epsilon) = 0.600
进行 10 个回合的测试...测试 #1:奖励值 = -119.00测试 #2:奖励值 = -121.00测试 #3:奖励值 = -186.00测试 #4:奖励值 = -160.00测试 #5:奖励值 = -158.00测试 #6:奖励值 = -115.00测试 #7:奖励值 = -160.00测试 #8:奖励值 = -158.00测试 #9:奖励值 = -89.00测试 #10:奖励值 = -119.00
10 个回合的平均奖励值:-138.50 ± 28.30

最后提供代码:

import os
import time
import argparse
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import matplotlib.pyplot as plt# ---------- 参数配置区域 ----------
EPISODES = 2000                 # 最大训练轮数
BATCH_SIZE = 64                 # 批次大小
GAMMA = 0.99                    # 折扣因子
LR = 1e-3                       # 学习率
EPS_START, EPS_END = 1.0, 0.6   # ε-greedy 起始/最小
EPS_DECAY = 0.995               # ε 衰减
TARGET_UPDATE = 10              # (已废弃,按步数更新)
TARGET_UPDATE_STEPS = 500       # 梯度更新步数间隔更新目标网络
REPLAY_BUFFER_SIZE = 10000      # Replay Buffer 容量
SAVE_INTERVAL = 50              # Checkpoint 保存间隔 (episodes)
RENDER_DELAY = 0.01             # 渲染延迟(秒)
ALPHA = 6                       # reward 参数(未用)
ALPHA2 = 2
ALPHA1 = 1
BETA = 1
# ----------------------------device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ---------- 标准 DQN ----------
class DQN(nn.Module):def __init__(self, obs_dim, action_dim):super().__init__()self.net = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(),#  nn.Linear(128, 128),     nn.ReLU(),nn.Linear(128, action_dim))def forward(self, x):return self.net(x)# ---------- Dueling DQN ----------
class DuelingDQN(nn.Module):def __init__(self, obs_dim, action_dim):super().__init__()# 共享特征层self.feature = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(),# nn.Linear(128, 128),     nn.ReLU())# Advantage 分支self.advantage = nn.Sequential(# nn.Linear(128, 128), nn.ReLU(),nn.Linear(128, action_dim))# Value 分支self.value = nn.Sequential(# nn.Linear(128, 128), nn.ReLU(),nn.Linear(128, 1))def forward(self, x):x = self.feature(x)adv = self.advantage(x)                    # [B, A]val = self.value(x)                        # [B, 1]# Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a))q = val + adv - adv.mean(dim=1, keepdim=True)return qclass ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)def push(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):batch = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = map(np.array, zip(*batch))return state, action, reward, next_state, donedef __len__(self):return len(self.buffer)def save_checkpoint(path, policy_net, target_net, optimizer, episode, epsilon, rewards):os.makedirs(os.path.dirname(path), exist_ok=True)torch.save({'episode': episode,'epsilon': epsilon,'policy_state': policy_net.state_dict(),'target_state': target_net.state_dict(),'optim_state': optimizer.state_dict(),'episode_rewards': rewards,}, path)print(f"[Checkpoint] Saved to {path}")def load_checkpoint(path, policy_net, target_net, optimizer):data = torch.load(path, map_location=device)policy_net.load_state_dict(data['policy_state'])target_net.load_state_dict(data['target_state'])optimizer.load_state_dict(data['optim_state'])print(f"[Checkpoint] Loaded from {path}: episode={data['episode']}, epsilon={data['epsilon']:.3f}")return data['episode'], data['epsilon'], data['episode_rewards']def train(args):base_dir = os.path.join('runs', args.task_name)env = gym.make("MountainCar-v0", render_mode="human") if args.visualize_train else gym.make("MountainCar-v0")obs_dim = env.observation_space.shape[0]action_dim = env.action_space.n# 根据命令行参数选择网络net_cls = DuelingDQN if args.dueling else DQNpolicy_net = net_cls(obs_dim, action_dim).to(device)target_net = net_cls(obs_dim, action_dim).to(device)target_net.load_state_dict(policy_net.state_dict())target_net.eval()optimizer = optim.Adam(policy_net.parameters(), lr=LR)buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)writer = SummaryWriter(base_dir)episode_rewards = []epsilon = EPS_STARTstart_ep = 0update_steps = 0if args.resume and args.checkpoint and os.path.isfile(args.checkpoint):start_ep, epsilon, episode_rewards = load_checkpoint(args.checkpoint, policy_net, target_net, optimizer)print("Start training...")for episode in range(start_ep, EPISODES):state, _ = env.reset()state = np.array(state, dtype=np.float32)total_r = 0.0done = Falsewhile not done:if args.visualize_train:env.render()time.sleep(RENDER_DELAY)# ε-greedyif random.random() < epsilon:action = env.action_space.sample()else:with torch.no_grad():qv = policy_net(torch.from_numpy(state).unsqueeze(0).to(device))action = qv.argmax(dim=1).item()next_s, r, terminated, truncated, _ = env.step(action)done = terminated or truncatednext_s = np.array(next_s, dtype=np.float32)position, velocity = next_sshaped_r = rshaped_r += (ALPHA2 * (position + 0.5)**2 + ALPHA1 * (position + 0.5)) + BETA * velocity# if (position >= 0.4):#     shaped_r += 1buffer.push(state, action, shaped_r, next_s, done)state = next_stotal_r += shaped_r# 学习更新if len(buffer) >= BATCH_SIZE:s, a, r_b, s2, d = buffer.sample(BATCH_SIZE)s_t = torch.from_numpy(s).to(device)a_t = torch.from_numpy(a).long().unsqueeze(1).to(device)r_t = torch.from_numpy(r_b).float().unsqueeze(1).to(device)s2_t = torch.from_numpy(s2).to(device)d_t = torch.from_numpy(d.astype(np.float32)).unsqueeze(1).to(device)q_curr = policy_net(s_t).gather(1, a_t)with torch.no_grad():q_next = target_net(s2_t).max(1)[0].unsqueeze(1)q_target = r_t + GAMMA * q_next * (1 - d_t)loss = nn.functional.mse_loss(q_curr, q_target)optimizer.zero_grad()loss.backward()optimizer.step()writer.add_scalar("Loss/Train", loss.item(), episode)# 按步数更新目标网络update_steps += 1if update_steps % TARGET_UPDATE_STEPS == 0:target_net.load_state_dict(policy_net.state_dict())# 每集结束后的记录epsilon = max(EPS_END, epsilon * EPS_DECAY)episode_rewards.append(total_r)writer.add_scalar("Reward/Episode", total_r, episode)writer.add_scalar("Epsilon", epsilon, episode)print(f"Episode {episode+1}/{EPISODES}  Reward={total_r:.2f}  Epsilon={epsilon:.3f}")# 早停检查if args.stop_threshold is not None and len(episode_rewards) >= 10:last10_avg = np.mean(episode_rewards[-10:])if last10_avg > args.stop_threshold:print(f"[Early Stop] Last 10 episodes avg reward = {last10_avg:.2f} > threshold {args.stop_threshold}")break# 定期保存if (episode+1) % SAVE_INTERVAL == 0:ckpt = os.path.join(base_dir, f"model_{episode+1}.pt")save_checkpoint(ckpt, policy_net, target_net, optimizer,episode+1, epsilon, episode_rewards)# 保存最后一版模型final_ckpt = os.path.join(base_dir, "model_final.pt")save_checkpoint(final_ckpt, policy_net, target_net, optimizer,episode+1, epsilon, episode_rewards)env.close()writer.close()# 画训练曲线plt.figure(figsize=(8,4))plt.plot(range(1, len(episode_rewards)+1), episode_rewards)plt.xlabel("Episode")plt.ylabel("Total Reward")plt.title("Reward Trend")plt.grid(True)plt.savefig(os.path.join(base_dir, "training_rewards.png"))plt.show()return policy_netdef test(policy_net, episodes=10):env = gym.make("MountainCar-v0", render_mode="human")rewards = []print(f"Testing over {episodes} episodes...")for i in range(episodes):state, _ = env.reset()state = np.array(state, dtype=np.float32)done = Falsetotal_r = 0.0while not done:env.render()with torch.no_grad():action = policy_net(torch.from_numpy(state).unsqueeze(0).to(device)).argmax(1).item()state, r, terminated, truncated, _ = env.step(action)done = terminated or truncatedstate = np.array(state, dtype=np.float32)total_r += rtime.sleep(0.02)rewards.append(total_r)print(f" Test #{i+1}: Reward = {total_r:.2f}")env.close()mean_r = np.mean(rewards)std_r = np.std(rewards)print(f"Average Reward over {episodes} episodes: {mean_r:.2f} ± {std_r:.2f}")if __name__ == "__main__":parser = argparse.ArgumentParser(description="DQN / DuelingDQN MountainCar with Early Stop & Step-wise Target Update")parser.add_argument("--train",          action="store_true")parser.add_argument("--test",           action="store_true")parser.add_argument("--resume",         action="store_true")parser.add_argument("--checkpoint",     type=str,   default=None)parser.add_argument("--task_name",      type=str,   default="default")parser.add_argument("--visualize_train",action="store_true")parser.add_argument("--dueling",        action="store_true", help="Use Dueling DQN instead of standard DQN")parser.add_argument("--stop_threshold", type=float, default=None,help="Early stop if avg reward over last 10 eps > this")args = parser.parse_args()model = Noneif args.train:model = train(args)if args.test:net_cls = DuelingDQN if args.dueling else DQNif not model:dummy = net_cls(2, 3).to(device)opt = optim.Adam(dummy.parameters(), lr=LR)ckpt = args.checkpoint or os.path.join('runs', args.task_name, "model_final.pt")if os.path.isfile(ckpt):_, _, _ = load_checkpoint(ckpt, dummy, dummy, opt)model = dummyif model:test(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/77700.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络综合实验指南

计算机网络综合实验指南 本实验将结合《计算机网络自顶向下》前三章的核心概念&#xff0c;通过实际操作加深对应用层、运输层和网络层的理解。实验涵盖 HTTP/TCP抓包分析、DNS解析观察、网页性能评估及简单Socket编程&#xff0c;帮助你将理论转化为实践。 实验准备 工具&…

【AI部署】腾讯云GPU-RUN—SadTalker的AI数字人视频—未来之窗超算中心

磁盘空间 创建未来之窗 查看磁盘命令 df -h 指定路径创建环境 conda create --prefix sadtalker python3.10 指令路径运行环境 conda activate ./sadtalker 安装环境 pip install torch1.12.1cu113 torchvision0.13.1cu113 torchaudio0.12.1 --extra-index-url https://…

爬虫利器SpiderTools谷歌插件教程v1.0.0!!!web端JavaScript环境检测!!!

SpiderTools谷歌插件教程v1.0.0 一、SpiderTools简介二、下载通道三、插件介绍四、插件使用五、工具函数使用 补环境工具推荐&#xff1a;爬虫补环境利器webEnv 一、SpiderTools简介 SpiderTools主要用于检测和监控网页的JavaScript运行环境。该插件可以帮助开发者更好地查看…

Android开发协调布局滑动悬停

Android开发协调布局滑动悬停 直接给个xml,防止下次忘了怎么写。 <?xml version="1.0" encoding="utf-8"?> <androidx.coordinatorlayout.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android"x…

Linux学习——TCP

一.TCP编程API 1.socket函数 1.socket函数 include include int socket(int domain,int type,int protocol); 参数 domain AF_INET AF_INET6 AF_UNIX,AF_LOCAL AF_NETLINK AF_PACKET type SOCK_STREAM: 流式…

Linux驱动开发--异步通知与异步I/O

3、异步通知与异步I/O 3.1 Linux信号 阻塞与非阻塞访问、poll()函数提供了较好的解决设备访问的机制&#xff0c;但是如果有了异步通知&#xff0c;整套机制则更加完整了。 异步通知的意思是&#xff1a;一旦设备就绪&#xff0c;则主动通知应用程序&#xff0c;这样应用程序…

大语言模型推理能力的强化学习现状理解GRPO与近期推理模型研究的新见解

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

【Linux系统】Linux基础指令(详解Linux命令行常用指令,每一个指令都有示例演示)

文章目录 一、与文件路径相关的指令0.补充知识&#xff1a;路径的认识1.pwd 指令2.cd 指令&#xff08;含家目录的介绍&#xff09; 二、创建和删除文件的指令0.补充知识&#xff1a;普通文件和目录文件1.touch 指令&#xff08;可以修改文件的时间戳&#xff09;2.mkdir 指令3…

LangChain 单智能体模式示例【纯代码】

# LangChain 单智能体模式示例import os from typing import Anyfrom langchain.agents import AgentType, initialize_agent, Tool from langchain_openai import ChatOpenAI from langchain.tools import BaseTool from langchain_experimental.tools.python.tool import Pyt…

解决:VSCode C++ conan 安装第三方库后 头文件报错

文章目录 1 头文件include路径查找报错参考 1 头文件include路径查找报错 找到conan_toolchain.cmake中 INCLUDE_PATH list(PREPEND CMAKE_INCLUDE_PATH "/Users/hanliqiang/.conan2/p/b/fmte8c4f7a755477/p/include")生成C编译配置 CtrlShiftP 中选择C Edit Confi…

松灵Cobot Magic双臂具身遥操机器人(基于ROS的定位建图与协同导航技术)

摘要 本文以CobotMagic可移动协作机器人为研究对象&#xff0c;从硬件架构设计、软件系统架构、多传感器融合定位建图系统、智能导航系统协同机制四个维度&#xff0c;深入解析机器人系统工作原理。重点研究多传感器融合定位建图系统实现原理&#xff0c;结合实测数据验证系统…

回归,git 分支开发操作命令

核心分支说明 主分支&#xff08;master/production&#xff09;存放随时可部署到生产环境的稳定代码&#xff0c;仅接受通过测试的合并请求。 开发分支&#xff08;develop&#xff09;集成所有功能开发的稳定版本&#xff0c;日常开发的基础分支&#xff0c;从该分支创建特性…

ASP.NET Core 最小 API:极简开发,高效构建(下)

在上篇文章 ASP.NET Core 最小 API&#xff1a;极简开发&#xff0c;高效构建&#xff08;上&#xff09; 中我们添加了 API 代码并且测试&#xff0c;本篇继续补充相关内容。 一、使用 MapGroup API 示例应用代码每次设置终结点时都会重复 todoitems URL 前缀。 API 通常具有…

Spring之我见 - Spring Boot Starter 自动装配原理

欢迎光临小站&#xff1a;致橡树 Spring Boot Starter 的核心设计理念是 约定优于配置&#xff0c;其核心实现基于 自动配置&#xff08;Auto-Configuration&#xff09; 和 条件化注册&#xff08;Conditional Registration&#xff09;。以下是其生效原理&#xff1a; 约定…

精益数据分析(7/126):打破创业幻想,拥抱数据驱动

精益数据分析&#xff08;7/126&#xff09;&#xff1a;打破创业幻想&#xff0c;拥抱数据驱动 在创业的道路上&#xff0c;我们都怀揣着梦想&#xff0c;但往往容易陷入自我编织的幻想中。我希望通过和大家一起学习《精益数据分析》&#xff0c;能帮助我们更清醒地认识创业过…

牛客java练习题

[toc] 1.依赖注入 依赖注入是一种设计模式和编程思想,不依赖 具体的框架实现,可以通过多种方式和框架来实现可以通过Spring , Google Guice , PicoContainer 等都可以实现依赖注入,也可以通过手动编写实现目的: 为了解耦合,将对象之间的依赖关系从代码中解耦出来, 使系统更加…

大模型应用开发自学笔记

理论学习地址&#xff1a; https://zh.d2l.ai/chapter_linear-networks/index.html autodl学术加速&#xff1a; source /etc/network_turboconda常见操作: 删除&#xff1a; conda remove --name myenv --all -y导出&#xff1a; conda env export > environment.yml…

鸿蒙ArkUI实战之TextArea组件、RichEditor组件、RichText组件、Search组件的使用

本文接上篇继续更新ArkUI中组件的使用&#xff0c;本文介绍的组件有TextArea组件、RichEditor组件、RichText组件、Search组件&#xff0c;这几个组件的使用对应特定场景&#xff0c;使用时更加需要注意根据需求去使用 TextArea组件 官方文档&#xff1a; TextArea-文本与输…

除了`String`、`StringBuffer` 和 `StringBuilder`之外,还有什么处理字符串的方法?

一、标准库中的字符串处理类 1. StringJoiner&#xff08;Java 8&#xff09; 用途&#xff1a;用于在拼接字符串时自动添加分隔符、前缀和后缀。示例&#xff1a;StringJoiner sj new StringJoiner(", ", "[", "]"); sj.add("A").…

Qt中读写结构体字节数据

在Qt中读写结构体字节数据通常涉及将结构体转换为字节数组(QByteArray)或直接从内存中读写。以下是几种常见方法&#xff1a; 方法1&#xff1a;使用QDataStream读写结构体 cpp #include <QFile> #include <QDataStream>// 定义结构体 #pragma pack(push, 1) //…