1、长短期记忆网络LSTM简介
在RNN 计算中,讲到对于传统RNN水平方向进行长时刻序列依赖时可能会出现梯度消失或者梯度爆炸的问题。LSTM 特别适合解决这种需要长时间依赖的问题。
LSTM(Long Short Term Memory,长短期记忆网络)是RNN的一种,大体结构一直,区别在于:
- LSTM 的‘记忆cell’ 是被改造过的,水平方向减少梯度消失与梯度爆炸
- 该记录的信息会一直传递,不该记录的信息会被截断掉,部分输出和输入被从网络中删除
RNN 在语音识别,语言建模,翻译,图片描述等问题的应用的成功,都是通过 LSTM 达到的。
2、LSTM工作原理
2.1、传统的RNN“细胞”结构
所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。
2.2、LSTM结构
如下图展示了LSTM的一个神经元内部的结构。单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。
图中使用的各种元素的图标:
- 每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。合在一起的线表示向量的连接,比如一个十维向量和一个二十维向量合并后形成一个三十维向量;分开的线表示内容被复制,然后分发到不同的位置。
- 粉色的圈代表 pointwise 的操作,诸如向量的加法,减法,乘法,除法,都是矩阵的。
- 黄色的矩阵就是神经网络层。
2.3、细胞状态
LSTM关键:“细胞状态” 。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变很容易。
LSTM怎么控制“细胞状态”?
- LSTM可以通过gates(“门”)结构来去除或者增加“细胞状态”的信息
- 包含一个sigmoid神经网络层次和一个pointwist乘法操作
- Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允 许任务变量通过”,1表示“运行所有变量通过”
- LSTM中主要有三个“门”结构来控制“细胞状态”
- 和 矩阵的shape 是一样的
忘记门
决定从“细胞状态”中丢弃什么信息;比如在语言模型中,细胞状态可能包含了性别信息(“他”或者“她”),当我们看到新的代名词的时候,可以考虑忘记旧的数据
信息增加门
- 决定放什么新信息到“细胞状态”中;
- Sigmoid层决定什么值需要更新;
- Tanh层创建一个新的候选向量;
- 主要是为了状态更新做准备
经过第一个和第二个“门”后,可以确定传递信息的删除和增加,即可以进行 “细胞状态”的更新
- 更新为;
- 将旧状态与相乘,丢失掉确定不要的信息;
- 加上新的候选值得到最终更新后的“细胞状态”
输出门
- 首先运行一个sigmoid网络层来确定细胞状态的那个部分将输出
- 使用tanh处理细胞状态得到一个-1到1之间的值,再将它和sigmoid门的输出相乘,输出程序确定输出的部分。
- 前向传播和反向传播可以参看前面的传播过程写下来,更新LSTM中的参数。具体的公式可以参看:RNN以及LSTM的介绍和公式梳理_Dark_Scope的博客-CSDN博客_lstm公式
- 作者的论文:https://arxiv.org/pdf/1402.1128v1.pdf
2.4、 前向传播和后向传播
前向传播
现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态,模型参数几乎是RNN的4倍,因为现在多了这些参数。
前向传播过程在每个序列索引位置的过程为:
1)更新遗忘门输出:
2)更新输入门两部分输出:
3)更新细胞状态:
4)更新输出门输出:
5)更新当前序列索引预测输出:
整体的过程如下图所示
可以看到,在t tt时刻,用于计算和。
反向传播
有了LSTM前向传播算法,推导反向传播算法就很容易了, 思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新我们所有的参数,关键点在于计算所有参数基于损失函数的偏导数。
在RNN中,为了反向传播误差,我们通过隐藏状态的梯度一步步向前传播。在LSTM这里也类似。只不过我们这里有两个隐藏状态和。这里我们定义两个,即:
反向传播时只使用了,变量仅为帮助我们在某一层计算用,并没有参与反向传播,这里要注意。如下图所示:
因为,我们在输出层定义的损失函数为对数损失,激活函数为softmax激活函数。因为,与RNN的推导类似,在最后的序列索引位置 的和为:
接着我们由反向推导。
的梯度由本层的输出梯度误差决定,与公式(10)类似,即:
而的反向梯度误差由前一层的梯度误差和本层的从传回来的梯度误差两部分组成,即:
公式(13)的前半部分由公式(4)和公式(9)得到,公式(13)的后半部分由公式(6)和公式(8)得到。
有了和, 计算这一大堆参数的梯度就很容易了,这里只给出的梯度计算过程,其他的的梯度大家只要照搬就可以了。
公式(13)的由公式(1)、公式(4)和公式(9)得到。
由上面可以得到,只要我们清晰地搞清楚前向传播过程,且只使用了进行反向传播的话,反向传播的整个过程是比较清晰的。
在这里有必要解释下为什么反向传播不使用,如果与循环神经网络(RNN)模型的前向反向传播算法里一样的话,那么的计算方式就不应该是(12)式了
因为,参与了和的计算,所以在RNN文章里的求梯度方法,应该是
但是,这里是一个比较复杂的时序模型,如果使用RNN的思路,将的部分也一起反向传播回来的话,这里的反向梯度根本无法得到闭式解。而只考虑一个的话,也可以做反向梯度优化,进度下降,但是优化起来容易的多,可以理解为这里做了一个近似。
3、另一种理解方式
图中方框我们称为记忆单元,其中实线箭头代表当前时刻的信息传递,虚线箭头表示上一时刻的信息传递。从结构图中我们看出,LSTM模型共增加了三个门: 输入门、遗忘门和输出门。进入block的箭头代表输入,而出去的箭头代表输出。
前向传播公式
上图中所有带h的权重矩阵均代表一种泛指,为LSTM的各种变种做准备,表示任意一条从上一时刻指向当前时刻的边,本文暂不考虑。与上篇公式类似,a代表汇集计算结果,b代表激活计算结果, Wil代表输入数据与输入门之间的权重矩阵, Wcl代表上一时刻Cell状态与输入门之间的权重矩阵, WiΦ代表输入数据与遗忘门之间的权重矩阵, WcΦ代表上一时刻Cell状态与遗忘门之间的权重矩阵, Wiω代表输入数据与输出门之间的权重矩阵, Wcω代表Cell状态与输出门之间的权重矩阵, Wic代表输入层原有的权重矩阵。 需要注意的是,图中Cell一栏描述的是从下方输入到中间Cell输出的整个传播过程。
反向传播
和朴素RNN的推导一样,有了前向传播公式,我们就能逐个写出LSTM网络中各个参数矩阵的梯度计算公式。首先,由于输出门不牵扯时间维度,我们可以直接写出输出门Wiω和Wcω的迭代公式,如下图:
遗忘门的权重矩阵 WiΦ也可以直接给出,如下图:
而对于遗忘门的权重矩阵 WcΦ,由于是和上一时刻Cell状态做汇集计算,残差除了来自当前Cell,还来自下一时刻的Cell,因此需要写出下一时刻Cell传播至本时刻遗忘门的时间维度前向传播公式,如下图:
有了上面的公式,我们就能完整写出 WcΦ的梯度公式了。如下图所示(如果对这个时间维度前向公式不理解,可以参考上一篇我对朴素RNN的公式推导过程):
请注意,上图中L”和前面的L’不一样,这里只是为了式子简洁。
推完遗忘门公式,就可以此类推输入门与Cell的公式。其中输入门基本与遗忘门的推法一样,残差都是来自本时刻和下一时刻Cell。而Cell的残差则来自三个地方:输出层、输出门和下一时刻Cell。其中输出层和输出门残差可直接写出;而下一时刻Cell的残差,我们只要写出对应的时间维度前向传播公式便可写出。由于时间关系,这里就不详细推导遗忘门和Cell的梯度公式了,各位若有兴趣可自行继续推导。
相比于朴素RNN模型,LSTM模型更为复杂,且可调整和变化的地方也更多。比如:增加peephole将Cell状态连接到每个门,变体模型Gated Recurrent Unit (GRU),以及后面出现的Attention模型等。LSTM模型在语音识别、图像识别、手写识别、以及预测疾病、点击率和股票等众多领域中都发挥着惊人的效果,是目前最火的神经网络模型之一。敬请期待下节。
4、LSTM 的变体
我们到目前为止都还在介绍正常的 LSTM。但是不是所有的 LSTM 都长成一个样子的。实际上,几乎所有包含 LSTM 的论文都采用了微小的变体。差异非常小,目前为止有上百种,常用的也就几种。 其中一个流形的 LSTM 变体,就是由 Gers & Schmidhuber (2000) 提出的,
4.1、变种1
- 增加了 “peephole connection”层。
- 让门层也会接受细胞状态的输入。
4.2、变种2
通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增 加什么信息,而是一起进行考虑。
4.3、Gated Recurrent Unit(GRU),2014年提出
- 将忘记门和输入门合并成为一个单一的更新门
- 同时合并了数据单元状态和隐藏状态
- 结构比LSTM的结构更加简单
4.4、https://arxiv.org/pdf/1402.1128v1.pdf 论文
论文中定义的 LTSM cell 如下图所示:
图示
- 代表两个数据源乘上参数后相加。代表两个数据源相加。
- 外面再加花边的,代表两个数据源相乘后再取
sigmoid
。 - 圆圈里是gg的,代表取
tanh
。 - state下标-1代表这是上一次迭代时的结果。
所以像论文里指出的,这里实现的 LSTM Cell 含有更多参数,效果更好?
一般的 LSTM 就够用了,GRU 用的也比较多。
参考
- LSTM Forward and Backward Pass
- Understanding LSTM Networks
- https://arxiv.org/pdf/1402.1128v1.pdf