Mamba详细介绍和RNN、Transformer的架构可视化对比

Transformer体系结构已经成为大型语言模型(llm)成功的主要组成部分。为了进一步改进llm,人们正在研发可能优于Transformer体系结构的新体系结构。其中一种方法是Mamba(一种状态空间模型)。

Mamba: Linear-Time Sequence Modeling with Selective State Spaces一文中提出了Mamba,我们在之前的文章中也有详细的介绍。

在本篇文章中,通过将绘制RNN,transformer,和Mamba的架构图,并进行详细的对比,这样我们可以更详细的了解它们之间的区别。

为了说明为什么Mamba是这样一个有趣的架构,让我们先介绍Transformer。

Transformer

Transformer将任何文本输入视为由令牌组成的序列。

transformer的一个主要优点是,无论它接收到多长的输入,它都使用序列中的任何令牌信息(无论序列有多长)来对输入数据进行处理。

这就是我们在论文中看到的注意力机制的作用,但是为了获得全局信息,注意力机制在长序列上非常耗费显存,这个我们后面说。

Transformer由两个结构组成,一组用于表示文本的编码器块和一组用于生成文本的解码器块。这些结构可以用于多种任务,包括翻译。

我们可以采用这种结构来创建仅使用解码器的生成模型。比如基于Transformer的GPT,使用解码器块来完成一些输入文本。

单个解码器块由两个主要部分组成,一个是自注意力模块,另一个是前馈神经网络。

注意力创建一个矩阵,将每个令牌与之前的每个令牌进行比较。矩阵中的权重由令牌对之间的相关性决定。

它支持并行化,所以可以极大地加快训练速度!

但是当生成下一个令牌时,我们需要重新计算整个序列的注意力,即使我们已经生成了一些新的令牌。

为长度为L的序列生成令牌大约需要L²的计算量,如果序列长度增加,计算量可能会很大。并且在这里需要计算所有令牌的注意力,所以如果序列很长,那么内存占用也会很大。所以需要重新计算整个序列是Transformer体系结构的主要瓶颈。当然也有很多技巧来提升注意力机制的效率,这里我们暂时不提,只看最经典的原始论文。

RNN

下面我们介绍更早的序列模型RNN。循环神经网络(RNN)是一种基于序列的网络。它在序列的每个时间步长取两个输入,即时间步长t的输入和前一个时间步长t-1的隐藏状态,以生成下一个隐藏状态并预测输出。

RNN有一个循环机制,允许它们将信息从上一步传递到下一步。我们可以“展开”这个可视化,使它更明确。

在生成输出时,RNN只需要考虑之前的隐藏状态和当前的输入。这样不会重新计算以前的隐藏状态,这正Transformer没有的。

这种流程可以让RNN进行快速推理,因为的时间与序列长度线性扩展!并且可以有无限的上下文长度(理论上),因为每次推理他只取一个隐藏状态和当前输入,内存的占用是非常稳定的。

我们将RNN应用于之前使用过的输入文本。

每个隐藏状态都是以前所有隐藏状态的聚合。但是这里就出现了问题,在生成名称“Maarten”时,最后一个隐藏状态不再包含关于单词“Hello”的信息(或者说最早的信息会被坐进的信息覆盖)。这会导致随着时间的推移,rnn会忘记信息,因为它们只考虑前一个状态。

并且rnn的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。

与Transformer相比,rnn的问题完全相反!它的推理速度非常快,但不能并行化导致训练很慢。

人们一直在寻找一种既能像Transformer那样并行化训练,能够记住先前的信息,并且在推理时间还是随序列长度线性增长的模型,Mamba就是这样宣传的。

在介绍Mamba之前,让我们还需要介绍以下状态空间模型

The State Space Model (SSM)

状态空间模型(SSM),像Transformer和RNN一样,可以处理序列信息,比如文本,也包括信号。

状态空间是包含能够完全描述一个系统的最少数量变量的概念。它是一种通过定义系统可能的状态来数学表示问题的方式。

比如说我们正在通过一个迷宫。“状态空间” 就是所有可能位置(状态)的地图。每个点代表迷宫中的一个独特位置,具有特定的细节,比如你离出口有多远。

“状态空间表示” 是对这个地图的简化描述。它展示了你当前所处的位置(当前状态),以及下一步可以去哪里(可能的未来)。

虽然状态空间模型使用方程和矩阵来跟踪这种行为,描述状态的变量,在我们的例子中是X和Y坐标以及到出口的距离,可以表示为“状态向量”。

听起来熟悉吗?这不就是强化学习中的状态吗,我个人认为是可以这么理解的,那么怎么和序列有关呢?

因为语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,你当前位置的向量(状态向量)可能看起来像这样:

在神经网络中,“状态”通常是指其隐藏状态,在大型语言模型的背景下,这是生成新标记的一个最重要的方面之一。

状态空间模型(SSMs)是用于描述这些状态表示并根据某些输入进行下一个状态预测的模型。

在时间t,状态空间模型(SSMs):

  • 将输入序列x(t)(例如,在迷宫中向左和向下移动)映射到潜在状态表示h(t)(例如,到出口的距离和x/y坐标),
  • 并推导出预测的输出序列y(t)(例如,再次向左移动以更快地到达出口)。

这里就与强化学习中使用离散序列(如仅向左移动一次)不同,它将连续序列作为输入并预测输出序列。

ssm假设动态系统,例如在三维空间中移动的物体,可以通过两个方程从时间t的状态预测。

通过求解这些方程,假设可以揭示基于观测数据(输入序列和先前状态)预测系统状态的统计原理。

它的目标是找到这个状态表示h(t)这样我们就可以从一个输入序列到一个输出序列。

这两个方程就是是状态空间模型的核心。状态方程描述了基于输入如何影响状态(通过矩阵B)的状态变化(通过矩阵A)。

h(t)表示任意时刻t的潜在状态表示,而x(t)表示某个输入。

输出方程描述了状态如何转化为输出(通过矩阵C),以及输入如何影响输出(通过矩阵D)。

矩阵A、B、C和D通常被称为参数,因为它们是可学习的。将这两个方程可视化,我们可以得到如下架构:

下面我们看看这些矩阵如何影响学习过程。

假设我们有一个输入信号x(t)这个信号首先乘以矩阵B它描述了输入如何影响系统。

更新状态(h)是包含环境核心“知识”的潜在空间。我们将状态与矩阵A相乘,矩阵A描述了所有内部状态是如何连接的,因为它们代表了系统的潜在表示。

这里可以看到,在创建状态表示之前应用矩阵A,并在状态表示更新之后更新矩阵A。

然后使用矩阵C来描述如何将状态转换为输出。

最后利用矩阵D提供从输入到输出的直接信号。这通常也被称为跳过(残差)连接。

由于矩阵D类似于跳过连接,所以SSM通常被视为为不进行跳过连接的部分

回到我们的简化视图,现在可以将重点放在矩阵A、B和C上,它们是SSM的核心。

更新原始方程并添加一些颜色来表示每个矩阵的目的

这两个方程根据观测数据预测系统的状态。由于期望输入是连续的,SSM是连续时间表示。

但是因为文字都是离散的输入,我们还需要将模型离散化。这里就要使用* Zero-order hold * 技术

每次我们接收到一个离散信号,都会保证他的值不变,直到接收到一个新的离散信号再改变。这个过程创建了一个SSM可以使用的连续信号:

我们保持该值的时间由一个新的可学习参数表示,称为步长∆。这样就得到了一个连续的信号并且可以只根据输入的时间步长对值进行采样。

这些采样值就是我们的离散输出!在数学上,我们可以应用Zero-order hold如下:

因为我们SSM处理的是离散信号,所以这里不是一个函数到函数,x(t)→y(t),而是一个序列到序列,xₖ→yₖ,我们用公式表示如下:

矩阵A和B现在表示模型的离散参数,用k代替t来表示离散的时间步长。

离散化的SSM允许在特定的时间步中处理信息。就像我们之前在循环神经网络(RNNs)中看到的那样,循环方法在这里也非常有用,可以将问题重新表述为时间步骤:

在每个时间步长,我们计算当前输入(Bxₖ)如何影响前一个状态(Ahₖ₁),然后计算预测输出(Chₖ)。

这种表示看起来是不是有点熟悉?其实他的处理方法和RNN一样

也可以这样展开:

这种技术与RNN类似,快速推理和慢速训练。

另一种ssm的表示是卷积的表示。我们应用过滤器(核)来获得聚合特征:

因为我们处理的是文本而不是图像,所以我只要一维的视角:

我们用来表示这个“过滤器”的核是由SSM公式推导出来的:

可以使用SSM核遍历每一组令牌并计算输出:

上图也说明了padding 可能对输出产生的影响,所以我们一般都会在末尾padding而不是在前面。第二步核被移动一次来执行下一步计算:

在最后一步,我们可以看到核的完整效果:

卷积的一个主要好处是它可以并行训练。但是由于核大小是固定,它们的推理不如rnn快速并且对序列长度有限制。

上面的三种SMM都有各自的优缺点

这里可以使用一个简单的技巧,即根据任务选择表示。在训练过程中使用可以并行化的卷积表示,在推理过程中,我们使用高效的循环表示:

听起来有点奇幻,但是有人就是实现出来了,这个模型叫做Linear State-Space Layer (LSSL)

https://proceedings.neurips.cc/paper_files/paper/2021/hash/05546b0e38ab9175cd905eebcc6ebb76-Abstract.html

它结合了线性动态系统理论和神经网络的概念,可以有效地捕获数据中的时序信息和动态特征。LSSL 基于线性动态系统理论,这种系统可以用状态空间模型表示。在这个模型中,系统的行为由状态变量的演化和外部控制信号的影响决定。状态变量是系统的内部表示,可以捕获系统的动态特性。

这些表示都有一个重要的特性,即线性时不变性(LTI)。LTI表示ssm参数A、B和C对于所有时间步长都是固定的。这意味着对于SSM生成的每个令牌,矩阵A、B和C都是相同的。

也就是说无论给SSM的序列是什么,A、B和C的值都保持不变。这样就得到了一个不感知内容的静态表示。但是静态表示没有任何意义对吧,所以Mamba解决的就是这个问题,但是在介绍Mamba之前,我们还有一个知识点需要强调,那就是矩阵A

因为SSM公式中最重要的就是矩阵a。正如我们之前在循环表示中看到的那样,它捕获了关于前一个状态的信息来构建新状态,如果矩阵a如果跟RNN一样会遗忘掉非常靠前的信息那么SMM将没有任何的意义,对吧。

矩阵A产生隐藏状态:

如何保留大上下文大小的方式创建矩阵A呢?

HiPPO 的模型结合了递归记忆(Recurrent Memory)和最优多项式投影(Optimal Polynomial Projections)的概念,这种投影技术可以显著改善递归记忆的性能,特别是在处理长序列和长期依赖关系时。

https://proceedings.neurips.cc/paper/2020/hash/102f0bb6efb3a6128a3c750dd16729be-Abstract.html

使用矩阵A来构建一个状态表示,该状态表示可以很好地捕获最近的令牌并衰减较旧的令牌。其公式可表示为:

具体的详细内容我们就不介绍了,有兴趣的查看原论文。

这样我们就基本上解决了所有的问题:1、状态空间模型;2、处理远程依赖关系;3、离散化和并行计算

如果想深入了解有关如何计算HiPPO矩阵和自己构建S4模型建议您阅读注释的S4。

https://srush.github.io/annotated-s4/

Mamba

上面介绍完所有必要的基础知识,最后就是我们的重点了

Mamba 有两个主要贡献:

1、选择性扫描算法,模型可以过滤有关和无关的信息

2、硬件感知算法,通过并行扫描、核融合和重计算有效地存储(中间)结果。

在探讨这两个主要贡献之前,我们先看看一下为什么它们是必要的。

状态空间模型,S4(Structured State Space Model),在语言建模和生成中的某些任务上表现不佳

比如在选择性复制任务中,SSM的目标是按顺序复制输入和输出的部分:

(循环/卷积)SSM在这个任务中表现不佳,因为它是线性时不变的。对于SSM生成的每个令牌,矩阵A、B和C都是相同的。

因为它将每个令牌平等地视为固定的a、B和C矩阵的结果,所以SSM不能执行内容感知推理

SSM表现不佳的第二个任务是重现输入中发现的模式:

我们的提示在“教”模型在每个“Q:”之后提供“A:”响应。但是由于ssm是时间不变的,它不能选择从其历史中获取先前的令牌。

以矩阵B为例不管输入x是什么,矩阵B保持完全相同,并且与x无关:

同理无论输入是什么,A和C也不变,这就是我们上面说的静态。

而Transformers 可以根据输入序列动态地改变注意力。可以选择性地“看”或“注意”序列的不同部分,再加上位置编码,这使得Transformers对于这种任务非常的简单。

ssm在这些任务上的糟糕性能说明了定常ssm的潜在问题,矩阵A、B和C的静态特性导致了内容感知问题。

选择性地保留信息

SSM的循环表示创建了一个非常有效的小状态,因为它压缩了整个历史信息,所以与不压缩历史(注意力矩阵)的Transformer模型相比,它的功能要弱得多。

Mamba 的目标是获得Transformer一样强大的“小”状态

通过有选择地将数据压缩到状态,当输入一个句子时,通常会有一些信息,比如停顿词,这些信息没有太多的意义。

我们先看看SSM在训练期时的输入和输出维度:

在结构化状态空间模型(S4)中,矩阵a、B和C独立于输入,因为它们的维度N和D是静态的,不会改变。

而Mamba通过结合输入的序列长度和批量大小,使矩阵B和C,甚至步长∆依赖于输入:

这意味着对于每个输入标记,有不同的B和C矩阵,这解决了内容感知问题!这里矩阵A保持不变,因为希望状态本身保持静态,但影响它的方式(通过B和C)是动态的。

也就是说它们一起选择性地选择将什么保留在隐藏状态中,什么需要忽略,这都是由输入确定的。

较小的步长∆导致忽略特定的单词,而是更多地使用之前的上下文,而较大的步长∆则更多地关注输入单词而不是上下文:

扫描操作

这些矩阵现在是动态的了所以它们不能使用卷积表示来计算,只能使用循环进行处理,这就使得无法进行并行化。

为了实现并行化,我们先看看循环的输出:

每个状态都是前一个状态(乘以A)加上当前输入(乘以B)的和。这被称为扫描操作,可以很容易地通过for循环计算出来。但是并行化似乎是不可能的,因为每个状态只有在我们有前一个状态时才能计算出来。

但是Mamba使用并行扫描算,通过关联属性假定执行操作的顺序无关紧要。这样就可以计算部分序列并迭代组合它们:

这样还有一个好处是因为顺序不重要,也可以省略掉Transformer的位置编码。

硬件感知的算法

最近gpu的一个缺点是它们在小但高效的SRAM和大但效率稍低的DRAM之间的传输(IO)速度有限。在SRAM和DRAM之间频繁地复制信息成为瓶颈。

Mamba的DRAM和SRAM分配的具体实例如下:

中间状态不被保存,但对于反向传播计算梯度是必要的。作者重新计算了反向传递过程中的中间状态。尽管这看起来效率很低,但它比从相对较慢的DRAM读取所有这些中间状态的成本要低得多。

这里我们就不详细说明了,因为这部分我也没太研究过

Mamba 块

选择性SSM可以作为一个块,就像在Transformer中的的注意力模块一样。我们可以堆叠多个块,并使用它们的输出作为下一个曼巴块的输入:

最后一个端到端(输入到输出)的例子包含了归一化层和选择输出标记softmax。

这样就得到了快速的推理和训练,而且是“无限”长度上下文的模型

总结

看完这篇文章,我希望你能对Mamba 和状态空间模型有一定的了解,最后我们以作者的发现为结尾:

作者发现模型与相同尺寸的Transformer模型的性能相当,有时甚至超过了它们!

https://avoid.overfit.cn/post/94105fed36de4cd981da0b916c0ced47

作者:Maarten Grootendorst

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/698890.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS7 安装SSH

说实话,感觉CentOS有点难用。初始配置不是很友好,连自动获取IP地址都是关着的。当然,我这里本地用的是虚拟机。 开启获取IP 首先是获取IP地址,这是一些的起点。 首先使用ip addr 看看网卡地址和名称。我这边是ens33。然后打开自…

【TCP/IP】组播

一、组播介绍 组播(Multicast)是网络技术中数据传输的一种方法,它允许将数据包同时发送给一组指定的目标,而不是单个的目标(单播 Unicast)或所有可能的目标(广播 Broadcast)。组播传…

2024年 Openai的API相关全部概论汇总(通用版)

2024年 Openai的API相关全部概论汇总(通用版) 文章目录 2024年 Openai的API相关全部概论汇总(通用版)一、前言1、python快速开始 二、Openai 平台以及相关项目1、Openai的API管理平台2、ChatGPT项目推荐(1)…

开源大语言模型作为 LangChain 智能体

概要 开源大型语言模型 (LLMs) 现已达到一种性能水平,使它们适合作为推动智能体工作流的推理引擎: Mixtral 甚至在我们的基准测试中 超过了 GPT-3.5,并且通过微调,其性能可以轻易的得到进一步增强。 引言 针对 因果语言建模 训练的大型语言模…

2024-02-21 作业

作业要求: 复习课上内容 //已完成结构体字节对齐,64位没做完的做完,32位重新都做一遍,课上指定2字节对齐的做一遍,自己验证 //已完成两种验证大小端对齐的代码写一遍复习指针内容 //已完成完善顺序表已写出的…

Linux运维-DHCP服务器

DHCP服务器的配置与管理 项目场景 学校各部门共有180台电脑,除了计算机学院的教师会配置电脑的网络连接,其他部门的老师和工作人员均不会,为了提高网络的管理效率,技术人员决定配置一台DHCP服务器,来提供动态的IP地址…

gin源码实战 day1

gin框架源码实战day1 Radix树 这个路由信息: r : gin.Default()r.GET("/", func1) r.GET("/search/", func2) r.GET("/support/", func3) r.GET("/blog/", func4) r.GET("/blog/:post/", func5) r.GET("/…

5G端到端案例三:锚点基站侧5G连接与VOLTE专载建立流程冲突导致CSFB回落问题

1. 问题描述: NSA组网场景下,语音业务仍使用4G VoLTE方案,在拉网测试中,发现存在较多流程交叉导致的VOLTE接入失败的问题。 流程冲突时的空口信令表现为,终端添加SCG流程与语音专载流程冲突时,专有承载建…

重点媒体如何投稿?考核稿件投稿指南

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 机构组织,国企央企都需要定期将相关新闻投递到央媒,官媒,或者地方重点媒体中,那么如何进行投稿了,今天就与大家分享下。 央媒投…

vue-nextTick(nextTick---入门到离职系列)

官方定义 在下次 DOM 更新循环结束之后执行延迟回调。在修改数据之后立即使用这个方法&#xff0c;获取更新后的 DOM。 个人理解 假设我们更改了某个 dom 元素内部的文本&#xff0c;而这时候我们想直接打印这个更改之后的文本是需要 dom 更新之后才会实现的。 小案例 <tem…

聊一聊EGO-Planner膨胀系数的大小对无人机避障飞行的影响

EGO-Planner简介 EGO-Planner作为业界知名的无人机轨迹规划算法&#xff0c;其优势在于能够在复杂环境中快速规划出安全、平滑且动态可行的飞行轨迹。在这个算法中&#xff0c;膨胀系数发挥着关键作用。它通过扩大障碍物的感知范围&#xff0c;提供额外的安全边距&#xff0c;…

NOIP2018-J-4-对称二叉树的题解

原题描述&#xff1a; 题目描述 时间&#xff1a;1s 空间&#xff1a;256M 一棵有点权的有根树如果满足以下条件&#xff0c;则被轩轩称为对称二叉树&#xff1a; 1. 二叉树&#xff1b; 2. 将这棵树所有节点的左右子树交换&#xff0c;新树和原树对应位置的结构相同且…

【深度学习】LoRA: Low-Rank Adaptation of Large Language Models,论文解读

文章&#xff1a; https://arxiv.org/abs/2106.09685 文章目录 摘要介绍LoRA的特点什么是低秩适应矩阵&#xff1f;什么是适应阶段&#xff1f;低秩适应矩阵被注入到预训练模型的每一层Transformer结构中&#xff0c;这一步是如何做到的&#xff1f; 摘要 自然语言处理的一个重…

计算机网络-网络互联与互联网(一)

1.常用网络互联设备&#xff1a; 1层物理层&#xff1a;中继器、集线器2层链路层&#xff1a;网桥、交换机3层网络层&#xff1a;路由器、三层交换机4层以上高层&#xff1a;网关 2.网络互联设备&#xff1a; 中继器Repeater、集线器Hub&#xff08;又叫多端口中继器&#xf…

图论(算法竞赛、蓝桥杯)--Dijkstra算法最短路

1、B站视频链接&#xff1a;D02 最短路 Dijkstra 算法_哔哩哔哩_bilibili 题目链接&#xff1a;【模板】单源最短路径&#xff08;弱化版&#xff09; - 洛谷 #include <bits/stdc.h> using namespace std; #define INF 2147483647 int n,m,s,a,b,c; const int N100010…

Redis的主从复制和哨兵模式

Redis的主从复制和哨兵模式 Redis集群搭建&#xff08;一主二从&#xff09;replication 主从复制配置文件 redis.confRedis主从复制工作原理全量复制增量复制redis主从复制策略 搭建集群 &#xff08;主从复制引入&#xff09; 哨兵模式概念哨兵配置文件 sentinel.conf哨兵配置…

ArcgisForJS如何使用ArcGIS Server发布的切片地图服务?

文章目录 0.引言1.准备海量地理数据2.ArcGIS Server发布切片地图服务3.ArcgisForJS使用ArcGIS Server发布的切片地图服务 0.引言 ArcGIS Server是一个由Esri开发的地理信息系统&#xff08;GIS&#xff09;服务器软件&#xff0c;它提供了许多功能&#xff0c;包括发布切片地图…

java面试设计模式篇

面试专题-设计模式 前言 在平时的开发中&#xff0c;涉及到设计模式的有两块内容&#xff0c;第一个是我们平时使用的框架&#xff08;比如spring、mybatis等&#xff09;&#xff0c;第二个是我们自己开发业务使用的设计模式。 面试官一般比较关心的是你在开发过程中&#…

挑战杯 基于卷积神经网络的乳腺癌分类 深度学习 医学图像

文章目录 1 前言2 前言3 数据集3.1 良性样本3.2 病变样本 4 开发环境5 代码实现5.1 实现流程5.2 部分代码实现5.2.1 导入库5.2.2 图像加载5.2.3 标记5.2.4 分组5.2.5 构建模型训练 6 分析指标6.1 精度&#xff0c;召回率和F1度量6.2 混淆矩阵 7 结果和结论8 最后 1 前言 &…

Oracle迁移到mysql-导出mysql所有索引和主键

导出建库表索引等&#xff1a; [rootlnpg ~]# mysqldump -ugistar -pxxx -h192.168.207.143 --no-data -d lndb > lndb20230223-1.sql 只导出索引&#xff1a;参考&#xff1a;MYSQL导出现有库中的索引脚本_mysql 导出数据库所有表的主键和索引-CSDN博客 -- MYSQL导出现有…