论文:https://arxiv.org/abs/2202.00512v2
这篇论文提出了一种"渐进式蒸馏"(Progressive Distillation)的方法,来大幅降低用于生成对抗网络(GAN)的扩散模型(Diffusion Model)在推理阶段的采样步数,从而提高其推理效率。主要贡献有:
-
提出了一种新的扩散模型参数化方式,使得模型在采样步数较少时也能保持稳定。
-
提出了一种渐进式的蒸馏算法,将原先需要几千步采样的扩散模型蒸馏成只需4步采样就能生成高质量样本的新模型。
-
证明了整个渐进式蒸馏过程的计算开销不会超过训练原始模型的开销,从而实现了在训练和推理阶段都高效的生成建模。
论文在标准的图像生成基准测试上,如CIFAR-10、ImageNet和LSUN,展示了其所提方法的有效性。通过渐进式蒸馏,论文将原始需要8192步采样的模型成功蒸馏为只需4步采样就能取得接近的生成质量的新模型,大幅提高了扩散模型的推理效率。
(1)新的扩散模型参数化方式及其稳定性:
在论文中,为了提高扩散模型在使用较少采样步骤时的稳定性,提出了几种新的参数化方式。这些参数化方式包括:
- 直接预测去噪后的样本 ( \hat{x} )。
- 通过神经网络的两个独立输出通道分别预测样本 ( \hat{x} ) 和噪声 ( \hat{\epsilon} ),然后根据它们合并预测,平滑地在直接预测 ( x ) 和通过 ( \epsilon ) 预测之间插值。
- 预测一个新变量 ( v ),其定义为 ( v \equiv \alpha_t \epsilon - \sigma_t x ),从而得到 ( \hat{x} = \alpha_t z_t - \sigma_t \hat{v} \theta(z_t) )。
这些参数化方式之所以能够保持稳定,是因为它们能够在信号噪声比(SNR)变化时保持对 ( \hat{x} ) 的稳定预测。特别是预测 ( v ) 的方式,由于它使得 DDIM 步骤的大小与 SNR 无关,因此在低 SNR 条件下仍然能够保持稳定。
(2)关于采样步骤的问题:
在传统的扩散模型(如 DDPM)中,生成高质量样本通常需要大量的采样步骤,这些步骤可能达到数百或数千步。然而,通过渐进式蒸馏方法,可以将这些需要大量步骤的模型转化为只需少量步骤(例如4步)的新模型,同时保持样本质量。这种方法通过迭代地将一个需要较多步骤的模型(称为教师模型)蒸馏成一个需要较少步骤的学生模型,每次迭代都将采样步骤减半,直到达到所需的采样步骤数量。
(3)渐进式蒸馏过程的计算开销:
论文中提出,渐进式蒸馏过程的计算开销不会超过训练原始模型的开销。这是通过以下方式实现的:
- 在每次蒸馏迭代中,学生模型的参数是从教师模型复制而来,并且使用的训练数据是从训练集中采样得到的带噪声数据。
- 蒸馏过程中的目标设置与标准训练不同。在蒸馏中,学生模型的目标是通过教师模型进行两次 DDIM 采样步骤得到的 ( \tilde{x} ),而不是原始数据 ( x )。这意味着学生模型只需要在一个步骤中预测出与教师模型两次步骤相同的结果。
- 蒸馏过程中,每一步的学生模型训练所需的参数更新次数是固定的,并且随着采样步骤的减少,学生模型的复杂度也在减少,这使得每次迭代的训练时间得以控制。
通过这种方式,整个蒸馏过程可以在与原始模型训练相当的时间内完成,从而实现了在训练和推理阶段都高效的生成建模。这种方法不仅减少了生成高质量样本所需的时间,而且还降低了在实际应用中运行扩散模型的计算成本。