【论文精读】Self-Attentive Assocative Memory,2020

目录

1 引言

这篇论文介绍了基于对象和对象关系的记忆模型,这对于设计类脑记忆模型有很大的启发作用。

2 Outer product attention (OPA)

标准transformer模型中定义的是内积注意力,即dot product attention:
A ° ( q , K , V ) = ∑ i = 1 n k v S ( q ⋅ k i ) v i A^°(q, K, V ) = \sum ^{n_{kv}}_{i=1}S(q \cdot k_i) v_i A°(q,K,V)=i=1nkvS(qki)vi
其中, A ° ∈ R d v , q , k i ∈ R d q k , v i ∈ R d v A^° ∈ R^{d_v} , q, k_i ∈ R^{d_{qk}} , v_i ∈ R^{d_v} A°Rdv,q,kiRdqk,viRdv ⋅ \cdot 表示内积,计算结果是个标量, S S S是一个对向量元素的softmax计算函数。

作者定义了外积注意力命名为Outer product attention:
A ⊗ ( q , K , V ) = ∑ i = 1 n k v F ( q ⊙ k i ) ⊗ v i A^⊗ (q, K, V ) = \sum ^{n_{kv}}_{i=1} \text{F}(q ⊙ k_i) ⊗ v_i A(q,K,V)=i=1nkvF(qki)vi
其中, A ⊗ ∈ R d q k × d v , q , k i ∈ R d q k , v ∈ R d v A^⊗ ∈ R^{d_{qk}×d_v} , q, k_i ∈ R^{d_{qk}} , v ∈ R^{d_v} ARdqk×dv,q,kiRdqk,vRdv ⊙ ⊙ 表示对应位置元素的相乘,计算结果是个同维数向量, ⊗ ⊗ 表示外积, F F F是一个对向量元素的tanh计算函数。

最好对照着标准注意力去理解。
差异: A ° A^° A°是token序列中受注意力关注的token, A ⊗ A^⊗ A是token序列中token之间的关系表征。

3 Self-attentive Associative Memory (SAM)

作者设计了一个关联记忆网络模块,命名为SAM,用来表征item及item之间的关系。

SAM θ ( M ) [ s ] = A ⊗ ( M q [ s ] , M k , M v ) = ∑ j = 1 n k v F ( M q [ s ] ⊙ M k [ j ] ) ⊗ M v [ j ] \begin{align} \text{SAM}_θ (M) [s] &= A^⊗ (M_q [s] , M_k, M_v) \\ &=\sum ^{n_{kv}}_{j=1} \text{F} (M_q [s] ⊙ M_k [j]) ⊗ M_v [j] \end{align} SAMθ(M)[s]=A(Mq[s],Mk,Mv)=j=1nkvF(Mq[s]Mk[j])Mv[j]
其中,
与注意力相关的q,k,v三个向量 M q , M k , M v M_q,M_k,M_v Mq,Mk,Mv
M q = L N ( W q M ) M k = L N ( W k M ) M v = L N ( W v M ) \begin{align} M_q &= \mathcal{LN} (W_qM) \\ M_k &= \mathcal{LN} (W_kM) \\ M_v &= \mathcal{LN} (W_vM) \end{align} MqMkMv=LN(WqM)=LN(WkM)=LN(WvM)
M是输入token序列组成的向量矩阵, M ∈ R n × d M ∈ R^{n×d} MRn×d,n为token序列长度,d为token的维度;
s s s为M中第s行;
W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv是q,k,v对应线性变换层的参数矩阵;
L N \mathcal{LN} LN是 layer normalization操作,而不是激活函数;
θ θ θ代表SAM模块的内部参数是 { W q ∈ R n k v × n , W k ∈ R n k v × n , W v ∈ R n k v × n } \{W_q ∈ R^{n_{kv}×n},W_k ∈ R^{n_{kv}×n},W_v ∈ R^{n_{kv}×n}\} {WqRnkv×n,WkRnkv×n,WvRnkv×n} n q n_q nq是query的个数, n k v n_{kv} nkv是key-value对的个数;

4 SAM-based Two-Memory Model (STM)

作者设计了2个记忆模块分别为 M t i ∈ R d × d , M t r ∈ R n q × d × d M^i_t ∈ R^{d×d}, M^r_t ∈ R^{n_q×d×d} MtiRd×d,MtrRnq×d×d,都是基于SAM实现的,前者是用来记忆item,后者用来记忆item之间的关联关系。
在这里插入图片描述

4.1 M i M^i Mi写操作

X t = f 1 ( x t ) ⊗ f 2 ( x t ) M t i = F t ( M t − 1 i , x t ) ⊙ M t − 1 i + I t ( M t − 1 i , x t ) ⊙ X t \begin{align} X_t &= f_1 (x_t) ⊗ f_2 (x_t) \\ M^i_t &= F_t(M^i_{t−1} , x_t) ⊙ M^i_{t−1} + I_t(M^i_{t−1} , x_t) ⊙X_t \end{align} XtMti=f1(xt)f2(xt)=Ft(Mt1i,xt)Mt1i+It(Mt1i,xt)Xt
其中,
x t x_t xt是输入数据;
f 1 , f 2 f_1, f_2 f1,f2是前馈神经网络,输出维度为d;
F t F_t Ft为遗忘门,计算公式为 F t ( M t − 1 i , x t ) = W F x t + U F t a n h ( M t − 1 i ) + b F F_t(M^i_{t−1} , x_t)= W_F x_t + U_F\mathcal tanh(M^i_{t−1}) + b_F Ft(Mt1i,xt)=WFxt+UFtanh(Mt1i)+bF,其中 W F , U F ∈ R d × d W_F , U_F ∈ R^{d×d} WF,UFRd×d为网络参数;
I t I_t It为输入的门控,计算公式为 I t ( M t − 1 i , x t ) = W I x t + U I t a n h ( M t − 1 i ) + b I I_t(M^i_{t−1} , x_t)= W_I x_t + U_I\mathcal tanh(M^i_{t−1}) + b_I It(Mt1i,xt)=WIxt+UItanh(Mt1i)+bI,其中 W I , U I ∈ R d × d W_I , U_I ∈ R^{d×d} WI,UIRd×d为网络参数;

4.2 M r M^r Mr读操作

v t r = s o f t m a x ( f 3 ( x t ) ⊤ ) M t − 1 r f 2 ( x t ) \begin{align} v^r_t = \mathcal{softmax}(f_3 (x_t)^⊤) M^r_{t−1} f_2 (x_t) \end{align} vtr=softmax(f3(xt))Mt1rf2(xt)
其中,
v t r v^r_t vtr为从关系记忆模块 M r M^r Mr中读出的值,将在下式(9)中使用;
f 3 f_3 f3是前馈神经网络,输出维度为 n q n_q nq;
M t − 1 r M^r_{t−1} Mt1r M r M^r Mr的前一个状态,其状态值由下式(9)计算得到;

4.3 M i M^i Mi读操作和 M r M^r Mr写操作过程

M t r = M t − 1 r + α 1 SAM θ ( M t i + α 2 v t r ⊗ f 2 ( x t ) ) \begin{align} M^r_t = M^r_{t−1} + α_1 \text{SAM}_ \theta (M^i_t + α_2 v^r_t ⊗ f_2 (x_t)) \end{align} Mtr=Mt1r+α1SAMθ(Mti+α2vtrf2(xt))
其中,
α 1 , α 2 α_1,α_2 α1,α2是调和超参数,用于平衡量纲的,又类似于学习率;

4.4 用 M r M^r Mr实现item转移

M i M^i Mi利用 M r M^r Mr实现更新,可以认为是hebbian更新,更新公式如下:
M t i = M t i + α 3 G 1 ◦ V f ◦ M t r \begin{align} M^i_t = M^i_t + α_3 \mathcal{G_1} ◦ \mathcal{V_f} ◦ M^r_t \end{align} Mti=Mti+α3G1VfMtr
其中,
V f \mathcal{V_f} Vf是输入X(其shape为(batch_size, sequeue_length, dimension))的前两维展开的向量;
G 1 \mathcal{G_1} G1是前馈神经网络,负责维度变换 R ( n q d ) × d → R d × d R^{(n_qd)×d} → R^{d×d} R(nqd)×dRd×d,其计算公式为 G 1 ( X ) = W g V f ( X ) \mathcal{G_1}(X) = W^g\mathcal{V_f}(X) G1(X)=WgVf(X)
α 3 α_3 α3是调和超参数;

4.5 模型输出 o t o_t ot

o t = G 3 ◦ V l ◦ G 2 ◦ V l ◦ M t r \begin{align} o_t = \mathcal{G_3} ◦ \mathcal{V_l} ◦ \mathcal{G_2} ◦ \mathcal{V_l} ◦ M^r_t \end{align} ot=G3VlG2VlMtr
其中,
V l \mathcal{V_l} Vl是输入X(其shape为(batch_size, sequeue_length, dimension))的后两维展开的向量;
G 2 , G 3 \mathcal{G_2},\mathcal{G_3} G2G3是前馈神经网络,分别负责维度变换 R n q × d d → R d × d R^{n_q×dd} → R^{d×d} Rnq×ddRd×d R n q n r → R n o R^{n_qn_r} → R^{n_o} RnqnrRno,其计算公式为 G 2 ( X ) = W g V l ( X ) \mathcal{G_2}(X) = W^g\mathcal{V_l}(X) G2(X)=WgVl(X) n q n_q nq是query的个数, n r n_r nr是超参数;

5 实验结果

源代码:https://github.com/thaihungle/SAM
作者做了消融实验,并在几何与图任务、强化学习任务、问答任务上做了测试。具体可以看论文附录和源码。
在这里插入图片描述
在这里插入图片描述

6 总结

该论文一个有趣的idea就是用两个前馈神经网络 M i , M r M^i,M^r Mi,Mr分别表示对象与对象间关系,但是参数更新方法不是梯度下降而是赫布更新,后续可能是一个改进点。

7 参考资料

[1]. Self-Attentive Assocative Memory , 2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/14255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言内存函数的深度解析

本章对 memcpy,memmove,memcmp 三个函数进行详解和模拟实现; 本章重点:3个常见内存函数的使用方法及注意事项并学会模拟实现; 如果您觉得文章不错,期待你的一键三连哦,你的鼓励是我创作的动力…

el-table数据处理

在写表格时遇到,后端返回的数据是对象,并且缺少字段 1.每一条数据加上 一个字段 2.将对象转成数组 以下是数据 {"groupA": {"groupName": null,"orgName": null,"orgId": null,"allPeoper": &quo…

C# 泛型(Generic)

方法重载:方法名称相同,参数个数和参数类型不同; 优势:可以节约方法名称 劣势:方法过多 语法:public void writeContent(T t) 原理:普通的C#代码他是运行在前端进行编译,所有的类型需…

IntersectionObserver实现小程序长列表优化

IntersectionObserver实现小程序长列表优化 关于 IntersectionObserver 思路 这里以一屏数据为单位【一个分页的10条数据,最好大于视口高度】, 监听每一屏数据和视口的相交比例,即用户能不能看到它 只将可视范围的数据渲染到页面上&#x…

[Spring]Spring声明式事务总结

文章目录 1、介绍2、Spring事务的隔离级别3、事务的传播行为4、Transactional注解包含的属性5、使用6、Transactional失效场景 1、介绍 声明式事务管理是建立在 AOP 之上的。其本质是通过 AOP 功能,对方法前后进行拦截,将事务处理的功能编织到拦截的方法…

Oracle 19c 报ORA-704 ORA-01555故障处理---惜分飞

异常断电导致数据库无法启动,尝试对数据文件进行recover操作,报ORA-00283 ORA-00742 ORA-00312错误,由于redo写丢失无法正常应用 D:\check_db>sqlplus / as sysdba SQL*Plus: Release 19.0.0.0.0 - Production on 星期日 7月 30 07:49:19 2023 Version 19.3.0.0.0 Copyrig…

利用读时建模等数据分析能力,实现网络安全态势感知的落地

摘要:本文提出一种基于鸿鹄数据平台的网络安全态势感知系统,系统借助鸿鹄数据平台读时建模、时序处理、数据搜索等高效灵活的超大数据存储和分析处理能力,支持海量大数据存储、分类、统计到数据分析、关联、预测、判断的网络安全态势感知能力…

FastAPI 5 - 依赖、安全

文章目录 一、Dependencies 依赖注入1、函数作为依赖2、类作为依赖3、多次依赖4、同时依赖多个二、安全、授权2、获取当前用户3、密码验证、令牌使用4、JWT 令牌、哈希加密学习自:FastAPI教程第二季(三):依赖+安全(最快python异步并发web框架之一) https://www.bilibili.…

PID模块化__以stm32直流电机速度为例

文章目录 前言一、相关PID源码.c.h 二、如何使用1.创建变量2.初始化3.运算4.修改pid参数 总结 前言 本篇使用到的基于这个STM32CubeMX 直流电机PID速度控制、HAL库、cubemx、PID、速度控制、增量式 由于上次使用的pid没有模块化,当多出使用pid的时候就会很麻烦 所以…

CentOS7系统Nvidia Docker容器基于TensorFlow2.12测试GPU

CentOS7系统Nvidia Docker容器基于TensorFlow1.15测试GPU 参考我的另一篇博客 1. 安装NVIDIA-Docker的Tensorflow2.12.0版本 1. 版本依赖对应关系:从源代码构建 | TensorFlow GPU 版本Python 版本编译器构建工具cuDNNCUDAtensorflow-2.6.03.6-3.9GCC 7.3.1Ba…

beego通过gorm访问mysql数据库

一、下载golang 二、解压下载包到C盘 三、配置golang系统环境变量 四、进入新建的工作目录C:\project下载并安装beego 五、将新生成的bee.exe所在的路径c:\project\bin加入到系统变量path里面 六、下载并安装mysql 例如在上图中, 选“No thanks,just start my down…

如何在3ds max中创建可用于真人场景的巨型机器人:第 3 部分

推荐: NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 1. 创建腿部装备 步骤 1 打开 3ds Max。 打开在本教程最后一部分中保存的文件。 打开 3ds Max 步骤 2 转到创建> 系统并单击骨骼。 创建>系统 步骤 3 为的 侧视口中的腿,如下图所示…

Java 程序员:本是为了跳槽刷完 1000 道真题,想不到被老板知道直接给我升职

同事:前阵子听说你要跳槽,现在准备得怎么样啊? 程序员 T:不跳了 同事:啊?为什么? 程序员 T:涨薪了呗? 同事:真的吗?涨了多少?你自…

R语言无法调用stats.dll的问题解决方案[补充]

写在前面 在去年10月份,出过一起关于R语言无法调用stats.dll的问题解决方案,今天(你看到后是昨天)不知道为什么,安装包,一直安装不了,真的是炸裂了。后面再次把R与Rstuido升级。说实话,我是真不…

flutter 图片相关

官方链接:https://api.flutter.dev/flutter/widgets/Image-class.html 图片基本使用 显示本地图片时,要在pubspec.yaml文件里面添加如:(注意空格) assets: - assets/images/logo.png Fit属性: BoxFit.cover最常用 显示可能拉伸,可能裁…

etcd入门和常用操作

概述 etcd 是一个高可用的分布式键值(key-value)数据库,采用了更为简洁的Raft共识算法来实现数据强一致。基于Go语言实现,主要用于共享配置和服务发现。 名称说明 名称说明etcd一种基于 raft 协议的分布式 kv 数据库&#xff0…

秋招算法备战第31天 | 贪心算法理论基础、455.分发饼干、376. 摆动序列、53. 最大子序和

贪心算法理论基础 贪心算法并没有固定的套路,唯一的难点就是如何通过局部最优,推出整体最优。如何验证可不可以用贪心算法呢?最好用的策略就是举反例,如果想不到反例,那么就试一试贪心吧。刷题或者面试的时候&#xf…

C语言指针详解

C语言指针详解 字符指针1.如何定义2.类型和指向的内容3.代码例子 指针数组1.如何定义2.类型和内容 数组指针1.如何定义2.类型和指向类型3.数组名vs&数组名数组指针运用 数组参数&指针参数一维数组传参二维数组传参一级指针传参二级指针传参 函数指针1.如何定义2.类型和…

Java ~ Collection/Executor ~ DelayQueue【总结】

前言 文章 相关系列:《Java ~ Collection【目录】》(持续更新)相关系列:《Java ~ Executor【目录】》(持续更新)相关系列:《Java ~ Collection/Executor ~ DelayQueue【源码】》(学…

transformer从开始到结束

首先输入是64 * 10的矩阵,代表64个句子,每个句子10个词。 X = self.positionalEncoding(self.embedding(X)*math.sqrt(self.num_hiddens))在经过embeddeding之后,变为64 * 10 *32 矩阵,每个词使用32维向量表示。然后将数据放入 X = encoder_block(X,valid_lens),这里我们将…