文章目录
- 1 什么是RNN
- 2 LSTM
- 3 Training
- 3.1 Learning Target
- 3.2 为什么难train
- 4 应用举例
- 4.1 Many To One
- 4.2 Many To Many
- 4.3 其他
本文为李弘毅老师【Recurrent Neural Network(Part I)】和【Recurrent Neural Network(Part II)】的课程笔记,课程视频来源于youtube(需翻墙)。
下文中用到的图片均来自于李宏毅老师的PPT,若有侵权,必定删除。
1 什么是RNN
顾名思义,RNN就是一个不断循环的神经网络,只不过它在循环的过程当中是有记忆的,这也是发明RNN的初衷,就是希望神经网络在看一个序列的输入的时候可以考虑一下前面看过的内容。
我们以Slot Filling来举例。Slot Filiing就是填空的意思,比如今天有一个旅客说了一句"I would like to arrive Taipei on November 2nd",我们的订票系统就需要从这句话中找到"目的地"和"期望的到达时间"这两个slot。而帮助订票系统来解析这个句子的,也就是我们的RNN模型。
RNN是怎么运作的呢?简单粗暴地来说,就比如我们先往RNN里塞进一个"arrive"和一个随机初始化的状态向量a0a^0a0,然后RNN会输出一个output(yiy^iyi)和一个hidden state(aia^iai)。output用来表示"arrive"这个词在每个slot中的概率,hidden state是一个向量,包含着看过"arrive"之后,模型自己记下来的信息,然后再把这个hidden state和下一个输入"Taipei"塞进RNN里,如此反复,直到把整个句子看完。这就是一个最最最简单的RNN的运作流程,aia^iai就表示着RNN的记忆。不过也正是因为RNN需要记忆,所以RNN没法并行计算。
市面上两种RNN的记忆方法,一种是利用hidden state来当作记忆,叫做"Elman Network",另一种是利用"output"当作记忆,叫做"Jordan Network"。据说,output是有label在监督的,不像hidden state那么自由,所以"Jordan Network"的效果会好一些。不过,市面上也是有把两者结合起来的记忆方法的,名字就不知道了。
当然,这样的记忆只是单向的,有些时候句子的理解是需要句子后面的一些词汇的辅助的。为了解决这个问题,也就有了双向的RNN。双向的RNN的两个方向是可以并行计算的。下面这幅图应该是比较清楚的了,每个output是结合了从头到尾和从尾到头两个方向的output。
2 LSTM
市面上在使用的RNN并不是上述那么简单的,其中LSTM是比较常用的一种方法。LSTM的每一个cell的结构如下图所示,它吃四个input,吐出一个output。四个input分别是该time step的输入、控制input gate的信号、控制forget gate的信号以及控制output gate的信号。最中间的那个memory cell是用来存储之前序列留下的记忆信息的。input gate用来决定要不要接收这个输入,forget gate用来决定要不要使用之前的记忆,output gate用来决定要不要输出这个输出。
这么说可能还是有点糊涂,看下面这张图吧。比如我们某个time step的输入为xtx^txt,首先这个xtx^txt会分别乘以一个矩阵得到LSTM需要的四个输入。
z=Wxt+bzf=Wfxt+bfzi=Wixt+bizo=Woxt+boz = Wx^t+b\\ z^f = W^fx^t+b^f\\ z^i = W^ix^t+b^i\\ z^o = W^ox^t+b^o z=Wxt+bzf=Wfxt+bfzi=Wixt+bizo=Woxt+bo
然后,ziz^izi会经过一个激活层,来控制输入zzz,我们把这个中间量记作inputinputinput吧。
input=σ(zi)∗zinput = \sigma(z^i)*z input=σ(zi)∗z
同时,zfz^fzf也会经过一个激活层,来决定是否要使用之前的记忆ct−1c^{t-1}ct−1,我们记这个中间变量叫做memorymemorymemory吧。
memory=σ(zf)∗ct−1memory = \sigma(z^f)*c^{t-1} memory=σ(zf)∗ct−1
这个inputinputinput和memorymemorymemory会相加在一起,作为输出的结果,这个结果由经过一层激活层的zoz^ozo来控制。这个输出我们叫做hidden state(hhh)。
h=σ(zo)∗σ(input+memory)h = \sigma(z^o)*\sigma(input+memory) h=σ(zo)∗σ(input+memory)
最后的结果yty^tyt一般还要来一层全连接。
yt=Wyhy^t = W^yh yt=Wyh
不过这只是一个time step,LSTM在多个循环的时候,长下面这样。
可见,实际情况下,我们的输入并不是xtx^txt,而是xtx^txt,hth^tht和ct−1c^{t-1}ct−1的结合。其中,利用ct−1c^{t-1}ct−1的这个操作被称为peehole。
而现在主流的框架之间的实现,也略有差别,比如pytorch的实现就没有利用peehole,激活层也有一些区别,不过整体的思路是完全一致的。
3 Training
3.1 Learning Target
在进行训练的时候,我们需要一个目标。这个目标其实是需要根据实际的应用场景来定的。比如,我们还是用上面订票系统的例子,我们的每一个time step的输出是一个概率向量,分别表示着[“other”, “dest”, “time”]的概率大小。我们的label就是一个one-hot encoding的向量,比如"arrive"的label中"other"的标签为1,其他为0;"Taipei"的label中"dest"的标签为1,其他为0;"on"的label中"other"的标签为1,其他为0;以此类推。然后用prediction算下cross entrophy loss就行了。
而RNN的反向传播也是和其他的神经网络一样,是可以用梯度下降来做的。不过,因为它是有时间顺序的,所以计算时略有不同,得要用一个叫做BPTT(Backpropagation through time)的方法来做,这里不详述了。总之就是可以和其他网络一样train下去。想了解的可以看下吴恩达的RNN W1L04 : Backpropagation through time。
3.2 为什么难train
虽然RNN也是和正常的神经网络那样可以用gradient descent不断地更新参数来train下去,但在RNN刚出来的时候,几乎没有人可以把它train出来,往往会得到一条如下图绿色曲线这样的结果。只有一个叫做Razvan Pascanu的人,可以train出那条蓝色的曲线。其实原因时因为RNN的loss space非常陡峭,参数微小的变动,可能引起loss极大的改变。Razvan Pascanu在他写博士论文的时候,把他一直可以train出好结果的秘诀公布了出来,那就是gradient clipping,即人为地把gradient的大小限制住了。没错就是这么简单的一个技巧。
但究竟为什么RNN的梯度会发生这么大的变化呢?我们来举个例子说明一下。假如我们有一个全世界最简单的RNN,它输入的weight是1,输出的weight也是1,用来memory的weight为www,那么当我们输入一个长度为1000,且只有第一个元素为1,其余都为0的序列时,最后一个time step的输出就为y1000=w999y^{1000}=w^{999}y1000=w999。
这是一个什么概念呢?比如我们的w=1w=1w=1,那么y1000=1y^{1000}=1y1000=1。而此时,www只要稍稍变大一点,那么y1000y^{1000}y1000就会产生很大的变化,比如w=1.01w=1.01w=1.01时,y1000=20000y^{1000}=20000y1000=20000。这时候也就会发生所谓的梯度爆炸。而当www降到1一下时,y1000y^{1000}y1000又一直变为0了,这也就是所谓的梯度消失。
而这一切的一切都是因为RNN的参数在循环的过程中被不断的重复使用。
LSTM在一定程度上是可以解决梯度消失的问题的。为什么LSTM可以解决梯度消失?我感觉李老师这里还是有点没讲清楚,推荐看下这篇blog(需翻墙)。一句话就是,传统的RNN的反向传播操作(BPTT)有个连乘的东西在,在LSTM的BPTT里是连加的,然后LSTM又有forget gate这个门在,可以使得相隔较远的序列不相互影响。
简单来说就是用clip可以缓解梯度爆炸,用LSTM可以缓解梯度消失。
我有一点很想知道的是,发明LSTM的大佬,是为了解决梯度消失的问题而发明了LSTM呢?还是发明了LSTM之后,发现可以缓解梯度消失?这个也真是太6了~
4 应用举例
RNN的应用范围非常广泛,这里简单列举一下李老师视频中提到的一些例子。
4.1 Many To One
- 情感分析。输入一段评论,输出该段评论是好评还是差评。
- 关键信息提取。输入一篇文章,输出该文章中的关键信息。
4.2 Many To Many
- 语音识别。输入一段语音,输出对应的文字。
- 语言翻译。输入一段某国的文字或语音,输出一段另一个国家的对应意思的文字或语音。
- 聊天机器人。输入一句话,输出回答。
4.3 其他
- 句子文法结果分析。输入一个句子,输出该句子的文法结构。
- 句子自编码。