一、说明
在本文中,我们将阅读有关Wasserstein GANs的信息。具体来说,我们将关注以下内容:i)什么是瓦瑟斯坦距离?,ii)为什么要使用它?iii) 我们如何使用它来训练 GAN?
二、Wasserstein距离概念
Wasserstein距离,又称为Earth Mover's Distance (EMD),是衡量两个概率分布之间的差异程度的一种数学方式。它考虑了分布之间的距离和它们之间的“传输成本”。
简单来说,Wasserstein距离将两个分布看作“堆积在地图上的土堆”,并计算将一个堆移到另一个的最小成本。这个距离度量的优点是它能够处理非均匀分布,并且能够考虑分布的形状和结构。
Wasserstein距离在机器学习领域中应用非常广泛,特别是在生成模型中用来评估生成器生成的图像与真实图像之间的差异。
2.1 瓦瑟施泰因距离
Wasserstein 距离(地球移动器的距离)是给定度量空间上两个概率分布之间的距离度量。直观地说,它可以被视为将一个分布转换为另一个分布所需的最小功,其中功被定义为必须移动的分布的质量和要移动的距离的乘积。在数学上,它被定义为:
在方程1中,Π(P_r,P_g)是x和y上所有联合分布的集合,使得边际分布等于P_r和P_g。 γ(x, y)可以看作是必须从x移动到y才能将P_r转换为P_g的质量量[1]。因此,瓦瑟斯坦距离是最佳运输计划的成本。
2.2 瓦瑟斯坦距离 vs. 詹森-香农分歧
最初的GAN目标被证明是Jensen-Shannon分歧的最小化[2]。JS背离定义为:
与JS相比,Wasserstein距离具有以下优点:
- Wasserstein 距离是连续的,几乎可以在任何地方微分,这使我们能够训练模型达到最佳状态。
- 随着鉴别器的变好,JS散度局部饱和,因此梯度变为零并消失。
- Wasserstein 距离是一个有意义的度量,即当分布彼此靠近时,它收敛到 0,当它们越来越远时发散。
- 作为目标函数的 Wasserstein 距离比使用 JS 散度更稳定。当使用Wasserstein距离作为目标函数时,模式崩溃问题也得到了缓解。
从图 1 我们清楚地看到,最佳GAN鉴别器饱和并导致梯度消失,而优化Wasserstein距离的WGAN评论家在整个过程中具有稳定的梯度。
有关数学证明和更详细的研究,请查看此处的论文!
三、瓦瑟斯坦·GAN
现在可以清楚地看到,优化 Wasserstein 距离比优化 JS 散度更有意义,还需要注意的是,方程 1 中定义的 Wasserstein 距离非常棘手[3],因为我们不可能计算所有 γ ∈Π(Pr ,Pg) 的下界(最大下界)。然而,从坎托罗维奇-鲁宾斯坦二元性中,我们有,
这里我们有 W(P_r, P_g) 作为所有 1-Lipschitz 函数 f: X → R 的上确界(最低上限)。
K-利普希茨连续性:给定 2 个度量空间 (X, d_X) 和 (Y, d_Y),变换函数 f: X → Y 是 K-利普希茨连续的,如果
其中d_X和d_Y是各自度量空间中的距离函数。当一个函数是 K-Lipschitz 时,从方程 2 开始,我们最终得到 K ∙ W(P_r, P_g)。
现在,如果我们有一系列参数化函数 {f_w},其中 w∈W 是 K-Lipschitz 连续的,我们可以有
即,w∈W 最大化方程 4 给出瓦瑟斯坦距离乘以一个常数。
四、WGAN评论家
为此,WGAN引入了一个批评者,而不是我们在GAN中了解到的鉴别器。批评者网络在设计上类似于判别器网络,但通过优化找到将最大化方程 4 的 w* 来预测 Wasserstein 距离。为此,批评家的客观功能如下:
在这里,为了在函数f上强制执行Lipschitz连续性,作者诉诸于将权重w限制在一个紧凑的空间内。这是通过将砝码夹紧到一个小范围(论文中的[-1e-2,1e-2][1])来完成的。
鉴别器和批评者之间的区别在于,鉴别器经过训练以正确识别P_r样本和P_g样本,批评家估计P_r和P_g之间的Wasserstein距离。
这是训练批评家的python代码。
for ix in n_critic_steps:opt_critic.zero_grad()real_images = data[0].float().to(device)# * Generate imagesnoise = sample_noise()fake_images = netG(noise)# * though they are name so, they are not logits!real_logits = netCritic(real_images)fake_logits = netCritic(fake_images)# * max E_{x~P_X}[C(x)] - E_{Z~P_Z}[C(g(z))]loss = -(real_logits.mean() - fake_logits.mean())loss.backward(retain_graph=True)opt_critic.step()# * Gradient clipplingfor p in netCritic.parameters():p.data.clamp_(-self.c, self.c)
五、WGAN生成器目标
当然,发电机的目标是最小化P_r和P_g之间的瓦瑟斯坦距离。生成器试图找到最小化P_g和P_r之间的 Wasserstein 距离的 θ*。为此,生成器的目标函数如下:
公式 6:生成器目标函数。
在这里,WGAN生成器和标准生成器之间的主要区别再次在于,WGAN生成器试图最小化P_r和P_g之间的Wasserstein距离,而标准生成器试图用生成的图像欺骗鉴别器。
以下是训练生成器的 python 代码:
opt_gen.zero_grad()noise = sample_noise()fake_images = netG(noise)# again, these are not logits.
fake_logits = netCritic(fake_images)# * - E_{Z~P_Z}[C(g(z))]
loss = -fake_logits.mean().view(-1)loss.backward()
opt_gen.step()
六、培训结果
图例.2显示了训练WGAN的一些早期结果。请注意,图 2 中的图像是早期结果,一旦确认模型按预期训练,训练就会停止。
七、代码
Wasserstein GAN的完整实现可以在这里找到[3]。
八、结论
WGAN提供非常稳定的培训和有意义的培训目标。本文介绍并直观地解释了什么是 Wasserstein 距离,Wasserstein 距离相对于标准 GAN 使用的 Jensen-Shannon 散度的优势,以及如何使用 Wasserstein 距离来训练 WGAN。我们还看到了用于训练 Critic 和生成器的代码片段,以及早期训练模型的大量输出。尽管WGAN比标准GAN具有许多优势,但WGAN论文的作者明确承认,权重裁剪不是执行Lipschitz连续性的最佳方法[1]。为了解决这个问题,他们提出了带有梯度惩罚的Wasserstein GAN[4],我们将在后面的文章中讨论。
如果您喜欢这个,请查看本系列的下一篇文章,其中讨论了 WGAN-GP!