深度学习:梯度裁剪的理解
- 梯度裁剪简介
- 设置范围值裁剪
- 通过 L2 范数裁剪
- 附
在深度学习领域,梯度裁剪是一个常用的技巧,用于防止梯度过小或过大。下面简单介绍一下 梯度裁剪的原理与方法。
梯度裁剪简介
在深度学习模型的训练过程中,通过梯度下降算法更新网络参数。一般地,梯度下降算法分为前向传播和反向更新两个阶段。
在前向传播阶段,输入向量通过各层神经元的计算,得到输出向量,假设网络可以用一个抽象函数 f f f表示,则公式为:
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ y = f(x) \end{…
在计算出网络的估计值后,使用类似均方误差的方法,计算出真值和估计值之间的差距,即损失函数loss:
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ loss = \frac{1…
在反向更新阶段,调整网络参数 θ \theta θ包括权重 W W W和偏差 b b b。为了更新网络参数,首先要计算损失函数对于参数的梯度 ∂ l o s s ∂ θ \frac{\partial loss}{\partial \theta} ∂θ∂loss,然后使用某种梯度更新算法,执行一步梯度下降,以减小损失函数值。如下式:
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ \theta_{t+1} =…
注意:从上式可以看出有时候,减小学习率和梯度裁剪是等效的。
在上述训练过程中,可能出现梯度值变得特别小或者特别大甚至溢出的情况,这就是所谓的“梯度消失”和“梯度爆炸”,这时候训练很难收敛 。梯度爆炸一般出现在由初始权重计算的损失特别大的情况,大的梯度值会导致参数更新量过大,最终梯度下降将发散,无法收敛到全局最优。此外, 随着网络层数的增加,"梯度爆炸"的问题可能会越来越明显。考虑具有三层隐藏层网络的链式法则公式,如果每一层的输出相对输入的偏导 > 1,随着网络层数的增加,梯度会越来越大,则有可能发生 “梯度爆炸”。
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ \frac{\partial…
当出现下列情形时,可以认为发生了梯度爆炸:两次迭代间的参数变化剧烈,或者模型参数和损失函数值变为 NaN。
如果发生了 “梯度爆炸”,在网络学习过程中会直接跳过最优解,甚至可能会发散(无法收敛),所以有必要进行梯度裁剪,防止网络在学习过程中越过最优解。梯度裁剪方式:设置范围值裁剪和通过 L2 范数裁剪。
设置范围值裁剪
设置范围值裁剪方法简单,将参数的梯度限定在一个范围内,如果超出这个范围,则进行裁剪(大于阈值为上限阈值 max \max max;小于阈值为下限阈值 min \min min),但是阈值通常较难确定一个合适的。
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ y=\left\{ \beg…
通过 L2 范数裁剪
通过 L2 范数裁剪是用阈值限制梯度向量的 L2 范数,从而对梯度进行裁剪。
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ y=\left\{ \beg…
附
在模型训练时出现报错:
ValueError: matrix contains invalid numeric entries
通过print对应报错位置的变量可以发现出现nan值:
tensor([[nan, nan, nan, ..., nan, nan, nan],[nan, nan, nan, ..., nan, nan, nan],[nan, nan, nan, ..., nan, nan, nan],...,[nan, nan, nan, ..., nan, nan, nan],[nan, nan, nan, ..., nan, nan, nan],[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0',
出现此问题的原因:
(1)梯度爆炸
(2)不当的输入
(3)不当的模型设计
解决方法:
(1)降低学习率;
(2)加入归一化Norm;
(3)加入梯度裁剪gradient clipping;
(4)数据存在脏数据,需要清洗;
(5)检查网络设计是否存在错误。