【强化学习】值函数算法DQNs详解【Vanilla DQN Double DQN Dueling DQN】

DQNs【Vanilla DQN & Double DQN & Dueling DQN】

文章目录

  • DQNs【Vanilla DQN & Double DQN & Dueling DQN】
    • 1. DQN及其变种介绍
      • 1.1 Vanilla DQN
      • 1.2 Double DQN
      • 1.3 Dueling DQN
    • 2. Gym环境介绍
      • 2.1 Obseravtion Space
      • 2.2 Reward Function
      • 2.3 Action Space
    • 3. DQNs Code
      • 3.1 Vanilla DQN效果
      • 3.2 Double DQN效果
      • 3.3 Dueling DQN效果
    • Reference

​ 在 Reinforcement Learning with Code 【Code 1. Tabular Q-learning】中讲解的 Q-learning 算法中,我们以矩阵的方式建立了一张存储每个状态下所有动作 Q Q Q值的表格。表格中的每一个动作价值 Q ( s , a ) Q(s,a) Q(s,a)表示在状态 s s s下选择动作然后继续遵循某一策略预期能够得到的期望回报。然而,这种用表格存储动作价值的做法只在环境的状态和动作都是离散的,并且空间都比较小的情况下适用。当状态或者动作数量非常大的时候,这种做法就不适用了。例如,当状态是一张 RGB 图像时,假设图像大小是 210 × 160 × 3 210\times160\times3 210×160×3,此时一共有 25 6 ( 210 × 160 × 3 ) 256^{(210\times160\times3)} 256(210×160×3)种状态,在计算机中存储这个数量级的值表格是不现实的。更甚者,当状态或者动作连续的时候,就有无限个状态动作对,我们更加无法使用这种表格形式来记录各个状态动作对的 Q Q Q值。

​ 对于这种情况,我们需要用函数拟合的方法来估计值 Q Q Q,即将这个复杂的 Q Q Q值表格视作数据,使用一个参数化的函数 Q θ Q_\theta Qθ来拟合这些数据。很显然,这种函数拟合的方法存在一定的精度损失,因此被称为近似方法。我们今天要介绍的 DQN 算法便可以用来解决连续状态下离散动作的问题。

1. DQN及其变种介绍

1.1 Vanilla DQN

Vanilla DQN便是最基本的DQN算法,在Q-learning中需要优化的目标函数为
min ⁡ θ J ( θ ) = E [ ( R + γ max ⁡ a Q ( S ′ , a ; θ ) − Q ( S , A ; θ ) ) ] \min_\theta J(\theta) = \mathbb{E} \Big[\Big( R + \gamma \max_a Q(S^\prime,a;\theta)- Q(S,A;\theta) \Big) \Big] θminJ(θ)=E[(R+γamaxQ(S,a;θ)Q(S,A;θ))]
在Q-learning中的TD target是 Y t Q = R t + 1 + γ max ⁡ a Q ( S t + 1 , a ; θ t ) Y_t^Q = R_{t+1} + \gamma \max_a Q(S_{t+1},a;\theta_t) YtQ=Rt+1+γmaxaQ(St+1,a;θt),若直接更新上述网络需要考虑很复杂的时许问题,在DQN使用了两个网络来简化这一过程。

Vallina DQN使用两个技巧:

  • Experience Replay,维护了一个经验池,将智能体与环境交互的experience ( s , a , r , s ′ , done ) (s,a,r,s^\prime,\text{done}) (s,a,r,s,done)存储进经验池中,然后再从维护的经验池中取出experience来进行训练,这样做有两个好处,第一,因为我们所解决的问题是被建模成Markov Decision Process(MDP)。在 MDP 中交互采样得到的数据本身不满足独立假设,因为这一时刻的状态和上一时刻的状态有关。非独立同分布的数据对训练神经网络有很大的影响,会使神经网络拟合到最近训练的数据上。采用经验回放可以打破样本之间的相关性,让其满足独立假设。第二,提高样本效率。每一个样本可以被使用多次,十分适合深度神经网络的梯度学习。

  • Two Networks, Vanilla DQN为了解决训练中的更新时序问题,在未引入两套网络之前,是希望网络的参数 θ \theta θ能够跟踪 Y t Q = R t + 1 + γ max ⁡ a Q ( S t + 1 , a ; θ t ) Y^Q_t=R_{t+1}+\gamma\max_{a}Q(S_{t+1},a;\theta_t) YtQ=Rt+1+γmaxaQ(St+1,a;θt),其中 Y t Q Y_t^Q YtQ被称为TD target。引入了两套网络后,分记为训练网络和目标网路,网络参数分别用 θ \theta θ θ − \theta^- θ来表示,Vanilla DQN的最终目的是让目标网络 θ − \theta^- θ的输出能够逼近
    Y t D Q N = R t + 1 + γ max ⁡ a Q ( S t + 1 , a ; θ − ) \textcolor{red}{Y^{DQN}_t = R_{t+1} + \gamma \max_a Q(S_{t+1},a;\theta^-)} YtDQN=Rt+1+γamaxQ(St+1,a;θ)
    所以这样之后,待优化的目标函数变成了
    min ⁡ θ J ( θ ) = E [ ( R + γ max ⁡ a Q ( S ′ , a ; θ − ) − Q ( S , A ; θ ) ) ] \min_\theta J(\theta) = \mathbb{E} \Big[\Big( R + \gamma \max_a Q(S^\prime,a;\theta^-)- Q(S,A;\theta) \Big) \Big] θminJ(θ)=E[(R+γamaxQ(S,a;θ)Q(S,A;θ))]
    我们使用这个目标函数来作为Vallina DQN需要优化的损失函数来更新训练网络参数 θ \theta θ,然后每隔 τ \tau τ代,将训练网络参数 θ \theta θ拷贝给目标网络 θ − \theta^- θ

1.2 Double DQN

Double DQN的提出是用于解决Vallina DQN在实际应用中对Q值估计过高的问题(overestimation)。Vallina DQN的优化TD target是
Y t D Q N = R t + 1 + γ max ⁡ a Q ( S t + 1 , a ; θ − ) Y^{DQN}_t = R_{t+1} + \gamma \max_a Q(S_{t+1},a;\theta^-) YtDQN=Rt+1+γamaxQ(St+1,a;θ)
其中 max ⁡ a Q ( S t + 1 , a ; θ − ) \max_a Q(S_{t+1},a;\theta^-) maxaQ(St+1,a;θ)是由目标网络的参数 θ − \theta^- θ计算而来的,那么我们可以获得最优的动作 a ∗ = arg ⁡ max ⁡ a Q ( S t + 1 , a ; θ − ) a^*=\arg\max_a Q(S_{t+1},a;\theta^-) a=argmaxaQ(St+1,a;θ),将最优的动作带回,则我们可以将上式进行改写成
Y t D Q N = R t + 1 + γ Q ( S t + 1 , arg ⁡ max ⁡ a Q ( S t + 1 , a ; θ − ) ; θ − ) \textcolor{red}{Y^{DQN}_t = R_{t+1} + \gamma Q(S_{t+1}, \arg\max_a Q(S_{t+1},a;\theta^-);\theta^-)} YtDQN=Rt+1+γQ(St+1,argamaxQ(St+1,a;θ);θ)
换句话说 max ⁡ \max max操作其实可以分为两部分,首先选取状态 S t + 1 S_{t+1} St+1下的最优的动作 a ∗ = arg ⁡ max ⁡ a Q ( S t + 1 , a ; θ − ) a^*=\arg\max_a Q(S_{t+1},a;\theta^-) a=argmaxaQ(St+1,a;θ)接着计算该动作对于的价值 Q ( S t + 1 , a ∗ ; θ − ) Q(S_{t+1},a^*;\theta^-) Q(St+1,a;θ)。但是当这两个部分都采用同一套Q网络来进行训练时,每次计算得到的都是神经网络中当前估计的所有动作价值中最大值。考虑到通过神经网络估计的Q值本身在某些时刻也会产生正向或负向的误差,在DQN的更新方式下神经网络会将正向误差进行累积。

为了解决这一问题,Double DQN提出了利用两个独立的网络估算 max ⁡ a Q ∗ ( S t + 1 , A t + 1 ) \max_a Q^*(S_{t+1},A_{t+1}) maxaQ(St+1,At+1)。具体的做法是将原有的 max ⁡ a Q ( S t + 1 , a ; θ − ) \max_a Q(S_{t+1},a;\theta^-) maxaQ(St+1,a;θ)更改为 Q ( S t + 1 , arg ⁡ max ⁡ a Q ( S t + 1 , a ; θ ) , θ − ) Q(S_{t+1},\arg\max_a Q(S_{t+1},a;\theta),\theta^-) Q(St+1,argmaxaQ(St+1,a;θ),θ)。即利用一套训练网络 θ \theta θ来选取价值最大的动作,用目标神经网络 θ − \theta^- θ来计算该动作的价值。这样,即使其中一套神经网络的某个动作存在比较严重的过高估计问题,由于另一套神经网络的存在,这个动作最终使得Q值不会被过高估计。则我们可以将Double DQN的TD target写作
Y t D D Q N = R t + 1 + γ Q ( S t + 1 , arg ⁡ max ⁡ a Q ( S t + 1 , a ; θ ) ; θ − ) \textcolor{red}{Y^{DDQN}_t = R_{t+1} + \gamma Q(S_{t+1}, \arg\max_a Q(S_{t+1},a;\theta);\theta^-)} YtDDQN=Rt+1+γQ(St+1,argamaxQ(St+1,a;θ);θ)
Pesudocode

Image

1.3 Dueling DQN

Dueling DQN是Vanilla DQN的一种变种算法,它在传统Vanilla DQN的基础上只进行了微小的改动,却大幅提升了DQN的表现能力。具体来说就是Dueling DQN并未直接来估计Q值函数,而是通过估计V状态价值函数和A优势函数来间接获得Q值函数。Dueling DQN非常创新地引入了优势函数的概念(Advantage function),在介绍优势函数之前,我们先回顾一下动作价值函数Q和状态价值函数V的定义:
Q π ( s , a ) = E [ R t ∣ s t = s , a t = a , π ] V π ( s , a ) = E a ∼ π [ Q π ( s , a ) ] Q^\pi(s,a) = \mathbb{E}[R_t|s_t=s,a_t=a,\pi] \\ V^\pi(s,a) = \mathbb{E}_{a\sim \pi}[Q^\pi(s,a)] Qπ(s,a)=E[Rtst=s,at=a,π]Vπ(s,a)=Eaπ[Qπ(s,a)]
优势函数(Advantage function)被定义为
A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) Aπ(s,a)=Qπ(s,a)Vπ(s)
我们对优势函数取服从 a ∼ π a\sim\pi aπ的期望,则有
A π ( s , a ) = E a ∼ π [ Q π ( s , a ) ] − E a ∼ π [ Q π ( s , a ) ] = 0 A^\pi(s,a) = \mathbb{E}_{a\sim \pi}[Q^\pi(s,a)] - \mathbb{E}_{a\sim \pi}[Q^\pi(s,a)] = 0 Aπ(s,a)=Eaπ[Qπ(s,a)]Eaπ[Qπ(s,a)]=0
直观上来理解,状态价值函数V衡量的是处于状态 s s s的好坏程度,然而,Q函数衡量的是在此状态 s s s下选择特定操作的价值,优势函数A是从Q函数中减去状态值V,来获得每个动作重要性的相对程度。对于确定性的策略 π \pi π,最优的动作可以表示为
a ∗ = arg ⁡ max ⁡ a ′ ∈ A Q ( s , a ′ ) a^* = \arg\max_{a^\prime\in\mathcal{A}}Q(s,a^\prime) a=argaAmaxQ(s,a)
又因为策略是确定性的,那么则有
Q ( s , a ∗ ) = V ( s ) Q(s,a^*) = V(s) Q(s,a)=V(s)
那么对于最优动作 a ∗ a^* a的优势函数则有
A ( s , a ∗ ) = 0 A(s,a^*)=0 A(s,a)=0
接下来介绍下在网络结构上的创新,在Dueling DQN中为了实现上述的势函数网络结构也发生了变化,

Image

在图中,位于上方的网络结构是DQN的结构,位于下方的网络结构是Dueling DQN的结构。Dueling网络有两个流来分别估计状态值V(scalar)和每个动作的优势A(vector);绿色输出模块实现等式 A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) Aπ(s,a)=Qπ(s,a)Vπ(s)以将它们结合起来。结合上述的网络结构,我们将网络结构的参数加上,来重写优势函数A的表达式
A ( s , a ; θ , α ) = Q ( s , a ; θ , α , β ) − V ( s ; θ , β ) Q ( s , a ; θ , α , β ) = A ( s , a ; θ , α ) + V ( s ; θ , β ) \begin{aligned} A(s,a;\theta,\alpha) & = Q(s,a;\theta,\alpha,\beta) - V(s;\theta,\beta) \\ Q(s,a;\theta,\alpha,\beta) & = A(s,a;\theta,\alpha) + V(s;\theta,\beta) \end{aligned} A(s,a;θ,α)Q(s,a;θ,α,β)=Q(s,a;θ,α,β)V(s;θ,β)=A(s,a;θ,α)+V(s;θ,β)
其中, θ \theta θ代表了前面共享网络结构的参数, α , β \alpha,\beta α,β分别代表两个流各自的网络结构参数。但是上述式子中存在着对于V值,A值的唯一性不确定的问题,为了解决这一问题,我们对于同样的Q值加上任意大小的常数C,再将所有A值减去C,这样得到的Q值仍然不变,这就导致了训练的不稳定性,为了解决这一问题,Dueling DQN强制将最优动作的优质函数输出置为0,即
Q ( s , a ; θ , α , β ) = V ( s ; θ , β ) + ( A ( s , a ; θ , α ) − max ⁡ a ′ ∈ A A ( s , a ′ ; θ , α ) ) \textcolor{red}{Q(s,a;\theta,\alpha,\beta) = V(s;\theta,\beta) + \Big(A(s,a;\theta,\alpha) -\max_{a^\prime\in\mathcal{A}}A(s,a^\prime;\theta,\alpha) \Big)} Q(s,a;θ,α,β)=V(s;θ,β)+(A(s,a;θ,α)aAmaxA(s,a;θ,α))
根据之前的分析我们知道,对于最优的动作 a ∗ a^* a,优势函数的值为0,即 A ( s , a ∗ ) = 0 A(s,a^*)=0 A(s,a)=0,对于最优的动作 a ∗ a^* a
a ∗ = arg ⁡ max ⁡ a ′ ∈ A Q ( s , a ′ ; θ , α , β ) = arg ⁡ max ⁡ a ′ ∈ A [ A ( s , a ; θ , α ) + V ( s ; θ , β ) ] = arg ⁡ max ⁡ a ′ ∈ A A ( s , a ; θ , α ) \begin{aligned} a^* & = \arg\max_{a^\prime\in\mathcal{A}}Q(s,a^\prime;\theta,\alpha,\beta) \\ & = \arg\max_{a^\prime\in\mathcal{A}}[A(s,a;\theta,\alpha) + V(s;\theta,\beta)] \\ & = \arg\max_{a^\prime\in\mathcal{A}} A(s,a;\theta,\alpha) \end{aligned} a=argaAmaxQ(s,a;θ,α,β)=argaAmax[A(s,a;θ,α)+V(s;θ,β)]=argaAmaxA(s,a;θ,α)
再将最优动作 a ∗ a^* a代回上式中,则有
Q ( s , a ∗ ; θ , α , β ) = V ( s ; θ , β ) Q(s,a^*;\theta,\alpha,\beta) = V(s;\theta,\beta) Q(s,a;θ,α,β)=V(s;θ,β)
因此这样就成功实现了解耦,一个流 V ( s ; θ , β ) V(s;\theta,\beta) V(s;θ,β)提供状态价值的估计,另一个流 A ( s , a ; θ , α ) A(s,a;\theta,\alpha) A(s,a;θ,α)实现优势值的估计。

在实际中,我们常用平均(average)操作来替换最大值(max)操作,这样更具稳定性,如下
Q ( s , a ; θ , α , β ) = V ( s ; θ , β ) + ( A ( s , a ; θ , α ) − 1 ∣ A ∣ A ( s , a ′ ; θ , α ) ) \textcolor{red}{Q(s,a;\theta,\alpha,\beta) = V(s;\theta,\beta) + \Big(A(s,a;\theta,\alpha) -\frac{1}{|\mathcal{A}|}A(s,a^\prime;\theta,\alpha) \Big)} Q(s,a;θ,α,β)=V(s;θ,β)+(A(s,a;θ,α)A1A(s,a;θ,α))
将上述Q值函数来计算Vanilla DQN中的TD target就能得到Dueling DQN的TD target,再按照Vanilla DQN剩下的方法来进行更新,我们就得到了Dueling DQN的完整算法。
Y t DuelingDQN = R t + 1 + γ max ⁡ a Q ( S t + 1 , a ; θ , α , β ) \textcolor{red}{Y^{\text{DuelingDQN}}_t = R_{t+1} + \gamma \max_a Q(S_{t+1},a;\theta,\alpha,\beta)} YtDuelingDQN=Rt+1+γamaxQ(St+1,a;θ,α,β)
其中 Q ( S t + 1 , a ; θ , α , β ) Q(S_{t+1},a;\theta,\alpha,\beta) Q(St+1,a;θ,α,β)可以替换成average操作产生的结果,那么完整的TD target就成了
Y t DuelingDQN = R t + 1 + γ V ( s ; θ , β ) + γ max ⁡ a ( A ( s , a ; θ , α ) − 1 ∣ A ∣ A ( s , a ′ ; θ , α ) ) Y^{\text{DuelingDQN}}_t = R_{t+1} + \gamma V(s;\theta,\beta) + \gamma\max_a \Big(A(s,a;\theta,\alpha) - \frac{1}{|\mathcal{A}|}A(s,a^\prime;\theta,\alpha) \Big) YtDuelingDQN=Rt+1+γV(s;θ,β)+γamax(A(s,a;θ,α)A1A(s,a;θ,α))
其中 Q ( S t + 1 , a ; θ , α , β ) Q(S_{t+1},a;\theta,\alpha,\beta) Q(St+1,a;θ,α,β)也可以替换成max操作产生的记过,那么完整的TD target就成了
Y t DuelingDQN = R t + 1 + γ V ( s ; θ , β ) + γ max ⁡ a ( A ( s , a ; θ , α ) − max ⁡ a ′ ∈ A A ( s , a ′ ; θ , α ) ) Y^{\text{DuelingDQN}}_t = R_{t+1} + \gamma V(s;\theta,\beta) + \gamma\max_a \Big(A(s,a;\theta,\alpha) - \max_{a^\prime\in\mathcal{A}}A(s,a^\prime;\theta,\alpha) \Big) YtDuelingDQN=Rt+1+γV(s;θ,β)+γamax(A(s,a;θ,α)aAmaxA(s,a;θ,α))

2. Gym环境介绍

为了更好得观察到DQN存在的Q值估计overestimation的问题,我们使用gym中的Pendulum-v1的环境(详见官网),注意这里gym的版本为v26

倒立摆的数据标注如下所示

Image

2.1 Obseravtion Space

Image

2.2 Reward Function

Image

该环境的奖励函数为
− ( θ 2 + 0.1 θ ˙ 2 + 0.001 a 2 ) -(\theta^2 + 0.1\dot{\theta}^2 + 0.001a^2) (θ2+0.1θ˙2+0.001a2)
倒立摆向上保持不动的时候奖励为0,其余位置倒立摆的奖励为负数,所以该环境下动作值Q不会超过0。

2.3 Action Space

Image

动作空间是一个连续值,是作用于倒立摆末端的扭矩。但是DQNs只能用于处理离散动作空间环境,因此我们无法直接使用DQNs来处理倒立摆环境,但是倒立摆环境会比较方便地验证Vanilla DQN存在对Q值过估计的问题。为了使得DQNs能够处理这种连续动作空间的环境,我们可以使用将连续动作空间离散化的方法,来达到伪连续的效果。

import gym
env_name = 'Pendulum-v1'
env = gym.make(id=env_name)
print("The minimum of action space is ", env.action_space.low[0])
print("The maximum of action space is ", env.action_space.high[0])def dis2con(discrete_action, action_dim, env):action_upbound = env.action_space.high[0]action_lowbound = env.action_space.low[0]return action_lowbound + discrete_action * (action_upbound - action_lowbound) / (action_dim - 1)# 示例将[-2.0,2.0]的连续动作空间转换成30维度离散动作空间
action_dim = 30 
discrete_action =  list(range(action_dim))
continue_action = []
for da in discrete_action:continue_action.append(dis2con(da, action_dim, env))
print(discrete_action)
print(continue_action)
env.close()

结果如下

The minimum of action space is  -2.0
The maximum of action space is  2.0
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[-2.0, -1.8620689655172413, -1.7241379310344827, -1.5862068965517242, -1.4482758620689655, -1.3103448275862069, -1.1724137931034484, -1.0344827586206895, -0.896551724137931, -0.7586206896551724, -0.6206896551724137, -0.48275862068965525, -0.3448275862068966, -0.2068965517241379, -0.06896551724137923, 0.06896551724137945, 0.2068965517241379, 0.34482758620689635, 0.48275862068965525, 0.6206896551724137, 0.7586206896551726, 0.896551724137931, 1.0344827586206895, 1.1724137931034484, 1.3103448275862069, 1.4482758620689653, 1.5862068965517242, 1.7241379310344827, 1.8620689655172415, 2.0]

3. DQNs Code

在本节,我们正式实现Vanilla DQN,Double DQN以及Dueling DQN并且观察对比其效果

import random
import collections
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import gym
import matplotlib.pyplot as plt# Experience Replay
class ReplayBuffer():def __init__(self,  capacity):self.buffer = collections.deque(maxlen=capacity)def size(self):return len(self.buffer)def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):transition = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*transition)return np.array(states), np.array(actions), np.array(rewards),  np.array(next_states), np.array(dones)# Value Approximation Net
class QNet(torch.nn.Module):# DDQN & VanillaDQN网络框架def __init__(self, state_dim, hidden_dim, action_dim):super(QNet,self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self,observation):x = F.relu(self.fc1(observation))x = self.fc2(x)return x# Advantage Net
class AVNet(torch.nn.Module):# DuelingDQN网络框架def __init__(self, state_dim, hidden_dim, action_dim):super(AVNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc_V = torch.nn.Linear(hidden_dim, 1)self.fc_A = torch.nn.Linear(hidden_dim, action_dim)def forward(self, observation):x = F.relu(self.fc1(observation))V = self.fc_V(x)A = self.fc_A(x)Q = A + V - A.mean(dim=1).view(-1,1)return Q# Vanilla DQN algorithm & Double DNQ & Dueling DQN 
class DQNs():def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device, dqn_type):self.action_dim = action_dimif dqn_type == "VanillaDQN" or dqn_type == "DoubleDQN":self.q_net = QNet(state_dim , hidden_dim, action_dim).to(device) # behavior net将计算转移到cuda上self.target_q_net = QNet(state_dim, hidden_dim, action_dim).to(device) # target netprint(self.q_net)elif dqn_type == "DuelingDQN": # DuelingDQN采取不同的网络框架self.q_net = AVNet(state_dim, hidden_dim, action_dim).to(device)self.target_q_net = AVNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)self.target_update = target_update  # 目标网络更新频率 self.gamma = gamma  # 折扣因子self.epsilon = epsilon # epsilon-greedyself.count = 0 # record update timesself.device = device # deviceself.dqn_type = dqn_type # VanillaDQN or DoubleDQN or DuelingDQNdef choose_action(self, state): # epsilon-greedy# one state is a list [x1, x2, x3, x4] if np.random.random() < self.epsilon:action = np.random.randint(self.action_dim) # 产生[0,action_dim)的随机数作为actionelse:state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.q_net(state).argmax(dim=1).item()return actiondef max_q_values(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)return self.q_net(state).max(dim=1)[0].item()def learn(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions'], dtype=torch.int64).view(-1,1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1,1).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1,1).to(self.device)q_values = self.q_net(states).gather(dim=1, index=actions)if self.dqn_type == 'DoubleDQN':    # DoubleDQNmax_action = self.q_net(next_states).max(dim=1)[1].view(-1,1)max_next_q_values = self.target_q_net(next_states).gather(dim=1, index=max_action)elif self.dqn_type == 'VanillaDQN' or dqn_type == 'DuelingDQN': # VanillaDQN & DuelingDQNmax_next_q_values = self.target_q_net(next_states).max(dim=1)[0].view(-1,1)q_target = rewards + self.gamma * max_next_q_values * (1 - dones)   # TD targetdqn_loss = torch.mean(F.mse_loss(q_target, q_values)) # 均方误差损失函数self.optimizer.zero_grad()dqn_loss.backward()self.optimizer.step()# 一定周期后更新target network参数if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())self.count += 1def dis_to_con(discrete_action, env, action_dim):  # 离散动作转回连续的函数action_lowbound = env.action_space.low[0]  # 连续动作的最小值action_upbound = env.action_space.high[0]  # 连续动作的最大值return action_lowbound + (discrete_action /(action_dim - 1)) * (action_upbound -action_lowbound)# train DQN agent
def train_DQN_agent(env, agent, replaybuffer, num_episodes, batch_size, minimal_size, seed):return_list = []max_q_value = 0max_q_value_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc="Iteration %d"%(i+1)) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0observation, _ = env.reset(seed=seed)done = Falsewhile not done:env.render()action = agent.choose_action(observation)max_q_value = agent.max_q_values(observation) * 0.005 + max_q_value * 0.995 # smooth the maximum q-valuemax_q_value_list.append(max_q_value)  # save maximum q-value# convert discrete action to pesudo-continuousaction_continuous = dis_to_con(action, env,agent.action_dim)observation_, reward, terminated, truncated, _ = env.step([action_continuous])done = terminated or truncatedreplaybuffer.add(observation, action, reward, observation_, done)observation = observation_episode_return += rewardif replaybuffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replaybuffer.sample(batch_size)transition_dict  = {'states': b_s,'actions': b_a,'rewards': b_r,'next_states': b_ns,'dones': b_d}# print('\n--------------------------------\n')# print(transition_dict)# print('\n--------------------------------\n')agent.learn(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)env.close()return return_list, max_q_value_list        def plot_curve(return_list, mv_return, algorithm_name, env_name):episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list, c='gray', alpha=0.6)plt.plot(episodes_list, mv_return)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('{} on {}'.format(algorithm_name, env_name))plt.show()def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))if __name__ == "__main__":# reproducibleseed_number = 0random.seed(seed_number)np.random.seed(seed_number)torch.manual_seed(seed_number)# render or notrender = Falseenv_name = 'Pendulum-v1'hidden_dim = 128 # number of hidden layerslr = 2e-3   # learning ratenum_episodes = 500  # episode lengthgamma = 0.98 # discounted rateepsilon = 0.01 # epsilon-greedytarget_update = 10 # per step to update target networkbuffer_size = 10000 # maximum size of replay bufferminimal_size = 500 # minimum size of replay buffer to begin learningbatch_size = 64 # batch_size using to train the neural networkdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')if render:  env = gym.make(id=env_name, render_mode='human')else:env = gym.make(id=env_name)state_dim = env.observation_space.shape[0]action_dim = 30 # discrete the action space to 30 dimensiondqn_type = 'DuelingDQN'  # VanillaDQN & DoubleDQN & DuelingDQNagent = DQNs(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device, dqn_type)replaybuffer = ReplayBuffer(buffer_size)return_list, max_q_value_list = train_DQN_agent(env, agent, replaybuffer, num_episodes, batch_size, minimal_size, seed_number)# plot moving average return curvemv_return = moving_average(return_list, 9)plot_curve(return_list, mv_return, dqn_type, env_name)# plot maximum q-value curveframe_list = list(range(len(max_q_value_list)))plt.plot(frame_list, max_q_value_list)plt.axhline(0, c='green', ls='--')plt.axhline(10, c='red', ls='--')plt.xlabel('Frames')plt.ylabel('Max Q Values')plt.title("{} on {}".format(dqn_type, env_name))plt.show()

3.1 Vanilla DQN效果

Vanilla DQN的学习回报(return)如下图

Image

Vanilla DQN的最大Q值估计如图

Image

按照agent交互环境的分析,我们对Q值的估计不应该超过0,但是Vanilla DQN已经超过了600,这是很严重的过高估计的问题。

3.2 Double DQN效果

Double DQN的学习回报如下图

Image

Double DQN的最大Q值估计如下图

Image

可以看出来Double DQN确实有效缓解了对最大Q值估计过高的问题。

3.3 Dueling DQN效果

Dueling DQN的学习回报如下图

Image

Dueling DQN的最大Q值估计如下图

Image

可以看出来Dueling DQN也能缓解对最大Q值估计过高的问题。

Reference

Materials
DQN及其多种变式
Hands on RL
Deep Reinforcement Learning with Double Q-learning
Reinforcement Learning with Code 【Code 1. Tabular Q-learning】
Reinforcement Learning with Code 【Chapter 7. Temporal-Difference Learning】
Reinforcement Learning with Code 【Code 4. Vanilla DQN】

Papers

  • Vallina DQN

    Playing Atari with Deep Reinforcement Learning

  • Double DQN

    Deep Reinforcement Learning with Double Q-learning

  • Dueling DQN

    Dueling Network Architectures for Deep Reinforcement Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

28 | Boss直聘数据分析

针对boss直聘网的招聘信息,然后分析互联网发展排名前十的城市在互联网方面职位的薪水,学历要求,经验要求,等等信息。 准备从以下几个方面进行分析: (1)各个城市的平均工资 (2)各个学历的平均工资 (3)各个岗位的平均工资 (4)不同工作经验要求的工资 (5)各个经验…

LinuxC编程——进程间通信(一)(管道)

目录 一、Linux平台通信方式发展史二、进程间通信方式⭐⭐⭐三、无名管道3.1 特点⭐⭐⭐3.2 函数pipe3.3 注意事项⭐⭐⭐3.4 练习 四、有名管道4.1 特点⭐⭐⭐4.2 函数 mkfifo4.3 注意事项⭐⭐4.4 练习 五、无名管道与有名管道对比⭐⭐ 复杂的编程环境通常使用多个相关的进程来…

CTF-Flask-Jinja2(持续更新)

放心&#xff0c;我会一直陪着你 一.知识一.在终端的一些指令1.虚拟环境2.docker容器二.SSTI相关知识介绍1.魔术方法2.python如何执行cmd命令3.SSTI常用注入模块(1)文件读取(2)内建函数eval执行命令(3)os模块执行命令(4)importlib类执行命令(5)linecache函数执行命令(6)subproc…

线性代数(二) 矩阵及其运算

前言 行列式det(A) 其实表示的只是一个值 ∣ a b c d ∣ a d − b c \begin{vmatrix} a & b\\ c & d\end{vmatrix} ad -bc ​ac​bd​ ​ad−bc&#xff0c;其基本变化是基于这个值是不变。而矩阵表示的是一个数表。 定义 矩阵与线性变换的关系 即得 ( a 11 a 12…

逆向破解学习-登山赛车

试玩 课程中的内容 Hook代码 import de.robv.android.xposed.XC_MethodHook; import de.robv.android.xposed.XposedHelpers; import de.robv.android.xposed.callbacks.XC_LoadPackage;public class HookComYoDo1SkiSafari2TXYYB_01 extends HookImpl{Overridepublic String p…

科技云报道:一波未平一波又起?AI大模型再出邪恶攻击工具

AI大模型的快速向前奔跑&#xff0c;让我们见识到了AI的无限可能&#xff0c;但也展示了AI在虚假信息、深度伪造和网络攻击方面的潜在威胁。 据安全分析平台Netenrich报道&#xff0c;近日&#xff0c;一款名为FraudGPT的AI工具近期在暗网上流通&#xff0c;并被犯罪分子用于编…

Vue输入框或者选择框无效,或者有延迟

问题剖析 使用Vue这种成熟好用的框架&#xff0c;一般出现奇奇怪怪的问题都是因为操作不当导致的&#xff0c;例如没有合理调用组件、组件位置不正确、没有合理定义组件或者变量、样式使用不当等等... 解决方案 如果你也出现了输入框输入东西&#xff0c;但是没有效果…

实时通信应用的开发:Vue.js、Spring Boot 和 WebSocket 整合实践

目录 1. 什么是webSocket 2. webSocket可以用来做什么? 3. webSocket协议 4. 服务器端 5. 客户端 6. 测试通讯 1. 什么是webSocket WebSocket是一种在单个TCP连接上进行全双工通信的协议。WebSocket使得客户端和服务器之间的数据交换变得更加简单&#xff0c;允许服务…

百度网盘非会员倍速播放(电脑端)

百度网盘非会员倍速播放&#xff08;电脑端&#xff09; 1. 打开edge浏览器&#xff0c;点击右上角的三个点后&#xff0c;选择“扩展” 2. 选择“管理扩展” 3. 选择“获取MicrosoftEdge扩展” 4. 搜索“Global Speed” 5. 选择Global Speed:视频速度控制&#xff0c;然…

PHP 求解两字符串所有公共子序列及最长公共子序列 支持多字节字符串

/*** 获取两字符串所有公共子序列【不连续的】 例&#xff1a;abc ac > ac** param string $str1 字符串1* param string $str2 字符串2** return array*/ function public_sequence(string $str1, string $str2): array {$data [[-1, -1, , 0, ]]; // 子序列容器【横坐标 …

配置Arduino+ESP32走过的巨坑

项目场景&#xff1a; 前几天去淘宝买了块ESP32拿来用&#xff0c;配置Arduino最新版ESP32 2.0.11走过的巨坑。 问题描述 先安装好了ArduinoIDE最新版&#xff08;教程里介绍去官网下&#xff09;&#xff0c;然后配置ESP32开发板&#xff0c;后面发现安装速度惊人。 去找加速…

Pytorch深度学习-----完整神经网络模型训练套路

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…

websocket知识点

http协议 http协议特点&#xff1a; 无状态协议每个请求是独立的单双工通信&#xff0c;且服务器无法主动给客户端发信息http协议受浏览器同源策略影响 http实现双向通信方法: 轮询长轮询iframe流sse EventSource websocket协议 websocket协议: 全双工协议支持跨域支持多…

自动测试框架airtest应用一:将XX读书书籍保存为PDF

一、Airtest的简介 Airtest是网易出品的一款基于图像识别和poco控件识别的一款UI自动化测试工具。Airtest的框架是网易团队自己开发的一个图像识别框架&#xff0c;这个框架的祖宗就是一种新颖的图形脚本语言Sikuli。Sikuli这个框架的原理是这样的&#xff0c;计算机用户不需要…

asp.net core webapi如何执行周期性任务

使用Api执行周期性任务 第一种&#xff0c;无图形化界面1.新建类&#xff0c;继承IJob&#xff0c;在实现的方法种书写需要周期性执行的事件。2.编写方法类&#xff0c;定义事件执行方式3.在启动方法中&#xff0c;进行设置&#xff0c;.net 6中在program.cs的Main方法中&#…

旅卦-火山旅

前言&#xff1a;人生就像一趟旅行&#xff0c;为谋生奔波也是旅&#xff0c;旅是人生的常态&#xff0c;我们看一下易经里的旅卦&#xff0c;分析下卦辞和爻辞以及自己的理解。 目录 卦辞 爻辞 总结 卦辞 旅&#xff1a;小亨&#xff0c;旅贞吉。 卦序&#xff1a;穷大者…

java获取到heapdump文件后,如何快速分析?

简介 在之前的OOM问题复盘之后&#xff0c;本周&#xff0c;又一Java服务出现了内存问题&#xff0c;这次问题不严重&#xff0c;只会触发堆内存占用高报警&#xff0c;没有触发OOM&#xff0c;但好在之前的复盘中总结了dump脚本&#xff0c;会在堆占用高时自动执行jstack与jm…

560. 和为 K 的子数组

思路 本题的主要思路为创建一个哈希表记录每个0~i的和&#xff0c;在遍历这个数组的时候查询有没有sum-k的值在哈希表中&#xff0c;如果有&#xff0c;说明有个位置到当前位置的和为k。   有可能不止一个&#xff0c;哈希表负责记录有几个sum-k&#xff0c;将和记录下来。这…

【ArcGIS Pro二次开发】(60):按图层导出布局

在使用布局导图时&#xff0c;会遇到如下问题&#xff1a; 为了切换图层和导图方便&#xff0c;一般情况下&#xff0c;会把相关图层做成图层组。 在导图的时候&#xff0c;如果想要按照图层组进行分开导图&#xff0c;如上图&#xff0c;想导出【现状图、规划图、管控边界】3…

UNIX网络编程——TCP协议API 基础demo服务器代码

目录 一.TCP客户端API 1.创建套接字 2.connect连接服务器​编辑 3.send发送信息 4.recv接受信息 5.close 二.TCP服务器API 1.socket创建tcp套接字(监听套接字) 2.bind给服务器套接字绑定port,ip地址信息 3.listen监听并创建连接队列 4.accept提取客户端的连接 5.send,r…