强化学习11——DQN算法

DQN算法的全称为,Deep Q-Network,即在Q-learning算法的基础上引用深度神经网络来近似动作函数 Q ( s , a ) Q(s,a) Q(s,a) 。对于传统的Q-learning,当状态或动作数量特别大的时候,如处理一张图片,假设为 210 × 160 × 3 210×160×3 210×160×3,共有 25 6 ( 210 × 60 × 3 ) 256^{(210×60×3)} 256(210×60×3)种状态,难以存储,但可以使用参数化的函数 Q θ Q_{\theta} Qθ 来拟合这些数据,即DQN算法。同时DQN还引用了经验回放和目标网络,接下来将以此介绍。

CartPole 环境

image.png

在车杆环境中,通过移动小车,让小车上的杆保持垂直,如果杆的倾斜度数过大或者车子偏离初始位置的距离过大,或者坚持了一定的时间,则结束本轮训练。该智能体的状态是四维向量,每个状态是连续的,但其动作是离散的,动作的工作空间是2。

维度意义最小值最大值
0车的位置-2.42.4
1车的速度-InfInf
2杆的角度~ -41.8°~ 41.8°
3杆尖端的速度-InfInf
标号动作
0向左移动小车
1向右移动小车

深度网络

我们通过神经网络将输入向量 x x x映射到输出向量 y y y,通过下式表示:
y = f θ ( x ) y=f_{\theta}(x) y=fθ(x)
神经网络可以理解为是一个函数,输入输出都是向量,并且拥有可以学习的参数 θ \theta θ ,通过梯度下降等方法,使得神经网络能够逼近任意函数,当然可以用来近似动作价值函数:
y ⃗ = Q θ ( s ⃗ , a ⃗ ) \vec{y}=Q_{\theta}(\vec{s},\vec{a}) y =Qθ(s ,a )
在本环境种,由于状态的每一维度的值都是连续的,无法使用表格记录,因此可以使用一个神经网络表示函数Q。当动作是连续(无限)时,神经网络的输入是状态s和动作a,输出一个标量,表示在状态s下采取动作a能获得的价值。若动作是离散(有限)的,除了采取动作连续情况下的做法,还可以只将状态s输入到神经忘了,输出每一个动作的Q值。

假设使用神经网络拟合w,则每一个状态s下所有可能动作a的Q值为 Q w ( s , a ) Q_w(s,a) Qw(s,a),我们称为Q网络:

image.png

我们在Q-learning种使用下面的方式更新:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a'\in\mathcal{A}}Q(s',a')-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]
即让 Q ( s , a ) Q(s,a) Q(s,a) r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal{A}}Q(s',a') r+γmaxaAQ(s,a)靠近,那么Q网络的损失函数为均方误差的形式:
ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_{\omega}\frac{1}{2N}\sum_{i=1}^{N}\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a'}Q_\omega\left(s_i',a'\right)\right)\right]^2 ω=argωmin2N1i=1N[Qω(si,ai)(ri+γamaxQω(si,a))]2

经验回访

将Q-learning过程中,每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,之后在训练Q网络时,再从回访缓冲区中,随机采样若干数据进行训练。

image.png

在一般的监督学习中,都是假定训练数据是独立同分布的,而在强化学习中,连续的采样、交互所得到的数据有很强的相关性,这一时刻的状态和上一时刻的状态有关,不满足独立假设。通过在回访缓冲区采样,可以打破样本之间的相关性。另外每一个样本可以使用多次,也适合深度学习。

目标网络

构建两个网络,一个是目标网络,一个是当前网络,二者结构相同,都用于近似Q值。在实践中每隔若干步才把每步更新的当前网络参数复制给目标网络,这样做的好处是保证训练的稳定,当训练的结果不好时,可以不同步当前网络的值,避免Q值的估计发散。

image.png

在计算期望时,使用目标网络来计算:
Q 期望 = [ r t + γ max ⁡ a ′ Q ω ˉ ( s ′ , a ′ ) ] Q_\text{期望}=[r_t+\gamma\max_{a^{\prime}}Q_{\bar{\omega}}(s^{\prime},a^{\prime})] Q期望=[rt+γamaxQωˉ(s,a)]
具体流程如下所示:

  • 使用随机的网络参数 ω \omega ω初始化初始化当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a)
  • 复制相同的参数初始化目标网络 ω ˉ ← ω \bar{\omega}\gets \omega ωˉω
  • 初始化经验回访池R
  • for 序列 e = 1 → E e=1\to E e=1E do
    • 获取环境初始状态 s 1 s_1 s1
    • for 时间步 t = 1 → T 时间步t=1\to T 时间步t=1T do
      • 根据当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a) ϵ − g r e e d y \epsilon -greedy ϵgreedy策略选择动作 a t a_t at
      • 执行动作 a t a_t at,获得回报 r t r_t rt,环境状态变为 s t + 1 s_{t+1} st+1
      • ( s t , a t , r t , s t + 1 ) (s_t,a_t,r_t,s_{t+1}) (st,at,rt,st+1)存储进回池R
      • 若R中数据足够,则从R中采样N个数据 { ( s i , a i , r i , s i + 1 ) } i = 1 , … , N \{(s_i,a_i,r_i,s_{i+1})\}_{i=1,\ldots,N} {(si,ai,ri,si+1)}i=1,,N
      • 对每个数据,用目标网络计算 y = r i + γ max ⁡ a Q ω ˉ ( s i + 1 , a ) y=r_i+\gamma\max_aQ_{\bar{\omega}}(s_{i+1},a) y=ri+γmaxaQωˉ(si+1,a)
      • 最小化目标损失 L = 1 N ∑ i ( y i − Q ω ( s i , a i ) ) 2 L=\frac{1}{N}\sum_{i}(y_{i}-Q_{\omega}(s_{i},a_{i}))^{2} L=N1i(yiQω(si,ai))2,以更新当前网络 Q ω Q_{\omega} Qω
      • 更新目标网络
    • end for
  • end for
import random
from typing import Any
import gymnasium as gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils# 首先定义经验回收池的类,包括加入数据、采样数据
class ReplayBuffer:def __init__(self, capacity):# 创建一个队列,先进先出self.buffer=collections.deque(maxlen=capacity)def add(self,state,action,reward,next_state,done):# 加入数据self.buffer.append((state,action,reward,next_state,done))def sample(self,batch_size):# 随机采样数据mini_batch=random.sample(self.buffer,batch_size)# zip(*)取mini_batch中的每个元素(即取列),并返回一个元组state,action,reward,next_state,done=zip(*mini_batch)return np.array(state), action, reward, np.array(next_state), donedef size(self):return len(self.buffer)# 定义一个只有一层隐藏层的Q网络
class Qnet(torch.nn.Module):def __init__(self,state_dim,hidden_dim,action_dim):super(Qnet,self).__init__()# 定义一个全连接层,输入为state_dim维向量,输出为hidden_dim维向量self.fc1=torch.nn.Linear(state_dim,hidden_dim)# 定义一个全连接层,输入为hidden_dim维向量,输出为action_dim维向量self.fc2=torch.nn.Linear(hidden_dim,action_dim)def forward(self,state):x = F.relu(self.fc1(state))return self.fc2(x)class DQN:def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device):self.action_dim=action_dimself.q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)# 目标网络self.target_q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)# 使用Adam优化器self.optimizer=torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)# 折扣因子self.gamma=gamma# 贪婪策略self.epsilon=epsilon# 目标网络更新频率self.target_update=target_update# 计数器self.count=0self.device=devicedef take_action(self,state):# 判断是否需要贪婪策略if np.random.random()<self.epsilon:action=np.random.randint(self.action_dim)else:state=torch.tensor([state],dtype=torch.float).to(self.device)action=self.q_net(state).argmax().item()return actiondef update(self,transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# Q值q_values=self.q_net(states).gather(1,actions)# 下一个状态的最大Q值max_next_q_values=self.target_q_net(next_states).max(1)[0].view(-1, 1)q_targets=rewards+self.gamma*max_next_q_values*(1-dones)# 反向传播更新参数dqn_loss=torch.mean(F.mse_loss(q_values, q_targets)) # 均方误差损失函数self.optimizer.zero_grad()dqn_loss.backward()self.optimizer.step()if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络self.count += 1lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)return_list = []
for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()[0]aa=state[0]print(state)done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done,info, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

加密经济学:Web3时代的新经济模型

随着Web3技术的迅猛发展&#xff0c;我们正迈入一个全新的数字经济时代。加密经济学作为这一时代的核心&#xff0c;不仅在数字货币领域崭露头角&#xff0c;更是重新定义了传统经济模型&#xff0c;为我们开启了一个充满创新和机遇的新纪元。 1. 去中心化的经济体系 Web3时代…

7.3 CONSTANT MEMORY AND CACHING

掩模数组M在卷积中的使用方式有三个有趣的属性。首先&#xff0c;M阵列的大小通常很小。大多数卷积掩模在每个维度上都少于10个元素。即使在3D卷积的情况下&#xff0c;掩码通常也只包含少于1000个元素。其次&#xff0c;在内核执行过程中&#xff0c;M的内容不会改变。第三&am…

启动Vue项目,报错:‘vue-cli-service‘ 不是内部或外部命令,也不是可运行的程序

前言&#xff1a; 最近在打开一个Vue项目的时候&#xff0c;打开之后输入命令行&#xff1a;npm run serve之后发现&#xff0c;报错&#xff1a;vue-cli-service 不是内部或外部命令&#xff0c;也不是可运行的程序&#xff0c;以下是解决方案&#xff1a; 报错图片截图&…

HNU-算法设计与分析-实验3

算法设计与分析实验3 计科210X 甘晴void 202108010XXX 目录 文章目录 算法设计与分析<br>实验31 用Dijkstra贪心算法求解单源最短路径问题问题重述证明模板&#xff1a;Dijkstra算法代码验证算法分析 1【扩展】 使用堆优化的Dijkstra原因代码算法分析验证 2 回溯法求解…

运筹说 第98期|无约束极值问题

上一期我们一起学习了关于非线性规划问题的一维搜索方法的相关内容&#xff0c;本期小编将带大家学习非线性规划的无约束极值问题。 下面&#xff0c;让我们从实际问题出发&#xff0c;学习无约束极值问题吧&#xff01; 一、问题描述及求解原理 1 无约束极值问题的定义 无约…

ArkUI-X跨平台已至,何需其它!

运行环境 DevEco Studio&#xff1a;4.0Release OpenHarmony SDK API10 开发板&#xff1a;润和DAYU200 自从写了一篇ArkUI-X跨平台的文章之后&#xff0c;好多人都说对这个项目十分关注。 那么今天我们就来完整的梳理一下这个项目。 1、ArkUI-X 我们之前可能更多接触的…

登录验证

目录 会话技术 Cookie Session JWT JWT生成 JWT校验 会话技术 会话 打开浏览器&#xff0c;访问web服务器的资源&#xff0c;会话建立&#xff0c;直到有一方断开连接&#xff0c;会话结束。在一次会话中可以包含多次请求与响应 会话跟踪 一种维护浏览器的方法 服务器需要…

性能测试jmeter

选的这些怎么添加 在一个列表里面 方法调用${__time(YMD)} 两个下划线&#xff0c;后跟函数名&#xff0c;小括号内是输入参数&#xff0c;整个用大括号包裹。 注意POST一定要在消息体数据里面写,不能再参数里面 否则报错:loginOut,没cookie等

VueCli-自定义创建项目

参考 1.安装脚手架 (已安装可以跳过) npm i vue/cli -g2.创建项目 vue create 项目名 // 如&#xff1a; vue create dn-demo键盘上下键 - 选择自定义选型 Vue CLI v5.0.8 ? Please pick a preset:Default ([Vue 3] babel, eslint)Default ([Vue 2] babel, eslint) > M…

小迪安全第二天

文章目录 一、Web应用&#xff0c;架构搭建二、web应用环境架构类三、web应用安全漏洞分类总结 一、Web应用&#xff0c;架构搭建 #网站搭建前置知识 域名&#xff0c;子域名&#xff0c;dns,http/https,证书等 二、web应用环境架构类 理解不同web应用组成角色功能架构 开发…

显示CPU架构的有关信息 lscpu

文章目录 显示CPU架构的有关信息 lscpu默认实例更多信息 显示CPU架构的有关信息 lscpu Linux的CPU设备查看器。lscpu命令用来显示cpu的相关信息。 lscpu从sysfs和/proc/cpuinfo收集cpu体系结构信息&#xff0c;命令的输出比较易读 。 命令输出的信息包含cpu数量&#xff0c;线…

tensorflow报错: DNN library is no found

错误描述 如上图在执行程序的时候&#xff0c;会出现 DNN library is no found 的报错 解决办法 这个错误基本上说明你安装的 cudnn有问题&#xff0c;或者没有安装这个工具。 首先检测一下你是否安装了 cudnn 进入CUDA_HOME下&#xff0c;也就是进入你的cuda的驱动的安装目…

个人数据备份方案分享(源自一次悲惨经历)

文章目录 1 起源2 备份架构2.1 生活照片2.2 生活录音2.3 微信文件2.4 工作文件2.5 笔记、影视音乐、书籍 3 使用工具介绍3.1 小米云服务3.2 中国移动云盘3.3 小米移动硬盘&#xff08;1T&#xff09;3.4 FreeFileSync 4 总结 1 起源 本文的灵感源于我个人的一次不幸遭遇&#…

领域驱动设计——DDD领域驱动设计进阶

摘要 进阶篇主要讲解领域事件、DDD 分层架构、几种常见的微服务架构模型以及中台设计思想等内容。如何通过领域事件实现微服务解耦&#xff1f;、怎样进行微服务分层设计&#xff1f;、如何实现层与层之间的服务协作&#xff1f;、通过几种微服务架构模型的对比分析&#xff0…

记一个有关 Vuetify 组件遇到的一些问题

Vuetify 官网地址 所有Vuetify 组件 — Vuetify 1、Combobox使用对象数组 Combobox 组合框 — Vuetify items数据使用对象数组时&#xff0c;默认选中的是整个对象&#xff0c;要对数据进行处理 <v-comboboxv-model"defaultInfo.variableKey":rules"rules…

基于springboot体育场馆运营管理系统源码

基于springboot体育场馆运营管理系统源码330 -- MySQL dump 10.13 Distrib 5.7.31, for Linux (x86_64) -- -- Host: localhost Database: springboot3cprm -- ------------------------------------------------------ -- Server version 5.7.31/*!40101 SET OLD_CHARACT…

网络安全全栈培训笔记(53-WEB攻防-通用漏洞CRLF注入URL重定向资源处理拒绝服务)

第53天 WEB攻防-通用漏洞&CRLF注入&URL重定向&资源处理拒绝服务 知识点&#xff1a; 1、CRLF注入-原理&检测&利用 2、URL重定向-原理&检测&利用 3、Web拒绝服务-原理&检测&利用 #下节预告&#xff1a; 1、JSONP&CORS跨域 2、域名安全…

嵌入式软件工程师面试题——2025校招社招通用(十八)

说明&#xff1a; 面试群&#xff0c;群号&#xff1a; 228447240面试题来源于网络书籍&#xff0c;公司题目以及博主原创或修改&#xff08;题目大部分来源于各种公司&#xff09;&#xff1b;文中很多题目&#xff0c;或许大家直接编译器写完&#xff0c;1分钟就出结果了。但…

共识算法介绍

文章目录 共识算法Paxos 算法三种角色一致性提交算法prepare 阶段accept 阶段commit 阶段 CAP 定理BASE 理论Zookeeper 算法实现三类角色三个数据三种模式四种状态消息广播算法Leader选举算法 共识算法 Paxos 算法 Paxos 算法是莱斯利兰伯特(Leslie Lamport)1990 年提出的一种…

基于Java (spring-boot)的社团管理系统

一、项目介绍 系统管理员的功能概述&#xff1a; ①用户管理 a.注册用户账户 当一个新用户注册时&#xff0c;用户填写基本信息并上传。用户基本信息包括账号、 姓名、密码、手机、地址等信息。 b.用户信息管理 管理员可以查看系统所有用户的基本信息&#xff0c;并修改和…