Double DQN算法

Double DQN算法

问题

DQN 算法通过贪婪法直接获得目标 Q 值,贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛,但易导致过估计Q 值的问题,使模型具有较大的偏差。
即:
对于DQN模型, 损失函数使用的
Q(state) = reward + Q(nextState)max
Q(state)由训练网络生成, Q(nextState)max由目标网络生成

这种损失函数会存在问题,即当Q(nextState)max总是大于0时,那么Q(state)总是在不停的增大,同时Q(nextState)max也在不断的增大, 即Q(state)存在被高估的情况。

作者采用 Double DQN 算法解耦动作的选择和目标 Q 值的计算,以解决过估计 Q 值的问题。

Double DQN 原理

Double DQN 算法结构如下。在 Double DQN 框架中存在两个神经网络模型,分别是训练网络与目标网络。这两个神经网络模型的结构完全相同,但是权重参数不同;每训练一段之间后,训练网络的权重参数才会复制给目标网络。训练时,训练网络用于估计当前的 ,而目标网络用于估计 ,这样就能保证真实值 的估计不会随着训练网络的不断自更新而变化过快。此外,DQN 还是一种支持离线学习的框架,即通过构建经验池的方式离线学习过去的经验。将均方误差 MSE(Q_{train}, Q_{target}) 作为训练模型的损失函数,通过梯度下降法进行反向传播,对训练模型进行更新;若干轮经验池采样后,再将训练模型的权重赋给目标模型,以此进行 Double DQN 框架下的模型自学习。

目标 Q 值的计算公式如下所示:
y j = r j + γ max ⁡ a ′ Q ( s j + 1 , a ′ ; θ ′ ) y_j=r_j+\gamma \max _{a^{\prime}} Q\left(s_{j+1}, a^{\prime} ; \theta^{\prime}\right) yj=rj+γamaxQ(sj+1,a;θ)

Double DQN 算法不直接通过最大化的方式选取目标网络计算的所有可能 Q Q Q 值,而是首先通过估计网络选取最大 Q Q Q 值对应的动作,公式表示如下:
a max ⁡ = argmax ⁡ a Q ( s t + 1 , a ; θ ) a_{\max }=\operatorname{argmax}_a Q\left(s_{t+1}, a ; \theta\right) amax=argmaxaQ(st+1,a;θ)

然后目标网络根据 a max ⁡ a_{\max } amax 计算目标 Q 值,公式表示如下:
y j = r j + γ Q ( s j + 1 , a max ⁡ ; θ ′ ) y_j=r_j+\gamma Q\left(s_{j+1}, a_{\max } ; \theta^{\prime}\right) yj=rj+γQ(sj+1,amax;θ)

最后将上面两个公式结合,目标 Q Q Q 值的最终表示形式如下:
y j = r j + γ Q ( s j + 1 , argmax ⁡ a Q ( s t + 1 , a ; θ ) ; θ ′ ) y_j=r_j+\gamma Q\left(s_{j+1}, \operatorname{argmax}_a Q\left(s_{t+1, a ; \theta}\right) ; \theta^{\prime}\right) yj=rj+γQ(sj+1,argmaxaQ(st+1,a;θ);θ)

目标是最小化目标函数,即最小化估计 Q Q Q 值和目标 Q Q Q 值的差值,公式如下:
δ = ∣ Q ( s t , a t ) − y t ∣ = ∣ Q ( s t , a t ; θ ) − ( r t + γ Q ( S t + 1 , argmax ⁡ a Q ( s t + 1 , a ; θ ) ; θ ′ ) ) ∣ \begin{aligned} & \delta=\left|Q\left(s_t, a_t\right)-y_t\right|=\mid Q\left(s_t, a_t ; \theta\right)-\left(r_t+\right. \\ & \left.\gamma Q\left(S_{t+1}, \operatorname{argmax}_a Q\left(s_{t+1}, a ; \theta\right) ; \theta^{\prime}\right)\right) \mid \end{aligned} δ=Q(st,at)yt=∣Q(st,at;θ)(rt+γQ(St+1,argmaxaQ(st+1,a;θ);θ))

结合目标函数,损失函数定义如下:
loss  = { 1 2 δ 2 for  ∣ δ ∣ ⩽ 1 ∣ δ ∣ − 1 2 otherwize  } \text { loss }=\left\{\begin{array}{cl} \frac{1}{2} \delta^2 & \text { for }|\delta| \leqslant 1 \\ |\delta|-\frac{1}{2} & \text { otherwize } \end{array}\right\}  loss ={21δ2δ21 for δ1 otherwize }

代码

  1. 游戏环境
import gym#定义环境
class MyWrapper(gym.Wrapper):def __init__(self):env = gym.make('CartPole-v1', render_mode='rgb_array')super().__init__(env)self.env = envself.step_n = 0def reset(self):state, _ = self.env.reset()self.step_n = 0return statedef step(self, action):state, reward, terminated, truncated, info = self.env.step(action)over = terminated or truncated#限制最大步数self.step_n += 1if self.step_n >= 200:over = True#没坚持到最后,扣分if over and self.step_n < 200:reward = -1000return state, reward, over#打印游戏图像def show(self):from matplotlib import pyplot as pltplt.figure(figsize=(3, 3))plt.imshow(self.env.render())plt.show()env = MyWrapper()env.reset()env.show()
  1. Q价值函数
import torch#定义模型,评估状态下每个动作的价值
model = torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#延迟更新的模型,用于计算target
model_delay = torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#复制参数
model_delay.load_state_dict(model.state_dict())model, model_delay
  1. 单条轨迹
from IPython import display
import random#玩一局游戏并记录数据
def play(show=False):data = []reward_sum = 0state = env.reset()over = Falsewhile not over:action = model(torch.FloatTensor(state).reshape(1, 4)).argmax().item()if random.random() < 0.1:action = env.action_space.sample()next_state, reward, over = env.step(action)data.append((state, action, reward, next_state, over))reward_sum += rewardstate = next_stateif show:display.clear_output(wait=True)env.show()return data, reward_sumplay()[-1]
  1. 经验池
#数据池
class Pool:def __init__(self):self.pool = []def __len__(self):return len(self.pool)def __getitem__(self, i):return self.pool[i]#更新动作池def update(self):#每次更新不少于N条新数据old_len = len(self.pool)while len(pool) - old_len < 200:self.pool.extend(play()[0])#只保留最新的N条数据self.pool = self.pool[-2_0000:]#获取一批数据样本def sample(self):data = random.sample(self.pool, 64)state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1)reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 4)over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)return state, action, reward, next_state, overpool = Pool()
pool.update()
pool.sample()len(pool), pool[0]
  1. 训练
#训练
def train():model.train()optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)loss_fn = torch.nn.MSELoss()#共更新N轮数据for epoch in range(1000):pool.update()#每次更新数据后,训练N次for i in range(200):#采样N条数据state, action, reward, next_state, over = pool.sample()#计算valuevalue = model(state).gather(dim=1, index=action)#计算targetwith torch.no_grad():target = model_delay(next_state)target = target.max(dim=1)[0].reshape(-1, 1)target = target * 0.99 * (1 - over) + rewardloss = loss_fn(value, target)loss.backward()optimizer.step()optimizer.zero_grad()#复制参数if (epoch + 1) % 5 == 0:model_delay.load_state_dict(model.state_dict())if epoch % 100 == 0:test_result = sum([play()[-1] for _ in range(20)]) / 20print(epoch, len(pool), test_result)train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/142625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java14新增特性

前言 前面的文章&#xff0c;我们对Java9、Java10、Java11、Java12 、Java13的特性进行了介绍&#xff0c;对应的文章如下 Java9新增特性 Java10新增特性 Java11新增特性 Java12新增特性 Java13新增特性 今天我们来一起看一下Java14这个版本的一些重要信息 版本介绍 Java 14…

线程相关问题

多线程 计算机在同一时间可以执行多个线程 并行 多个事情在同一时间点内发生&#xff0c;并行的发生是不会抢占资源的 并发 多个事情在一段时间内同时发生&#xff0c;并发的产生会抢占资源 多线程的好处 如果为单线程计算机一次只能处理一个线程&#xff0c;那么当处理的线程需…

JNDI注入

1、什么是 JNDI JNDI(Java Naming and Directory Interface, Java命名和目录接口)&#xff0c;JNDI API 映射为特定的命名&#xff08;Name&#xff09;和目录服务&#xff08;Directory&#xff09;系统&#xff0c;使得Java应用程序可以和这些命名&#xff08;Name&#xff…

【Shell脚本11】Shell 函数

Shell 函数 linux shell 可以用户定义函数&#xff0c;然后在shell脚本中可以随便调用。 shell中函数的定义格式如下&#xff1a; [ function ] funname [()]{action;[return int;]}说明&#xff1a; 1、可以带function fun() 定义&#xff0c;也可以直接fun() 定义,不带任何…

SQL基础理论篇(一):什么是SQL

文章目录 什么是SQLSQL的四大部分常用的SQL标准参考文献 什么是SQL SQL的全称是Structured Query Language&#xff0c;即结构化查询语句。 其最早诞生于1974年&#xff0c;IBM研究员发布的一篇论文"SEQUEL&#xff1a;一门结构化的英语查询语言"。这几十年里&…

旺店通·企业版对接打通金蝶云星空查询调拨单接口与分布式调入单新增接口

旺店通企业版对接打通金蝶云星空查询调拨单接口与分布式调入单新增接口 源系统:旺店通企业版 旺店通是北京掌上先机网络科技有限公司旗下品牌&#xff0c;国内的零售云服务提供商&#xff0c;基于云计算SaaS服务模式&#xff0c;以体系化解决方案&#xff0c;助力零售企业数字化…

Android framework添加自定义的Product项目,lunch目标项目

文章目录 Android framework添加自定义的Product项目1.什么是Product&#xff1f;2.定义自己的Product玩一玩 Android framework添加自定义的Product项目 1.什么是Product&#xff1f; 源码目录下输入lunch命令之后&#xff0c;简单理解下面这些列表就是product。用于把系统编…

OpenCV+特征检测

检测 函数cv.cornerHarris()。其参数为&#xff1a; img 输入图像&#xff0c;应为灰度和float32类型blockSize是拐角检测考虑的邻域大小ksize 使用的Sobel导数的光圈参数k 等式中的哈里斯检测器自由参数 import numpy as np import cv2 as cv filename chessboard.png img…

如何显示标注的纯黑mask图

文章目录 前言一、二分类mask显示二、多分类mask显示 前言 通常情况下&#xff0c;使用标注软件标注的标签图看起来都是纯黑的&#xff0c;因为mask图为单通道的灰度图&#xff0c;而灰度图一般要像素值大于128后&#xff0c;才会逐渐显白&#xff0c;255为白色。而标注的时候…

sass 生成辅助色

背景 一个按钮往往有 4 个状态。 默认状态hover鼠标按下禁用状态 为了表示这 4 个状态&#xff0c;需要设置 4 个颜色来提示用户。 按钮类型一般有 5 个&#xff1a; 以 primary 类型按钮为例&#xff0c;设置它不同状态下的颜色&#xff1a; <button class"btn…

IP-guard Webserver view 远程命令执行漏洞【2023最新漏洞】

IP-guard Webserver view 远程命令执行漏洞【2023最新漏洞】 一、漏洞描述二、漏洞影响三、漏洞危害四、FOFA语句五、漏洞复现1、手动复现yaml pocburp发包 2、自动化复现小龙POC检测工具下载地址 免责声明&#xff1a;请勿利用文章内的相关技术从事非法测试&#xff0c;由于传…

R程序 示例4.3.2版本包 在centos进行编译部署

为了在CentOS上下载和编译R语言4.3.2包&#xff0c;可以按照以下步骤进行操作&#xff1a; 1.首先&#xff0c;需要安装一些必要的依赖项。可以使用以下命令安装它们&#xff1a; sudo yum install -y epel-release sudo yum install -y gcc gcc-c gcc-gfortran readline-dev…

Linux 使用随记

Linux 使用随记 shell 命令行模式登录后所取得的程序被成为shell&#xff0c;这是因为这个程序负责最外层的跟用户&#xff08;我们&#xff09;通信工作&#xff0c;所以才被戏称为shell。 命令 1、命令格式 command [-options] parameter1 parameter2 … 1、一行命令中第…

UML建模语言

UML建模语言 类的关系 依赖关系 类的方法中使用形参、局部变量或者静态方法的方式调用其他类&#xff0c;表示当前类依赖其他类。 public class Main {public void eat(Person person) {person.play();// 方法参数Student student new Student();student.study();// 局部变…

4 条件判断和循环

文章目录 一、条件判断和循环1.1 if语句1.2 if-else1.3 if-elif-else1.4 for循环1.5 while循环1.6 break退出循环1.7 continue继续循环1.8 多重循环 二、练习题小结 一、条件判断和循环 1.1 if语句 输入用户年龄&#xff0c;根据年龄打印不同的内容&#xff0c;在Python程序中…

C#几种截取字符串的方法

在C#编程中&#xff0c;经常需要对字符串进行截取操作&#xff0c;即从一个长字符串中获取所需的部分信息。本文将介绍几种常用的C#字符串截取方法&#xff0c;并提供相应的示例代码。 目录 1. 使用Substring方法2. 使用Split方法3. 使用Substring和IndexOf方法4. 使用Regex类…

JVM之垃圾回收

1. 如何判断对象可以回收 1.1 引用计数法 引用计数法是一种内存管理技术&#xff0c;其中每个对象都有一个与之关联的引用计数。引用计数表示当前有多少个指针引用了该对象。当引用计数变为零时&#xff0c;表示没有指针再指向该对象&#xff0c;该对象可以被释放&#xff0c…

HBase学习笔记(3)—— HBase整合Phoenix

目录 Phoenix Shell 操作 Phoenix JDBC 操作 Phoenix 二级索引 HBase整合Phoenix Phoenix 简介 Phoenix 是 HBase 的开源 SQL 皮肤。可以使用标准 JDBC API 代替 HBase 客户端 API来创建表&#xff0c;插入数据和查询 HBase 数据 使用Phoenix的优点 在 Client 和 HBase …

C++虚基类详解

多继承&#xff08;Multiple Inheritance&#xff09; 是指从多个直接基类中产生派生类的能力&#xff0c;多继承的派生类继承了所有父类的成员。尽管概念上非常简单&#xff0c;但是多个基类的相互交织可能会带来错综复杂的设计问题&#xff0c;命名冲突就是不可回避的一个。…

云原生Kubernetes系列 | 通过容器互联搭建wordpress博客系统

云原生Kubernetes系列 | 通过容器互联搭建wordpress博客系统 通过容器互联搭建一个wordpress博客系统。wordpress系统是需要连接到数据库上的,所以wordpress和mysql的镜像都是需要的。wordpress在创建过程中需要指定一些参数。创建mysql容器时需要把mysql的数据保存在宿主机本…