1.1 简介
LSTM,全称为长短期记忆网络(Long Short-Term Memory),是一种特殊的循环神经网络(RNN)结构,由Sepp Hochreiter和Jürgen Schmidhuber在1997年提出。它的设计初衷是为了克服传统RNN在处理长序列数据时面临的梯度消失和梯度爆炸问题,从而能够更好地学习和利用长期依赖关系。
LSTM的核心特点
-
记忆单元(Cell State):这是LSTM的核心组成部分,用来存储长期状态信息。与标准RNN中的隐藏状态不同,记忆单元的设计使得它可以有选择性地存储或遗忘信息,从而避免了长期依赖问题。
-
门控机制(Gates):LSTM通过三个主要的门控机制来控制信息的流动,分别是:
- 忘记门(Forget Gate):决定记忆单元中哪些信息需要被遗忘,通常基于前一时刻的隐藏状态和当前输入计算得出。
- 输入门(Input Gate):控制新信息的流入,决定哪些信息应该被更新到记忆单元中。
- 输出门(Output Gate):控制记忆单元中哪些信息应该被用于当前时刻的输出,以及如何影响下一个时间步的隐藏状态。
-
细胞状态更新:在每个时间步骤,LSTM根据输入门、忘记门和输出门的计算结果来更新记忆单元的状态,这一过程允许网络有选择性地保留或丢弃信息。
-
隐藏状态(Hidden State):LSTM的隐藏状态结合了记忆单元的信息和当前需要输出的信息,它不仅受到当前输入的影响,还受到历史信息的影响,但通过输出门进行了过滤,以便于更好地处理当前任务。
-
LSTM的创新在于其独特的门控机制,这使得网络能够灵活地处理和保留长期依赖信息,从而在复杂序列数据的学习任务中表现出色。随着深度学习技术的发展,LSTM及其变体(如GRU、Transformer等)继续推动着人工智能领域的进步。
1.2 RNN与LSTM的比较
RNN(循环神经网络)
想象一下,RNN像一个有记忆力的工人,他在处理一条生产线上的产品,每个产品代表序列中的一个元素(比如一个单词或者一个音符)。这个工人会根据当前的产品和他之前所有产品的累积经验(记忆)来决定如何处理当前的产品。但是,他的记忆力有限,容易忘记很久以前的事情,这就是所谓的“长期依赖问题”。这是因为RNN在反向传播过程中,随着时间步的增加,梯度可能会变得非常小(消失)或非常大(爆炸),导致学习长期依赖很困难。
LSTM(长短期记忆网络)
LSTM可以看作是一个升级版的记忆力大师,他不仅有普通记忆,还有特别设计的“记忆盒子”(记忆单元)和几个聪明的助手(门控机制)。记忆盒子能够长时间保存重要信息,不会轻易忘记。而那些助手——忘记门、输入门、输出门——负责管理这个盒子,决定什么信息该丢弃、什么新信息该加入、以及什么时候从盒子里取出信息用于当前任务。
- 忘记门决定清理记忆盒子中哪些旧信息不再有用。
- 输入门控制哪些新信息值得添加进记忆盒子。
- 输出门决定记忆盒子中的哪些信息对当前任务是重要的,应该用于生成输出。
这些机制让LSTM能够在序列数据中捕捉到非常长的依赖关系,解决了很多RNN面临的难题。
总结区别
- 记忆能力:RNN的记忆能力有限,容易忘记远期信息;而LSTM通过记忆单元和门控机制,能更有效地保留和利用长期信息。
- 结构复杂度:LSTM结构相比RNN更复杂,因为它增加了门控机制和记忆单元。
- 梯度问题:RNN容易遭受梯度消失或爆炸的问题,而LSTM的设计通过门控机制缓解了这个问题。
- 应用场景:虽然RNN对于处理较短序列或不需要长期记忆的任务也能表现良好,但LSTM在处理复杂的、需要长期依赖的序列问题上更为出色,比如自然语言处理、语音识别等领域。
1.3 前向传播
我们先来看一下RNN和LSTM的区别:
RNN:
LSTM:
LSTM的前向传播如下图所示:
(更正:下图中tanh旁的小写ct对应下图的gt)
输入和初始化
- 输入:在时间步t,LSTM接收输入向量 xt∈Rn,以及前一时间步的隐藏状态 ht−1∈Rm 和细胞状态 ct−1∈Rl。
- 权重矩阵:LSTM使用不同的权重矩阵(Wf, Wi, Wg, Wo)和偏置项(bf, bi, bg, bo)来计算各门和候选值。
门控机制计算
-
遗忘门:决定哪些信息从细胞状态中丢弃。
- 计算遗忘门的激活值:ft=σ(Wf⋅[ht−1,xt]+bf),其中 σ 是sigmoid函数,·表示矩阵乘法,[ht−1,xt] 是前一隐藏状态和当前输入的拼接。
-
输入门:决定哪些新信息被存储进细胞状态。
- 计算输入门的激活值:it=σ(Wi⋅[ht−1,xt]+bi)。
-
细胞状态更新门(或称候选值门):生成新细胞状态的候选值。
- 计算候选值:gt=tanh(Wg⋅[ht−1,xt]+bg)。
-
细胞状态更新:结合遗忘和输入门的决策更新细胞状态。
- 新细胞状态:ct=ft⊙ct−1+it⊙gt,其中 ⊙ 表示逐元素乘法。
输出门计算
-
输出门:决定从细胞状态中提取多少信息用于生成当前时间步的隐藏状态。
- 计算输出门的激活值:ot=σ(Wo⋅[ht−1,xt]+bo)。
-
当前隐藏状态:使用输出门调节细胞状态来生成。
- 当前隐藏状态:ht=ot⊙tanh(ct)。
经过以上步骤,LSTM在时间步t产生了新的隐藏状态 ht 和更新后的细胞状态 ct,这两个状态都会传递给下一个时间步继续这一过程。
通俗理解
想象一下LSTM是一个记忆力超群的秘书,负责记录一系列事件并根据这些信息做出决策。在前向传播的过程中,这个秘书在每个时间点接收一个新的事件(比如一封邮件),然后决定如何基于过去的信息和当前事件更新自己的记事本和对外表达。
-
遗忘阶段:秘书首先翻看记事本(细胞状态),用一支有魔力的笔(遗忘门,由sigmoid函数控制)标出哪些旧信息不再重要,笔迹越深(值接近1),这部分信息就越会被保留;反之,浅笔迹(值接近0)意味着这部分信息将被擦除。
-
记录新信息阶段:接着,秘书收到今天的邮件(输入数据),开始判断哪些内容值得记录。他先用一个筛选器(输入门,同样是sigmoid函数控制)决定哪些新信息重要,然后把这些信息写成草稿(通过tanh函数转换成合适的格式,得到候选值)。
-
更新记事本:秘书把刚才决定保留的旧信息和新草稿合并。用遗忘门控制保留旧信息的程度,用输入门控制加入新信息的程度,这样就更新了记事本(细胞状态)。
-
决定表达:最后,秘书准备对外发言(生成隐藏状态),他会基于更新后的记事本内容选择如何表达。先用一个策略(输出门,sigmoid函数控制)决定哪些记事本里的内容适合说出来,然后调整这些内容的表达方式(通过tanh函数),最终形成今天的对外声明。
这个过程在每一天(每个时间步)重复,秘书的记事本和发言会根据不断接收的新事件而演变,但总是试图保持最重要的信息,同时适时更新和表达。这就是LSTM前向传播的一个简单而形象的理解。
1.4 RNN为什么会导致梯度消失
RNN(循环神经网络)的反向传播涉及到时间上的展开,因此通常称为“随时间反向传播”(Backpropagation Through Time, BPTT)。简化的RNN单元可以表示为:
在反向传播过程中,我们需要计算损失关于模型参数的梯度。以权重矩阵𝑊ℎ为例,其梯度可以通过链式法则计算,考虑时间序列的影响,这涉及到将梯度从未来时间步向过去时间步传播。若考虑一个简化的梯度计算(忽略偏置项和其它权重矩阵的更新),对于隐藏状态到隐藏状态的权重𝑊ℎ,其梯度可近似表示为:
RNN中的梯度消失主要由以下几个原因造成:
- 深度的累积效应:在时间序列足够长时,即使每一步的梯度衰减很小,累积起来也可能导致严重的梯度消失。
- 激活函数的性质:像tanh和sigmoid这类饱和激活函数的导数在远离原点时接近于0,连乘效果下会使梯度迅速减小。
- 权重初始化和学习率:不恰当的初始化和过大的学习率也可能加剧梯度消失问题。
1.5 LSTM为什么能防止梯度消失
在反向传播过程中,LSTM通过几个关键点来防止梯度消失:
-
非线性激活函数的导数:sigmoid和tanh函数虽然在饱和区(接近0或1)时导数较小,但LSTM中的门控机制使用这些函数来决定信息的流动,而非直接用于信息的累积。因此,即使在某些部分饱和,重要的是门控信号如何影响整个网络,而不是单一神经元的输出。
-
单元状态(Cell State):最重要的特性是单元状态𝑐𝑡ct,它可以通过时间直接传递而不受激活函数约束(除了在更新时通过输入门和遗忘门进行调节)。这意味着,即使在多个时间步之后,单元状态中的信息仍然可以保持其“原始”梯度信息,不会因多次应用非线性函数而导致梯度消失。
-
门控机制:遗忘门和输入门的使用使得LSTM能够有选择地保留或丢弃信息。特别是遗忘门,当其接近1时,可以有效地“记住”过去的梯度,从而避免了在反向传播时因连乘而导致的梯度消失。输入门则控制新信息的加入,保证了新梯度的流入。
反向传播公式简化说明
在反向传播时,对于每个门和单元状态的更新,对应的梯度计算都会涉及到门的导数(如sigmoid函数的导数接近0时仍然能维持非零值,因为它们通常不会完全饱和到0或1),以及tanh函数在单元状态更新中的使用,这些都能帮助梯度顺利回传。
- 遗忘门和输入门的导数(sigmoid函数的导数)为f′(x)=f(x)(1−f(x)),即使f(x)接近边界值,导数也不会接近0,有助于梯度回传。
- 单元状态更新中的tanh函数导数最大为1,保证了梯度可以无损传递。
- 输出门控制了隐藏状态到输出的梯度流,但即使这里存在一定程度的梯度消失风险,由于前面的机制已经保护了大部分重要信息,LSTM整体上仍能有效避免梯度消失。
LSTM通过其独特的门控机制和对单元状态的处理,在理论上能够有效缓解梯度消失问题,允许梯度在长序列中有效地传播回去,从而优化深层网络的训练。
2.pytorch代码实现
pytorch实现LSTM用于股票预测+时间序列预测-CSDN博客