论文精读(保姆级解析)—— Flash Diffusion

0 前言

  今天分析的论文是《Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation》。该论文发表在2024年,目前已开源在arxiv上,主要提出了一种高效、快速且多功能的蒸馏方法,用于加速预训练扩散模型的生成:Flash Diffusion。下面给出论文的地址和代码仓库链接:

  • paper
  • code

1 摘要

  本文提出了一种高效、快速且多功能的蒸馏方法,以加速预训练扩散模型的生成:Flash Diffusion。在COCO2014和COCO2017数据集上,该方法在FID和CLIP-Score方面达到了最先进的性能,与现有的方法相比,只需几个GPU小时的训练和更少的可训练参数。除了其高效性外,该方法的多功能性也在多个任务中得到了体现,例如文本生成图像、修复、换脸、超分辨率,以及使用不同的骨干网络(如基于UNet的去噪器SD1.5、SDXL或DiT (Pixart-α))和适配器。在所有情况下,该方法都显著减少了采样步骤的数量,同时保持了非常高质量的图像生成。
在这里插入图片描述

2 引言

  扩散模型很有效,应用非常广泛。然而,由于采样机制的内在迭代行,导致计算成本比较高。。。
  最近,出现了一些更高效的求解器或扩散蒸馏方法,旨在从训练的扩散模型中减少所需的采样步骤以生成令人满意的样本。然而,求解器通常需要至少10次神经函数评估(NFEs)才能生成令人满意的样本,而蒸馏方法可能需要大量的训练资源或需要一种迭代训练程序来在整个训练过程中更新teacher model,这限制了它们的应用范围。此外,大多数现有的蒸馏方法是为特定任务(如文本生成图像)量身定制的,目前尚不清楚它们在使用不同条件和扩散模型架构的其他任务中的表现。此外,最有效的方法依赖对抗训练程序,这可能导致训练不稳定并需要大量的超参数调整。

  在本文中,我们提出了Flash Diffusion,这是一种快速、稳健且多功能的扩散蒸馏方法,能够大幅减少采样步骤的数量,同时保持非常高的图像生成质量。提出的方法旨在训练一个student model,使其在单步预测中对损坏的输入样本进行去噪的多步教师预测。该方法还通过对抗目标将学生分布引导至真实输入样本流形,并通过分布匹配确保其不会过度偏离已学习的教师分布。

  该方法与LoRA兼容,与现有的方法相比,该方法能够在仅需几个步骤的情况下生成高质量的样本,同时只需要几个GPU小时的训练时间更少的可训练参数。该方法能够在COCO2014和COCO2017数据集上以少步图像生成的FID和CLIP分数达到最新的(SOTA)性能。除了其高效性外,该方法在多个任务(如文本生成图像、修复、换脸、图像放大)以及使用不同的扩散模型骨干(SD1.5 [57]、SDXL [50] 和Pixart-α [5])和适配器 [46] 中也展示了其多功能性。在所有情况下,该方法都能显著减少采样步骤的数量,同时保持非常高质量的图像生成(论文都得吹一下)。

  本文的主要贡献如下:

  • 提出了一种高效、快速、多功能且与LoRA兼容的蒸馏方法,旨在减少从训练的扩散模型中生成高质量样本所需的采样步骤。
  • 验证了该方法在文本生成图像任务中的效果,显示其能够在标准基准数据集上仅用两个神经函数评估(NFE)以及少量的图像生成步骤达到SOTA效果,相当于使用无分类器指导的一步,同时所需的训练参数远少于竞争对手,仅需几个GPU小时的训练。
  • 进行了广泛的消融研究,展示了该方法各个组成部分的影响,并证明了其稳健性和可靠性。
  • 通过广泛的实验研究,强调了该方法的多功能性,涵盖了各种任务(文本生成图像、图像修复、超分辨率、换脸)和扩散模型架构(SD1.5、SDXL和Pixart-α),并展示了其与适配器的兼容性。

3 相关工作

3.1 扩散模型

  扩散模型包括根据给定的噪声调度人为地破坏输入数据[64,17,67],使数据分布最终类似于标准高斯分布。然后,它们被训练来估计添加的噪声量,以学习反向扩散过程,从而在训练完成后能够从高斯噪声生成新样本。这些模型可以根据各种输入进行条件化,如图像、深度图、边缘、姿势或文本,在这些条件下它们展示了非常令人印象深刻的结果。然而,为了生成高质量的样本,在推理时需要大量的采样步骤(通常为50步),这限制了它们在实时应用中的使用和推广。

3.2 扩散蒸馏法

  为了解决这一限制,最近出现了几种方法来减少推理时所需的函数计算次数。一方面,几篇论文尝试构建更高效的求解器来加速生成过程,但这些方法仍然需要使用多个步骤(通常为10步)来生成令人满意的样本。另一方面,一些依赖模型蒸馏的方法提出训练一个学生网络,使其学会在更少的步骤中匹配教师模型生成的样本。一种简单的方法是建立噪声/教师样本对,并训练一个学生模型,使其在具有回归损失的单步中匹配来自相同噪声的教师预测。尽管如此,这种方法仍然非常有限,并且很难与教师模型的质量相匹配,因为在充满噪音的环境中,学生没有潜在的有用信息可以学习。在此基础上,一些方法提出先对输入样本应用正向扩散过程,然后将其传递给学生网络。学生的预测然后使用回归损失、对抗目标或分布匹配与教师模型的学习分布进行比较。

3.3 渐进式蒸馏

  渐进蒸馏(Progressive Distillation)也是一种被证明相当有前景的方法。它包括训练一个学生模型在一个步骤中预测一个噪声样本的两步教师去噪,理论上减少了一半所需的采样步骤。然后教师模型被新的学生模型替换,这个过程重复多次。这种方法也被丰富为基于GAN的目标,使得所需的采样步骤从4-8步进一步减少到一个步骤。InstaFlow提出依靠修正流来简化单步蒸馏过程。然而,这种方法可能需要大量的训练参数和长时间的训练过程,使其计算密集。

3.4 一致性模型

  一致性模型(Consistency Models)也是一种在文献中提出的有前景、有效且多功能的蒸馏方法。主要思想是训练一个模型,将位于概率流常微分方程上的任何点映射到其原点,理论上解锁单步生成。Luo等人结合潜在一致性模型和LoRAs,展示了在非常有限的训练参数和几个GPU数小时的训练下,训练出一个强大的学生模型的可能性。然而,这些模型仍然难以实现单步生成并达到同类方法的采样质量。

  在最近进行的一项平行研究中,Yin等人还引入了联合使用分布匹配损失和对抗损失的方法,作者也在论文中使用了这种方法。然而,它们不依赖于在我们的实验中证明非常有效的蒸馏损失的使用,也不计算相对于相同输入的对抗损失。此外,他们的方法仍然需要训练另一个去噪器来评估假样本的分数,显著增加了可训练参数的数量和方法的计算负担。此外,他们方法在不同任务和扩散模型架构中进行泛化和有效表现的能力仍不明确。

4.1 扩散模型

  设 x 0 ∈ X x_0 \in X x0X 是一组输入数据,使得 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0),其中 p ( x 0 ) p(x_0) p(x0) 是一个未知分布。扩散模型是一类生成模型,它们定义了一个马尔可夫过程 ( x t ) t ∈ [ 0 , T ] (x_t)_{t \in [0, T]} (xt)t[0,T],通过向数据 x 0 x_0 x0中逐步注入高斯噪声来创建 x 0 x_0 x0的噪声版本 x t x_t xt。随着 t t t 增加,噪声样本 x t x_t xt 的分布最终变得等同于各向同性高斯分布。噪声调度由两个可微函数 α ( t ) \alpha(t) α(t) σ ( t ) \sigma(t) σ(t) 控制,对于任意 t ∈ [ 0 , T ] t \in [0, T] t[0,T],使得信噪比的对数 log ⁡ [ α ( t ) 2 / σ ( t ) 2 ] \log[\alpha(t)^2 / \sigma(t)^2] log[α(t)2/σ(t)2] 随时间递减。给定任意 t ∈ [ 0 , T ] t \in [0, T] t[0,T],噪声样本相对于输入 x 0 x_0 x0 的分布 q ( x t ∣ x 0 ) q(x_t | x_0) q(xtx0) 称为前向过程,定义为 q ( x t ∣ x 0 ) = N ( x t ; α ( t ) ⋅ x 0 , σ ( t ) 2 ⋅ I ) q(x_t | x_0) = \mathcal{N}(x_t; \alpha(t) \cdot x_0, \sigma(t)^2 \cdot I) q(xtx0)=N(xt;α(t)x0,σ(t)2I),可以如下进行采样:
x t = α ( t ) ⋅ x 0 + σ ( t ) ⋅ ϵ 其中 ϵ ∼ N ( 0 , I ) (1) x_t = \alpha(t) \cdot x_0 + \sigma(t) \cdot \epsilon \quad \text{其中} \quad \epsilon \sim \mathcal{N}(0, I) \tag{1} xt=α(t)x0+σ(t)ϵ其中ϵN(0,I)(1)

  扩散模型的主要思想是学习对噪声样本 x t ∼ q ( x t ∣ x 0 ) x_t \sim q(x_t | x_0) xtq(xtx0) 进行去噪,以学习反向过程,最终从纯噪声中生成样本 x ~ 0 \tilde{x}_0 x~0。在实践中,在训练过程中,扩散模型包括学习一个以时间步长 t t t为条件的参数化函数 x θ x_\theta xθ,并将噪声样本 x t x_t xt作为输入,使其预测原始样本 x 0 x_0 x0的去噪版本。参数 θ \theta θ 通过去噪得分匹配学习:
L = E x 0 ∼ p ( x 0 ) , t ∼ π ( t ) , ϵ ∼ N ( 0 , I ) [ λ ( t ) ∥ x θ ( x t , t ) − x 0 ∥ 2 ] (2) L = \mathbb{E}_{x_0 \sim p(x_0), t \sim \pi(t), \epsilon \sim \mathcal{N}(0, I)} \left[ \lambda(t) \left\| x_\theta(x_t, t) - x_0 \right\|^2 \right] \tag{2} L=Ex0p(x0),tπ(t),ϵN(0,I)[λ(t)xθ(xt,t)x02](2)
  其中 λ ( t ) \lambda(t) λ(t) 是一个取决于时间步 t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1] 的缩放因子, π ( t ) \pi(t) π(t) 是时间步的分布。注意,公式 (2) 实际上等同于学习一个函数 ϵ θ \epsilon_\theta ϵθ,它估计添加到原始样本上的噪声 ϵ \epsilon ϵ,通过重参数化 ϵ θ ( x t , t ) = x t − α ( t ) ⋅ x θ ( x t , t ) σ ( t ) \epsilon_\theta(x_t, t) = \frac{x_t - \alpha(t) \cdot x_\theta(x_t, t)}{\sigma(t)} ϵθ(xt,t)=σ(t)xtα(t)xθ(xt,t)得到。Song 等人表明, ϵ θ \epsilon_\theta ϵθ 可以通过求解以下 PF-ODE 来从高斯噪声生成新的数据点:
d x t = [ f ( x t , t ) − 1 2 g 2 ( t ) ∇ log ⁡ p θ ( x t ) ] d t (3) d x_t = \left[ f(x_t, t) - \frac{1}{2} g^2(t) \nabla \log p_\theta(x_t) \right] dt \tag{3} dxt=[f(xt,t)21g2(t)logpθ(xt)]dt(3)
其中 f ( x t , t ) f(x_t, t) f(xt,t) g ( t ) g(t) g(t) 分别是 PF-ODE 的漂移函数和扩散函数,定义如下:
f ( x t , t ) = d log ⁡ α ( t ) d t ⋅ x t g 2 ( t ) = d σ ( t ) 2 d t − 2 ⋅ d log ⁡ α ( t ) d t ⋅ σ 2 ( t ) f(x_t, t) = \frac{d \log \alpha(t)}{dt} \cdot x_t \\ g^2(t) = \frac{d \sigma(t)^2}{dt} - 2 \cdot \frac{d \log \alpha(t)}{dt} \cdot \sigma^2(t) f(xt,t)=dtdlogα(t)xtg2(t)=dtdσ(t)22dtdlogα(t)σ2(t)
∇ log ⁡ p θ ( x t ) = − ϵ θ ( x t , t ) σ ( t ) \nabla \log p_\theta(x_t) = -\frac{\epsilon_\theta(x_t, t)}{\sigma(t)} logpθ(xt)=σ(t)ϵθ(xt,t)称为 p θ ( x t ) p_θ(x_t) pθ(xt)的分数函数。PF-ODE 可以使用神经 ODE 积分器 [7] 求解,该积分器通过给定的更新规则,如欧拉方法 [67] 或 Heun 解算器 [23],迭代地应用学习到的函数 ϵ θ \epsilon_\theta ϵθ
  通过学习条件去噪函数 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c) x θ ( x t , t , c ) x_\theta(x_t, t, c) xθ(xt,t,c) ,可以训练条件扩散模型从条件分布 p ( x 0 ∣ c ) p(x_0 | c) p(x0c)生成样本。在这种特定设置下,Classifier-Free Guidance (CFG) 已证明是一种非常有效的方法,可以更好地强制模型遵守条件,从而提高采样质量。CFG 是一种技术,它在训练期间以一定概率丢弃条件 c c c,并在推理时用以下线性组合替换条件噪声估计 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c)
ϵ θ ( x t , t , c ) = ω ⋅ ϵ θ ( x t , t , c ) + ( 1 − ω ) ⋅ ϵ θ ( x t , t , ∅ ) (4) \epsilon_\theta(x_t, t, c) = \omega \cdot \epsilon_\theta(x_t, t, c) + (1 - \omega) \cdot \epsilon_\theta(x_t, t, \emptyset) \tag{4} ϵθ(xt,t,c)=ωϵθ(xt,t,c)+(1ω)ϵθ(xt,t,)(4)
其中 ω > 0 \omega > 0 ω>0 被称为引导尺度。

4.2 一致性模型

  由于本文的方法受到一致性模型(Consistency Models,CM)的启发,作者回顾了一些这些模型的要素。CM 是一种新型的生成模型,主要用于学习一致性函数 f θ f_\theta fθ,该函数将位于公式(3)给出的PF-ODE轨迹上的任意样本 x t x_t xt 直接映射到原始样本 x 0 x_0 x0,同时确保任意 t ∈ [ ϵ , T ] t \in [\epsilon, T] t[ϵ,T] ϵ > 0 \epsilon > 0 ϵ>0时的自一致性属性:
f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ ( t , t ′ ) ∈ [ ϵ , T ] 2 (5) f_\theta(x_t, t) = f_\theta(x_{t'}, t'), \forall(t, t') \in [\epsilon, T]^2 \tag{5} fθ(xt,t)=fθ(xt,t),(t,t)[ϵ,T]2(5)
为了确保一致性属性,Song等人提出对 f θ f_\theta fθ 进行如下参数化:
f θ ( x t , t ) = c skip ( t ) ⋅ x t + c out ( t ) ⋅ F θ ( x t , t ) , f_\theta(x_t, t) = c_{\text{skip}}(t) \cdot x_t + c_{\text{out}}(t) \cdot F_\theta(x_t, t) , fθ(xt,t)=cskip(t)xt+cout(t)Fθ(xt,t),
其中 F θ F_\theta Fθ 是使用神经网络进行参数化的, c skip c_{\text{skip}} cskip c out c_{\text{out}} cout 是可微函数【68, 42】。一致性模型可以从头开始训练(Consistency Training)或可用于蒸馏现有的DM(Consistency Distillation)。在这两种情况下,模型的目标是学习 f θ f_\theta fθ 以匹配目标函数 f θ − f_{\theta^-} fθ 的输出,其权重使用指数移动平均(EMA)进行更新,针对任意位于 PF-ODE 轨迹上的点 ( x t , x t ′ ) (x_t, x_{t'}) (xt,xt)
L = E x 0 ∼ p ( x 0 ) , t ∼ π ( t ) , ϵ ∼ N ( 0 , I ) [ ∥ f θ ( x t , t ) − f θ − ( x t ′ , t ′ ) ∥ 2 ] L = \mathbb{E}_{x_0 \sim p(x_0), t \sim \pi(t), \epsilon \sim \mathcal{N}(0, I)} \left[ \| f_\theta(x_t, t) - f_{\theta^-}(x_{t'}, t') \|^2 \right] L=Ex0p(x0),tπ(t),ϵN(0,I)[fθ(xt,t)fθ(xt,t)2]
  换句话说,给定使用公式(1)得到的噪声样本 x t x_t xt,其思想是强制 f θ ( x t , t ) = f θ − ( x t ′ , t ′ ) f_\theta(x_t, t) = f_{\theta^-}(x_{t'}, t') fθ(xt,t)=fθ(xt,t),其中 x t ′ x_{t'} xt 是使用相同噪声 ϵ \epsilon ϵ 和输入 x 0 x_0 x0 通过公式(1)进行一致性训练得到的或使用训练好的扩散模型 ϵ ∅ teacher \epsilon_{\emptyset}^{\text{teacher}} ϵteacher 和 ODE 求解器 Ψ \Psi Ψ 进行一致性蒸馏。一旦模型训练完毕,可以通过首先绘制噪声样本 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I),然后应用学习到的函数 f θ f_\theta fθ 来理论上在一步内生成样本 x ~ 0 \tilde{x}_0 x~0。在实际操作中,需要进行多次迭代以生成令人满意的样本,因此估计的样本 x ~ 0 \tilde{x}_0 x~0 被使用学习到的函数 f θ f_\theta fθ 多次反复地添加噪声和去噪。

5 方法

  在这一部分中,作者介绍了基于文献中提出的若干理念构建的方法。接下来,作者将自己置于潜在扩散模型[57]的背景下进行图像生成,并将教师模型称为 ϵ ϕ teacher \epsilon_\phi^{\text{teacher}} ϵϕteacher,学生模型称为 ϵ θ student \epsilon_\theta^{\text{student}} ϵθstudent,训练图像称为 x 0 x_0 x0 及其未知分布为 p ( x 0 ) p(x_0) p(x0)。将 z 0 = ϵ ( x 0 ) z_0 = \epsilon(x_0) z0=ϵ(x0) 称为通过编码器 ϵ \epsilon ϵ 得到的相关潜变量。记时间步长的概率密度函数为 π \pi π,并设定 T = 1 T = 1 T=1。请注意,所提出的方法也可以直接应用于像素空间扩散模型。

5.1 蒸馏一个预训练的扩散模型

  该方法主要是为了实现一种快速、鲁棒且可靠的方案,能够轻松适应不同的使用场景。给定一组数据 x 0 ∈ X x_0 \in X x0X,使得 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0),以及通过编码器 E E E 得到的相关潜变量 z 0 = ϵ ( x 0 ) z_0 = \epsilon(x_0) z0=ϵ(x0),该方法的主要思想与扩散模型相似。给定由 α ( t ) \alpha(t) α(t) σ ( t ) \sigma(t) σ(t) 定义的噪声时间表,作者建议创建一个带噪声的潜变量样本 z t z_t zt,其中 t ∼ π ( t ) t \sim \pi(t) tπ(t),如公式(1)所示,并训练一个函数 f θ student f_\theta^{\text{student}} fθstudent 来预测原始样本 z 0 z_0 z0 的去噪版本 z ~ 0 \tilde{z}_0 z~0。与扩散模型的主要区别在于它不是使用 z 0 z_0 z0作为目标,作者建议利用教师模型的知识,使用属于教师模型学习的数据分布 p ϕ t e a c h e r ( z 0 ) p_{\phi}^{teacher} (z_0) pϕteacher(z0) 的样本。换句话说,就是使用教师模型和一个 ODE 求解器生成一个去噪的潜变量样本 z ~ 0 teacher ( z t ) \tilde{z}_{0}^{\text{teacher}}(z_t) z~0teacher(zt),它属于学习数据分布,并将其作为学生模型的目标。主要的蒸馏损失函数写作:
L distil = E z 0 , t , ϵ [ ∥ f student θ ( z t , t ) − z ~ 0 teacher ( z t ) ∥ 2 ] , ( 6 ) L_{\text{distil}} = \mathbb{E}_{z_0, t, \epsilon} \left[ \| f_{\text{student}}^\theta(z_t, t) - \tilde{z}_{0}^{\text{teacher}}(z_t) \|^2 \right] ,(6) Ldistil=Ez0,t,ϵ[fstudentθ(zt,t)z~0teacher(zt)2](6)
其中 π ( t ) \pi(t) π(t) 表示时间步长的分布, z 0 ~ teacher ( z t ) \tilde{z_0}^{\text{teacher}}(z_t) z0~teacher(zt) 是通过在教师模型 ϵ ϕ t e a c h e r \epsilon_\phi^{teacher} ϵϕteacher 上从 z t = α ( t ) ⋅ z 0 + σ ( t ) ⋅ ϵ z_t = \alpha(t) \cdot z_0 + \sigma(t) \cdot \epsilon zt=α(t)z0+σ(t)ϵ 开始运行ODE 求解器 Ψ \Psi Ψ 的若干步得到的。类似的思想在Sauer等人的方法中也有应用,但作者生成完全合成的样本,这意味着样本 z t z_t zt 是纯噪声, z t ∼ N ( 0 , I ) z_t \sim \mathcal{N}(0, I) ztN(0,I)。相反,在本文的方法中,作者假设允许 z t z_t zt 保留一些来自真实编码样本 z 0 z_0 z0 的信息可以增强蒸馏过程。如Luo等人所述,在蒸馏条件扩散模型时,我们还与教师模型一起执行无分类器指导(Classifier-Free Guidance,CFG),以更好地确保模型遵守条件。这项技术实际上显著提高了学生生成的样本质量。此外,它消除了在学生推理过程中执行 CFG 的需要,进一步减少了每步计算量的一半。训练期间使用的指导尺度 ω \omega ω 的值作者在消融实验中进行了展示,但在实践中, ω \omega ω [ ω min ⁡ , ω max ⁡ ] [\omega_{\min}, \omega_{\max}] [ωmin,ωmax] 中均匀采样,其中 0 ≤ ω min ⁡ ≤ ω max ⁡ 0 \leq \omega_{\min} \leq \omega_{\max} 0ωminωmax。如第4.2节所述,作者说他的方法与现有的一致性模型相似。但是他们不是依赖学生模型的先前实例来估计 PF-ODE 的起源,而是直接使用教师模型结合 ODE 求解器生成目标。并且观察到这些要素增强了训练过程的稳定性。

5.2 时间步长采样

  方法的核心在于时间步长概率密度函数 π ( t ) \pi(t) π(t) 的选择。根据【67】中介绍的连续建模,扩散模型(DMs)被训练在任意连续时间 t t t 上从潜在样本 z t z_t zt 中去除噪声。然而,由于我们目标是在推理时实现少步数数据生成(通常为1-4步),学习的函数 ϵ θ \epsilon_\theta ϵθ 仅在少数离散的时间步长 { t i } i = 1 K \{t_i\}_{i=1}^K {ti}i=1K 上进行评估。

  为了解决这个问题并确保蒸馏过程集中于最相关的时间步长,我们建议在区间 [ 0 , 1 ] [0, 1] [0,1] 内选择 K(通常为16、32或64)个均匀分布的时间步长,并根据概率质量函数 π ( t ) \pi(t) π(t) 为每个时间步长分配概率。我们选择 π ( t ) \pi(t) π(t) 作为由一系列权重 { β i } i = 1 K \{\beta_i\}_{i=1}^K {βi}i=1K 控制的高斯分布混合:
π ( t ) = 1 2 π σ 2 ∑ i = 1 K β i exp ⁡ ( − ( t − μ i ) 2 2 σ 2 ) , (7) \pi(t) = \frac{1}{\sqrt{2\pi\sigma^2}} \sum_{i=1}^K \beta_i \exp \left( - \frac{(t - \mu_i)^2}{2\sigma^2} \right) , \tag{7} π(t)=2πσ2 1i=1Kβiexp(2σ2(tμi)2),(7)

其中每个高斯分布的均值由 { μ i = i K } i = 1 K \{\mu_i = \frac{i}{K}\}_{i=1}^K {μi=Ki}i=1K 控制,方差固定为 σ = 0.5 K 2 \sigma = \sqrt{\frac{0.5}{K^2}} σ=K20.5 。这种方法使得在蒸馏教师模型时,只有少数 K 个离散时间步长被采样,而不是连续区间 [ 0 , 1 ] 3 [0, 1]^3 [0,1]3。此外,分布 π \pi π 定义为在 K 个选定的时间步长中,用于1、2和4步生成的4个时间步长被过采样(通常我们设定 β i > 0 \beta_i > 0 βi>0 如果 i ∈ [ K 4 , K 2 , 3 K 4 , K ] i \in [\frac{K}{4}, \frac{K}{2}, \frac{3K}{4}, K] i[4K,2K,43K,K] β i = 0 \beta_i = 0 βi=0)。与其Sauer等人的方法相比,本文不仅关注这4个时间步长,因为我们注意到这可能会导致生成样本的多样性减少,对此,作者进行了消融研究验证。实际上,作者注意到热身阶段对训练过程是有益的。因此,决定首先对对应于最少噪声增加的时间步长施加更高的概率,通过设定 β K / 4 = β K / 2 = 0.5 \beta_{K/4} = \beta_{K/2} = 0.5 βK/4=βK/2=0.5 和其他 β i = 0 \beta_i = 0 βi=0。然后我们逐渐将概率质量转移到全噪声,以促进单步生成,同时仍然对目标的4个时间步长进行过采样,设定严格正值的 β i \beta_i βi,其中 i ≡ 0 [ K 4 ] i \equiv 0[\frac{K}{4}] i0[4K],其他 β i = 0 \beta_i = 0 βi=0。图2中展示了 K=32 的 π \pi π 示例。如图所示, [ 0 , 1 ] [0, 1] [0,1] 区间被分为32个时间步长。在热身阶段,概率质量将更高的概率分配给时间步长 [ 0.25 , 0.5 ] [0.25, 0.5] [0.25,0.5] 以简化蒸馏过程。随着训练的进行,概率质量函数逐渐向全噪声转移,以促进单步生成,同时始终为4个时间步长 [ 0.25 , 0.5 , 0.75 , 1 ] [0.25, 0.5, 0.75, 1] [0.25,0.5,0.75,1] 分配更高的概率。时间步长分布的影响在第6.2节中进一步讨论。
在这里插入图片描述

5.3对抗性目标

  为了进一步提高样本的质量,并且由于一些文献中提出的几项工作证明了实现几步图像生成的效率很高,于是作者决定引入对抗性目标。核心思想是训练学生模型生成与真实数据分布 p ( x 0 ) p(x_0) p(x0) 难以区分的样本。为此,我们提出训练一个判别器 D ν D_\nu Dν 来区分生成样本 x ~ 0 \tilde{x}_0 x~0 与真实样本 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0)。如Sauer和Lin等人所建议的,我们也直接在潜在空间中应用判别器。这种方法避免了使用VAE解码样本的必要性,这一过程在Sauer等人的文章中有所概述,被证明是昂贵的并且阻碍了方法在高分辨率图像上的可扩展性。

  借鉴Sauer和Lin等人提出的文章中的灵感,作者提出了一种方法,在这种方法中,一步的学生预测 z ~ 0 \tilde{z}_0 z~0 和输入的潜变量 z 0 z_0 z0 按照教师的噪声计划重新添加噪声。这个过程使用一个时间步长 t ′ t' t,它从集合 [ 0.01 , 0.25 , 0.5 , 0.75 ] [0.01, 0.25, 0.5, 0.75] [0.01,0.25,0.5,0.75] 中均匀选择。样本首先通过冻结的教师模型,然后通过判别器,得到真假预测。当使用 UNet 架构作为教师模型时,我们的方法专注于仅使用 UNet 的编码器部分,生成更压缩的潜变量表示,并进一步减少判别器的参数数量。我们仔细选择特定的时间步长,以使判别器能够有效地根据高频和低频细节区分样本,如Lin等人所讨论的。需要注意的是,在本文提出的方法中,判别器是唯一需要训练的组件,而教师模型保持冻结状态。对抗损失 L a d v L_{adv} Ladv 和判别器损失 L discriminator L_{\text{discriminator}} Ldiscriminator 写作:
L adv = 1 2 E z 0 , t ′ , ϵ [ ∣ ∣ D ν ( f θ s t u d e n t ( z t ′ , t ′ ) ) − 1 ∣ ∣ 2 ] , L discriminator = 1 2 E z 0 , t ′ , ϵ [ ∥ D ν ( z 0 ) − 1 ∥ 2 + ( D ν ( f θ s t u d e n t ( z t ′ , t ′ ) ) − 0 ) 2 ] (8) L_{\text{adv}} = \frac{1}{2} \mathbb{E}_{z_0,t',\epsilon} \left[ || D_\nu(f_\theta^{student}(z_{t'}, t')) - 1 ||^2 \right], \\ L_{\text{discriminator}} = \frac{1}{2} \mathbb{E}_{z_0,t',\epsilon} \left[ \|D_\nu(z_0) - 1\|^2 + \left( D_\nu(f_\theta^{student}(z_{t'}, t')) - 0 \right)^2 \right] \tag{8} Ladv=21Ez0,t,ϵ[∣∣Dν(fθstudent(zt,t))1∣2]Ldiscriminator=21Ez0,t,ϵ[Dν(z0)12+(Dν(fθstudent(zt,t))0)2](8)
其中 ν \nu ν 表示判别器参数。我们选择这些特定的损失是因为它们在训练过程中表现出可靠性和稳定性,如我们的实验所示。在消融研究中,作者强调了所选择的对抗损失 L adv L_{\text{adv}} Ladv 的影响。在实践中,鉴别器的架构被设计为一个简单的卷积神经网络(CNN),其步幅为2,核大小为4,SiLU激活和组归一化。

5.4 分布匹配

  受Yin等人工作的启发,作者还提出引入分布匹配蒸馏(DMD)损失,以确保生成的样本紧密反映教师模型学习到的数据分布。具体来说,这涉及最小化学生模型的样本分布 p θ student p_\theta^{\text{student}} pθstudent 和教师模型学习到的数据分布 p ∅ teacher p_\emptyset^{\text{teacher}} pteacher 之间的Kullback-Leibler(KL)散度:

L DMD = D KL ( p θ s t u d e n t ∣ ∣ p ∅ t e a c h e r ) = E z 0 , t , ϵ [ − ( log ⁡ p ∅ t e a c h e r ( f θ s t u d e n t ( z t , t ) ) − log ⁡ p θ s t u d e n t ( f θ s t u d e n t ( z t , t ) ) ) ] (9) L_{\text{DMD}} = D_{\text{KL}}(p_\theta^{student}|| p_\emptyset^{teacher}) = \\ \mathbb{E}_{z_0,t,\epsilon} \left[ -\left( \log p_\empty^{teacher} \left( f_\theta^{student}(z_t, t) \right)- \log p_\theta^{student} \left( f_\theta^{student}(z_t, t) \right) \right) \right] \tag{9} LDMD=DKL(pθstudent∣∣pteacher)=Ez0,t,ϵ[(logpteacher(fθstudent(zt,t))logpθstudent(fθstudent(zt,t)))](9)
对KL散度关于学生模型参数 θ \theta θ 求导得到以下更新规则:
∇ θ L DMD = E z 0 , t , ϵ [ − ( s t e a c h e r ( f θ s t u d e n t ( z t , t ) ) − s s t u d e n t ( f θ s t u d e n t ( z t , t ) ) ) ∇ f θ s t u d e n t ( z t , t ) ] , \nabla_\theta L_{\text{DMD}} = \\ \mathbb{E}_{z_0,t,\epsilon} \left[ -\left( s^{teacher}\left( f_\theta^{student}(z_t, t) \right)- s^{student}\left( f_\theta^{student}(z_t, t) \right) \right) \nabla f_\theta^{student}(z_t, t) \right], θLDMD=Ez0,t,ϵ[(steacher(fθstudent(zt,t))sstudent(fθstudent(zt,t)))fθstudent(zt,t)]
其中 s teacher s^{\text{teacher}} steacher s student s^{\text{student}} sstudent 分别是教师和学生分布的得分函数。

  受Yin等人的启发,单步学生预测使用均匀采样的时间步长 t ′ ′ ∼ U ( [ 0 , 1 ] ) t'' \sim U([0, 1]) t′′U([0,1]) 和教师的噪声计划重新加噪。新的有噪声样本通过冻结的教师模型以获取教师分布的得分函数: s t e a c h e r ( f θ s t u d e n t ( z t ′ ′ , t ′ ′ ) ) = − ( ϵ ∅ t e a c h e r ( x t ′ ′ , t ′ ′ ) / σ ( t ′ ′ ) ) s^{teacher}(f_\theta^{student}(z_{t^{\prime\prime}}, t^{\prime\prime}))=-(\epsilon_\empty^{teacher}(x_{t^{\prime\prime}},t^{\prime\prime})/\sigma(t^{\prime\prime})) steacher(fθstudent(zt′′,t′′))=(ϵteacher(xt′′,t′′)/σ(t′′))。在我们的方法中,我们利用学生模型来获取学生分布的得分函数,而不是像Yin等人所提到的专用扩散模型。这一选择显著减少了可训练参数的数量和计算成本。

5.5 模型训练

  在追求鲁棒性和多样性的同时,我们还旨在设计一个可训练参数最少的模型,因为它涉及加载计算密集型函数(教师模型和学生模型)。为此,我们提出依赖参数高效方法LoRA并将其应用于我们的学生模型。通过这种方式,我们大幅减少了参数数量并加快了训练过程。

  简而言之,我们的学生模型被训练以最小化蒸馏损失(Eq. (6))、对抗性损失(Eq. (8))和分布匹配损失(Eq. (9))的加权组合:
L = L distil + λ adv L adv + λ DMD L DMD (10) L = L_{\text{distil}} + \lambda_{\text{adv}} L_{\text{adv}} + \lambda_{\text{DMD}} L_{\text{DMD}} \tag{10} L=Ldistil+λadvLadv+λDMDLDMD(10)
在这里插入图片描述

  训练过程详见算法1,并在下图中进行了说明。具体来说,首先从未知数据分布中随机选取一个样本 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0)。然后使用编码器 ϵ \epsilon ϵ 对该样本进行编码,得到相应的潜在样本 z 0 z_0 z0。根据第5.2节中详述的时间步长概率质量函数 π \pi π,绘制时间步长 t t t,并使用Eq. (1) 创建一个有噪声样本 z t z_t zt。然后使用教师模型 ϵ φ teacher \epsilon_\varphi^{\text{teacher}} ϵφteacher 和ODE求解器 Ψ \Psi Ψ 来求解PF-ODE,从而生成一个属于教师模型学习到的分布的合成样本 z ~ 0 teacher \tilde{z}_0^{\text{teacher}} z~0teacher。同时,学生模型 f θ s t u d e n t f_\theta^{student} fθstudent 被用来在单步内生成一个去噪样本 z ~ 0 student = f θ student ( z t , t ) \tilde{z}_0^{\text{student}} = f_\theta^{\text{student}}(z_t, t) z~0student=fθstudent(zt,t)。然后,根据Eq. (6) 计算蒸馏损失。接着,重新对单步学生预测 z ~ 0 student \tilde{z}_0^{\text{student}} z~0student 和输入的潜在样本 z 0 z_0 z0 进行加噪,并按第5.3节所述计算对抗性损失。最后,对于分布匹配,再次取单步学生预测 z ~ 0 student \tilde{z}_0^{\text{student}} z~0student,并使用均匀采样的时间步长 t ∼ U ( [ 0 , 1 ] ) t \sim U([0, 1]) tU([0,1]) 对其进行加噪。新的有噪声样本通过教师模型获取教师得分函数 s teacher s^{\text{teacher}} steacher,同时我们使用学生模型(而不是Yin等人的专用扩散模型)获取学生得分函数 s student s^{\text{student}} sstudent。然后按第5节所述计算分布匹配损失。
在这里插入图片描述

  总的来说,我们提出的方法仅依赖少量参数的训练。这是通过将LoRA应用于学生模型,利用冻结的教师模型进行对抗性方法,并直接使用学生去噪器而不是引入一个新的扩散模型来计算分布匹配损失的假分数实现的。这种方法不仅大大减少了参数数量,还加快了训练过程。

实验

  作者将所提出的方法与现有的蒸馏方法在文本到图像生成中的效果进行比较。在本节中,我们将我们的蒸馏方法应用于公开可用的SD1.5模型,并在COCO2014和COCO2017数据集上报告FID和CLIP得分。模型在LAION数据集上进行训练,我们选择美学评分高于6的样本,并使用CogVLM提示词生成合成图像。对于COCO2017,我们依赖于[45]中提出的评估方法,并从验证集中选择5000个提示来生成合成图像。对于COCO2014,采用【22】中提出的评估协议,从验证集中选择30000个提示词。然后,计算与各自验证集中真实图像的FID,COCO2017验证集包含5000张图像,COCO2014验证集包含40504张图像。模型在2个H100-80Gb GPU上进行20k次迭代训练,批量大小为4,学习率为1e-5,使用Adam优化器训练学生模型和判别器。使用第5.2节中详述的时间步长分布 π ( t ) \pi(t) π(t),其中 K = 32 K=32 K=32,每5000次迭代进行一次相移。从 λ adv = 0 \lambda_{\text{adv}}=0 λadv=0 λ DMD = 0 \lambda_{\text{DMD}}=0 λDMD=0开始,并在每次更改时间步长分布时逐步增加,最终值分别设为0.3和0.7。指导尺度 ω \omega ω U ( [ 3 , 13 ] ) U([3, 13]) U([3,13])中采样。学生模型的权重全部用教师模型的权重初始化。

  表1和表2给出了定量比较结果。本文方法在COCO2017和COCO2014上分别达到了22.6和12.27的FID,仅需要2个NFE(网络功能评估)即可达到少步数图像生成的SOTA(最新技术)。在COCO2017上,该方法在2和4个NFE下分别达到了0.306和0.311的CLIP得分。另外,该方法只需要训练2640万参数(相对于900M的教师参数)和仅26个H100 GPU小时的训练时间。这与许多竞争对手需要训练整个学生UNet架构(涉及数亿参数)的情况形成鲜明对比。在图4中提供了1、2和4个NFE下生成样本的视觉可视化。
在这里插入图片描述
在这里插入图片描述

  以上就是对本篇论文的解读,如有任何问题欢迎留言,批评指正!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/49680.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[C++][STL源码剖析] 详解AVL树的实现

目录 1.概念 2.实现 2.1 初始化 2.2 插入 2.2.1 旋转(重点) 左单旋 右单旋 双旋 2.❗ 双旋后,对平衡因子的处理 2.3 判断测试 完整代码: 拓展:删除 1.概念 二叉搜索树虽可以缩短查找的效率,但…

遇到Websocket就不会测了?别慌,学会这个Jmeter插件轻松解决....

websocket 是一种双向通信协议,在建立连接后,websocket服务端和客户端都能主动向对方发送或者接收数据,而在http协议中,一个request只能有一个response,而且这个response也是被动的,不能主动发起。 websoc…

【研路导航】保研英语面试高分攻略,助你一路过关斩将

面试攻略之 千锤百炼英语口语 写在前面 在保研面试中,英语口语往往是让许多同学感到头疼的一部分。如何在面试中展现出自信和流利的英语表达能力,是我们今天要探讨的主题。以下是一些有效的英语口语练习方法和常见题型解析,帮助你在保研面试…

LoRA:低秩自适应

LoRA:低秩自适应 本章节是对轻松上手微调大语言模型——QLORA篇中提到的LoRA的原理解释。 背后动机 现今模型的参数量变得越来越大,对预训练模型进行全微调变得越来越不可行。为了解决这个问题有了LoRA(Low-Rank Adaption)的诞生。将可训练…

Nginx制作下载站点

使用nginx制作一个类似nginx官网的下载站点 如何制作一个下载站点,首先需要ngx_http_autoindex_module模块 该模块处理以斜杠(“/”)结尾的请求,并生成目录列表。 nginx编译的时候会自动加载该模块,但是该模块默认是关闭的,需要使用下来指令…

3 FreeRTOS移植(从FREERTOS官网移植进自己的工程)

3 FreeRTOS移植 1 获取FreeRTOS源码(熟悉)1.1 介绍源码内容1.2 FreeRTOS内核1.2.1 Demo文件夹1.2.2 Source文件夹1.2.2.1 portable文件夹 2 FreeRTOS手把手移植(掌握)(重要)2.1 移植步骤 3 系统配置文件说明…

GraphHopper-map-navi_路径规划、导航(web前端页面版)

文章目录 一、项目地址二、踩坑环境三、问题记录3.1、graphhopper中地图问题3.1.1. getOpacity不存在的问题3.1.2. dispatchEvent不存在的问题3.1.3. vectorLayer.set(background-maplibre-layer, true)不存在set方法3.1.4. maplibre-gl.js.map不存在的问题3.1.5. Uncaught Ref…

学习记录:ESP32控制舵机 FREERTOS BLE

控制舵机 PWM信号 PWM信号是一种周期性变化的方波信号,它有两个关键参数: 周期(Period):一个完整的PWM信号的时间长度,通常用秒(s)或毫秒(ms)表示。占空比…

FFmpeg解复用器如何从封装格式中解析出不同的音视频数据

目录 1、ffmpeg介绍 2、FFMPEG的目录结构 3、FFmpeg的格式封装与分离 3.1、数据结构 3.2、封装和分离(muxer和demuxer) 3.2.1、Demuxer流程 3.2.2、Muxer流程 4、总结 4.1、播放器 4.2、转码器 C++软件异常排查从入门到精通系列教程(专栏文章列表,欢迎订阅,持续…

微服务上(黑马)

文章目录 微服务011 认识微服务1.1 单体架构1.2 微服务1.3 SpringCloud 2 微服务拆分2.1 熟悉黑马商城2.2 服务拆分原则2.2.1.什么时候拆2.2.2.怎么拆 2.3 拆分服务2.3.1 拆分商品管理功能模块2.3.2 拆分购物车功能模块 2.4 远程调用2.4.1 RestTemplate2.4.2.远程调用 2.5 总结…

顺序表算法题

在学习了顺序表专题后,了解的顺序表的结构以及相关概念后就可以来试着完成一些顺序表的算法题了,在本篇中将对三道顺序表相关的算法题进行讲解,希望能对你有所帮助,一起加油吧!!! 1.移除元素 2…

nginx转发netty长链接(nginx负载tcp长链接配置)

首先要清楚一点,netty是长链接是tcp连接不同于http中负载在http中配置server监听。长连接需要开启nginx的stream模块(和http是并列关系) 安装nginx时注意开启stream,编译时加上参数 --with-stream (其他参数根据自己所需来加) …

脊髓损伤的小伙伴锻炼贴士

Hey小伙伴们~👋 今天要跟大家聊一个超燃又超温馨的话题!🌟 对于我们脊髓损伤的小伙伴们来说,保持身体活力,不仅是健康的小秘诀,更是拥抱美好生活的超能量哦!💪 #脊髓损伤# 首先&…

Cache 替换策略--PLRU算法详解

一、引言 LRU(Least Recently Used)是 cache 的经典替换策略之一,但当 Cache 的路数比较大时(多路组相连结构),实现 LRU 的硬件开销就会变得很大。现代处理器一般会考虑使用 PLRU(pseudo-LRU&a…

一文带你搞懂C++运算符重载

7. C运算符重载 C运算符重载 什么是运算符重载 运算符重载赋予运算能够操作自定义类型。 运算符重载前提条件: 必定存在一个自定义类型 运算符重载实质: 就是函数调用 友元重载 类重载 在同一自定义类型中,一个运算符只能被重载一次 C重载只能重载…

vue element-ui日期控件传参

前端&#xff1a;Vue element-ui <el-form-item label"过期时间" :rules"[ { required: true, message: 请选择过期时间, trigger: blur }]"><el-date-picker v-model"form.expireTime" type"date" format"yyyy-MM-dd&…

【C++】透析类和对象(下)

有不懂的可以翻阅我之前文章&#xff01; 个人主页&#xff1a;CSDN_小八哥向前冲 所属专栏&#xff1a;CSDN_C入门 目录 拷贝构造函数 运算符重载 赋值运算符重载 取地址运算符重载 const成员函数 取地址重载 再探构造函数 初始化列表 类型转换 static成员 友元 内…

MySQL查询执行(二):order by工作原理

假设你要查询城市是“杭州”的所有人名字&#xff0c; 并且按照姓名排序返回前1000个人的姓名、 年龄。 假设这个表的部分定义是这样的&#xff1a; -- 创建表t CREATE TABLE t (id int(11) NOT NULL,city varchar(16) NOT NULL,name varchar(16) NOT NULL,age int(11) NOT N…

Docker 搭建Elasticsearch详细步骤

本章教程使用Docker搭建Elasticsearch环境。 一、拉取镜像 docker pull docker.elastic.co/elasticsearch/elasticsearch:8.8.2二、运行容器 docker run -d --name elasticsearch -p 9200:9200 -p 9300:9300 -e "discovery.type=single-n