原文链接:https://arxiv.org/abs/2402.10739
1. 引言
基于Transformer的点云分析方法有二次时空复杂度,一些方法通过限制感受野降低计算。这引出了一个问题:如何设计方法实现线性复杂度并有全局感受野。
状态空间模型(SSM)作为序列建模方法,Mamba在结构状态空间模型(S4)的基础上使用时变SSM参数和硬件感知算法,实现了线性复杂度和全局感受野。但目前的Mamba较少用于视觉任务。
本文探索SSM在点云分析任务中的潜力。直接使用Mamba的性能不佳,这是因为SSM的单向建模能力(相反,自注意力是输入顺序不变的)。本文提出点状态空间模型(PointMamba),首先生成点的token序列,然后使用重排序策略以特定顺序扫描数据,使模型捕捉点云结构。最后将重排序后点的token输入Mamba编码器,进行全局建模。
实验表明,本文方法可以超过基于Transformer方法的性能,且有更少的参数和计算量。
3. 方法
3.1 准备知识
状态空间模型:状态空间模型建模了时不变(LTI)系统,使用一阶微分方程捕捉系统动态:
h ˙ ( t ) = A h ( t ) + B x ( t ) , y ( t ) = C h ( t ) + D x ( t ) . \dot h(t)=Ah(t)+Bx(t),\\y(t)=Ch(t)+Dx(t). h˙(t)=Ah(t)+Bx(t),y(t)=Ch(t)+Dx(t).
为处理离散token序列输入,需要进行离散化:
h k = A ˉ h k − 1 + B ˉ x k , y k = C ˉ h k + D ˉ x k . h_k=\bar Ah_{k-1}+\bar Bx_k,\\y_k=\bar Ch_k+\bar Dx_k. hk=Aˉhk−1+Bˉxk,yk=Cˉhk+Dˉxk.
其中 A ˉ ∈ R N × N , B ˉ ∈ R N × 1 , C ˉ ∈ R 1 × N , D ˉ ∈ R \bar A\in\mathbb R^{N\times N},\bar B\in\mathbb R^{N\times 1},\bar C\in\mathbb R^{1\times N},\bar D\in\mathbb R Aˉ∈RN×N,Bˉ∈RN×1,Cˉ∈R1×N,Dˉ∈R为参数矩阵。 D ˉ \bar D Dˉ为残差连接,通常可简化或忽略。离散化需要使用时间步长 Δ \Delta Δ,在连续信号 x ( t ) x(t) x(t)进行采样,得到 x k = x ( k Δ ) x_k=x(k\Delta) xk=x(kΔ)。这使得:
A ˉ = ( I − Δ / 2 ⋅ A ) − 1 ( I + Δ / 2 ⋅ A ) , B ˉ = ( I − Δ / 2 ⋅ A ) − 1 Δ B , C ˉ = C \bar A=(I-\Delta/2\cdot A)^{-1}(I+\Delta/2\cdot A),\\\bar B=(I-\Delta/2\cdot A)^{-1}\Delta B,\\\bar C=C Aˉ=(I−Δ/2⋅A)−1(I+Δ/2⋅A),Bˉ=(I−Δ/2⋅A)−1ΔB,Cˉ=C
选择性SSM: B ˉ , C ˉ \bar B,\bar C Bˉ,Cˉ和 Δ \Delta Δ为动态、输入相关的参数,从而使得SSM为时变模型。这样能够过滤和捕捉时间相关的特征和关系,从而更精确地表达输入序列。
3.2 PointMamba
3.2.1 概述
如图所示,本文方法包括点tokenizer,重排序策略、Mamba和下游任务头。本文使用轻量化PointNet嵌入点的patch,得到点的token,然后根据几何坐标进行重排序,将序列长度变为3倍,输入Mamba。
3.2.2 点tokenizer
使用最远点采样(FPS)和K近邻(KNN)算法将点云分为不规则的点patch。具体来说,给定含 M M M个点的点云 I ∈ R M × 3 I\in\mathbb R^{M\times3} I∈RM×3,使用FPS采样 n n n个关键点,然后为每个关键点,使用KNN算法选择 k k k个最近点,得到 n n n个patch P ∈ R n × k × 3 P\in\mathbb R^{n\times k\times3} P∈Rn×k×3。然后,求取patch中各点相对关键点的相对坐标,并使用轻量化PointNet映射到特征空间,得到点token E 0 ∈ R n × C E_0\in\mathbb R^{n\times C} E0∈Rn×C。
3.2.3 重排序策略
由于Mamba是单向处理数据,适合1D数据;但难以处理点云这类无序数据。
本文通过特定顺序扫描点云,以捕捉点云结构。如图所示,本文分别基于点token簇中心的几何 x , y , z x,y,z x,y,z坐标进行排序并拼接,得到 E 0 ′ ∈ R 3 n × C E'_0\in\mathbb R^{3n\times C} E0′∈R3n×C。该方法通过提供更有逻辑的几何扫描顺序,提高了Mamba的几何建模能力。
3.2.4 Mamba块
每个Mamba块包含层归一化(LN)、SSM、逐深度卷积和残差连接,如图1右侧所示。公式表示为:
Z l ′ = D W ( M L P ( L N ( Z l − 1 ) ) ) , Z l = M L P ( L N ( S S M ( σ ( Z l ′ ) ) ) × σ ( L N ( Z l − 1 ) ) ) + Z l − 1 Z'_l=DW(MLP(LN(Z_{l-1}))),\\Z_l=MLP(LN(SSM(\sigma(Z_l')))\times\sigma(LN(Z_{l-1})))+Z_{l-1} Zl′=DW(MLP(LN(Zl−1))),Zl=MLP(LN(SSM(σ(Zl′)))×σ(LN(Zl−1)))+Zl−1
其中 Z l ∈ R 3 n × C Z_l\in\mathbb R^{3n\times C} Zl∈R3n×C为第 l l l块的输出, Z 0 = E 0 ′ Z_0=E'_0 Z0=E0′; σ \sigma σ为SiLU激活函数。
3.2.5 预训练
本文使用PointMAE的设置进行预训练,即随机掩蔽60%的点patch,使用自编码器提取点的特征并使用预测头重建点云。
自编码器可公式化为:
T v = F e ( T v + P E ) , H v , H m = F d ( C o n c a t ( T v , T m ) ) , P m = F h ( H m ) . T_v=F_e(T_v+PE),\\H_v,H_m=F_d(Concat(T_v,T_m)),\\P_m=F_h(H_m). Tv=Fe(Tv+PE),Hv,Hm=Fd(Concat(Tv,Tm)),Pm=Fh(Hm).
其中 F e F_e Fe为编码器,以未掩蔽的token T v T_v Tv为输入; F d F_d Fd为Mamba解码器,以 F e F_e Fe的输出和掩蔽的token T m T_m Tm为输入。本文仅在编码器和解码器的第一层加入位置编码 P E PE PE。 F h F_h Fh为线性层,将掩蔽token H m H_m Hm投影为与掩蔽输入点形状相同的向量。使用Chamfer距离作为重建损失,以恢复掩蔽点的坐标。
4. 实验
4.1 实施细节
与ViT不同,本文不使用类别token。分类时,本文将最后一层Mamba的所有输出平均值用于分类。分割任务则将中间多层的输出合并,进行最大和均值池化得到全局特征,然后与逐点特征拼接,输入线性层预测。
4.2 与基于Transformer的方法比较
实验表明,本文方法在无预训练情况下能达到与基于Transformer的方法相当的性能,且有更少的参数和计算量。预训练和使用重排序策略均能提高性能。
此外,随着序列长度的增加,基于Transformer的方法GPU内存占用显著增加,但本文的PointMamba仅线性增长。
4.3 消融研究
重排序策略:比较不进行重排序(1倍序列长度)、进行重排序(3倍序列长度)和双向重排序(即将重排序结果逆序后与重排序结果拼接,6倍序列长度)。实验表明,基于Transformer的方法在序列长度增加时,性能略微下降;重排序策略能提高单向建模Mamba在点云中的适应能力;进一步增加序列长度能进一步提高性能,但为平衡计算量与性能,本文选择3倍序列长度;尽管如此,由于本文方法的线性复杂度,计算量增长也远小于基于Transformer的方法。
分类token的分析:实验表明,不使用类别token能达到最好的分类性能。
4.4 局限性
预训练没有考虑Mamba的单向建模特点;重排序需要将序列长度变为3倍。