最近新出了一种很火的架构mamba,听说吊打transformer,特此学习一下,总结一下学习的内容。
state-spaces/mamba (github.com)3个月8Kstar,确实有点受欢迎。
目录
1.先验
RNN
LSTM
2.mamba
State Space Models
Selective State Space Models
1.先验
RNN
RNN(循环神经网络)是一种在序列数据处理中广泛应用的神经网络模型。相较于传统的前馈神经网络(Feedforward Neural Network),RNN具有一种循环结构,使其能够对序列数据进行建模和处理。
RNN的关键思想是引入了时间维度上的循环连接,使得网络在处理序列数据时可以保持信息的传递和记忆。在RNN中,每个时间步都有一个隐藏状态(hidden state),它对应于该时间步的输入以及前面时间步的隐藏状态。这种隐藏状态的传递和更新机制使得RNN能够捕捉序列中的时序依赖关系。
在每个时间步,RNN的计算可以分为两个关键步骤:
-
当前隐藏状态 = 上一时刻隐藏状态 *W + input*W
-
输出预测 = 当前隐藏状态*W
-
隐藏状态:可以看作是一种记忆能力,记录了上一时刻的信息,但是上上时刻,上上上时刻的信息都没有,即短期记忆,只记得昨天,不记得前天和之前的信息(举例)
RNN的灵活性使其适用于多种任务,例如自然语言处理(NLP)、语音识别、机器翻译、时间序列预测等。然而,传统的RNN存在梯度消失(gradient vanishing)和梯度爆炸(gradient explosion)等问题,导致对长期时序依赖的建模能力有限。
梯度爆炸:偏导数 ∂ℎ�∂ℎ0 将会变得非常大,实际在训练时将会遇到NaN错误,会影响训练的收敛,甚至导致网络不收敛。可以用梯度裁剪(gradient clipping)来解决。
梯度消失:此时偏导数 ∂ℎ�∂ℎ0 将会变得十分接近0。LSTM和GRU通过门(gate)机制控制RNN中的信息流动,用来缓解梯度消失问题。其核心思想是有选择性的处理输入
推荐视频:【循环神经网络】5分钟搞懂RNN,3D动画深入浅出_哔哩哔哩_bilibili
这是CNN的神经元连接图。从输入层-->隐藏层-->输出层。
RNN更加关注时间维度上的信息。此时的隐藏状态是黄色部分
有了记忆力的RNN可以根据隐藏层状态推出出苹果不是水果,而是手机。
关于2D图怎么看:
蓝色的球就是隐藏层,红色的是输入,最后右边那个不知道什么颜色的球就是输出
黄色的线就是Ws*St-1的那个Ws
把上面旋转90°就变成这样的。横轴为时间。
为了解决这些问题,出现了一些RNN的变体,如长短期记忆网络(LSTM)和门控循环单元(GRU)。这些变体通过引入门控机制,能够更好地处理长期依赖关系,并在许多任务中取得了显著的性能提升。
LSTM
LSTM(长短期记忆网络)是一种循环神经网络(RNN)的变体,专门设计用于解决传统RNN中的梯度消失和长期依赖问题。它通过引入门控机制,能够更好地处理和捕捉序列中的长期依赖关系。
LSTM的关键思想是引入了称为“门”的结构,它能够控制信息的流动和存储。一个标准的LSTM单元包含以下组件:
-
输入门(Input Gate):决定是否将新的输入信息纳入到记忆中的门控单元。
-
遗忘门(Forget Gate):决定是否从记忆中删除某些信息的门控单元。
-
输出门(Output Gate):根据当前输入和记忆状态,决定输出的门控单元。
-
记忆单元(Cell State):负责存储和传递信息的长期记忆。
-
sigmoid实现,把值map到[0,1], =1增加到长期记忆C里面,=0删除(举例)
S:短期记忆链条
C: 长期记忆链条
注意sigmoid是在当前输入Xt和上一时刻隐藏状态St-1(上一时刻的短期记忆)里面来决定删除(遗忘门)不重要的信息,重要的信息添加(输入门)到Ct长期记忆里面,并且把Ct-1的信息归并到Ct
- sigmoid 用在了各种gate上,产生0~1之间的值,这个一般只有sigmoid最直接了。
- tanh 用在了状态和输出上,是对数据的处理,使用tanh函数,是因为其输出在-1-1之间,这与大多数场景下特征分布是0中心的吻合。此外,tanh函数在输入为0近相比 Sigmoid函数有更大的梯度,通常使模型收敛更快。
2.mamba
2312.00752.pdf (arxiv.org)
本文是基于前人的状态空间模型做出优化,提出选择性状态空间模型的算法。
State Space Models
找到个写的不错的blog:https://blog.csdn.net/weixin_4528312/article/details/134829021
RNN和SSM的本质一样。
((输入x * wB + wA * 上一时刻状态) * wC) + wD*输入x = OUTPUT
可以看出这个一个递归函数,进行公式推导:
H4 = A*H3 + B*X4
H3 = A*H2 + B *X3
H2 = A*H1 + B * x2
那么:
H4 = A*H3 + B*X4=A*(A*H2 + B *X3) + B*X4 = A^2*H2 + AB*X3 +B*X4
= A^2(A*H1 + B * x2)+AB*X3 +B*X4 = A^3*H1 +A^2B*X2 +AB*X3 +B*X4
上面的式子表明:
H4 = F(H3,X4) = F(H2,X3,X4) =F(H1,X2,X3,X4)
那么
H3 = F(H2,X3) = F(H1,X2,X3)
也就是说:H3的计算是可以和H2无关的,只需要知道H1,X2,X3即可,同时H4的计算也是可以独立于H3,只需要知道H1,X2,X3,X4即可,这也是后面并行计算的关键
同理:
Y4 = CH4
Y3 = CH3
Y2 = CH2
Y1 =CH1
Y4 = C( A^3*H1 +A^2B*X2 +AB*X3 +B*X4)
使用卷积的原因,RNN没法并行训练,卷积可以,所以使用conv1D卷积进行。
为什么可以使用卷积等价计算?
但是为什么推理又使用RNN?
因为RNN推理速度更快。
训练CNN(可并行),推理RNN(速度快)
RNN的通病是只有短期记忆,如何解决--HiPPO!!
对信息的压缩方法。
Selective State Space Models
SB决定输入词的权重,SC决定Ht的权重(Ht保护过去信息Ht-1和输入Xt)
传统B,输入的所有词的权重是一样的:
SB决定输入信息中每个词的权重:
简单点说,RNN推理快但训练慢,且会遗忘以前的信息,SSM压缩以前的信息可以记住所有,但是矩阵参数固定,无法针对输入做针对性推理。mamba选择性的关注以前的信息,参数化ssm的输入,(不停总结以前的信息)。这样既可以记住以前的重要信息,又内存占用不大,推理快
总结:因为SSM和RNN几乎等价:
都是:
Ht = A * Ht-1 + B *Xt
y = C * Ht
mamba核心就是在此基础上把B,C改成SB,SC可以选择性给信息加权,然后输入信息关注重点,让过去信息保留重要的部分,从而解决遗忘和短期记忆问题。并且可以并行计算。
参考:
https://www.bilibili.com/video/BV1z5411f7Bm/?spm_id_from=333.337.search-card.all.click&vd_source=3aec03706e264c240796359c1c4d7ddc
https://blog.csdn.net/weixin_4528312/article/details/134829021