Target Network缓解DQN的动作价值的高估问题

1、高估问题产生的原因

原因1:由于噪声的存在,影响 m a x ( Q ) max(Q) max(Q)的估计最大值比真实的最大值更大,最小值比真实最小值更小;

原因2:Bootstrapping,DQN近似动作价值 Q Q Q,使用TD算法更新DQN,因为TD算法存在高估,更新DQN时造成高估,下一次TD更新时也会不断高估;


2、Target Network解决动作价值高估问题思路

使用Target Network计算: max ⁡ a Q ( s t + 1 , a ; w − ) \max_aQ(s_{t+1},a;\mathbf{w}^-) maxaQ(st+1,a;w)

TD learning with naïve update:
TD Target:  y t = r t + γ ⋅ max ⁡ a Q ( s t + 1 , a ; w ) . \begin{gathered} \text{TD Target: }\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}). \\ \end{gathered} TD Target: yt=rt+γamaxQ(st+1,a;w).
TD learning with target network:
TD Target: y t = r t + γ ⋅ max ⁡ a Q ( s t + 1 , a ; w − ) \text{TD Target:}\\ y_t=r_t+\gamma\cdot\max_aQ(s_{t+1},a;\mathbf{w}^-) TD Target:yt=rt+γamaxQ(st+1,a;w)


3、代码实现

实现带有target network的DQN

class DQNWithTargetNetwork:def __init__(self, dim_state=None, num_action=None, discount=0.9):self.discount = discountself.Q = QNet(dim_state, num_action)# 添加target networkself.target_Q = QNet(dim_state, num_action)self.target_Q.load_state_dict(self.Q.state_dict())def get_action(self, state):# 使用最大价值的动作qvals = self.Q(state)return qvals.argmax()def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):# 计算s_batch,a_batch对应的值。qvals = self.target_Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()# 使用target Q网络计算next_s_batch对应的值。next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)# 使用MSE计算loss。loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)return loss

隔一段时间在再更新target network

# 加权更新target network
def soft_update(target, source, tau=0.01):"""update target by target = tau * source + (1 - tau) * target."""for target_param, param in zip(target.parameters(), source.parameters()):target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

4、对gather的理解
例如三维的input,从广播机制很容易理解。当dim==0,意味着

out[i][j][k]中的[i]指的是用[index[i][j][k]]取数据放到i的,out[j][k]指的是这两个维度与out同时变化

广播机制是计算循环的一种更快的机制,因此用循环来理解是一样的:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0

等价于:

out = torch.zeros(index.shape)#定义zero空tensor# 循环赋值
for j in range(input.shape[1]):for k in range(input.shape[2]):out[:, j, k] = input[index[i][j][k], j, k]

如果是其他维度可参考:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

一个例子:

t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))>>tensor([[ 1,  1],[ 4,  3]])

torch.gather — PyTorch 2.0 documentation


5、对detech的理解:
将tensor从计算图中分离,不进行梯度更新

torch.Tensor.detach — PyTorch 2.0 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29173.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

备战大型攻防演练,“3+1”一套搞定云上安全

在重大活动保障期间,企业不仅要面对愈发灵活隐蔽的新型攻击挑战,还要在人员、精力有限的情况下应对不分昼夜的高强度安全运维任务。如何在这种多重压力下,从“疲于应付”迈向“胸有成竹”呢? 知己知彼,百战不殆&#…

用户体验旅程图:改进用户体验的好工具

用户体验旅程图:改进用户体验的好工具 怎么改进体验,是有方法的 用户情绪曲线来衡量用户感觉 趣讲大白话:没有流程刨析,就没法改进 【趣讲信息科技245期】 **************************** 企业管理需要基本的流程的 企业流程简称BP…

docker容器监控:Cadvisor+InfluxDB+Grafana的安装部署

目录 CadvisorInfluxDBGrafan安装部署 1、安装docker-ce 2、阿里云镜像加速器 3、下载组件镜像 4、创建自定义网络 5、创建influxdb容器 6、创建Cadvisor 容器 7、查看Cadvisor 容器: (1)准备测试镜像 (2)通…

java判断字符串是否包含英文,以及英文个数

在Java中,可以使用正则表达式或字符遍历的方式来判断字符串是否包含英文字符,并统计英文字符的个数。 使用正则表达式判断字符串是否包含英文字符: String str "Hello, 你好!"; boolean containsEnglish str.matche…

pycharm、idea、golang等JetBrains其他IDE修改行分隔符(换行符)、在Git CRLF、LF 换行符转换

文章目录 pycharm、idea、golang系列修改行分隔符我应该选择什么换行符JetBrains IDE,默认行分隔符 是跟随系统修改JetBrains IDE,默认行分隔符 在Git CRLF、LF 换行符转换需求Git 配置选项 pycharm、idea、golang系列修改行分隔符 一般来说,不同的开发…

JSX语法基础总结

题记:首先我们要了解一下jsx是什么,跟js有什么区别,其实就是js的语法糖,加上了xml的语法,使得产生虚拟dom更加的方便,简单说一下,xml就是存储数据的格式,想了解xml的话,可…

【陈老板赠书活动 - 10期】- 【Python之光:Python编程入门与实战】

陈老老老板🦸 👨‍💻本文专栏:赠书活动专栏(为大家争取的福利,免费送书) 👨‍💻本文简述:生活就像海洋,只有意志坚强的人,才能到达彼岸。讲一些我刚进公司的学…

一篇文章告诉你帮助中心系统如何去落实用户的需求

我们在搭建帮助中心系统的时候最最最主要的就是要知道我们的客户到底需要什么!我们又可以怎么帮到他们呢!那我们就要开始思考,怎么才能洞察到客户的需求并且落实。今天looklook就从这个角度展开,教你们怎么去了解到客户的需求&…

solidworks(2)

记得选择双向

【YOLO】替换骨干网络为轻量级网络MobileNet3

替换骨干网络为轻量级网络MobileNet_v3 上一章 模型网络结构解析&增加小目标检测 文章目录 替换骨干网络为轻量级网络MobileNet_v3前言一、MobileNetV3介绍二、MobileNetV2&MobileNetV3三、MobileNetV3网络结构1. 结构查看2. 查看每层featuremap大小三、YOLOV5替换骨干…

【解放ipad生产力】如何在平板上使用免费IDE工具完成项目开发

我的博客即将同步至腾讯云开发者社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code3o19zyy2pneoo 前言 很多人应该会像我一样吧,有时候身边没电脑突然要写项目,发现自己的平板没有一点作用&…

手机开启应急预警通知 / 地震预警

前言 安卓手机在检测到地震时,将发送地震预警通知,但此设置是默认关闭的,原因是以防引发用户恐慌从而引发安全问题,且开启此设置需要完成指引教程,因此默认关闭此设置。下文介绍如何开启此设置。 开启方法 华为手机开…

未来C#上位机软件发展趋势

C#上位机软件迎来新的发展机遇。随着工业自动化的快速发展,C#作为一种流行的编程语言在上位机软件领域发挥着重要作用。未来,C#上位机软件可能会朝着以下几个方向发展: 1.智能化:随着人工智能技术的不断演进,C#上位机…

【CHI】架构介绍

Learn the architecture - Introducing AMBA CHI AMBA CHI协议导论--言身寸 1. AMBA CHI简介 一致性集线器接口(CHI)是AXI一致性扩展(ACE)协议的演进。它是Arm提供的高级微控制器总线架构(AMBA)的一部分。…

【VisualGLM】大模型之 VisualGLM 部署

目录 1. VisualGLM 效果展示 2. VisualGLM 介绍 3. VisualGLM 部署 1. VisualGLM 效果展示 VisualGLM 问答 原始图片 2. VisualGLM 介绍 VisualGLM 主要做的是通过图像生成文字,而 Stable Diffusion 是通过文字生成图像。 一种方法是将图像当作一种特殊的语言进…

Flink 两阶段提交(Two-Phase Commit)协议

Flink 两阶段提交(Two-Phase Commit)是指在 Apache Flink 流处理框架中,为了保证分布式事务的一致性而采用的一种协议。它通常用于在流处理应用中处理跨多个分布式数据源的事务性操作,确保所有参与者(数据源或计算节点…

springboot传给前端日期少了八小时

在Spring Boot中,如果从MySQL数据库中获取日期,并在前端显示时少了8小时,这通常是由于时区的问题导致的。MySQL默认使用系统的时区,而Spring Boot默认使用UTC时区。 spring-boot默认使用Jackson对返回到前端的值进行序列化。Jack…

使用node搭建服务器,前端自己写接口,将vue或react打包后生成的dist目录在本地运行

使用node.jsexpress或者使用node.jspm2搭建服务器,将vue或react打包后生成的dist目录在本地运行 vue项目打包后生成的dist目录如果直接在本地打开index.html,在浏览器中会报错,无法运行起来。 通常我是放到后端搭建的服务上面去运行,当时前端…

高速公路巡检新手段——道路智能巡检系统

高速公路作为我国公路建设的一项重要成果,其建设和运营对于促进我国经济发展、改善交通运输条件和提高人民生活水平具有重要的意义。 高速公路巡检是确保公路安全的重要措施之一。每年数以万计的车辆在高速公路上穿行,因此高速公路的安全性显得尤为重要。…

objectMapper.configure 方法的作用和使用

objectMapper.configure 方法是 Jackson 提供的一个用于配置 ObjectMapper 对象的方法。ObjectMapper 是 Jackson 库的核心类,用于将 Java 对象与 JSON 数据相互转换。 configure 方法的作用是设置 ObjectMapper 的配置选项,例如设置日期格式、设置序列…