文章目录
- 3.2 对角结构化状态空间模型
- 3.2.1 S4D:对角SSM算法
- 3.2.2 完整应用实例
- 3.3 对角化加低秩(DPLR)参数化
- 3.3.1 DPLR 状态空间核算法
- 3.3.2 S4-DPLR 算法和计算复杂度
- 3.3.3赫尔维兹(稳定)DPLR形式
这篇文章是Mamba作者博士论文 MODELING SEQUENCES WITH STRUCTURED STATE SPACES
的第三章的部分翻译,为了解决计算上存在的代价问题,引入了结构化状态空间模型,介绍了对角结构化状态空间模型和低秩对角结构化状态空间模型。
3.2 对角结构化状态空间模型
为了解决SSM的计算瓶颈,我们使用一个允许我们变换和简化SSM的结构化结果。
Lemma 3.3 共轭是SSM的等价关系:
( A , B , C ) ∼ ( V − 1 AV , V − 1 B , CV ) (\textbf A, \textbf B, \textbf C) \sim(\textbf V^{-1}\textbf A \textbf V, \textbf V^{-1}\textbf B, \textbf C \textbf V) (A,B,C)∼(V−1AV,V−1B,CV)
证明:写出两个SSM, x x x和 x ~ \tilde{x} x~为对应的状态:
x ′ = A x + B u x ~ = V − 1 AV x ~ + V − 1 B u y = C x y = CV x ~ x^{'} = \textbf Ax +\textbf Bu \ \ \ \ \ \ \ \ \ \tilde x = \textbf V^{-1}\textbf A \textbf V\tilde x +\textbf V^{-1}\textbf Bu \\ y = \textbf C x \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ y = \textbf C \textbf V \tilde x x′=Ax+Bu x~=V−1AVx~+V−1Buy=Cx y=CVx~
当用 V \textbf V V乘以SSM右侧,两个SSM相同,其中 x = V x ~ x = \textbf V \tilde x x=Vx~。因此它们在计算相同的算子 u ↦ y u\mapsto y u↦y,但是状态基被 V \textbf V V改变了。
Lemma 3.3说明了状态空间 ( A , B , C ) (\textbf A, \textbf B, \textbf C) (A,B,C)和 ( V − 1 AV , V − 1 B , CV ) (\textbf V^{-1}\textbf A \textbf V, \textbf V^{-1}\textbf B, \textbf C \textbf V) (V−1AV,V−1B,CV)
实际是等价的。换句话说,它们表示的是同一个映射 u ↦ y u\mapsto y u↦y,在SSM文献中也被叫做状态空间变换。
对此一个非常自然的选择是对角矩阵形式,可能是最典型的形式。众所周知,几乎所有的矩阵在复平面上对角化。
Proposition 3.4 集合 D ⊂ C N × N \mathcal D \subset \mathcal C ^{N\times N} D⊂CN×N 可对角矩阵在 C N × N \mathcal C ^{N\times N} CN×N 上稠密且满测度
换句话说,Proposition 3.4说明(几乎)所有的SSM可以等价成一个对角SSM。除此之外,对角SSM结构化可以解决问题一和问题二,特别是计算 K ‾ \overline{\textbf K} K成为一个成熟的结构化矩阵乘法有高效的时间和空间复杂度。
3.2.1 S4D:对角SSM算法
Remark 3.6. 对于对角SSM的例子, A \textbf A A是对交的,因此我们重载定义 A n \textbf A_n An代表其对角的迹。回想我们定义SISO情况为 B ∈ R N × 1 \textbf B \in \mathcal R ^ {N\times 1} B∈RN×1和 C ∈ R 1 × N \textbf C \in \mathcal R ^ {1\times N} C∈R1×N,因此我们令 B n , C n \textbf B_n, \textbf C_n Bn,Cn直接索引它们的元素。
现在我们提出了S4D在对角SSM上解决了问题一和问题二
S4D 递归
在对角SSM上计算任何对角化都很简单,因为对角矩阵上的解析函数简化为其对角线上按元素进行。实现一个对角矩阵的矩阵乘法也很简单,因为它减少了元素级的乘法。因此对角SSM轻松地适合Definition3.1
S4D 卷积核:范德蒙矩阵乘法
当 A \textbf A A是对角的,计算卷积核变得十分简单:
K ‾ ℓ = ∑ n = 0 N − 1 C n A ‾ n ℓ B ‾ n ⟹ K ‾ = ( B ‾ ⊤ ∘ C ) ⋅ V L ( A ‾ ) where V L ( A ‾ ) n , ℓ = A ‾ n ℓ ( 3.2 ) \begin{aligned}\overline{K}_\ell=\sum_{n=0}^{N-1}C_n\overline{A}_n^\ell\overline{B}_n\implies\overline{K}=(\overline{B}^\top\circ C)\cdot\mathcal{V}_L(\overline{A})\quad\text{where}\quad\mathcal{V}_L(\overline{A})_{n,\ell}=\overline{A}_n^\ell\quad(3.2)\end{aligned} Kℓ=n=0∑N−1CnAnℓBn⟹K=(B⊤∘C)⋅VL(A)whereVL(A)n,ℓ=Anℓ(3.2)
∘ \circ ∘是哈达玛积, ⋅ \cdot ⋅是矩阵乘法, V \mathcal V V被称为范德蒙矩阵
再展开一下,我们可以把 K ‾ \overline {\textbf K} K写成下面的范德蒙矩阵-向量乘法
K ‾ = [ B ‾ 0 C 0 … B ‾ N − 1 C N − 1 ] [ 1 A ‾ 0 A ‾ 0 2 … A ‾ 0 L − 1 1 A ‾ 1 A ‾ 1 2 … A ‾ 1 L − 1 ⋮ ⋮ ⋮ ⋱ ⋮ 1 A ‾ N − 1 A ‾ N − 1 2 … A ‾ N − 1 L − 1 ] \overline{K}=\begin{bmatrix}\overline{B}_0C_0&\ldots&\overline{B}_{N-1}C_{N-1}\end{bmatrix}\begin{bmatrix}1&\overline{A}_0&\overline{A}_0^2&\ldots&\overline{A}_0^{L-1}\\1&\overline{A}_1&\overline{A}_1^2&\ldots&\overline{A}_1^{L-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&\overline{A}_{N-1}&\overline{A}_{N-1}^2&\ldots&\overline{A}_{N-1}^{L-1}\end{bmatrix} K=[B0C0…BN−1CN−1] 11⋮1A0A1⋮AN−1A02A12⋮AN−12……⋱…A0L−1A1L−1⋮AN−1L−1
对角化结构SSM(S4D)有一个非常简单的解释。(左)对角化结构允许它被看作1维SSM的集合,或者scalar递归(右)。作为一个卷积模型,S4D有一个简单的可解释的卷积核,可以用两行代码实现。颜色代表独立的1-D SSM;紫色代表可训练参数。
时间和空间复杂度
原始方法计算3.2是通过范德蒙矩阵 V L ( A ‾ ) \mathcal V_L(\overline{\textbf A}) VL(A)和实现一个矩阵乘法,需要 O ( N L ) O(NL) O(NL)的时间和空间。
然而,范德蒙矩阵已经经过大量研究在理论上乘法可以以 O ~ ( N + L ) \tilde O(N+L) O~(N+L)操作和 O ( N + L ) O(N+L) O(N+L)空间实现。
3.2.2 完整应用实例
整个S4D方法可以直接应用,仅仅需要几行代码来参数化和初始化,核计算和完整的前向传播。
最后,注意结合不同的参数化选择可能导致在kernel实现上的少许不同。图3.1说明了用ZOH离散化的S4D核甚至可以进一步简化到两行代码。
def parameters(N, dt_min = 1e-3, dt_max = 1e-1):#初始化#几何均匀时间尺度 [第五章]log_dt = np.rnadom.rand() * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min)# S4D-Lin 初始化 (A, B) [第六章]A = -0.5 + 1j * np.pi * np.arange(N // 2)B = np.one(n // 2) + 0j#方差保持初始化 [第五章]C = np.random.randn(N // 2) + 1j * np.random.randn(N)return log_dt, np,log(-A.real, A_imag, B, C)def kernel(L, log_dt, log_A_real, A_imag, B, C):#离散化(例如双线性变换)dt, A = np.exp(log_dt), -np.exp(log_A_real) + 1j * A_imagdA, dB = (1 + dt * A /2) / (1 - dt * A / 2), dt * B / (1 - dt * A / 2)#计算(范德蒙矩阵乘法-可以被优化)#返回实部两倍-核添加共轭对相同return 2 * ((B * C) @ (dA[:, None] ** np.arrange(L))).realdef forward(u, parameters):L = u.shape[-1]K = kernel(L, *parameters)#用FFT卷积 y = u * KK_f,u_f = np.fft.fft(K, n = 2 * L), np.fft.fft(u, n = 2 * L)return np.fft.ifft(K_f*u_f, n = 2 * L)[...,:L]
参数化和计算一通道S4D模型的完整Numpy示例
3.3 对角化加低秩(DPLR)参数化
当可能的时候,对角SSM在实际中使用是理想的因为它们的简单和灵活。然而,它们的强结构有时太过限制。特别是,Chapter 6将会说明基于HIPPO矩阵的重要SSM类(Chapter 4 和 5)不能在数值上表达为对角SSM,而使用一个对角结构的拓展替代。虽然我们推迟这一动机到部分二,这一部分从计算角度,独立地表示这个结构。除了和部分二中的特殊SSM的关系,这个重参数化背后的想法和算法理论上是独立的,在之后的序列模型中会用到3.6.2
这个部分定义了对角SSM的拓展依然可以高效计算的**对角低秩(DPLR)**SSM。我们主要的技术结果关注于发展这个参数化和展示如何高效计算所有的SSM表达(Section 2.3),特别是找到一个问题一和问题二的算法。
3.3.1给出了我们方法关键组成部分的总览并形式上定义了S4—DPLR参数化。3.3.2给出了主要的结构,说明S4是渐进有效的对于序列模型。证明在附录3.1。
3.3.1 DPLR 状态空间核算法
尽管从对角到DPLE矩阵的扩展看起来很小,额外的低秩项时矩阵计算更困难。特别是不像对角矩阵,计算等式2.8DPLR矩阵的幂次方依然很慢(和非结构化矩阵相同)并且难以被优化,我们通过同时应用三种新技术解决这个瓶颈。
- 我们通过评估它的单元 ζ \zeta ζ的根截断生成函数 ∑ j = 0 L − 1 K ‾ j ζ j \sum _{j = 0}^{L - 1}\overline{\textbf K}_j\zeta^j ∑j=0L−1Kjζj来计算它的谱而不是直接计算 K ‾ \overline {\textbf K} K。 K ‾ \overline {\textbf K} K之后可以通过一个反FFT实现。
- 这个生成函数和矩阵分解相近,现在包括一个矩阵求逆而不是幂。低秩项现在可以通过Woodbury恒等式(Proposition A.2)将$(A + PQ*){-1} 按 按 按A^{-1}$真正减少到对角情形。
- 最后,我们表明对角矩阵形式是Cauchy kernel 1 w j − ζ k \frac{1}{w_j - \zeta _k} wj−ζk1的等价形式,一个使用stable near-linear算法的充分研究问题。
3.3.2 S4-DPLR 算法和计算复杂度
我们的算法在循环和卷积表达下都是经过优化的,满足Definitions 3.1和3.2
Theorem 3.5 (S4递归) 给定任意步长 Δ \Delta Δ,计算提柜的以部可以在 O ( N ) O(N) O(N)操作下完成, N N N是状态大小。
Theorem 3.6 (S4卷积)给定任意步长 Δ \Delta Δ,计算SSM卷积核 K ‾ \overline{\textbf K} K可以被减少到4次Cauchy 乘法,需要仅仅 O ~ ( N + L ) \tilde O(N+L) O~(N+L)次操作和 O ( N + L ) O(N+L) O(N+L)空间
附录C.1,定义C.5形式上定义了Cauchy 矩阵,和有理插值问题相关。在数值分析上计算Cauchy 矩阵同样得到充分研究,有基于著名的快速多极子算法(FMM)的快速算术和数值算法。不同情况下这些算法的计算复杂度在附录C.1 Proposition C.6中展示。
3.3.3赫尔维兹(稳定)DPLR形式
独立于计算S4-DPLR的算法细节,我们使用一个基础DPLR参数化的修正来确保状态空间模型的稳定性。特别是,赫尔维茨矩阵(又称为稳定矩阵)是一类可以确保SSM渐进稳定的。
Definition 3.7. 一个赫尔维茨矩阵 A \textbf A A是一个所有本征值都有负实数部分的矩阵
从离散时间SSM角度,我们很容易明白为什么 A \textbf A A需要是一个赫尔维茨矩阵从基本原则和下面简单的观察。受限,展开RNN模式包含重复升幂 A ‾ \overline {\textbf A} A,只有在 A ‾ \overline {\textbf A} A的所有本征值在(复数)单位圆内或上才是稳定的。第二,变换(2.4)(不论是对于双线性还是ZOH离散化)映射复数左半平面到单位圆,因此计算一个SSM的RNN模式(例如自回归推断)需要 A \textbf A A是一个赫尔维兹矩阵。
从连续角度看,另一种方式看到至一点是线性ODE解是指数形式。我们也可以看到等价卷积形式有脉冲响应 K ( t ) = C e t A B K(t) = \textbf C e^{t\textbf A}\textbf B K(t)=CetAB当 t → ∞ t\rightarrow\infin t→∞时, K ( t ) = C e t A B K(t) = \textbf C e^{t\textbf A}\textbf B K(t)=CetAB也会爆炸到 ∞ \infin ∞
然而,控制一个常见DPLR矩阵的谱是困难的。在S4的先前版本,我们发现无限制DPLR矩阵在训练胡变成非赫尔维茨(因此不能再无限循环模式中运用)。
为了解决这一点,我们使用DPLR矩阵的小改建,我们称之为赫尔维茨 DPLR形式,我们可以使用参数 Λ − P P ∗ \Lambda - PP^* Λ−PP∗代替 Λ + P Q ∗ \Lambda + PQ^* Λ+PQ∗。这相当于基本上绑定了参数 Q = − P Q = -P Q=−P。注意在技术上这依然是一个DPLR,因此我们使用S4-DPLR算法作为黑盒。
接着,我们讨论这种参数化是如何让S4稳定。高阶想法是SSM的稳定性包含状态矩阵 A \textbf A A的谱,更容易被控制因为 − P P ∗ -\textbf P \textbf P^* −PP∗是半负定矩阵(我们知道它的谱的符号)
Lemma 3. 8 一个矩阵 A = Λ − P P ∗ \textbf A = \Lambda - \textbf P \textbf P^{*} A=Λ−PP∗是赫尔维茨的如果 Λ \Lambda Λ的所有迹有负的实数部分。
证明:我们首先观察到如果 A + A ∗ \textbf A +\textbf A^* A+A∗是半负定(NSD)的,那么 A \textbf A A是赫尔维茨的。这是因为 0 > v ∗ ( A + A ∗ ) v = ( v ∗ A v ) + ( v ∗ A v ) ∗ = 2 R e ( v ∗ A v ) = 2 λ 0>v^*(A+A^*)v = (v^*Av)+(v^*Av)^* = 2\mathcal Re(v^*Av) = 2\lambda 0>v∗(A+A∗)v=(v∗Av)+(v∗Av)∗=2Re(v∗Av)=2λ对于任何 A A A的(单位长度)本征对来说。之后,注意到条件暗示 A + A ∗ \textbf A +\textbf A^* A+A∗是半负定(NSD)的(非正数迹的实数对角矩阵)。因为矩阵 − P P ∗ -PP^* −PP∗也是NSD的, A + A ∗ A+A^* A+A∗也是这样。
Lemma 3.8表明,对于赫尔维兹DPLR表示,控制学习的A矩阵的频谱变成简单地控制对角线部分 Λ \Lambda Λ。这是一个比控制一般DPLR矩阵容易得多的问题,可以通过正则化或重新参数化来强制执行(第3.4.2节)。
Remark 3.7. 赫尔维茨DPLR形式 Λ − P P ∗ \Lambda - PP* Λ−PP∗有更少的参数而且在技术上表现能力差于不受限DPLR形式 Λ + P Q ∗ \Lambda + PQ^* Λ+PQ∗但在经验上并没有影响模型表现。
Remark 3.8. 潜在的稳定性问题只在使用S4在特定内容如自回归生成时上升,因为S4的卷积模式在训练时并没有升幂 A ‾ \overline {\textbf A} A因此对赫尔维茨矩阵并不是严格要求。在实践中,出于原则,我们仍然总是使用赫尔维茨DPLR。