优先经验回放(prioritized experience replay)

prioritized experience replay 思路

优先经验回放出自ICLR 2016的论文《prioritized experience replay》。

prioritized experience replay的作者们认为,按照一定的优先级来对经验回放池中的样本采样,相比于随机均匀的从经验回放池中采样的效率更高,可以让模型更快的收敛。其基本思想是RL agent在一些转移样本上可以更有效的学习,也可以解释成“更多地训练会让你意外的数据”。

那优先级如何定义呢?作者们使用的是样本的TD error δ \delta δ 的幅值。对于新生成的样本,TD error未知时,将样本赋值为最大优先级,以保证样本至少将会被采样一次。每个采样样本的概率被定义为
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=kpkαpiα
上式中的 p i > 0 p_i >0 pi>0是回放池中的第i个样本的优先级, α \alpha α则强调有多重视该优先级,如果 α = 0 \alpha=0 α=0,采样就退化成和基础DQN一样的均匀采样了。

p i p_i pi如何取值,论文中提供了如下两种方法,两种方法都是关于TD error δ \delta δ 单调的:

  • 基于比例的优先级: p i = ∣ δ i ∣ + ϵ p_i = |\delta_i| + \epsilon pi=δi+ϵ ϵ \epsilon ϵ是一个很小的正数常量,防止当TD error为0时样本就不会被访问到的情形。(目前大部分实现都是使用的这个形式的优先级)
  • 基于排序的优先级: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi=rank(i)1, 式中的 r a n k ( i ) rank(i) rank(i)是样本根据 ∣ δ i ∣ |\delta_i| δi 在经验回放池中的排序号,此时P就变成了带有指数 α \alpha α的幂率分布了。

作者们定义的概率调整了样本的优先级,因此也就在数据分布中引入了偏差,为了弥补偏差,使用了重要性采样权重(importance-sampling (IS) weights):
w i = ( 1 N ⋅ 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi=(N1P(i)1)β
上式权重中,当 β = 1 \beta=1 β=1时就完全补偿了非均匀概率采样引入的偏差,作者们提到为了收敛性考虑,最后让 β \beta β从0到1中的某个值开始,并逐渐增加到1。在Q-learning更新时使用这些权重乘以TD error,也就是使用 w i δ i w_i \delta_i wiδi而不是原来的 δ i \delta_i δi。此外,为了使训练更稳定,总是对权重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxiwi进行归一化。

以Double DQN为例,使用优先经验回放的算法(论文算法1)如下图:

在这里插入图片描述

prioritized experience replay 实现

直接实现优先经验回放池如下代码(修改自代码 )

class PrioReplayBufferNaive:def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):self.prob_alpha = prob_alphaself.capacity = buf_sizeself.pos = 0self.buffer = []self.priorities = np.zeros((buf_size, ), dtype=np.float32)self.beta = betaself.beta_increment_per_sampling = beta_increment_per_samplingself.epsilon = epsilondef __len__(self):return len(self.buffer)def size(self):  # 目前buffer中数据的数量return len(self.buffer)def add(self, sample):# 新加入的数据使用最大的优先级,保证数据尽可能的被采样到max_prio = self.priorities.max() if self.buffer else 1.0if len(self.buffer) < self.capacity:self.buffer.append(sample)else:self.buffer[self.pos] = sampleself.priorities[self.pos] = max_prioself.pos = (self.pos + 1) % self.capacitydef sample(self, batch_size):if len(self.buffer) == self.capacity:prios = self.prioritieselse:prios = self.priorities[:self.pos]probs = np.array(prios, dtype=np.float32) ** self.prob_alphaprobs /= probs.sum()indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)samples = [self.buffer[idx] for idx in indices]total = len(self.buffer)self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])weights = (total * probs[indices]) ** (-self.beta)weights /= weights.max()return samples, indices, np.array(weights, dtype=np.float32)def update_priorities(self, batch_indices, batch_priorities):'''更新样本的优先级'''for idx, prio in zip(batch_indices, batch_priorities):self.priorities[idx] = prio + self.epsilon

直接实现的优先经验回放,在样本数很大时的采样效率不够高,作者们通过定义了sumtree的数据结构来存储样本优先级,该数据结构的每一个节点的值为其子节点之和,而样本优先级被放在树的叶子节点上,树的根节点的值为所有优先级之和 p t o t a l p_{total} ptotal,更新和采样时的效率为 O ( l o g N ) O(logN) O(logN)。在采样时,设采样批次大小为k,将 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal]均分为k等份,然后在每一个区间均匀的采样一个值,再通过该值从树中提取到对应的样本。python 实现如下(代码来源)

class SumTree:"""父节点的值是其子节点值之和的二叉树数据结构"""write = 0def __init__(self, capacity):self.capacity = capacityself.tree = np.zeros(2 * capacity - 1)self.data = np.zeros(capacity, dtype=object)self.n_entries = 0# update to the root nodedef _propagate(self, idx, change):parent = (idx - 1) // 2self.tree[parent] += changeif parent != 0:self._propagate(parent, change)# find sample on leaf nodedef _retrieve(self, idx, s):left = 2 * idx + 1right = left + 1if left >= len(self.tree):return idxif s <= self.tree[left]:return self._retrieve(left, s)else:return self._retrieve(right, s - self.tree[left])def total(self):return self.tree[0]# store priority and sampledef add(self, p, data):idx = self.write + self.capacity - 1self.data[self.write] = dataself.update(idx, p)self.write += 1if self.write >= self.capacity:self.write = 0if self.n_entries < self.capacity:self.n_entries += 1# update prioritydef update(self, idx, p):change = p - self.tree[idx]self.tree[idx] = pself._propagate(idx, change)# get priority and sampledef get(self, s):idx = self._retrieve(0, s)dataIdx = idx - self.capacity + 1return (idx, self.tree[idx], self.data[dataIdx])class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTreeepsilon = 0.01alpha = 0.6beta = 0.4beta_increment_per_sampling = 0.001def __init__(self, capacity):self.tree = SumTree(capacity)self.capacity = capacitydef _get_priority(self, error):return (np.abs(error) + self.epsilon) ** self.alphadef add(self, error, sample):p = self._get_priority(error)self.tree.add(p, sample)def sample(self, n):batch = []idxs = []segment = self.tree.total() / npriorities = []self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])for i in range(n):a = segment * ib = segment * (i + 1)s = random.uniform(a, b)(idx, p, data) = self.tree.get(s)priorities.append(p)batch.append(data)idxs.append(idx)sampling_probabilities = priorities / self.tree.total()is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)is_weight /= is_weight.max()return batch, idxs, is_weightdef update(self, idx, error):'''这里是一次更新一个样本,所以在调用时,写for循环依次更次样本的优先级'''p = self._get_priority(error)self.tree.update(idx, p)

参考资料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的实现代码

  3. 相关blog: 1 (对应的代码), 2, 3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/160617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UML建模图文详解教程——类图

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文参考资料&#xff1a;《UML面向对象分析、建模与设计&#xff08;第2版&#xff09;》吕云翔&#xff0c;赵天宇 著 类图概述 类图用来描述系统内各种实体的类型以及不同…

Unsupervised MVS论文笔记

Unsupervised MVS论文笔记 摘要1 引言2 相关工作3 实现方法 Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and Martial Hebert. Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and …

JAVA小游戏拼图

第一步是创建项目 项目名自拟 第二部创建个包名 来规范class 然后是创建类 创建一个代码类 和一个运行类 代码如下&#xff1a; package heima; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import …

10、信息打点——APP小程序篇抓包封包XP框架反编译资产提取

APP信息搜集思路 外在——抓包封包——资产安全测试 抓包&#xff08;Fiddle&茶杯&burp&#xff09;封包&#xff08;封包监听工具&#xff09;&#xff0c;提取资源信息 资产收集——资源提取——ICO、MAD、hash——FOFA等网络测绘进行资产搜集 外在——功能逻辑 内在…

国际版Amazon Lightsail的功能解析

Amazon Lightsail是一项易于使用的云服务,可为您提供部署应用程序或网站所需的一切,从而实现经济高效且易于理解的月度计划。它是部署简单的工作负载、网站或开始使用亚马逊云科技的理想选择。 作为 AWS 免费套餐的一部分&#xff0c;可以免费开始使用 Amazon Lightsail。注册…

【Python进阶】近200页md文档14大体系第4篇:Python进程使用详解(图文演示)

本文从14大模块展示了python高级用的应用。分别有Linux命令&#xff0c;多任务编程、网络编程、Http协议和静态Web编程、htmlcss、JavaScript、jQuery、MySql数据库的各种用法、python的闭包和装饰器、mini-web框架、正则表达式等相关文章的详细讲述。 Python全套笔记直接地址…

028 - STM32学习笔记 - ADC结构体学习(二)

028 - STM32学习笔记 - 结构体学习&#xff08;二&#xff09; 上节对ADC基础知识进行了学习&#xff0c;这节在了解一下ADC相关的结构体。 一、ADC初始化结构体 在标准库函数中基本上对于外设都有一个初始化结构体xx_InitTypeDef&#xff08;其中xx为外设名&#xff0c;例如…

YOLO目标检测——卫星遥感多类别检测数据集下载分享【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;卫星遥感目标检测数据集说明&#xff1a;卫星遥感多类别检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富&#xff0c;含网球场、棒球场、篮球场、田径场、储罐、车辆、桥、飞机、船等类别标签说明&#xff1a;使用lableimg标…

2023年【上海市安全员C证】考试及上海市安全员C证找解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年上海市安全员C证考试为正在备考上海市安全员C证操作证的学员准备的理论考试专题&#xff0c;每个月更新的上海市安全员C证找解析祝您顺利通过上海市安全员C证考试。 1、【多选题】2017年9月颁发的《中共上海市委…

基于STM32的烟雾浓度检测报警仿真设计(仿真+程序+讲解视频)

这里写目录标题 &#x1f4d1;1.主要功能&#x1f4d1;2.仿真&#x1f4d1;3. 程序&#x1f4d1;4. 资料清单&下载链接&#x1f4d1;[资料下载链接](https://docs.qq.com/doc/DS0VHTmxmUHBtVGVP) 基于STM32的烟雾浓度检测报警仿真设计(仿真程序讲解&#xff09; 仿真图prot…

SkyWalking配置报警推送到企业微信

1、先在企业微信群里创建一个机器人&#xff0c;复制webhook的地址&#xff1a; 2、找到SkyWalking部署位置的alarm-settings.yml文件 编辑&#xff0c;在最后面加上此段配置 &#xff01;&#xff01;&#xff01;一定格式要对&#xff0c;不然一直报警报不出来按照网上指导…

JVM 堆外内存详解

Java 进程内存占用除了JVM 运行时数据区&#xff0c;还有直接内存&#xff08;Direct Memory&#xff09;区域及 JVM 程序自身也会占用内存 直接内存&#xff08;Direct Memory&#xff09;区域&#xff1a;直接内存通过使用Native堆外内存来存储数据&#xff0c;这意味着数据…

大数据平台实践之CDH6.2.1+spark3.3.0+kyuubi-1.6.0

前言&#xff1a;关于kyuubi的原理和功能这里不做详细的介绍&#xff0c;感兴趣的同学可以直通官网&#xff1a;https://kyuubi.readthedocs.io/en/v1.7.1-rc0/index.html 下载软件版本 wget http://distfiles.macports.org/scala2.12/scala-2.12.16.tgz wget https://archi…

pikachu_php反序列化

pikachu_php反序列化 源代码 class S{var $test "pikachu";function __construct(){echo $this->test;} }//O:1:"S":1:{s:4:"test";s:29:"<script>alert(xss)</script>";} $html; if(isset($_POST[o])){$s $_POST[…

基于python人脸性别年龄检测系统-深度学习项目

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介简介技术组成1. OpenCV2. Dlib3. TensorFlow 和 Keras 功能流程 二、功能三、系统四. 总结 一项目简介 # Python 人脸性别年龄检测系统介绍 简介 该系统基…

Android studio 迁移之后打开没反应

把Android studio由d盘迁移到c盘&#xff0c;点击没反应&#xff1b; 需要把C:\Users\xxxx\AppData\Roaming\Google\AndroidStudio2022.3 目录下的studio64.exe.vmoptions 修改为C:&#xff0c;删除该文件会导致无法安装app。 里面配置了一个

SpringMVC问题

文章目录 SpringMVC运行流程MVC的概念与请求在MVC中的执行路径&#xff0c;ResponsBody注解的用途SpringMVC启动流程 SpringMVC运行流程 • 客户端&#xff08;浏览器&#xff09;发送请求&#xff0c;直接请求到 DispatcherServlet 。 • DispatcherServlet 根据请求信息调用 …

【React-Router】路由导航

1. 概念 路由系统中的多个路由之间需要进行路由跳转&#xff0c;并且在跳转的同时有可能需要传递参数进行通信。 2. 声明式导航 // /page/Login/index.jsimport { Link } from react-router-dom const Login () > {return <div>登录页{/* 解析成 a 链接 */}<Li…

Windows平台如何实现RTSP流二次编码并添加动态水印后推送RTMP或轻量级RTSP服务

技术背景 我们在对接RTSP播放器相关的技术诉求的时候&#xff0c;遇到这样的需求&#xff0c;客户做特种设备巡检的&#xff0c;需要把摄像头拍到的RTSP流拉下来&#xff0c;然后添加动态水印后&#xff0c;再生成新的RTSP URL&#xff0c;供平台调用。真个流程需要延迟尽可能…

6.基于蜻蜓优化算法 (DA)优化的VMD参数(DA-VMD)

代码原理 基于蜻蜓优化算法 (Dragonfly Algorithm, DA) 优化的 VMD 参数&#xff08;DA-VMD&#xff09;是指使用蜻蜓优化算法对 VMD 方法中的参数进行自动调优和优化。 VMD&#xff08;Variational Mode Decomposition&#xff09;是一种信号分解方法&#xff0c;用于将复杂…