(CVPR-2024)通过多阶段框架和定制的多解码器架构提高扩散模型的训练效率

通过多阶段框架和定制的多解码器架构提高扩散模型的训练效率

Paper Title:Improving Training Efficiency of Diffusion Models via Multi-Stage Framework and Tailored Multi-Decoder Architecture

Paper是密歇根大学发表在CVPR 2024的工作

Paper地址

Code地址

Abstract

扩散模型作为强大的深度生成工具,在各种应用中表现出色。它们通过两个步骤进行操作:将噪声引入训练样本,然后使用模型将随机噪声转换为新样本(例如图像)。然而,它们出色的生成性能受到训练和采样速度缓慢的阻碍。这是因为需要跟踪广泛的正向和反向扩散轨迹,并使用具有跨多个时间步(即噪声级别)的大量参数的大型模型。为了应对这些挑战,我们提出了一个受经验发现启发的多阶段框架。这些观察结果表明,使用针对每个时间步量身定制的不同参数,同时保留跨所有时间步共享的通用参数的优势。我们的方法涉及将时间间隔划分为多个阶段,其中我们使用自定义多解码器 U-net 架构,将时间相关模型与通用共享编码器相结合。我们的框架能够高效分配计算资源并减轻阶段间干扰,从而大大提高训练效率。大量数值实验证实了我们框架的有效性,展示了对三种最先进的扩散模型(包括大规模潜在扩散模型)的显著训练和采样效率提升。此外,我们的消融研究说明了我们框架中两个重要组件的影响:(i)用于阶段划分的新型时间步长聚类算法,以及(ii)创新的多解码器 Unet 架构,无缝集成通用和定制超参数。

1. Introduction

图1

图 1. 三种扩散模型架构概览:(a) 统一、(b) 分离和 © 我们提出的多阶段架构。与 (a) 和 (b) 相比,我们的方法提高了采样质量,并显著提高了训练效率,如 FID 分数及其相应的训练迭代 (d) 所示。

最近,扩散模型作为强大的深度生成建模工具取得了显著进展,在无条件图像生成 [1, 2]、条件图像生成 [3, 4]、图像到图像的翻译 [5–7]、文本到图像的生成 [8–10]、逆问题求解 [11–14]、视频生成 [15, 16] 等各种应用中都表现出色。这些模型采用一种训练过程,包括向训练样本中不断注入噪声(“扩散”),然后通过由模型学习到的数据分布的“得分函数”引导的逆扩散过程对随机噪声实例进行变换,从而生成新样本,如图像。此外,最近的研究表明,与其他类型的生成模型相比,这些扩散模型具有优化稳定性和模型可重复性 [17]。然而,尽管扩散模型具有出色的生成能力,但其训练和采样速度较慢,这阻碍了其在需要实时生成的应用中的使用 [1, 2]。这些缺点主要源于需要跟踪大量的正向和反向扩散轨迹,以及管理跨多个时间步的具有众多参数的大型模型(即扩散噪声水平)。

在本文中,我们基于两个关键观察结果来解决这些挑战:(i)当前的扩散模型中存在大量参数冗余,(ii)由于不同噪声水平的梯度不同,它们的训练效率低下。具体而言,我们发现训练扩散模型在高噪声水平下需要较少的参数来准确学习得分函数,而在低噪声水平下则需要较大的参数。此外,我们还观察到,在学习得分函数时,不同噪声水平下分布的不同形状会导致不同的梯度,这似乎会减慢由梯度下降驱动的训练过程。

基于这些见解,我们提出了一个多阶段框架,该框架包含两个关键组件:(i) 多解码器 U-net 架构,以及 (ii) 一种新的分区算法,用于将时间步长(噪声级别)聚类到不同的阶段。就我们的新架构而言,我们设计了一个多解码器 U-Net,它包含一个在所有间隔中共享的通用编码器和针对每个时间阶段定制的单个解码器;参见图 1 © 中的说明。这种方法结合了通用架构和阶段特定架构的优点,比整个训练过程的统一架构效率高得多 [1, 2, 18](图 1 (a))。此外,与以前完全分离每个子间隔架构的方法 [19–22](图 1 (b))相比,我们的方法可以有效缓解过度拟合,从而提高训练效率。另一方面,在划分网络的训练阶段时,我们设计了一种旨在对时间步长进行分组的算法。这是通过最小化训练目标中每个聚类内的功能距离并利用最佳去噪器公式来实现的 [18]。通过整合这两个关键组件,我们的框架能够有效分配计算资源(例如,U-net 参数)和阶段定制参数化。在我们广泛的数值实验(第 5 节)中,我们表明我们的框架有效地提高了训练和采样效率。这些实验是在不同的基准数据集上进行的,与三种最先进的(SOTA)扩散模型架构相比,使用我们的框架可以显著加速。总结一下,这项工作的主要贡献可以概括如下:

  • 确定两个主要的低效率来源。我们确定了导致在不同时间步骤中训练扩散模型效率低下的两个主要来源:(i)模型容量要求的显著变化,以及(ii)梯度不相似性。因此,使用统一网络无法满足不同时间步骤不断变化的要求。
  • 新的多阶段框架。我们引入了一种新的多阶段架构,如图 1 © 所示。我们通过将时间间隔划分为多个阶段来解决这两个低效率来源,其中我们采用定制的多解码器 U-net 架构,将时间相关模型与通用共享编码器相结合。
  • 提高训练和采样效率。使用与无条件图像生成相当的计算资源,我们证明我们的多阶段方法可以提高所有 SOTA 方法的 Frechet Inception Distance (FID) 分数。例如,在 CIFAR-10 数据集 [23] 上,我们的方法将 DPM-Solver [24] 的 FID 从 2.84 提高到 2.37,并将 EDM [18] 的 FID 从 2.05(我们的训练结果)提高到 1.96。此外,在 CelebA 数据集 [25] 上,在保持相似生成质量的同时,我们的方法显著降低了 EDM 所需的训练 FLOPS(82%),降低了潜在扩散模型 (LDM) [8] 所需的训练 FLOPS (30%)。

组织。在第 2 部分中,我们提供了前期工作和相关文献的概述。在第 3 部分中,我们介绍了促使我们提出多阶段框架的观察和分析,并证明了其发展的合理性。在第 4 部分中,我们描述了我们提出的扩散模型多阶段框架,概述了两个核心组件。最后,在第 5 部分中,我们提供了数值实验的结果,验证了所提出的多阶段方法的有效性。

2. Preliminaries & Related Work

在本节中,我们首先回顾扩散模型的基本原理 [1, 2, 18]。随后,我们深入研究旨在通过划分时间步长间隔来提高扩散模型训练和效率的先前方法。最后,我们回顾了先前的研究,这些研究显著减少了所需的采样迭代次数。

扩散模型背景。设 x 0 ∈ R n x_0 \in \mathbb{R}^n x0Rn 表示来自数据分布 p data ( x ) p_{\text{data}}(\boldsymbol{x}) pdata(x) 的样本。扩散模型在前向和逆向过程中操作。前向过程通过高斯核逐渐将数据 x 0 \boldsymbol{x}_0 x0 扰动为一个噪声版本 x t ∈ [ 0 , 1 ] \boldsymbol{x}_{t \in[0,1]} xt[0,1]。该过程可以表示为形式为 d x = x t f ( t ) d t + g ( t ) d w t \mathrm{d} \boldsymbol{x}=\boldsymbol{x}_t f(t) \mathrm{d} t+g(t) \mathrm{d} \boldsymbol{w}_t dx=xtf(t)dt+g(t)dwt 的随机微分方程(SDE),其中 f ( t ) f(t) f(t) g ( t ) g(t) g(t) 分别是漂移和扩散系数,对应于预定义的噪声调度。 w t ∈ R n \boldsymbol{w}_t \in \mathbb{R}^n wtRn 是标准维纳过程。在前向 SDE 下,扰动核由定义为 p t ( x t ∣ x 0 ) = N ( x t ; s t x 0 , s t 2 σ t 2 I ) p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)=\mathcal{N}\left(\boldsymbol{x}_t ; s_t \boldsymbol{x}_0, s_t^2 \sigma_t^2 \mathbf{I}\right) pt(xtx0)=N(xt;stx0,st2σt2I) 的条件分布给出,其中
s t = exp ⁡ ( ∫ 0 t f ( ξ ) d ξ ) , and  σ t = ∫ 0 t g 2 ( ξ ) s ξ 2 d ξ .  s_t=\exp \left(\int_0^t f(\xi) \mathrm{d} \xi\right), \text { and } \sigma_t=\sqrt{\int_0^t \frac{g^2(\xi)}{s_{\xi}^2} \mathrm{~d} \xi} \text {. } st=exp(0tf(ξ)dξ), and σt=0tsξ2g2(ξ) dξ

参数 s t s_t st σ t \sigma_t σt 被设计为使得: (i) 当 t = 0 t=0 t=0 时数据分布被近似估计,并且 (ii) 当 t = 1 t=1 t=1 时获得接近标准高斯分布。扩散模型的目标是学习对应的逆向 SDE,定义为 d x = [ f ( t ) x t − g 2 ( t ) ∇ x t log ⁡ p t ( x t ) ] d t + g ( t ) d w ‾ \mathrm{d} \boldsymbol{x}=\left[f(t) \boldsymbol{x}_t-g^2(t) \nabla_{\boldsymbol{x}_t} \log p_t\left(\boldsymbol{x}_t\right)\right] \mathrm{d} t+g(t) \mathrm{d} \overline{\boldsymbol{w}} dx=[f(t)xtg2(t)xtlogpt(xt)]dt+g(t)dw,其中 w ‾ ∈ R n \overline{\boldsymbol{w}} \in \mathbb{R}^n wRn 是在时间上向后运行的标准维纳过程, ∇ x t log ⁡ p t ( x t ) \nabla_{\boldsymbol{x}_t} \log p_t\left(\boldsymbol{x}_t\right) xtlogpt(xt) 是 (Stein) 评分函数。在实践中,评分函数使用神经网络 ϵ θ : R n × [ 0 , 1 ] → R n \boldsymbol{\epsilon}_{\boldsymbol{\theta}}: \mathbb{R}^n \times[0,1] \rightarrow \mathbb{R}^n ϵθ:Rn×[0,1]Rn 参数化为 θ \boldsymbol{\theta} θ 来近似,它可以通过去噪评分匹配技术 [26] 进行训练,如下所示:
min ⁡ θ E [ ω ( t ) ∥ ϵ θ ( x t , t ) + s t σ t ∇ x t log ⁡ p t ( x t ∣ x 0 ) ∥ 2 2 ] , \min _{\boldsymbol{\theta}} \mathbb{E}\left[\omega(t)\left\|\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)+s_t \sigma_t \nabla_{\boldsymbol{x}_t} \log p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)\right\|_2^2\right], θminE[ω(t)ϵθ(xt,t)+stσtxtlogpt(xtx0)22],
也可以写成 min ⁡ θ E [ ω ( t ) ∥ ϵ θ ( x t , t ) \min _{\boldsymbol{\theta}} \mathbb{E}\left[\omega(t) \| \boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)\right. minθE[ω(t)ϵθ(xt,t) ϵ ∥ 2 ] + C \left.\boldsymbol{\epsilon} \|^2\right]+C ϵ2]+C,其中期望是关于 t ∼ [ 0 , 1 ] t \sim[0,1] t[0,1] x t ∼ p t ( x t ∣ x 0 ) \boldsymbol{x}_t \sim p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right) xtpt(xtx0) x 0 ∼ p data ( x ) \boldsymbol{x}_0 \sim p_{\text{data}}(\boldsymbol{x}) x0pdata(x),和 ϵ ∼ N ( 0 , I ) \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵN(0,I) 计算的。这里, C C C 是与 θ \boldsymbol{\theta} θ 无关的常数, ω ( t ) \omega(t) ω(t) 是表示损失权重随时间 t t t 变化的标量。在 DDPM [1] 中,它被简化为 ω ( t ) = 1 \omega(t)=1 ω(t)=1。一旦参数化评分函数 ϵ θ \epsilon_{\boldsymbol{\theta}} ϵθ 被训练,它可以被用于使用数值求解器如欧拉-马鲁扬方法来近似逆时间 SDE。

时间步聚类方法。扩散模型展示了卓越的性能,但在训练和采样中面临效率挑战。为应对这些挑战,若干研究提出将时间步范围 t ∈ [ 0 , 1 ] t \in[0,1] t[0,1] 划分为多个区间(例如, [ 0 , t 1 ) , [ t 1 , t 2 ) , … , [ t n , 1 ] \left[0, t_1\right),\left[t_1, t_2\right), \ldots,\left[t_n, 1\right] [0,t1),[t1,t2),,[tn,1])。值得注意的是,Choi 等 [19] 重新配置了不同区间的损失权重以提升性能。Deja 等 [27] 根据功能将整个过程分为去噪器和生成器。Balaji 等 [28] 引入了“专家去噪器”,该方法提出在文本到图像扩散模型中为不同时间区间使用不同的架构。Go 等 [22] 通过参数高效微调和无数据知识转移进一步提高了这些专家去噪器的效率。Lee 等 [21] 根据频率特征为每个区间设计了独立的架构。此外,Go 等 [20] 将不同区间视为不同任务,并在扩散模型训练中采用多任务学习策略,以及多种时间步聚类方法。

我们的方法在两个关键方面与上述方法有所不同。第一个关键组成部分是我们定制的 U-net 架构,使用统一的编码器和针对不同区间的不同解码器。以前的模型要么采用统一架构,如 [19,20] 所见,要么为每个区间采用独立架构(称为专家去噪器)[21,22,28]。相比之下,我们的多阶段架构优于这些方法,如第 5.3 节所示。其次,我们开发了一种新的时间步聚类方法,利用一般最优去噪器(Prop. 1),展示了卓越的性能(见第 5.4 节)。相比之下,以前的工作依赖于 (i) 简单的基于时间步的聚类成本函数 [20-22, 28],(ii) 基于信噪比 (SNR) 的聚类 [20],或 (iii) 使用任务亲和分数的基于梯度的划分[20]。

减少采样迭代方法。提高扩散模型采样效率的努力催生了 SDE 和常微分方程 (ODE) 采样器的许多最新进展 [2]。例如,去噪扩散隐式模型 (DDIM) [29] 将正向扩散表述为具有确定性生成路径的非马尔可夫过程,显著减少了采样所需的函数评估次数 (NFE)(从数千次减少到数百次)。广义 DDIM (gDDIM) [30] 通过修改评分网络的参数化进一步优化了 DDIM。此外,[24] 和 [31] 中的工作分别称为扩散概率模型求解器 (DPM-solver) 和扩散指数积分采样器 (DEIS),引入了快速高阶求解器,采用指数积分器,仅需 10 次 NFE 即可获得可比的生成质量。此外,一致性模型 [32] 引入了一种新颖的训练损失和参数化,仅用 1-2 次 NFE 即可实现高质量生成。

我们注意到,虽然上述方法与我们的工作间接相关,但我们在第 5.1 节和第 5.2 节的实验表明,我们的方法可以轻松集成到这些技术中,进一步提高扩散模型的整体训练和采样效率。

3. Identification of Key Sources of Inefficiency

传统的扩散模型架构(例如 [1、2、18])将扩散模型的训练视为跨所有时间步的统一过程。最近的研究(例如 [19–22])强调了识别不同时间步之间的区别的好处以及在训练过程中将它们视为单独任务的潜在效率提升。然而,我们的实验结果表明,统一和分离架构在训练扩散模型方面都存在效率低下的问题,其中效率低下来自 (i) 过度参数化、(ii) 梯度不相似和 (iii) 过拟合。

3.1. Empirical Observations on the Key Sources of Inefficiency


为了说明每个区间内的低效性,我们通过使用一个与其他部分不同的架构来隔离该区间。
实验设置。在我们的实验中,我们考虑了三阶段训练,并将时间步分为三个区间: [ 0 , t 1 ) , [ t 1 , t 2 ) , [ t 2 , 1 ] [0, t_1),[t_1, t_2),[t_2, 1] [0,t1),[t1,t2),[t2,1]。设 ( ϵ θ ) i [ a , b ] , 0 ≤ a < b ≤ 1 \left(\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\right)_i^{[a, b]}, 0 \leq a < b \leq 1 (ϵθ)i[a,b],0a<b1表示一个U-Net架构,其参数为 θ \boldsymbol{\theta} θ,用 i i i次迭代训练,并输入数据对 ( x t , t ) \left(\boldsymbol{x}_t, t\right) (xt,t),其中 t ∈ [ a , b ] t \in[a, b] t[a,b]。然后,我们采用两种不同的策略训练模型:一个具有108M网络参数的统一架构,用于所有区间,即 ( ϵ θ ) i [ 0 , 1 ] \left(\epsilon_\theta\right)_i^{[0,1]} (ϵθ)i[0,1];以及为每个区间提供不同网络参数的独立架构(例如, 47 M , 108 M , 169 M 47 \mathrm{M}, 108 \mathrm{M}, 169 \mathrm{M} 47M,108M,169M),例如, ( ϵ θ ) i [ 0 , t 1 ) \left(\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\right)_i^{\left[0, t_1\right)} (ϵθ)i[0,t1)用于区间 [ 0 , t 1 ) \left[0, t_1\right) [0,t1),等等。值得注意的是,除了网络参数的差异外,我们对统一和独立方法都使用相同的网络架构(例如UNet)。我们通过在不同的训练迭代中评估图像生成质量来评估每个模型的训练进展。值得注意的是,由于有些模型仅在一个区间进行训练,我们需要为其他区间提供真实分数。在图2中,顶部显示了采样过程,底部显示了实验结果。

统一架构中的低效性。从图2中,我们观察到以下几点:

  • 在统一架构中同时出现了过参数化和欠拟合。在图2a中,我们观察到增加区间0中的参数数量可以提高图像生成质量(如较低的FID分数所示)。相反,图2b显示增加区间2中的参数数量对图像生成质量的影响微乎其微。这意味着使用统一架构将在区间0中导致欠拟合,而在区间2中导致过参数化。当前统一架构的参数冗余使其效率提升有很大的空间。为了优化计算资源的使用,我们应该为区间0分配更多参数,而为区间2分配更少参数。
  • 梯度差异性阻碍了统一架构的训练。文献[20]的定量结果显示,不同区间内的差异会导致梯度差异性。从我们的结果基于图2a和图2b也可以观察到这一点。对于使用相同参数数量(108M)的统一和独立架构,独立架构在相同的训练迭代中实现了显著更低的FID,这意味着区间之间的梯度差异可能在使用统一架构时阻碍训练。这里,训练独立和统一架构的唯一区别在于,统一架构的批梯度是基于所有时间步计算的,而独立架构的批梯度仅从特定区间计算。

现有独立架构中的低效性。尽管独立架构[19, 21, 22]更好地为每个区间分配计算资源,但它会遭遇过拟合。这可以基于在图2a中区间0中训练的独立架构(169M)和(108M)来说明,其中增加参数数量会导致过拟合。在图2b中,当我们比较所有独立架构时,这也发生在区间2中。相比之下,具有108M参数的统一网络在区间0和区间2中不太容易过拟合。这表明我们可以通过在不同区间共同训练共享权重来减少过拟合。

图2

图2. 在不同区间内,分离架构与统一架构在图像生成质量上的比较:(a) 对区间 [ 0 , t 1 ) \left[0, t_1\right) [0,t1)的分析;(b) 对区间 [ t 2 , 1 ] \left[t_2, 1\right] [t2,1]的分析。正如每个图顶部所示,在(a)和(b)中,我们仅在特定区间内训练分离架构以进行采样过程。在剩余的采样期间,我们使用经过良好训练的扩散模型 ( ϵ θ ) 4 × 1 0 5 [ 0 , 1 ] \left(\epsilon_\theta\right)_{4 \times 10^5}^{[0,1]} (ϵθ)4×105[0,1]来近似真实分数函数。如(a)图上方所示,例如对于区间1的分离架构,采样利用区间0的训练模型 ( ϵ θ ′ ) i [ 0 , t 1 ) \left(\epsilon_{\theta^{\prime}}\right)_i^{\left[0, t_1\right)} (ϵθ)i[0,t1)和区间1和2的经过良好训练的模型 ( ϵ θ ) 4 × 1 0 5 [ 0 , 1 ] \left(\epsilon_\theta\right)_{4 \times 10^5}^{[0,1]} (ϵθ)4×105[0,1]。值得注意的是,对于 ( ϵ θ ) i [ 0 , 1 ] \left(\epsilon_\theta\right)_i^{[0,1]} (ϵθ)i[0,1] ( ϵ θ ) 4 × 1 0 5 [ 0 , 1 ] \left(\epsilon_\theta\right)_{4 \times 10^5}^{[0,1]} (ϵθ)4×105[0,1],我们使用的模型参数均为108M。对于分离架构,括号中的数字表示模型 ( ϵ θ ′ ) i [ a , b ] \left(\epsilon_{\theta^{\prime}}\right)_i^{[a, b]} (ϵθ)i[a,b]的参数数量。例如,在(a)中的分离架构(169M),模型 ( θ θ ′ ) i [ 0 , t 1 ) \left(\boldsymbol{\theta}_{\boldsymbol{\theta}^{\prime}}\right)_i^{\left[0, t_1\right)} (θθ)i[0,t1)的参数为169M。底部图(a-b)说明了在不同训练迭代下每种架构生成的FID。

3.2. Tackling the Inefficiency via Multistage U-Net Architectures


在应用于所有时间步的统一架构中,通常面临双重挑战:在区间 [ 0 , t 1 ) \left[0, t_1\right) [0,t1)需要更多的参数 ( 169 M ) (169 \mathrm{M}) (169M),但在区间 [ t 2 , 1 ] \left[t_2, 1\right] [t2,1]需要更少的参数 ( 47 M ) (47 \mathrm{M}) (47M)。这个问题因不同时间步之间的梯度差异性而加剧,这可能阻碍有效的训练。或者,为不同区间使用单独的架构可能导致过拟合,并缺乏稳健的早停机制。为了解决这些挑战,我们在第4节提出的多阶段架构结合了共享参数以减少过拟合,并在每个区间使用特定参数以减轻梯度差异性的影响。这种为每个区间量身定制的方法确保了改进的适应性。此外,我们在第5.3节中进行了深入的消融研究,以展示我们的多阶段架构相较于现有模型的有效性。

4. Proposed Multistage Framework

在本节中,我们介绍了新的多阶段框架(如图 1 © 所示)。具体来说,我们首先在 4.1 节中介绍了多阶段 U-Net 架构设计,然后在 4.2 节中介绍了一种新的聚类方法,用于选择最佳间隔将整个时间步长 [0, 1] 划分为间隔,并在 4.3 节中讨论了所提架构的原理。

4.1. Proposed Multi-stage U-Net Architectures


如第3节所述,大多数现有的扩散模型要么在所有区间 [ 1 , 2 , 18 ] [1,2,18] [1,2,18]中采用统一架构以共享所有时间步的特征,要么为不同的时间步区间使用完全独立的架构 [ 21 , 22 , 28 ] [21,22,28] [21,22,28],其目标是利用不同区间内的良性特性。

为了利用先前研究中使用的统一和独立架构的优势,我们引入了一种多阶段U-Net架构,如图1©所示。具体来说,我们将整个时间步 [ 0 , 1 ] [0,1] [0,1]划分为几个区间,例如图1中的三个区间 [ 0 , t 1 ) , [ t 1 , t 2 ) , [ t 2 , 1 ] \left[0, t_1\right),\left[t_1, t_2\right),\left[t_2, 1\right] [0,t1),[t1,t2),[t2,1]。对于该架构,我们引入了:

  • 所有时间区间共享的一个编码器。对于每个时间步区间,我们实现了一个共享编码器架构(在图1©中以蓝色绘制),类似于原始U-Net框架中使用的架构[33]。与独立架构不同,共享编码器在所有时间步之间提供共享信息,防止模型过拟合(见第5.3节的讨论)。
  • 不同时间区间的独立解码器。受Mask Region-based Convolutional Neural Networks (MaskRCNN)方法[34]中引入的多头结构启发,我们建议使用多个不同的解码器(在图1©中为不同区间绘制的颜色),其中每个解码器都针对特定的时间步区间进行了定制。每个解码器的架构与[2]中使用的架构非常相似,并对嵌入维度进行了刻意调整以优化性能。

如我们所见,架构的主要区别在于解码器结构。直观地,我们为接近噪声的区间使用参数较少的解码器,因为学习任务更简单。对于接近图像的区间,我们使用参数较多的解码器。

4.2. Optimal Denoiser-based Timestep Clustering


接下来,我们讨论在实践中如何选择区间划分的时间点。为简单起见,我们集中讨论将时间 [ 0 , 1 ] [0,1] [0,1]划分为三个区间 [ 0 , t 1 ) , [ t 1 , t 2 ) , [ t 2 , 1 ] \left[0, t_1\right),\left[t_1, t_2\right),\left[t_2, 1\right] [0,t1),[t1,t2),[t2,1]的情况,并开发了一种时间步聚类方法来选择最佳的 t 1 t_1 t1 t 2 t_2 t2。当然,我们的方法可以推广到具有任意区间数量的多阶段网络。然而,在实践中,我们发现选择三个区间在效果和复杂性之间达到了良好的平衡;参见附录B.6中的消融研究。

为了划分时间区间,我们采用命题1中建立的最佳去噪器。
命题1。假设我们使用数据集 { y i ∈ R n } i = 1 N \left\{\boldsymbol{y}_i \in \mathbb{R}^n\right\}_{i=1}^N {yiRn}i=1N训练一个扩散模型去噪函数 ϵ θ ( x , t ) \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}, t) ϵθ(x,t),参数为 θ \boldsymbol{\theta} θ,通过以下公式
min ⁡ θ L ( ϵ θ ; t ) = E x 0 , x t , ϵ [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] , \min _{\boldsymbol{\theta}} \mathcal{L}\left(\boldsymbol{\epsilon}_{\boldsymbol{\theta}} ; t\right)=\mathbb{E}_{\boldsymbol{x}_0, \boldsymbol{x}_t, \epsilon}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)\right\|^2\right], θminL(ϵθ;t)=Ex0,xt,ϵ[ϵϵθ(xt,t)2],
其中 x 0 ∼ p data  ( x ) = 1 N ∑ i = 1 N δ ( x − y i ) , ϵ ∼ N ( 0 , I ) \boldsymbol{x}_0 \sim p_{\text {data }}(\boldsymbol{x})=\frac{1}{N} \sum_{i=1}^N \delta\left(\boldsymbol{x}-\boldsymbol{y}_i\right), \boldsymbol{\epsilon} \sim \mathcal{N}(0, \boldsymbol{I}) x0pdata (x)=N1i=1Nδ(xyi),ϵN(0,I), 且 x t ∼ p t ( x t ∣ x 0 ) = N ( x t ; s t x 0 , s t 2 σ t 2 I ) \boldsymbol{x}_t \sim p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)=\mathcal{N}\left(\boldsymbol{x}_t ; s_t \boldsymbol{x}_0, s_t^2 \sigma_t^2 \boldsymbol{I}\right) xtpt(xtx0)=N(xt;stx0,st2σt2I),扰动参数为 s t s_t st σ t \sigma_t σt定义在方程(1)中。那么,定义为 t t t时的最佳去噪器 ϵ θ ∗ ( x ; t ) = arg ⁡ min ⁡ ϵ θ L ( ϵ θ ; t ) \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^*(\boldsymbol{x} ; t)=\arg \min _{\boldsymbol{\epsilon}_{\boldsymbol{\theta}}} \mathcal{L}\left(\boldsymbol{\epsilon}_{\boldsymbol{\theta}} ; t\right) ϵθ(x;t)=argminϵθL(ϵθ;t)
ϵ θ ∗ ( x ; t ) = 1 s t σ t [ x − s t ∑ i = 1 N N ( x ; s t y i , s t 2 σ t 2 I ) y i ∑ i = 1 N N ( x ; s t y i , s t 2 σ t 2 I ) ] 。 \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^*(\boldsymbol{x} ; t)=\frac{1}{s_t \sigma_t}\left[\boldsymbol{x}-s_t \frac{\sum_{i=1}^N \mathcal{N}\left(\boldsymbol{x} ; s_t \boldsymbol{y}_i, s_t^2 \sigma_t^2 \boldsymbol{I}\right) \boldsymbol{y}_i}{\sum_{i=1}^N \mathcal{N}\left(\boldsymbol{x} ; s_t \boldsymbol{y}_i, s_t^2 \sigma_t^2 \boldsymbol{I}\right)}\right]。 ϵθ(x;t)=stσt1[xsti=1NN(x;styi,st2σt2I)i=1NN(x;styi,st2σt2I)yi]

证明见附录A,该结果可以从Karras等人[18]的最新工作中推广,从特定核 p t ( x t ∣ x 0 ) = N ( x t ; x 0 , σ t 2 I ) p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)=\mathcal{N}\left(\boldsymbol{x}_t ; \boldsymbol{x}_0, \sigma_t^2 \mathbf{I}\right) pt(xtx0)=N(xt;x0,σt2I)扩展到包含更广泛的噪声扰动核,给出 p t ( x t ∣ x 0 ) = N ( x t ; s t x 0 , s t 2 σ t 2 I ) p_t\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)=\mathcal{N}\left(\boldsymbol{x}_t ; s_t \boldsymbol{x}_0, s_t^2 \sigma_t^2 \mathbf{I}\right) pt(xtx0)=N(xt;stx0,st2σt2I)。为简洁起见,我们在命题1中将最佳去噪器 ϵ θ ∗ ( x , t ) \boldsymbol{\epsilon}_\theta^*(\boldsymbol{x}, t) ϵθ(x,t)简化为 ϵ t ∗ ( x ) \boldsymbol{\epsilon}_t^*(\boldsymbol{x}) ϵt(x)

为了获得最佳区间,我们的基本原则是在每个单独的时间区间内尽可能均匀化回归任务。为了实现这一目标,给定采样的 x 0 , ϵ x_0, \epsilon x0,ϵ,我们定义在任何给定时间步 t a , t b t_a, t_b ta,tb下最佳去噪器的函数距离为:
D ( ϵ t a ∗ , ϵ t b ∗ , x 0 , ϵ ) = 1 n ∑ i = 1 n 1 ( ∣ ϵ t a ∗ ( x t a ) − ϵ t b ∗ ( x t b ) ∣ i ≤ η ) , \mathcal{D}\left(\boldsymbol{\epsilon}_{t_a}^*, \boldsymbol{\epsilon}_{t_b}^*, \boldsymbol{x}_0, \boldsymbol{\epsilon}\right)=\frac{1}{n} \sum_{i=1}^n \mathbb{1}\left(\left|\boldsymbol{\epsilon}_{t_a}^*\left(\boldsymbol{x}_{t_a}\right)-\boldsymbol{\epsilon}_{t_b}^*\left(\boldsymbol{x}_{t_b}\right)\right|_i \leq \eta\right), D(ϵta,ϵtb,x0,ϵ)=n1i=1n1( ϵta(xta)ϵtb(xtb) iη),
其中 1 ( ⋅ ) \mathbb{1}(\cdot) 1()是指标函数, η \eta η是预设的阈值, x t a = s t a x 0 + s t a σ t a ϵ \boldsymbol{x}_{t_a}=s_{t_a} \boldsymbol{x}_0+s_{t_a} \sigma_{t_a} \boldsymbol{\epsilon} xta=stax0+staσtaϵ,和 x t b = s t b x 0 + s t b σ t b ϵ \boldsymbol{x}_{t_b}=s_{t_b} \boldsymbol{x}_0+s_{t_b} \sigma_{t_b} \epsilon xtb=stbx0+stbσtbϵ。因此,我们定义在时间步 t a t_a ta t b t_b tb下最佳去噪器的函数相似性为:
S ( ϵ t a ∗ , ϵ t b ∗ ) = E x 0 ∼ p data  E ϵ ∼ N ( 0 , I ) [ D ( ϵ t a ∗ , ϵ t b ∗ , x 0 , ϵ ) ] \mathcal{S}\left(\boldsymbol{\epsilon}_{t_a}^*, \boldsymbol{\epsilon}_{t_b}^*\right)=\mathbb{E}_{\boldsymbol{x}_0 \sim p_{\text {data }}} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})}\left[\mathcal{D}\left(\boldsymbol{\epsilon}_{t_a}^*, \boldsymbol{\epsilon}_{t_b}^*, \boldsymbol{x}_0, \boldsymbol{\epsilon}\right)\right] S(ϵta,ϵtb)=Ex0pdata EϵN(0,I)[D(ϵta,ϵtb,x0,ϵ)]
基于定义,我们设计以下优化问题以找到最大的 t 1 t_1 t1和最小的 t 2 t_2 t2为:
t 1 ← arg ⁡ max ⁡ τ { τ ∣ E t ∼ [ 0 , τ ) [ S ( ϵ t ∗ , ϵ 0 ∗ ) ] ≥ α } , t 2 ← arg ⁡ min ⁡ τ { τ ∣ E t ∼ [ τ , 1 ] [ S ( ϵ t ∗ , ϵ 1 ∗ ) ] ≥ α } , \begin{aligned} & t_1 \leftarrow \underset{\tau}{\arg \max }\left\{\tau \mid \mathbb{E}_{t \sim[0, \tau)}\left[\mathcal{S}\left(\boldsymbol{\epsilon}_t^*, \boldsymbol{\epsilon}_0^*\right)\right] \geq \alpha\right\}, \\ & t_2 \leftarrow \underset{\tau}{\arg \min }\left\{\tau \mid \mathbb{E}_{t \sim[\tau, 1]}\left[\mathcal{S}\left(\boldsymbol{\epsilon}_t^*, \boldsymbol{\epsilon}_1^*\right)\right] \geq \alpha\right\}, \end{aligned} t1τargmax{τEt[0,τ)[S(ϵt,ϵ0)]α},t2τargmin{τEt[τ,1][S(ϵt,ϵ1)]α},
使得 ϵ t ∗ \epsilon_t^* ϵt(分别为 ϵ t ∗ \epsilon_t^* ϵt)在 [ 0 , t 1 ) \left[0, t_1\right) [0,t1)(分别为 [ t 2 , 1 ] \left[t_2, 1\right] [t2,1])的平均函数相似性大于或等于预定义的阈值 α \alpha α。由于上述优化问题是不可解的,我们提出了算法1中概述的程序以获得近似解。特别是,该算法采样 K K K ( y k , ϵ k , t k ) , k ∈ { 1 , … , K } \left(\boldsymbol{y}_k, \boldsymbol{\epsilon}_k, t_k\right), k \in\{1, \ldots, K\} (yk,ϵk,tk),k{1,,K}以计算距离 D ( ϵ t k ∗ , ϵ 0 ∗ , y k , ϵ k ) \mathcal{D}\left(\boldsymbol{\epsilon}_{t_k}^*, \boldsymbol{\epsilon}_0^*, \boldsymbol{y}_k, \boldsymbol{\epsilon}_k\right) D(ϵtk,ϵ0,yk,ϵk) D ( ϵ t k ∗ , ϵ 1 ∗ , y k , ϵ k ) \mathcal{D}\left(\boldsymbol{\epsilon}_{t_k}^*, \boldsymbol{\epsilon}_1^*, \boldsymbol{y}_k, \boldsymbol{\epsilon}_k\right) D(ϵtk,ϵ1,yk,ϵk)(步骤6)。基于这些距离,我们在算法1的第8行和第9行中定义的优化问题中解决以获得 t 1 t_1 t1 t 2 t_2 t2

4.3. Rationales for the proposed architecture


最后,我们总结了基于第3节中的经验观察和先前工作的我们所提出的架构的基本原理。

共享编码器的原理。(i) 防止过拟合:如果我们将训练扩散模型的不同阶段视为多任务学习,[35] 表明多任务学习中跨任务共享参数可以减轻过拟合。(ii) 保持 h h h-空间的一致性。UNet 编码器的输出被称为 h h h-空间 [36],具有如均匀性、线性、鲁棒性和时间步长一致性的语义操控属性。因此,与单独的编码器相比,共享编码器可以在所有时间步中保持 h h h-空间更好的一致性。

网络参数设计的原理。首先,我们提供一些直觉来说明为什么 t = 0 t=0 t=0 时的学习任务比 t = 1 t=1 t=1 时更困难。假设扩散模型可以收敛到方程(4)中给出的最佳去噪器 ϵ θ θ ( x t , t ) \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right) ϵθθ(xt,t)。基于此,我们观察到:(i) 当 t → 0 t \rightarrow 0 t0 时,我们有 x t → 0 = x 0 , ϵ θ ∗ ∣ t → 0 = ϵ \boldsymbol{x}_{t \rightarrow 0}=\boldsymbol{x}_0,\left.\boldsymbol{\epsilon}_{\boldsymbol{\theta}}^*\right|_{t \rightarrow 0}=\boldsymbol{\epsilon} xt0=x0,ϵθt0=ϵ,因此 ϵ θ ∗ \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^* ϵθ 是从训练数据分布 p data  ( x ) p_{\text {data }}(\boldsymbol{x}) pdata (x) 到高斯分布的复杂映射;(ii) 当 t → 1 t \rightarrow 1 t1 时,我们有 x t → 1 = ϵ \boldsymbol{x}_{t \rightarrow 1}=\boldsymbol{\epsilon} xt1=ϵ ϵ θ ∗ ∣ t → 1 = ϵ \left.\epsilon_{\boldsymbol{\theta}}^*\right|_{t \rightarrow 1}=\boldsymbol{\epsilon} ϵθt1=ϵ,因此 ϵ θ ∗ \epsilon_{\boldsymbol{\theta}}^* ϵθ 是一个恒等映射。两个极端情况揭示了对于 t → 1 t \rightarrow 1 t1(接近噪声)的恒等映射比 t → 0 t \rightarrow 0 t0 的更容易学习。其次,我们在不同阶段选择网络参数的原则与最近的工作 [37] 类似,该工作在 t → 0 t \rightarrow 0 t0 时采用高维子空间,并逐渐减少每个子空间的维数直到 t = 1 t=1 t=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/51517.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过哨兵1号SAR数据获取桂林619洪水内涝情况

目录 1.SAR数据下载 2.SAR数据处理 1、下载轨道数据并进行轨道校正。 2、数据处理 3、转换SAR单位并创建彩色合成影像 3.查看彩色合成SAR数据 4.水体提取方法探讨 方法1&#xff1a;阈值提取法 方法2&#xff1a;深度学习提取水域 5.SAR与DEM数据获取 2024年6月19日&a…

MobaXterm 软件安装及使用

MobaXterm 软件安装及使用 1. 引言 MobaXterm是一款功能强大的终端软件&#xff0c;支持SSH、Telnet、RDP、VNC、FTP、SFTP、X11转发和串口等远程会话功能。它使得在Windows系统上进行Linux系统的远程管理和文件传输变得简单便捷。 2. MobaXterm 软件下载 下载链接&#xff…

蚓链数字化生态平台:构建城市智能商业,引领协同发展新潮流

​在当今数字化飞速发展的时代&#xff0c;城市商业的运行模式正在经历着数字化变革。蚓链数字化生态平台应运而生&#xff0c;以其强大的功能和创新的理念&#xff0c;成为构建城市智能商业枢纽中心的关键力量&#xff0c;推动着平台互通、业务贯通、管理协同的全新发展格局。…

五个优秀的免费 Ollama WebUI 客户端推荐

认识 Ollama 本地模型框架&#xff0c;并简单了解它的优势和不足&#xff0c;以及推荐了 5 款开源免费的 Ollama WebUI 客户端&#xff0c;以提高使用体验。 什么是 Ollama&#xff1f; Ollama 是一款强大的本地运行大型语言模型&#xff08;LLM&#xff09;的框架&#xff0c…

opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习

前言 AIS_ManipulatorOwner是OpenCascade中的一个类&#xff0c;主要用于操纵对象的交互控制。AIS_ManipulatorOwner结合AIS_Manipulator类&#xff0c;允许用户通过可视化工具&#xff08;如旋转、平移、缩放等&#xff09;来操纵几何对象。 以下是AIS_ManipulatorOwner的基…

深度强化学习 ②(DRL)

参考视频&#xff1a;&#x1f4fa;王树森教授深度强化学习 前言&#xff1a; 最近在学习深度强化学习&#xff0c;学的一知半解&#x1f622;&#x1f622;&#x1f622;&#xff0c;这是我的笔记&#xff0c;欢迎和我一起学习交流~ 这篇博客目前还相对比较乱&#xff0c;后面…

angular入门基础教程(一)环境配置与新建项目

ng已经更新到v18了&#xff0c;我对他的印象还停留在v1,v2的版本&#xff0c;最近研究了下&#xff0c;与react和vue是越来越像了&#xff0c;所以准备正式上手了。 新官网地址:https://angular.cn/ 准备条件 nodejs > 18.0vscodeng版本18.x(最新的版本) {"name&qu…

【前端 17】使用Axios发送异步请求

Axios 简介与使用&#xff1a;简化 HTTP 请求 在现代 web 开发中&#xff0c;发送 HTTP 请求是一项常见且核心的任务。Axios 是一个基于 Promise 的 HTTP 客户端&#xff0c;适用于 node.js 和浏览器&#xff0c;它提供了一种简单的方法来发送各种 HTTP 请求。本文将介绍 Axio…

【虚拟化】KVM概念和架构

目录 一、什么是KVM&#xff1f; 二、KVM的功能 2.1 主要的功能 2.2 其它功能 三、KVM核心组件及作用 四、KVM与VMware的优势 五、KVM架构 六、qemu介绍 七、创建虚拟机流程 一、什么是KVM&#xff1f; Kernel-based Virtual Machine的简称&#xff0c;KVM 是基于虚拟…

ubuntu部署k8s/microk8s安装部署

资源 节点名称IP配置系统node310.2.20.174核8GUbuntu Server 22.04 LTS 64bitnode210.2.24.44核8GUbuntu Server 22.04 LTS 64bitnode110.2.20.134核8GUbuntu Server 22.04 LTS 64bitmaster10.2.24.104核8GUbuntu Server 22.04 LTS 64bit ps:所有命令尽量使用root账号操作 1…

实体店怎么做会员分析管理,告别“僵尸“会员?

在线上严重蚕食线下的当下&#xff0c;如果实体店不重视会员分析&#xff0c;那它将会错失更多的客户&#xff0c;甚至面临被淘汰的危险。 近年来&#xff0c;越来越多的实体店商家开始重视会员分析管理&#xff0c;但要做好会员分析管理并非易事&#xff0c;需要一整套的工具…

RK平台瑞发科NS6601 MIPI CSI VC虚拟通道支持不同分辨率

需求&#xff1a;两路不同分辨率的摄像头&#xff0c;通过des后输入给一路MIPI CSI。在capture的时候&#xff0c;可以分别支持不同分辨率的capture动作。 设备树 &i2c2 {status "okay";pinctrl-names "default";pinctrl-0 <&i2c2m4_xfer&g…

NLP笔记

文本处理和词嵌入 对于机器来说&#xff0c;不能理解一句话的意思&#xff0c;解决办法就是将一句话分割成多个词&#xff0c;用数字代表一个词&#xff0c;将一句话转化成数字的列表 这个词对应的字典又是怎么训练出来的呢&#xff0c;遍历这个词的列表&#xff0c;如果在字…

ConvGRU原理与开源代码

ConvGRU 1. 算法简介与应用场景2. 算法原理2.1 GRU基础2.2 ConvGRU原理2.2.1 ConvGRU的结构2.2.2 卷积操作的优点 2.3 GRU与ConvGRU的对比分析2.4 ConvGRU的应用 3. PyTorch代码 仅需要网络源码的可以直接跳到末尾即可 需要ConvLSTM的可以参考我的另外一篇博客&#xff1a;小白…

初识HTML文件,创建自己的第一个网页!

本文旨在初步介绍HTML&#xff08;超文本标记语言&#xff09;&#xff0c;帮助读者理解HTML中的相关术语及概念&#xff0c;并使读者在完成本文的阅读后可以快速上手编写一个属于自己的简易网页。 一、HTML介绍 HTML(全称HyperText Markup Language&#xff0c;超文本标记语言…

【C++】位图 + 布隆过滤器

目录 1. 位图1.1. 概念1.2. 实现1.3. 应用 2. 布隆过滤器2.1. 背景2.2. 概念2.3. 实现2.4. 优点2.5. 缺点 3. 海量数据面试题3.1. 哈希切割3.2. 位图应用3.3. 布隆过滤器3.4. 总结 1. 位图 1.1. 概念 位图是一种用于高效地存储和操作集合的数据结构。它的基本思想是使用一个二…

高并发内存池(四)Page Cache的框架及内存申请实现

目录 一、Page Cache的框架梳理 二、Page Cache的实现 2.1PageCache.h 2.2VirtualAlloc 2.3std::unordered_map _idSpanMap,> 2.4Page Cache.cpp 一、Page Cache的框架梳理 申请内存&#xff1a; 1. 当central cache向page cache申请内存时&#xff0c;page cache先检…

Intel 13/14代不稳定 微星率先发声:密切监视、8月中旬更新微码

不久前&#xff0c;Intel针对14/14代酷睿i9 K系列不稳定的问题发布了最新声明&#xff0c;确认问题源于微代码算法缺陷与电压过高&#xff0c;并承诺将在8月中旬完成新版BIOS的验证&#xff0c;随后发放。现在&#xff0c;微星在各家主板厂商中第一个站出来&#xff0c;表明了态…

Java 使用 POI 导出Excel,实现单元格输入内容提示功能

在使用Apache POI的库生成Excel导入模板的时候&#xff0c;有时候需要对单元格能够输入的内容进行一个提示&#xff0c;该如何实现这个特性呢&#xff1f;下面是一个示例代码&#xff0c;演示如何实现单元格输入内容提示功能。 代码 import org.apache.poi.ss.usermodel.*; im…

Frienda 4 件套幽灵狩猎猫球运动发光猫球 LED 运动激活猫球运动点亮猫狗互动玩具宠物发光迷你跑步健身球

来自 美国亚马逊&#xff1a;商品评论: Frienda 4 件套幽灵狩猎猫球运动发光猫球 LED 运动激活猫球运动点亮猫狗互动玩具宠物发光迷你跑步健身球玩具(亮色) (amazon.com) Kim 1.0 颗星&#xff0c;最多 5 颗星 Battery does not last/ cant replace 2024年5月29日 在美国审核…