这里写目录标题
- 1. RNN存在哪些问题呢?
- 1.1 梯度弥散和梯度爆炸
- 1.2 RNN为什么会出现梯度弥散和梯度爆炸呢?
- 2. 解决梯度爆炸方法
- 3. Gradient Clipping的实现
- 4. 解决梯度弥散的方法
1. RNN存在哪些问题呢?
1.1 梯度弥散和梯度爆炸
梯度弥散是梯度趋近于0
梯度爆炸是梯度趋近无穷大
1.2 RNN为什么会出现梯度弥散和梯度爆炸呢?
先看RNN的梯度推导公式,如下图:
从hk的梯度求导公式和hk的计算过程可以看出,hk的计算和Whh相关,也就是梯度也与Whh有关,因此从h1 时刻到hk时刻,Whh被乘了k-1次,即Whhk-1,那么当W>1时,就使得Wrk随着k(句子长度)的增大,梯度趋近无穷大,会出现梯度爆炸,而W<1时,Wrk随着k(句子长度)的增大,梯度会趋近于0,会出现梯度弥散。
综上:RNN并不是可以处理无限长的句子,其随着句子的增长可能出现梯度弥散和梯度爆炸的问题
2. 解决梯度爆炸方法
上图为一篇解决梯度爆炸的paper,其中左边的图描述的是梯度爆炸产生的原因,当W出现巨变的时候会导致loss的方向发生变化,从而偏移原来正确的方向,出现梯度爆炸。
解决梯度爆炸的方法是给w.grad设置一个阈值,比如是15,当大于阈值时,将w.grad’=w.grad/||w.grad||15=115=15,从而保证了loss的方向不变,loss虽然可能有一些跳变,比如:从0.23~0.32,,但慢慢的还会下降。
这种方法叫gradient clipping
3. Gradient Clipping的实现
只需获取到模型参数后调用torch.nn.utils.clip_grad_norm_(p,10)即可,10为阈值。
见下图,注意torch.nn.utils.clip_grad_norm_(p,10)和print是平齐的。
4. 解决梯度弥散的方法
下文LSTM会讲。