论文笔记之:Deep Attention Recurrent Q-Network

  

Deep Attention Recurrent Q-Network

5vision groups 

 

   摘要:本文将 DQN 引入了 Attention 机制,使得学习更具有方向性和指导性。(前段时间做一个工作打算就这么干,谁想到,这么快就被这几个孩子给实现了,自愧不如啊( ⊙ o ⊙ ))

    引言:我们知道 DQN 是将连续 4帧的视频信息输入到 CNN 当中,那么,这么做虽然取得了不错的效果,但是,仍然只是能记住这 4 帧的信息,之前的就会遗忘。所以就有研究者提出了 Deep Recurrent Q-Network (DRQN),一个结合 LSTM 和 DQN 的工作:

  1. the fully connected layer in the latter is replaced for a LSTM one , 

  2. only the last visual frame at each time step is used as DQN's input. 

  作者指出虽然只是使用了一帧的信息,但是 DRQN 仍然抓住了帧间的相关信息。尽管如此,仍然没有看到在 Atari game上有系统的提升。

 

   另一个缺点是:长时间的训练时间。据说,在单个 GPU 上训练时间达到 12-14天。于是,有人就提出了并行版本的算法来提升训练速度。作者认为并行计算并不是唯一的,最有效的方法来解决这个问题。 

  

   最近 visual attention models 在各个任务上都取得了惊人的效果。利用这个机制的优势在于:仅仅需要选择然后注意一个较小的图像区域,可以帮助降低参数的个数,从而帮助加速训练和测试。对比 DRQN,本文的 LSTM 机制存储的数据不仅用于下一个 actions 的选择,也用于 选择下一个 Attention 区域。此外,除了计算速度上的改进之外,Attention-based models 也可以增加 Deep Q-Learning 的可读性,提供给研究者一个机会去观察 agent 的集中区域在哪里以及是什么,(where and what)。

 

 


  

  Deep Attention Recurrent Q-Network:

 

 

    如上图所示,DARQN 结构主要由 三种类型的网络构成:convolutional (CNN), attention, and recurrent . 在每一个时间步骤 t,CNN 收到当前游戏状态 $s_t$ 的一个表示,根据这个状态产生一组 D feature maps,每一个的维度是 m * m。Attention network 将这些 maps 转换成一组向量 $v_t = \{ v_t^1, ... , v_t^L \}$,L = m*m,然后输出其线性组合 $z_t$,称为 a context vector. 这个 recurrent network,在我们这里是 LSTM,将 context vector 作为输入,以及 之前的 hidden state $h_{t-1}$,memory state $c_{t-1}$,产生 hidden state $h_t$ 用于:

  1. a linear layer for evaluating Q-value of each action $a_t$ that the agent can take being in state $s_t$ ; 

  2. the attention network for generating a context vector at the next time step t+1. 

 


 

  Soft attention 

  这一小节提到的 "soft" Attention mechanism 假设 the context vector $z_t$ 可以表示为 所有向量 $v_t^i$ 的加权和,每一个对应了从图像不同区域提取出来的 CNN 特征。权重 和 这个 vector 的重要程度成正比例,并且是通过 Attention network g 衡量的。g network 包含两个 fc layer 后面是一个 softmax layer。其输出可以表示为:

  其中,Z是一个normalizing constant。W 是权重矩阵,Linear(x) = Ax + b 是一个放射变换,权重矩阵是A,偏差是 b。我们一旦定义出了每一个位置向量的重要性,我们可以计算出 context vector 为:

  另一个网络在第三小节进行详细的介绍。整个 DARQN model 是通过最小化序列损失函数完成训练:

  其中,$Y_t$ 是一个近似的 target value,为了优化这个损失函数,我们利用标准的 Q-learning 更新规则:

  DARQN 中的 functions 都是可微分的,所以每一个参数都有梯度,整个模型可以 end-to-end 的进行训练。本文的算法也借鉴了 target network 和 experience replay 的技术。

 


 

  Hard Attention

  此处的 hard attention mechanism 采样的时候要求仅仅从图像中采样一个图像 patch。

  假设 $s_t$ 从环境中采样的时候,受到了 attention policy 的影响,attention network g 的softmax layer 给出了带参数的类别分布(categorical distribution)。然后,在策略梯度方法,策略参数的更新可以表示为:

  其中 $R_t$ 是将来的折扣的损失。为了估计这个值,另一个网络 $G_t = Linear(h_t)$ 才引入进来。这个网络通过朝向 期望值 $Y_t$ 进行网络训练。Attention network 参数最终的更新采用如下的方式进行:

    其中 $G_t - Y_t$ 是advantage function estimation。

  

  作者提供了源代码:https://github.com/5vision/DARQN  

  

  实验部分

  

 

 

 

 


 

  总结:   

 

 

  

 

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/457847.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Codeforces Round #354 (Div. 2)

贪心 A Nicholas and Permutation #include <bits/stdc.h>typedef long long ll; const int N 1e5 5; int a[105]; int pos[105];int main() {int n;scanf ("%d", &n);for (int i1; i<n; i) {scanf ("%d", ai);pos[a[i]] i;}int ans abs …

linux c程序中内核态与用户态内存存储问题

Unix/Linux的体系架构 如上图所示&#xff0c;从宏观上来看&#xff0c;Linux操作系统的体系架构分为用户态和内核态&#xff08;或者用户空间和内核&#xff09;。内核从本质上看是一种软件——控制计算机的硬件资源&#xff0c;并提供上层应用程序运行的环境。用户态即上层应…

线程自动退出_C++基础 多线程笔记(一)

join & detachjoin和detach为最基本的用法&#xff0c;join可以使主线程&#xff08;main函数&#xff09;等待子线程&#xff08;自定义的function_1函数&#xff09;完成后再退出程序&#xff0c;而detach可以使子线程与主线程毫无关联的独立运行&#xff0c;当主线程执行…

WEB在线预览PDF

这是我在博客园发表的第一篇文章。以后会陆续把在线预览其他格式文档的解决方案发表出来。 解决思路&#xff1a;把pdf转换成html显示。 在线预览pdf我暂时了解3种解决方案&#xff0c;欢迎大家补充。 方案一&#xff1a; 利用pdf2html软件将PDF转换成HTML。 用法: PDF2HTML [选…

[算法]判断一个数是不是2的N次方

如果一个数是2^n&#xff0c;说明这个二进制里面只有一个1。除了1. a (10000)b a-1 (01111)b a&(a-1) 0。 如果一个数不是2^n&#xff0c; 说明它的二进制里含有多一个1。 a (1xxx100)b a-1(1xxx011)b 那么 a&(a-1)就是 (1xxx000)b&#xff0c; 而不会为0。 所以可…

VMware Ubuntu 全屏问题解决

在终端中输入&#xff1a; sudo apt install open-vm* 回车 自动解决

数组拼接时中间怎么加入空格_【题解二维数组】1123:图像相似度

1123&#xff1a;图像相似度时间限制: 1000 ms 内存限制: 65536 KB【题目描述】给出两幅相同大小的黑白图像(用0-1矩阵)表示&#xff0c;求它们的相似度。说明&#xff1a;若两幅图像在相同位置上的像素点颜色相同&#xff0c;则称它们在该位置具有相同的像素点。两幅图像的…

(旧)子数涵数·C语言——条件语句

首先&#xff0c;我们讲一下理论知识&#xff0c;在编程中有三种结构&#xff0c;分别是顺序结构、条件结构、循环结构&#xff0c;如果用流程图来表示的话就是&#xff1a; 那么在C语言中&#xff0c;如何灵活运用这三种结构呢&#xff1f;这就需要用到控制语句了。 而条件语句…

apache.commons.lang.StringUtils 使用心得

apache.commons.lang.StringUtils 使用心得 转载于:https://www.cnblogs.com/qinglizlp/p/5549687.html

python哪个版本支持xp_windows支持哪个版本的python

Windows操作系统支持Python的Python2版本和Python3版本&#xff0c;下载安装时要根据windows的操作系统来选择对应的Python安装包&#xff0c;否则将不能安装成功。 Python是跨平台的&#xff0c;免费开源的一门计算机编程语言。是一种面向对象的动态类型语言&#xff0c;最初被…

Ubuntu 键盘错位解决 更改键盘布局

原因是键盘布局不能适应键盘 解绝方法&#xff1a;更改键盘布局 一般改为标准104键盘就行 在终端输入 sudo dpkg-reconfigure keyboard-configuration 选择 标准104键盘 然后一直回车就行

【No.1 Ionic】基础环境配置

Node 安装git clone https://github.com/nodejs/node cd node ./configure make sudo make install node -v npm -vnpm设置淘宝镜像npm config set registry https://registry.npm.taobao.org npm config set disturl https://npm.taobao.org/distIOS Simulatorsudo npm instal…

识别操作系统

使用p0f进行操作系统探测 p0f是一款被动探测工具&#xff0c;通过分析网络数据包来判断操作系统类型。目前最新版本为3.06b。同时p0f在网络分析方面功能强大&#xff0c;可以用它来分析NAT、负载均衡、应用代理等。 p0f的命令参数很简单&#xff0c;基本说明如下&#xff1a; l…

常用RGB颜色表

转载于:https://www.cnblogs.com/Itwonderful/p/5550800.html

python中seek函数的用法_在Python中操作文件之seek()方法的使用教程

seek()方法在偏移设定该文件的当前位置。参数是可选的&#xff0c;默认为0&#xff0c;这意味着绝对的文件定位&#xff0c;它的值如果是1&#xff0c;这意味着寻求相对于当前位置&#xff0c;2表示相对于文件的末尾。 没有返回值。需要注意的是&#xff0c;如果该文件被打开或…

WPF中Grid实现网格,表格样式通用类(转)

/// <summary> /// 给Grid添加边框线 /// </summary> /// <param name"grid"></param> public static void InsertFrameForGrid(Grid grid) { var rowcon grid.RowDefinitions.Count; var clcon grid.ColumnDefinitions.Count; for (var i…

VS2017 安装 QT5.9

VS2017专业版使用最新版Qt5.9.2教程&#xff08;最新教材&#xff09; 目录 VS2017专业版使用最新版Qt5.9.2教程&#xff08;最新教材&#xff09; 运行环境&#xff1a; 1.安装Qt5.9.2 2.安装Qt5.9与VS2017之间的插件: 3.配置Qt VS Tool的环境. 4.设置创建的Qt的项目的属…

异步与并行~ReaderWriterLockSlim实现的共享锁和互斥锁

返回目录 在System.Threading.Tasks命名空间下&#xff0c;使用ReaderWriterLockSlim对象来实现多线程并发时的锁管理&#xff0c;它比lock来说&#xff0c;性能更好&#xff0c;也并合理&#xff0c;我们都知道lock可以对代码块进行锁定&#xff0c;当多线程共同访问代码时&am…

linux ssh yum升级_Linux 运维必备的 13 款实用工具,拿好了

作者丨Erstickthttp://blog.51cto.com/13740508/2114819本文介绍几款 Linux 运维比较实用的工具&#xff0c;希望对 Linux 运维人员有所帮助。1. 查看进程占用带宽情况 - NethogsNethogs 是一个终端下的网络流量监控工具可以直观的显示每个进程占用的带宽。下载&#xff1a;htt…

iOS应用如何支持IPV6

本文转自 http://www.code4app.com/forum.php?modviewthread&tid8427&highlightipv6 果然是苹果打个哈欠&#xff0c;iOS行业内就得起一次风暴呀。自从5月初Apple明文规定所有开发者在6月1号以后提交新版本需要支持IPV6-Only的网络&#xff0c;大家便开始热火朝天的研…