【强化学习】Deep Q Learning

Deep Q Learning

在前两篇文章中,我们发现RL模型的目标是基于观察空间 (observations) 和最大化奖励和 (maximumize sum rewards) 的。

如果我们能够拟合出一个函数 (function) 来解决上述问题,那就可以避免存储一个 (在Double Q-Learning中甚至是两个) 巨大的Q_table。

Tabular -> Function

  • Continous Observation: 函数能够让我们处理连续的观察空间,而表只能处理离散的。
  • Saving the space: 不用存储 len(state) * len(action) 大小的Q_table

在早期人们试过使用核函数或者线性函数等各种方法去拟合这个function,但后来深度神经网络出现后人们纷纷开始研究如何用DNN来拟合。

然而以上的拟合方式不免存在一个问题,我们期望得到一个DNN,使得DNN(state)->Q-value

可是强化学习中,最好的Q-value在开始时是不知道的 (这也是强化学习和机器学习不一样的地方:我们不知道能否训练到一个Q值,直到有人把它训练出来),这就导致我们在训练过程中没有目标函数。

Natural Deep Q Learning

所有的第一步必须从高维的感官输入中获得对环境的有效表示

深度Q网络(DQN)是一种将深度学习和Q学习相结合的强化学习方法。DQN由DeepMind于2015年提出,并在玩Atari视频游戏方面取得了显著的成功。DQN的核心原理是使用深度神经网络来近似Q函数,即在给定状态下采取某一动作的预期累积奖励。

DQN的关键创新

  1. 使用神经网络近似Q函数

    • 传统的Q学习使用表格(Q表)来存储每个状态-动作对的Q值。当状态空间很大或连续时,这变得不切实际。
    • DQN通过使用深度神经网络来近似Q函数,克服了这一限制。网络输入是状态,输出是该状态下所有可能动作的Q值。
  2. 经验回放

    • DQN引入了经验回放机制,即将代理的经验(状态、动作、奖励、新状态)存储在回放缓冲区中。

      image-20231114211049019
    • 训练时,从这个缓冲区中随机抽取小批量经验进行学习。这增加了数据的多样性,减少了样本之间的相关性,从而稳定了训练。

  3. 目标网络

    • DQN使用两个结构相同但参数不同的网络:一个是在线网络 (dqn_model),用于当前Q值的估计;另一个是目标网络 (target_model),用于计算目标Q值。
    • 目标网络的参数定期从在线网络复制过来,但不是每个训练步骤都更新。这减少了学习过程中的震荡,提高了稳定性。
    image-20231114211236348

训练过程

  • 在每个时间步,代理根据当前的Q值(通常结合探索策略,如ε-贪婪)选择一个动作,接收环境的反馈(新状态和奖励),并将这个转换存储在经验回放缓冲区中。
  • 训练神经网络时,从缓冲区中随机抽取一批经验,然后使用贝尔曼方程计算目标Q值和预测Q值,通过最小化这两者之间的差异来更新网络参数。

DQN解决月球着陆问题

导入环境

import time
from collections import defaultdictimport gymnasium as gym
import numpy as np
import randomfrom matplotlib import pyplot as plt, animation
from IPython.display import display, clear_output
env = gym.make("LunarLander-v2", continuous=False, render_mode='rgb_array')

定义经验池

class ExperienceBuffer:def __init__(self, size=0):self.states = []self.actions = []self.rewards = []self.states_next = []self.actions_next = []self.size = 0def clear(self):self.__init__()def append(self, s, a, r, s_n, a_n):self.states.append(s)self.actions.append(a)self.rewards.append(r)self.states_next.append(s_n)self.actions_next.append(a_n)self.size += 1def batch(self, batch_size=128):indices = np.random.choice(self.size, size=batch_size, replace=True)return  (np.array(self.states)[indices],np.array(self.actions)[indices],np.array(self.rewards)[indices],np.array(self.states_next)[indices],np.array(self.actions_next)[indices],)
import torchfrom torch import nn
from torch.nn.functional import relu
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

定义DQN

class DQN(nn.Module):def __init__(self, state_size, action_size):super().__init__()self.state_size = state_sizeself.action_size = action_sizeself.hidden_size = 32self.linear_1 = nn.Linear(self.state_size, self.hidden_size)self.linear_2 = nn.Linear(self.hidden_size, self.action_size)nn.init.uniform_(self.linear_1.weight, a=-0.1, b=0.1)nn.init.uniform_(self.linear_2.weight, a=-0.1, b=0.1)def forward(self, state):if not isinstance(state, torch.Tensor):state = torch.tensor([state], dtype=torch.float)state = state.to(device)return self.linear_2(relu(self.linear_1(state)))

定义policy

def policy(model, state, eval=False):eps = 0.1if not eval and random.random() < eps:return random.randint(0, model.action_size - 1)else:q_values = model(torch.tensor([state], dtype=torch.float))action = torch.multinomial(F.softmax(q_values), num_samples=1)return int(action[0])

collect

dqn_model = DQN(state_size=8, action_size=4).to(device)
target_model = DQN(state_size=8, action_size=4).to(device)
from tqdm.notebook import tqdm
# 学习率
alpha = 0.9
# 折扣因子
gamma = 0.95
# 训练次数
episode = 1000
experience_buffer = ExperienceBuffer()eval_iter = 100
eval_num = 100# collect
def collect():for e in tqdm(range(episode)):state, info = env.reset()action = policy(dqn_model, state)sum_reward = 0while True:state_next, reward, terminated, truncated, info_next = env.step(action)action_next= policy(dqn_model, state_next)sum_reward += rewardexperience_buffer.append(state, action, reward, state_next, action_next)if terminated or truncated:breakstate = state_nextinfo = info_nextaction = action_next

learning

## learning
from torch.optim import Adamloss_fn = nn.MSELoss()
optimizer = Adam(lr=1e-5, params=dqn_model.parameters())losses = []
target_fix_period = 5
epoch = 3def train():for e in range(epoch):batch_size = 128for i in range(experience_buffer.size // batch_size):s, a, r, s_n, a_n = experience_buffer.batch(batch_size)s = torch.tensor(s, dtype=torch.float).to(device)s_n = torch.tensor(s_n, dtype=torch.float).to(device)r = torch.tensor(r, dtype=torch.float).to(device)a = torch.tensor(a, dtype=torch.long).to(device)a_n = torch.tensor(a_n, dtype=torch.long).to(device)y = r + target_model(s_n).gather(1, a_n.unsqueeze(1)).squeeze(1)y_hat = dqn_model(s).gather(1, a.unsqueeze(1)).squeeze(1)loss = loss_fn(y, y_hat)optimizer.zero_grad()loss.backward()optimizer.step()if i % 500 == 0:print(f'i == {i}, loss = {loss} ')if i % target_fix_period == 0:target_model.load_state_dict(dqn_model.state_dict())

a_n:动作
s_n:状态

image-20231205221613164

image-20231205221643890

将状态 s_n 作为输入,target_model的输出是针对每个可能动作的 Q 值;如果 s_n 包含多个状态(比如一个批量),输出将是一个批量的 Q 值

image-20231205221710717

image-20231205221746045

image-20231205221827050

训练

for i in range(10):print(f'collect/train: {i}')experience_buffer.clear()collect()train()

结果

task_num = 10
frames = []for _ in range(10):state, _ = env.reset()while True:action = policy(dqn_model, state, eval=True)state_next, reward, terminated, truncated, info_next = env.step(action)frames.append(env.render())if terminated or truncated:break

output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/235134.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git的使用思维导图

源文件在github主页&#xff1a;study_collection/cpp学习/git at main stu-yzZ/study_collection (github.com)

Unity的UI界面——Text/Image

编辑UI界面时&#xff0c;要先切换到2d界面 &#xff08;3d项目的话&#xff09; 1.Text控件 Text控件的相关属性&#xff1a; Character:&#xff08;字符&#xff09; Font&#xff1a;字体 Font Style&#xff1a;字体样式 Font Size&#xff1a;字体大小 Line Spac…

华清远见嵌入式学习——ARM——作业1

要求&#xff1a; 代码&#xff1a; mov r0,#0 用于加mov r1,#1 初始值mov r2,#101 终止值loop: cmp r1,r2addne r0,r0,r1addne r1,r1,#1bne loop 效果&#xff1a;

用户管理第2节课--idea 2023.2 后端规整项目目录

目的&#xff1a;当项目文件多了之后&#xff0c;咱们也能够非常清晰的去找到代码的一个目录 一、项目规整了两大处 1.1 com.yupi.usercenter & resources 二、具体操作 com.daisy.usercenter 2.1 原版 & 鱼皮有出入&#xff0c;demos.web就不删除了 原因&#…

Aurora8B10B(二) 从手册和仿真学习Aurora8B10B

一. 简介 在上篇文章中&#xff0c;主要结合IP配置界面介绍了一下Aurora8B10B&#xff0c;这篇文章将结合文档来学习一下Aurora8B10B内部的一些细节 和 相关的时序吧。文档主要是参考的是这个 pg046-aurora-8b10b-en-us-11.1 二. Aurora8B10B内部细节 在手册上&#xff0c;对…

VR全景是什么?普通人该如何看待VR全景创业?

如果你还没有开始了解VR&#xff0c;那么不妨驻足几分钟细致的了解一下&#xff0c;你就会对VR全景行业有不一样的看法。VR全景与普通的平面图片和视频相比&#xff0c;具有更加丰富的视觉体验和交互性&#xff0c;基于真实场景的全景图像的虚拟现实技术&#xff0c;制作流程简…

Maven仓库上传jar和mvn命令汇总

目录 导入远程仓库 命令结构 命令解释 项目pom 输入执行 本地仓库导入 命令格式 命令解释 Maven命令汇总 mvn 参数 mvn常用命令 web项目相关命令 导入远程仓库 命令结构 mvn deploy:deploy-file -Dfilejar包完整名称 -DgroupIdpom文件中引用的groupId名 -Dartifa…

Ubuntu 常用命令之 apt-get 命令用法介绍

apt-get是Ubuntu系统下的一个命令行工具&#xff0c;用于处理包。这个命令可以自动下载和安装软件包及其依赖项。它是Advanced Packaging Tool (APT)的一部分&#xff0c;APT是处理包的高级工具&#xff0c;可以处理复杂的包关系&#xff0c;如依赖关系等。 apt-get命令的常见…

一个真正的软件测试从业人员必备技能有哪些?

协同开发能力&#xff1a; 1. 项目管理&#xff08;SVN、Git&#xff09; 2. 数据分析能力&#xff08;Fiddler、Charles、浏览器F12&#xff09;。 接口测试&#xff1a; 1. 概念及接口测试原理概念&#xff08;概念、接口测试原理&#xff09; 2. 接口测试工具&#xff…

数据工作者最爱的AI功能,你知道吗~

在工作中难以避免的一项任务就是各种数据总结和汇报&#xff0c;怎么分析总结&#xff1f;以何种形式汇报&#xff1f;都是具有一定的难点&#xff0c;所以我要推荐的就是具有AI图表解析功能的可视化工具——Easyv数字孪生低代码可视化平台。可实现对数据的可视化展示&#xff…

软件测试项目测试报告总结

测试计划概念&#xff1a;就在软件测试工作实施之前明确测试对象&#xff0c;并且通过资源、时间、风险、测试范围和预算等方面的综合分析和规划&#xff0c;保证有效的实施软件测试。 需求挖掘的6个方面&#xff1a; 1、输入方面 2、处理方面 3、结果输出方面 4、性能需求…

linux 驱动——杂项设备驱动

杂项设备驱动 在 linux 中&#xff0c;将无法归类的设备定义为杂项设备。 相对于字符设备来说&#xff0c;杂项设备的主设备号固定为 10&#xff0c;而字符设备不管是动态分配还是静态分配设备号&#xff0c;都会消耗一个主设备号&#xff0c;比较浪费主设备号。 杂项设备会自…

uml用例图是什么?有哪些要素?

UML用例图是什么&#xff1f; UML用例图&#xff08;Unified Modeling Language Use Case Diagram&#xff09;是一种用于描述系统功能和用户之间交互的图形化建模工具。它是UML的一部分&#xff0c;主要用于识别和表示系统中的各个用例&#xff08;用户需求或功能点&#…

鸿蒙开发之压缩/解压缩

本次学习遗留一个问题&#xff1a;压缩/解压缩的路径怎么获取&#xff1f;&#xff1f;希望知道的小伙伴能给说一下&#xff0c;私聊评论皆可。 一、API使用 代码相对来说比较简单 //需要导入的头文件 import zlib from ohos.zlib//压缩函数 function zipFile() {let rawfil…

高通平台开发系列讲解(USB篇)adb应用adbd分析

沉淀、分享、成长,让自己和他人都能有所收获!😄 在apps_proc/system/core/adb/adb_main.cpp文件中main()函数会调用adb_main()函数,然后调用uab_init函数 在uab_init()函数中,会创建一个线程,在线程中会调用init_functionfs()函数,利用ep0控制节点,创建ep1、ep2输…

在区块链中看CHAT的独特见解

问CHAT&#xff1a;谈谈对区块链以及区块链金融的理解 CHAT回复&#xff1a;区块链是一种去中心化的分布式数据库技术&#xff0c;这种技术通过加密算法&#xff0c;使数据在网络中传输和存储的过程变得更加安全可靠。区块链的出现引领了存储、交易等形式的革命&#xff0c;改变…

通过https协议访问Tomcat部署并使用Shiro认证的应用跳转登到录页时协议变为http的问题

问题描述&#xff1a; 在最近的一个项目中&#xff0c;有一个存在较久&#xff0c;并且只在内部城域网可访问的一个使用Shiro框架进行安全管理的Java应用&#xff0c;该应用部署在Tomcat服务器上。起初&#xff0c;应用程序可以通过HTTP协议访问&#xff0c;一切运行都没…

FreeCodeCamp--数千免费编程入门教程,非盈利性网站,质量高且支持中文

在浏览话题“Github上获得Star最多的项目”时&#xff0c;看到了FreeCodeCamp&#xff0c;顾名思义--免费编程营地&#xff0c;于是就做了些调研&#xff0c;了解了下这是个什么项目 这是一个致力于推动编程教育的非营利性组织&#xff0c;团队由来自世界各地的杰出的技术开发…

java中常用的加密算法总结

目前在工作中常用到加密的一些场景&#xff0c;比如密码加密&#xff0c;数据加密&#xff0c;接口参数加密等&#xff0c;故通过本文总结以下常见的加密算法。 1. 对称加密算法 对称加密算法使用相同的密钥进行加密和解密。在Java中&#xff0c;常见的对称加密算法包括&…

机器人也能干的更好:RPA技术的优势和应用场景

RPA是什么&#xff1f; 机器人流程自动化RPA&#xff08;Robotic Process Automation&#xff09;是一种自动化技术&#xff0c;它使用软件机器人来高效完成重复且有逻辑性的工作。近年来&#xff0c;随着人工智能和自动化技术的不断发展和普及&#xff0c;RPA已经成为企业提高…