Mamba: Linear-Time Sequence Modeling with Selective State Spaces
论文:[2312.00752] Mamba: Linear-Time Sequence Modeling with Selective State Spaces
作者:Albert Gu 和 Tri Dao,分别来自卡内基梅隆大学机器学习系和普林斯顿大学计算机科学系。
Code:GitHub - state-spaces/mamba: Mamba SSM architecture
文章目录
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- 摘要
- 引言
- 主要贡献
- 模型结构
- 实验结果
- 讨论与未来工作
- 结论
摘要
- Mamba模型基于Transformer架构,旨在解决传统Transformer在长序列上的计算效率问题。
- 通过引入选择性状态空间SSMs(Selective State Spaces),Mamba能够在保持线性时间复杂度的同时,实现与Transformer相媲美的性能。
- 首先,简单地让SSM参数成为输入的函数,解决了它们在离散模态方面的弱点,允许模型根据当前令牌选择性地沿序列长度维度传播或忘记信息。其次,尽管这一变化阻碍了高效卷积的使用,但在循环模式下设计了一种硬件感知的并行算法。将这些选择性SSM集成到一个简化的端到端神经网络架构中,而不需要注意甚至MLP块(Mamba)。
引言
- 基础模型(Foundation Models, FMs)通常基于Transformer架构,用于预训练大型模型并在下游任务中进行适配。
- Transformer的自注意力机制虽然有效,但在长序列上存在计算瓶颈。
主要贡献
- 提出了一种新的选择性状态空间模型(Selective State Space Models, SSMs),在序列长度上实现线性扩展。
- 设计了一种硬件感知的并行算法,提高了模型在现代硬件上的运行效率。
- 将选择性SSMs集成到一个简化的端到端神经网络架构中,无需注意力或MLP块,该架构被称为Mamba。
模型结构
Mamba模型的结构设计旨在实现高效的序列建模,特别是在处理长序列时。论文认为,序列建模的一个基本问题是将上下文压缩到更小的状态。以下是Mamba模型结构的关键组成部分:
-
选择性状态空间模型(Selective State Space Models, SSMs):
-
Mamba模型的核心是选择性状态空间模型,引入了选择机制,允许模型根据输入动态地选择性地传播或遗忘信息,这些模型可以被视为循环神经网络(RNNs)和卷积神经网络(CNNs)的结合,同时受到经典状态空间模型的启发。
-
-
参数化选择:将SSM的参数(如状态转移矩阵和输入矩阵)作为输入的函数,使得模型能够根据当前的令牌动态地调整其行为。
-
内容感知:选择机制使模型能够基于内容进行推理,选择性地关注或忽略特定的输入。
-
硬件感知算法:设计了一种硬件感知的并行算法,利用现代硬件(如GPU)的内存层次结构,避免了在不同内存级别之间的IO访问,从而提高了计算速度。
-
递归计算:通过递归模式而非卷积模式计算模型,减少了计算和内存需求。
-
-
SSMs能够在序列长度上实现线性或近似线性的扩展,这使得它们在处理长序列时非常高效。
-
- state space sequence models (S4)模型由四个参数( Δ \Delta Δ、 A A A、 B B B和 C C C)定义,这些参数定义了两个阶段的序列到序列转换。在参数通过 ( Δ , A , B , C ) ↦ ( A ‾ , B ‾ , C ) (\Delta,A,B,C)\mapsto(\overline{A},\overline{B},C) (Δ,A,B,C)↦(A,B,C)的转换后,该模型可以通过两种方式计算,即线性递归或全局卷积。在这种情况下, A ∈ R N × N , B ∈ R N × 1 , C ∈ R 1 × N \boldsymbol{A}\in\mathbb{R}^{N\times N},\boldsymbol{B}\in\mathbb{R}^{N\times1},\boldsymbol{C}\in\mathbb{R}^{1\times N} A∈RN×N,B∈RN×1,C∈R1×N矩阵都可以用 N \boldsymbol{N} N表示。为了在批次大小为 B B B、长度为 L L L的输入序列 x x x上操作 D D D个通道,SSM被独立应用于每个通道。
-
参数化选择机制(Selection Mechanism):
- Mamba模型通过将SSM参数作为输入的函数来实现选择机制,这使得模型能够根据当前的令牌动态地选择性地传播或遗忘信息。
- 这种选择机制允许模型过滤掉不相关的信息,并无限期地记住相关信息。
- 将选择机制纳入模型的一种方法是让影响序列相互作用的参数(例如RNN的循环动力学或CNN的卷积核)依赖于输入。算法1和2说明了我们使用的主要选择机制。主要区别只是使输入的几个参数Δ、B、C成为函数,以及整个张量形状的相关变化。特别是,这些参数现在具有长度维度𝐿,这意味着模型已经从时不变变为时变。其中, S B ( x ) = L i n e a r N ( x ) S_B(x)=Linear_N(x) SB(x)=LinearN(x), S C ( x ) = L i n e a r N ( x ) S_C(x)=Linear_N(x) SC(x)=LinearN(x)、 S Δ ( x ) = B r o a d c a s t D ( L i n e a r 1 ( x ) ) S_\Delta(x)=Broadcast_D(Linear_1(x)) SΔ(x)=BroadcastD(Linear1(x))和 τ Δ = softplus \tau_\Delta\:=\:\text{softplus} τΔ=softplus,其中 L i n e a r d Linear_d Lineard是对维度d的参数化投影。
- 选择性允许过滤掉可能在感兴趣的输入之间出现的不相关的噪声标记。这在选择性复制任务中得到了体现,但在常见的数据模式中普遍发生,特别是对于离散数据,例如,存在诸如“um”之类的语言填充词。之所以出现这种特性,是因为模型可以机械地过滤掉任何特定的输入 x t x_t xt,例如,在门控 RNN 情况中,当 g t → 0 g_t → 0 gt→0 时。
- 从经验上观察到,尽管原则上更多的上下文应该导致严格的更好性能,但许多序列模型并没有随着更长的上下文而改善。一种解释是,许多序列模型在必要时无法有效地忽略不相关的上下文;一个直观的例子是全局卷积(和一般的 LTI 模型)。另一方面,选择性模型可以随时简单地重置其状态以删除无关的历史记录,因此它们的性能原则上会随着上下文长度的增加而单调地提高。
- 通常, Δ \Delta Δ控制着聚焦多少或忽略当前输入 x t x_t xt 之间的平衡。它推广了 RNN 门:机械地,一个大的 Δ \Delta Δ 重置状态 h h h 并专注于当前输入 x x x,而一个小的 Δ Δ Δ 保持状态并忽略当前输入。SSM (1)-(2) 可以解释为由时间步长 Δ \Delta Δ 离散化的连续系统,在这种情况下,直觉是,大 Δ → ∞ \Delta → ∞ Δ→∞ 表示系统更长时间地关注当前输入(从而“选择”它并忘记其当前状态),而小 Δ → 0 \Delta → 0 Δ→0 表示被忽略的瞬态输入。
- 在 SSM 中,将 B B B 和 C C C 修改为选择性,可以更精细地控制是让输入 x t x_t xt 进入状态 h t h_t ht,还是让状态进入输出 y t y_t yt。这些可以解释为允许模型分别根据内容(输入)和上下文(隐藏状态)调节循环动态。
-
硬件感知并行算法(Hardware-aware Parallel Algorithm):
- 为了克服模型计算中的技术挑战,Mamba设计了一种硬件感知的并行算法,该算法以递归模式而非卷积模式计算模型。
- 这种算法避免了在GPU内存层次结构的不同级别之间进行IO访问,从而提高了计算效率。
-
简化的端到端神经网络架构(Simplified End-to-End Neural Network Architecture):
- Mamba模型去除了传统的注意力机制和多层感知器(MLP)块,采用了一个简化的、同质的架构设计。
- 这种设计将SSM架构的设计和Transformer的MLP块结合在一起,形成了一个单一的块,这个块被重复堆叠。
-
空间-时间卷积模块(Spatial-Temporal Convolution Module):
- Mamba模型包含空间-时间卷积模块,该模块由图卷积和时间卷积组成,用于捕捉交通数据的空间和时间特征。
- 图卷积用于捕捉基于图结构的数据的空间相关性,而时间卷积则用于描述时间片之间的依赖性。
-
多组件融合(Multi-Component Fusion):
- Mamba模型由三个独立组件组成,每个组件分别针对不同的时间属性(近期、日周期性和周周期性依赖性)建模。
- 这三个组件的输出通过参数矩阵加权融合,以生成最终的预测结果。
-
空间-时间注意力机制(Spatial-Temporal Attention Mechanism):
- Mamba模型设计了一种新颖的空间-时间注意力机制,包括空间注意力和时间注意力。
- 空间注意力用于模拟不同位置之间的复杂空间相关性,而时间注意力则用于捕捉不同时间之间的动态时间相关性。
-
模型维度和参数化(Model Dimensionality and Parameterization):
- Mamba模型通过控制模型维度的扩展因子来优化参数化,使得模型在保持高效计算的同时,能够捕捉更丰富的信息。
-
模型初始化(Model Initialization):
- Mamba模型采用了特定的初始化方法,如S4D-Lin和S4D-Real,这些方法基于HIPPO理论,有助于模型在低数据环境下的表现。
实验结果
- Mamba在多种模态的数据上实现了最先进的性能,包括语言、音频和基因组学。
- 在语言建模任务中,Mamba-3B模型在预训练和下游评估中,超越了同等规模的Transformer模型,并且与规模是其两倍的Transformer模型相匹配。
讨论与未来工作
- 论文讨论了Mamba模型的局限性和未来的发展方向,包括考虑外部因素如天气和社会事件对交通流量预测的影响,以及将Mamba应用于其他实际应用,如估计到达时间。
结论
- Mamba模型通过引入选择性状态空间,实现了在长序列上的高效序列建模,具有快速推理和线性时间复杂度扩展的特点,并且在多个领域展现出卓越的性能。
这篇论文在深度学习序列建模领域提出了一种新的模型,通过创新的选择性状态空间机制,有效地提高了长序列数据处理的性能和效率。