(stop gradient,停止梯度)运算
它简称为sg
也就是说,前向传播时,𝑠𝑔里的值不变(sg=1);反向传播时,𝑠𝑔按值为0求导(sg=0),即此次计算无梯度。
为什么需要解决梯度截断(不可导)?
按理说我们需要求的损失函数为:
但是因为在encoder后的z(z_q),是通过在codebook中arg min得到的,所以会导致梯度消失,梯度无法在反向传播中从decoder传递到encoder。
怎么解决?
所以使用Straight-Through(直通估计)来解决这个问题:
直通估计器(Straight-Through Estimator, STE)的核心思想是:
- 前向传播:正常计算模型的输出。
- 反向传播:用一种自定义的方式计算梯度,以绕过某些不可微的操作
根据这个思想,我们可以设计一个把梯度从𝑧𝑒(𝑥)复制到𝑧𝑞(𝑥)的loss:
用在VQ-VAE的设计上,那就是:
- 前向传播:计算 decoder(zq(x))并基于此计算重建loss值。
- 反向传播:通过 stop gradient 操作,梯度将直接传递回 ze(x),从而更新encoder,而不会受到 zq(x) 和 ze(x) 之间量化过程的影响。
前向传播时,𝑠𝑔里的值不变(sg=1):
就是拿解码器来进行𝑧𝑞(𝑥)的解码并计算损失:
反向传播时,𝑠𝑔按值为0求导(sg=0):
按下面这个公式进行梯度回传(参数更新),等价于把解码器的梯度全部传给𝑧𝑒(𝑥):
代码实现方法:
在PyTorch里,(x).detach()
就是𝑠𝑔(𝑥),它的值在前向传播时取x
,反向传播时取0
。
L = x - decoder(z_e + (z_q - z_e).detach())
通过这一技巧,我们完成了梯度的传递,可以正常地训练编码器和解码器了。
VQ-VAE的简明介绍:量子化自编码器 - 科学空间|Scientific Spaces
轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型 | 周弈帆的博客