(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

公众号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

3. 线性 Transformers

3.1. Transformer

3.2. 线性注意力机制

3.2.1. 特征映射与计算成本

3.3. 因果掩码

3.3.1. 梯度计算

3.3.2. 训练和推理

3.4. transformer 是 RNN

4. 实验


0. 摘要

Transformer 在多项任务中表现出色,但由于其对输入长度的二次复杂度,对于非常长的序列来说,速度极慢。为了解决这一限制,我们将自注意力表示为核特征映射(kernel feature maps)的线性点积,并利用矩阵乘积的结合性将复杂度从 O(N^2) 降低到 O(N),其中 N 是序列长度。我们证明了这种表达方式允许一种迭代实现,大大加速了自回归 Transformer,并揭示了它们与递归神经网络的关系。我们的线性 Transformer 在性能上与普通 Transformer 相似,并且在非常长序列的自回归预测中速度快达 4000 倍。 

3. 线性 Transformers

在本节中,我们提出了线性 Transformer。我们展示了将传统的 softmax 注意力机制改为基于特征映射的点积注意力,可以改善时间和内存复杂度,并且可以实现类似于 RNN 的线性时间序列生成模型。

3.1. Transformer

3.2. 线性注意力机制

公式 2 中的注意力定义是通用的,可以用于定义多种其他注意力实现,例如多项式注意力或 RBF 核注意力(Tsai等人,2019)。注意,为了使公式 3 定义的注意力函数有效,我们需要对 sim(·) 施加的唯一约束是非负性。这包括所有核函数 k(x, y): R^(2 × F) → R_+。

给定具有特征表示 ϕ(x) 的核函数,我们可以将公式 2 重写为:

然后利用矩阵乘法的结合性进一步简化为:

当分子以向量形式书写时,上述公式更容易理解,如下所示:

注意,特征映射 ϕ(·) 是逐行应用于矩阵 Q 和 K 的。

从公式 2 可以看出,softmax 注意力的计算成本随 O(N^2) 缩放,其中 N 表示序列长度。内存需求也是如此,因为必须存储完整的注意力矩阵以计算查询、键和值的梯度。相比之下,我们在公式 5 中提出的线性 transformer 具有 O(N) 的时间和内存复杂度,因为我们可以计算

一次,并在每个查询中重复使用它们。

3.2.1. 特征映射与计算成本

对于 softmax 注意力,就乘法和加法的总成本而言,随着 O(N^2·max(D, M)) 缩放,其中 D 是查询和键的维度,M 是值的维度。相反,对于线性注意力,我们首先计算维度为 C 的特征映射。随后,计算新值需要 O(NCM) 次加法和乘法。

上述分析未考虑核函数和特征函数的选择。需要注意的是,对应于指数核的特征函数是无限维的,这使得精确 softmax 注意力的线性化不可行。另一方面,例如多项式核具有精确的有限维特征映射,并且已证明与指数或 RBF 核(Tsai等人,2019)同样有效。线性化多项式 transformer 的计算成本为 O(N·D^2·M)。当 N > D^2 时,这使得计算复杂度更具优势。实际上,由于我们希望能够处理成千上万元素的序列,这一情况是成立的。

对于我们的实验,处理较小的序列,我们采用了一个结果为正相似函数的特征映射,如下定义:

其中 elu(·) 表示指数线性单元(Clevert等人,2015)的激活函数。我们更喜欢 elu(·) 而不是relu(·),以避免在 x 为负时将梯度设置为 0。这种特征映射导致的注意力函数需要 O(NDM) 次乘法和加法。在我们的实验部分,我们展示了公式 7 的特征映射在性能上与完整 transformer 相当,同时显著减少了计算和内存需求。

3.3. 因果掩码

transformer  架构可以通过掩蔽(masking)注意力计算来高效地训练自回归模型,使得第 i 个位置只能被第 j 个位置影响当且仅当 j ≤ i,即一个位置不能被后续位置影响。形式上,这种因果掩码将公式 3 修改如下:

按照3.2节的推理,我们如下所述对掩码注意力进行线性化:

通过引入 Si 和 Zi 如下所示:

我们可以将公式 9 简化为:

注意,Si 和 Zi 可以从 S_(i-1) 和 Z_(i-1) 在固定时间内计算得出,因此使得具有因果掩码的线性 transformer 的计算复杂度相对于序列长度为线性。

3.3.1. 梯度计算

在任何深度学习框架中,公式 12 的朴素实现需要存储所有中间值 Si,以计算梯度。这会增加max(D, M) 倍的内存消耗,从而阻碍因果线性注意力在更长序列或更深模型中的应用。为了解决这个问题,我们将公式 9 中的分子(numerator)的梯度导出为累积和。这使我们能够在线性时间和固定内存中计算因果线性注意力的前向和后向传播。详细推导见附录材料。

给定分子 ¯V_i 和标量损失函数相对于分子的梯度

推导可得:

累计和项在公式 9 和 13-15 中以线性时间计算,并且相对于序列长度需要常量内存。这导致的算法在给定维度为 C 的特征映射下,其计算复杂度为 O(NCM),内存复杂度为 O(N·max (C, M))。算法 1 是分子部分前向和后向传播的伪代码实现。

3.3.2. 训练和推理

在训练自回归 transformer 模型时,可以使用完整的真实序列。这使得公式 1 中的函数 φ(·) 和注意力计算都可以进行分层并行化。因此,transformer 比 RNN 更高效地进行训练。然而,在推理过程中,时间步 i 的输出是时间步 i + 1 的输入。这使得自回归模型无法并行化。此外,transformer 每个时间步的成本不是常量,而是随着当前序列长度的平方增长,因为必须为所有先前的时间步计算注意力。

我们提出的线性 transformer 模型结合了这两者的优点。在训练时,计算可以并行化并充分利用 GPU 或其他加速器。在推理时,我们模型的每次预测在时间和内存上的成本是常量的。这意味着我们可以简单地将

矩阵存储为内部状态,并在每个时间步像递归神经网络一样更新它。这使得推理速度比其他 transformer 模型快数千倍。

3.4. transformer 是 RNN

在文献中,transformer 模型被认为是一种与递归神经网络(RNN)根本不同的方法。然而,从 3.3 节中的因果掩码公式和前一节的讨论可以看出,任何具有因果掩码的 transformer 层都可以被表示为一种模型,该模型在给定输入后修改内部状态,然后预测输出,即 RNN。注意,与通用变压器(Universal Transformers)(Dehghani等人,2018)不同,我们考虑的是时间上的递归,而不是深度上的递归。

在以下公式中,我们将公式 1 的 Transformer 层形式化为 RNN。所得的 RNN 有两个隐藏状态,即注意力记忆 s 和归一化记忆 z。我们用下标表示递归中的时间步。

在上述公式中,x_i 表示特定 Transformer 层的第 i 个输入,y_i 表示第 i 个输出。需要注意的是,我们的公式对特征函数没有任何约束,因此可以用于表示任何 Transformer 模型,理论上甚至包括使用 softmax 注意力的模型。这一公式是更好理解 Transformer 与流行的 RNN(Hochreiter & Schmidhuber, 1997)及其存储和检索信息过程之间关系的第一步。 

4. 实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/17911.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023、2024国赛web复现wp

2023 Unzip 类型&#xff1a;任意文件上传漏洞 主要知识点&#xff1a;软链接 随便上传一个一句话木马文件&#xff0c;得到一串php代码 根据代码上传zip文件发现进入后还是此页面 代码审计&#xff1a; <?php error_reporting(0); highlight_file(__FILE__);$finfo fin…

Stable Diffusion【写实模型】:逼真,逼真,超级逼真的国产超写实摄影大模型万享XL

今天和大家分享的是一个国产万享系列中使用量最高的大模型:万享XL_超写实摄影&#xff0c;顾名思义&#xff0c;该大模型主要是面向写实摄影&#xff0c;一方面生成的图片人物皮肤纹理细节超级逼真&#xff0c;另一方面对于光影效果的处理也非常到位。对于万享XL超写实摄影大模…

揭秘Tensor Core黑科技:如何让AI计算速度飞跃

揭秘 Tensor Core 底层&#xff1a;如何让AI计算速度飞跃 Tensor Core&#xff0c;加速深度学习计算的利器&#xff0c;专用于高效执行深度神经网络中的矩阵乘法和卷积运算&#xff0c;提升计算效率。 Tensor Core凭借混合精度计算与张量核心操作&#xff0c;大幅加速深度学习…

参数高效微调PEFT(二)快速入门P-Tuning、P-Tuning V2

参数高效微调PEFT(二)快速入门P-Tuning、P-Tuning V2 参数高效微调PEFT(一)快速入门BitFit、Prompt Tuning、Prefix Tuning 今天&#xff0c;我们继续了解下来自清华大学发布的两种参数高效微调方法P-Tuning和P-Tuning v2。可以简单的将P-Tuning是认为针对Prompt Tuning的改进…

零基础小白本地部署大疆上云api(个人记录供参考)

文章目录 运行前准备前后端项目运行1.前端项目&#xff1a; 后端项目运行必须先依靠emqx运行必须先依靠redis运行修改后端项目的application.yml文件 运行前准备 1.保证电脑又node.js环境&#xff0c;可以正常使用npm 2.Java的jdk必须是11及以上版本否则无效 3.下载好emqx,red…

《java数据结构》--队列详解

一.认识队列&#x1f431; 初识队列&#x1f638; 队列和栈类似都对数据的存取有着严格的要求&#xff0c;不同的是栈遵循先进后出的原则&#xff0c;而队列遵循先进先出的原则&#xff0c;栈是只有一端可以存取&#xff0c;队列是一端存&#xff0c;一端取。这里我来画一个图…

鸿蒙ArkUI-X跨语言调用说明:【平台桥接开发指南(Android)BridgePlugin】

BridgePlugin (平台桥接) 本模块提供ArkUI端和Android平台端消息通信的功能&#xff0c;包括数据传输、方法调用和事件调用。需配套ArkUI端API使用&#xff0c;ArkUI侧具体用法请参考[Bridge API]。 说明&#xff1a; 开发前请熟悉鸿蒙开发指导文档&#xff1a; gitee.com/li-…

“2024 亚马逊云科技中国峰会,挑战俱乐部 Hands On 动手实验课程正在直播中,点击链接畅享生成式AI建构之旅,赢心动好礼

只看不过瘾&#xff1f;别急&#xff01;我们为您准备了【生成式AI助手 Amazon Q 初体验】动手实验&#xff0c;一款生成式人工智能 (AI) 支持的对话助理&#xff0c;可以帮助您理解、构建、扩展和操作 Amazon 应用程序&#xff0c;您可以询问有关 Amazon 架构、最佳实践、文档…

马斯克开启军备竞赛,xAI筹集60亿美元

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在某个环节出现卡点&#xff0c;可以回到大模型必备腔调重新阅读。而最新科技&#xff08;Mamba&#xff0c;xLSTM,KAN&#xff09;则提供了大模…

ai智能写作怎么样,5款ai写作软件创作文章太棒了

ai智能写作究竟怎么样呢&#xff1f;在当今数字化的时代&#xff0c;AI智能写作正逐渐成为一种引人瞩目的趋势。AI智能写作是指利用人工智能技术来辅助或代替人类进行文本创作的过程。随着人工智能技术的不断发展&#xff0c;AI智能写作在各个领域都呈现出越来越广泛的应用。本…

微服务架构下的‘黑带’安全大师:Spring Cloud Security全攻略!

深入探讨了微服务间的安全通信、安全策略设计以及面对经典安全问题的应对策略。无论你是微服务的新手还是资深开发者&#xff0c;都能在本文中找到提升安全功力的秘籍。让我们一起成为微服务架构下的‘黑带’安全大师&#xff01; 文章目录 1. 引言微服务安全挑战与重要性Sprin…

SHELL编程(三)网络基础命令 Makefile

目标 一、网络基础及相关命令&#xff08;一&#xff09;网络相关命令&#xff08;二&#xff09;重启网络服务 二、Makefile&#xff08;一&#xff09;标签式语法&#xff08;二&#xff09;目标:依赖 式语法1. 格式2. 编译流程&#xff1a;预处理 编译 汇编 链接3. 目标和伪…

Java入门基础学习笔记50——ATM系统

1、项目演示&#xff1b; 2、项目技术实现&#xff1b; 1&#xff09;面向对象编程&#xff1a; 每个账户都是一个对象&#xff0c;所以要设计账户类Account&#xff0c;用于创建账户对象封装账户信息。ATM同样是一个对象&#xff0c;需要设计ATM类&#xff0c;代表ATM管理系…

windows tomcat服务注册和卸载

首页解压tomcat压缩包&#xff0c;然后进入tomcat bin目录&#xff0c;在此目录通过cmd进入窗口&#xff0c; 1&#xff1a;tomcat服务注册 执行命令&#xff1a;service.bat install tomcat8.5.100 命令执行成功后&#xff0c;会在注册服务列表出现这个服务&#xff0c;如果…

基于ssm+vue图书管理系统

基于ssmvue图书管理系统 ssm477图书管理系统 相关技术 javassmmysqlvueelementui

索引下推详情-简单入手

一.概念 索引下推&#xff08;Index Pushdown&#xff09;MySQL5.6添加的&#xff0c;是一种优化技术&#xff0c;用于在查询执行时将部分计算移动到存储引擎层&#xff0c;从而减少数据传输和计算的开销&#xff08;减少回表查询次数&#xff09;&#xff0c;提高查询性能。 …

14、类与对象(采用图解方式分析内存结构)①

在idea中创建一个新文件&#xff0c;名称为Hello.java 其中&#xff0c;Hello就是一个类&#xff0c;main是这个类里面的方法&#xff0c;这意味着我们在学习的时候已经在使用类了。 对象和类 一、概念二、⭐内存分配机制分析Ⅰ、基本内存结构⭐⭐Ⅱ、调用类方法的内存分析&am…

使用 Django 显示表中的数据

1、问题背景 当我们使用 Django 进行 Web 开发时&#xff0c;经常需要在 Web 页面上显示数据库中的数据。例如&#xff0c;我们可能需要在一个页面上显示所有用户的信息&#xff0c;或者在一个页面上显示所有文章的标题和作者。那么&#xff0c;如何使用 Django 来显示表中的数…

打包软件注意

1.建个文件夹D:333 /Dalsa_Cameras /cam1 cam2 2. 3.缺的包 4.自动启动.exe exe快捷方式放一起

编程零基础,如何学习Python?

初学者选择Python入手着实是一个不错的方向&#xff0c;入手简单且广泛的运用是它最显著的特色了。 那有几个问题&#xff0c;我想是开始学习Python之前应该了解的&#xff0c; python能做什么&#xff1f; 发展前景与工作机会有哪些&#xff1f; 需要学习哪些内容&#xf…