优先经验回放(prioritized experience replay)

prioritized experience replay 思路

优先经验回放出自ICLR 2016的论文《prioritized experience replay》。

prioritized experience replay的作者们认为,按照一定的优先级来对经验回放池中的样本采样,相比于随机均匀的从经验回放池中采样的效率更高,可以让模型更快的收敛。其基本思想是RL agent在一些转移样本上可以更有效的学习,也可以解释成“更多地训练会让你意外的数据”。

那优先级如何定义呢?作者们使用的是样本的TD error δ \delta δ 的幅值。对于新生成的样本,TD error未知时,将样本赋值为最大优先级,以保证样本至少将会被采样一次。每个采样样本的概率被定义为
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=kpkαpiα
上式中的 p i > 0 p_i >0 pi>0是回放池中的第i个样本的优先级, α \alpha α则强调有多重视该优先级,如果 α = 0 \alpha=0 α=0,采样就退化成和基础DQN一样的均匀采样了。

p i p_i pi如何取值,论文中提供了如下两种方法,两种方法都是关于TD error δ \delta δ 单调的:

  • 基于比例的优先级: p i = ∣ δ i ∣ + ϵ p_i = |\delta_i| + \epsilon pi=δi+ϵ ϵ \epsilon ϵ是一个很小的正数常量,防止当TD error为0时样本就不会被访问到的情形。(目前大部分实现都是使用的这个形式的优先级)
  • 基于排序的优先级: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi=rank(i)1, 式中的 r a n k ( i ) rank(i) rank(i)是样本根据 ∣ δ i ∣ |\delta_i| δi 在经验回放池中的排序号,此时P就变成了带有指数 α \alpha α的幂率分布了。

作者们定义的概率调整了样本的优先级,因此也就在数据分布中引入了偏差,为了弥补偏差,使用了重要性采样权重(importance-sampling (IS) weights):
w i = ( 1 N ⋅ 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi=(N1P(i)1)β
上式权重中,当 β = 1 \beta=1 β=1时就完全补偿了非均匀概率采样引入的偏差,作者们提到为了收敛性考虑,最后让 β \beta β从0到1中的某个值开始,并逐渐增加到1。在Q-learning更新时使用这些权重乘以TD error,也就是使用 w i δ i w_i \delta_i wiδi而不是原来的 δ i \delta_i δi。此外,为了使训练更稳定,总是对权重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxiwi进行归一化。

以Double DQN为例,使用优先经验回放的算法(论文算法1)如下图:

在这里插入图片描述

prioritized experience replay 实现

直接实现优先经验回放池如下代码(修改自代码 )

class PrioReplayBufferNaive:def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):self.prob_alpha = prob_alphaself.capacity = buf_sizeself.pos = 0self.buffer = []self.priorities = np.zeros((buf_size, ), dtype=np.float32)self.beta = betaself.beta_increment_per_sampling = beta_increment_per_samplingself.epsilon = epsilondef __len__(self):return len(self.buffer)def size(self):  # 目前buffer中数据的数量return len(self.buffer)def add(self, sample):# 新加入的数据使用最大的优先级,保证数据尽可能的被采样到max_prio = self.priorities.max() if self.buffer else 1.0if len(self.buffer) < self.capacity:self.buffer.append(sample)else:self.buffer[self.pos] = sampleself.priorities[self.pos] = max_prioself.pos = (self.pos + 1) % self.capacitydef sample(self, batch_size):if len(self.buffer) == self.capacity:prios = self.prioritieselse:prios = self.priorities[:self.pos]probs = np.array(prios, dtype=np.float32) ** self.prob_alphaprobs /= probs.sum()indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)samples = [self.buffer[idx] for idx in indices]total = len(self.buffer)self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])weights = (total * probs[indices]) ** (-self.beta)weights /= weights.max()return samples, indices, np.array(weights, dtype=np.float32)def update_priorities(self, batch_indices, batch_priorities):'''更新样本的优先级'''for idx, prio in zip(batch_indices, batch_priorities):self.priorities[idx] = prio + self.epsilon

直接实现的优先经验回放,在样本数很大时的采样效率不够高,作者们通过定义了sumtree的数据结构来存储样本优先级,该数据结构的每一个节点的值为其子节点之和,而样本优先级被放在树的叶子节点上,树的根节点的值为所有优先级之和 p t o t a l p_{total} ptotal,更新和采样时的效率为 O ( l o g N ) O(logN) O(logN)。在采样时,设采样批次大小为k,将 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal]均分为k等份,然后在每一个区间均匀的采样一个值,再通过该值从树中提取到对应的样本。python 实现如下(代码来源)

class SumTree:"""父节点的值是其子节点值之和的二叉树数据结构"""write = 0def __init__(self, capacity):self.capacity = capacityself.tree = np.zeros(2 * capacity - 1)self.data = np.zeros(capacity, dtype=object)self.n_entries = 0# update to the root nodedef _propagate(self, idx, change):parent = (idx - 1) // 2self.tree[parent] += changeif parent != 0:self._propagate(parent, change)# find sample on leaf nodedef _retrieve(self, idx, s):left = 2 * idx + 1right = left + 1if left >= len(self.tree):return idxif s <= self.tree[left]:return self._retrieve(left, s)else:return self._retrieve(right, s - self.tree[left])def total(self):return self.tree[0]# store priority and sampledef add(self, p, data):idx = self.write + self.capacity - 1self.data[self.write] = dataself.update(idx, p)self.write += 1if self.write >= self.capacity:self.write = 0if self.n_entries < self.capacity:self.n_entries += 1# update prioritydef update(self, idx, p):change = p - self.tree[idx]self.tree[idx] = pself._propagate(idx, change)# get priority and sampledef get(self, s):idx = self._retrieve(0, s)dataIdx = idx - self.capacity + 1return (idx, self.tree[idx], self.data[dataIdx])class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTreeepsilon = 0.01alpha = 0.6beta = 0.4beta_increment_per_sampling = 0.001def __init__(self, capacity):self.tree = SumTree(capacity)self.capacity = capacitydef _get_priority(self, error):return (np.abs(error) + self.epsilon) ** self.alphadef add(self, error, sample):p = self._get_priority(error)self.tree.add(p, sample)def sample(self, n):batch = []idxs = []segment = self.tree.total() / npriorities = []self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])for i in range(n):a = segment * ib = segment * (i + 1)s = random.uniform(a, b)(idx, p, data) = self.tree.get(s)priorities.append(p)batch.append(data)idxs.append(idx)sampling_probabilities = priorities / self.tree.total()is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)is_weight /= is_weight.max()return batch, idxs, is_weightdef update(self, idx, error):'''这里是一次更新一个样本,所以在调用时,写for循环依次更次样本的优先级'''p = self._get_priority(error)self.tree.update(idx, p)

参考资料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的实现代码

  3. 相关blog: 1 (对应的代码), 2, 3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/160617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UML建模图文详解教程——类图

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文参考资料&#xff1a;《UML面向对象分析、建模与设计&#xff08;第2版&#xff09;》吕云翔&#xff0c;赵天宇 著 类图概述 类图用来描述系统内各种实体的类型以及不同…

Unsupervised MVS论文笔记

Unsupervised MVS论文笔记 摘要1 引言2 相关工作3 实现方法 Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and Martial Hebert. Tejas Khot and Shubham Agrawal and Shubham Tulsiani and Christoph Mertz and Simon Lucey and …

JAVA小游戏拼图

第一步是创建项目 项目名自拟 第二部创建个包名 来规范class 然后是创建类 创建一个代码类 和一个运行类 代码如下&#xff1a; package heima; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import …

10、信息打点——APP小程序篇抓包封包XP框架反编译资产提取

APP信息搜集思路 外在——抓包封包——资产安全测试 抓包&#xff08;Fiddle&茶杯&burp&#xff09;封包&#xff08;封包监听工具&#xff09;&#xff0c;提取资源信息 资产收集——资源提取——ICO、MAD、hash——FOFA等网络测绘进行资产搜集 外在——功能逻辑 内在…

国际版Amazon Lightsail的功能解析

Amazon Lightsail是一项易于使用的云服务,可为您提供部署应用程序或网站所需的一切,从而实现经济高效且易于理解的月度计划。它是部署简单的工作负载、网站或开始使用亚马逊云科技的理想选择。 作为 AWS 免费套餐的一部分&#xff0c;可以免费开始使用 Amazon Lightsail。注册…

【Python进阶】近200页md文档14大体系第4篇:Python进程使用详解(图文演示)

本文从14大模块展示了python高级用的应用。分别有Linux命令&#xff0c;多任务编程、网络编程、Http协议和静态Web编程、htmlcss、JavaScript、jQuery、MySql数据库的各种用法、python的闭包和装饰器、mini-web框架、正则表达式等相关文章的详细讲述。 Python全套笔记直接地址…

PostgreSQL10安装postgis插件

1.安装pgsql10 2.下载插件&#xff0c;以Windows为例&#xff0c;地址&#xff1a;Index of /postgis/windows/pg10/ 3.安装插件&#xff0c;直接安装&#xff0c;和pgsql的目录相同即可&#xff0c;一直下一步 4.安装之后&#xff0c;需要执行sql打开 CREATE EXTENSION po…

028 - STM32学习笔记 - ADC结构体学习(二)

028 - STM32学习笔记 - 结构体学习&#xff08;二&#xff09; 上节对ADC基础知识进行了学习&#xff0c;这节在了解一下ADC相关的结构体。 一、ADC初始化结构体 在标准库函数中基本上对于外设都有一个初始化结构体xx_InitTypeDef&#xff08;其中xx为外设名&#xff0c;例如…

Redis设计与实现-数据结构(建设进度17%)

Redis数据结构 引言数据结构stringSDS数据结构原生string的不足 hash 本博客基于《Redis设计与实现》进行整理和补充&#xff0c;该书依赖于Redis 3.0版本&#xff0c;但是Redis6.0版本在一些底层实现上仍然没有明显的变动&#xff0c;因此本文将在该书的基础上&#xff0c;对于…

PostgreSQL基本操作

1.查询某个表的所在磁盘大小 select pg_size_pretty(pg_relation_size(grb_grid)); 2.插入point类型的记录 insert into tb_person ("name", "address", "location", "create_time", "area", "girls") values …

Java 两个线程交替打印1-100

线程题&#xff1a;交替打印1-100 这里演示两个线程&#xff0c;一个打印奇数&#xff0c;一个打印偶数 方式一&#xff1a;synchronized FixedThreadPool public class example {private static int count 1;private static final Object lock new Object();public stat…

WPF基础DataGrid控件

WPF DataGrid 是一个用于显示和编辑表格数据的强大控件。它提供了丰富的功能&#xff0c;包括排序、筛选、分组、编辑、选择等&#xff0c;使你能够以类似电子表格的方式呈现和操作数据。 DataGrid 的布局主要由以下部分组成&#xff1a; 列定义 (Columns): DataGrid 列定义了…

YOLO目标检测——卫星遥感多类别检测数据集下载分享【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;卫星遥感目标检测数据集说明&#xff1a;卫星遥感多类别检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富&#xff0c;含网球场、棒球场、篮球场、田径场、储罐、车辆、桥、飞机、船等类别标签说明&#xff1a;使用lableimg标…

2023年【上海市安全员C证】考试及上海市安全员C证找解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年上海市安全员C证考试为正在备考上海市安全员C证操作证的学员准备的理论考试专题&#xff0c;每个月更新的上海市安全员C证找解析祝您顺利通过上海市安全员C证考试。 1、【多选题】2017年9月颁发的《中共上海市委…

基于STM32的烟雾浓度检测报警仿真设计(仿真+程序+讲解视频)

这里写目录标题 &#x1f4d1;1.主要功能&#x1f4d1;2.仿真&#x1f4d1;3. 程序&#x1f4d1;4. 资料清单&下载链接&#x1f4d1;[资料下载链接](https://docs.qq.com/doc/DS0VHTmxmUHBtVGVP) 基于STM32的烟雾浓度检测报警仿真设计(仿真程序讲解&#xff09; 仿真图prot…

【数据结构】B : DS图应用--最短路径

B : DS图应用–最短路径 文章目录 B : DS图应用--最短路径DescriptionInputOutputSampleInput Output 解题思路&#xff1a;初始化主循环心得&#xff1a; AC代码 Description 给出一个图的邻接矩阵&#xff0c;再给出指定顶点v0&#xff0c;求顶点v0到其他顶点的最短路径 In…

SkyWalking配置报警推送到企业微信

1、先在企业微信群里创建一个机器人&#xff0c;复制webhook的地址&#xff1a; 2、找到SkyWalking部署位置的alarm-settings.yml文件 编辑&#xff0c;在最后面加上此段配置 &#xff01;&#xff01;&#xff01;一定格式要对&#xff0c;不然一直报警报不出来按照网上指导…

JVM 堆外内存详解

Java 进程内存占用除了JVM 运行时数据区&#xff0c;还有直接内存&#xff08;Direct Memory&#xff09;区域及 JVM 程序自身也会占用内存 直接内存&#xff08;Direct Memory&#xff09;区域&#xff1a;直接内存通过使用Native堆外内存来存储数据&#xff0c;这意味着数据…

大数据平台实践之CDH6.2.1+spark3.3.0+kyuubi-1.6.0

前言&#xff1a;关于kyuubi的原理和功能这里不做详细的介绍&#xff0c;感兴趣的同学可以直通官网&#xff1a;https://kyuubi.readthedocs.io/en/v1.7.1-rc0/index.html 下载软件版本 wget http://distfiles.macports.org/scala2.12/scala-2.12.16.tgz wget https://archi…

pikachu_php反序列化

pikachu_php反序列化 源代码 class S{var $test "pikachu";function __construct(){echo $this->test;} }//O:1:"S":1:{s:4:"test";s:29:"<script>alert(xss)</script>";} $html; if(isset($_POST[o])){$s $_POST[…