Diffusion Model Patching via Mixture-of-Prompts
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
0 摘要
1 简介
2 相关工作
3 扩散模型修补(DMP)
3.1 架构
3.2 训练
4 实验
4.2 设计选择
4.3 分析
5 结论
0 摘要
我们提出了扩散模型修补(Diffusion Model Patching,DMP),这是一种简单的方法,可以在参数增加可忽略不计的情况下提升已达到收敛状态的预训练扩散模型的性能。DMP 在模型的输入空间中插入一小组可学习的提示,同时保持原始模型冻结。DMP 的有效性不仅仅在于参数的增加,而是源于其动态门控机制,该机制在生成过程的每一步(即反向去噪步骤)选择并组合一部分可学习的提示。我们称这种策略为 “提示混合”,使模型能够在每一步利用每个提示的独特专长,以最少但专业化的参数 “修补”(patching)模型的功能。独特的是,DMP 通过在原始训练数据集上进一步训练来增强模型,即使在通常不期望显著改进的模型收敛情况下也是如此。实验表明,DMP 在FFHQ 256×256 数据集上,将 DiT-L/2 的收敛 FID 显著提升了 10.38%,仅增加了 1.43% 的参数和 50K 次额外训练迭代。
项目页面:https://sangminwoo.github.io/DMP/
1 简介
扩散模型的特点在于其多步骤去噪过程,该过程逐步将随机噪声精细化为结构化的输出,如图像。每一步都旨在对带噪输入进行去噪,逐步将完全随机的噪声转化为有意义的图像。尽管所有去噪步骤的目标都是生成高质量图像,但每一步都有不同的特性,有助于塑造最终输出【12, 42】。扩散模型学习的视觉概念因输入的噪声比率而异【6】。在较高噪声水平(时间步 t 接近 T)时,图像高度损坏,因此内容无法识别,模型专注于恢复全局结构和颜色。随着噪声水平降低,图像损坏程度减少(时间步 t 接近 0),恢复图像的任务变得更加直接,扩散模型学习恢复细粒度的细节。最近的研究【2, 6, 12, 42】表明,考虑阶段特异性(stage-specificity)是有益的,因为它更好地与生成过程不同阶段的细微需求对齐。然而,许多现有的扩散模型并未明确考虑这一方面。
我们的目标是通过引入阶段特异性能力来增强已经收敛的预训练扩散模型。我们提出扩散模型修补(DMP),一种为预训练扩散模型配备增强工具的方法,使它们能够以更大的精细度和精确度导航生成过程。DMP 的概述如图 1 所示。DMP 由两个主要组件组成:
- 一小组可学习的提示【29】,每个都针对生成过程的特定阶段进行优化。这些提示附加到模型的输入空间,作为某些去噪步骤(或噪声水平)的 “专家”。这种设计使得模型可以针对每个阶段的特定行为进行定向,无需重新训练整个模型,而是专注于仅调整输入空间的小参数。
- 一个动态门控机制,根据输入图像的噪声水平自适应地组合 “专家” 提示(或提示混合)。这种动态地利用提示使得模型具有灵活性,能够在生成的不同阶段利用提示知识集的不同方面。通过利用嵌入在每个提示中的专业知识,模型可以在整个多步骤生成过程中,适应阶段特定的需求。
通过结合这些组件,我们继续使用最初预训练的原始数据集训练收敛的扩散模型。鉴于模型已经收敛,一般认为传统的微调不会带来显著的改进,甚至可能导致过拟合。然而,DMP 为模型提供了对每个去噪步骤的细致理解,从而即使在相同的数据分布上训练,也能提升性能。如图 2 所示,DMP 在 FFHQ【25】 256×256 数据集上,通过仅 50K 次迭代,将 DiT-L/2【43】的性能提升了 10.38%。
2 相关工作
具有阶段特异性(stage-specificity)的扩散模型。扩散模型的最新进展【17, 54, 55】拓宽了它们在各种数据模态中的应用,包括图像【47, 50】、音频【27】、文本【30】和 3D【62】,展示了在众多生成任务中的显著多功能性。近期的努力集中在提高去噪阶段的特异性,在架构和优化方面均取得了显著进展。
在架构方面,
- eDiff-I【2】和 ERNIE-ViLG 2.0【10】引入了使用多个专家去噪器的概念,每个去噪器针对特定的噪声水平进行优化,从而增强了模型的整体能力。
- 此外,RAPHAEL【66】采用了时空专家,其中空间专家专注于特定图像区域,时间专家专注于特定的去噪阶段。
- DTR【42】通过为每个去噪步骤分配不同的通道组合,优化了扩散模型架构。
从优化角度看,
- Choi 等人【6】提出了一种策略,为构建内容的阶段分配更高的权重,而为清理剩余噪声的阶段分配较低的权重。
- Hang 等人【14】通过将扩散训练框架化为多任务学习问题【4】来加速收敛,在每个时间步基于任务难度调整损失权重。
- Go 等人【12】基于信噪比(SNR)聚类相似阶段,缓解了多个去噪阶段的学习冲突。
以往的研究旨在提高去噪阶段的特异性,通常假设从头开始训练或使用多个专家网络,这可能需要大量资源和显著的参数存储。我们的方法在不修改原始模型参数的情况下,实现了阶段特异性,仅使用单个预训练扩散模型。
扩散模型中的参数高效微调(PEFT)。PEFT 提供了一种通过微调少量(额外)参数来增强模型的方法,避免了重新训练整个模型的需求,大大降低了计算和存储成本。鉴于扩散模型的复杂性和高参数密度【43, 48】,这一点特别有吸引力。从头开始直接训练扩散模型往往不切实际,这引发了对 PEFT 方法的日益关注【63】。这一领域的最新进展大致可以分为三个流派:
- T2I-Adapter【39】、SCEdit【23】和 ControlNet【69】利用适配器【5, 19, 20】或旁路微调(side-tuning)【57, 68】在特定层修改神经网络的行为,从而重写激活。
- Prompt2prompt【15】和 Textual Inversion【11】使用类似于提示微调【22, 29, 31, 37, 71】的概念,修改输入或文本表示,以影响后续处理,而不改变功能本身。
- CustomDiffusion【28】、SVDiff【13】和 DiffFit【64】专注于部分参数微调【32, 65, 67】,微调神经网络的特定参数,如偏置项。
这些方法在个性化或定制编辑的扩散模型微调中取得了成功,通常使用与原始预训练数据集不同的样本【13, 15, 28, 49】。它们还有效地支持了用于可控图像生成的多样条件输入【23, 39, 69】。相比之下,我们的工作旨在通过原始训练数据集优化预训练扩散模型的性能。我们的方法在保持参数高效的同时,目标是域内增强。
3 扩散模型修补(DMP)
我们提出了 DMP,一种简单而有效的方法,旨在通过使预训练的、已收敛的扩散模型能够利用特定于不同去噪阶段的知识来增强其性能。DMP包含两个关键组件:可学习提示池(prompts pool)和动态门控机制。首先,将少量可学习参数(称为提示)附加到扩散模型的输入空间。其次,动态门控机制根据输入图像的噪声水平选择最佳的提示集(或提示混合)。在这些组件基础上,我们使用相同的预训练数据集进一步训练模型,以学习提示,同时保持主干参数冻结。我们选择 DiT【43】作为基础模型。DMP 的整体框架如图 3 所示。
提示微调。提示微调的核心思想是找到一小组参数,当与输入结合时,有效地 “微调” 预训练模型的输出以达到期望的结果。传统的微调旨在通过修改预训练模型 fθ,在给定输入 x 的情况下,最小化真实值 y 和预测值 ^y之间的差距:
其中 ^y 是优化后的预测值, f′_θ 是修改后的模型,红色和蓝色分别表示可学习和冻结的参数。这个过程通常计算量大且资源密集,因为它需要存储和更新完整的模型参数。相比之下,提示微调通过直接修改输入 x 来增强输出 ^y:
以往的工作【22, 29, 31, 61, 71】通常定义
其中 [·; ·] 表示连接。然而,我们采取了一种不同的方法,直接将提示添加到输入中,旨在更明确地影响输入本身,因此
提示通过梯度下降进行优化,类似于传统的微调,但不改变模型的参数。
动机。在多步去噪过程中,每个阶段的难度和目标因噪声水平而异【6, 12, 14, 42】。提示微调【22, 29, 31】假设如果预训练模型已经具备足够的知识,精心构建的提示可以从冻结的模型中提取特定下游任务的知识。同样,我们假设预训练的扩散模型已经掌握了所有去噪阶段的通用知识,通过为每个阶段学习不同的提示混合,可以用阶段特定的知识修补模型。
3.1 架构
我们采用 DiT【43】作为基础架构,这是一种基于 Transformer【59】的 DDPM【17】,在潜在空间中操作【48】。我们使用来自 Stable Diffusion【48】的预训练 VAE【26】将输入图像处理成形状为 H×W×D 的潜在代码(latent code)(对于 256×256×3 的图像,潜在代码为 32×32×4)。然后我们对潜在代码进行加噪。有噪潜在代码被分成 N 个固定大小的补丁(patch),每个补丁的形状为 K×K×D。这些补丁线性嵌入,并应用标准位置编码【59】,得到 x^(0) ∈ R^(N×D)。这些补丁随后由一系列 L 个 DiT 块处理。这些块使用时间步嵌入 t 进行训练,并可选择使用类别或文本嵌入。在最后一个 DiT 块之后,有噪潜在补丁经过最终的层归一化,并线性解码为 K×K×2D 的张量(D 用于噪声预测,另一个 D 用于对角协方差预测)。最后,解码后的 token 被重新排列以匹配原始形状 H×W×D。
可学习提示。我们的目标是通过在输入空间中调整小参数来有效地增强模型的去噪阶段特定知识。为此,我们从预训练的 DiT 模型开始,在每个 DiT 块的输入空间中插入 N 个维度为 D 的可学习连续嵌入,即提示。可学习提示集表示为:
这里, p^(i) 表示第 i 个 DiT 块的提示,L 是模型中的 DiT 块总数。不像之前的方法【22, 29, 31, 61】,提示通常被预先置于输入序列之前,我们直接将它们添加到输入中。这种方法的优点是不增加序列长度,从而保持几乎与之前相同的计算速度。此外,在生成过程中,每个添加到输入补丁中的提示为每个时间步的特定空间部分提供直接信号,有助于专门的去噪步骤。这样设计使模型能够专注于输入的不同方面,辅助特定的去噪步骤。第 i 个 DiT 块的输出计算如下:
在进一步训练过程中,只更新提示,而 DiT 块的参数保持不变。
动态门控。在公式(4)中,整个训练过程中使用相同的提示,因此它们将学习与去噪阶段无关的知识。为了用阶段特定的知识修补模型,我们引入了动态门控。这种机制根据输入图像的噪声水平以不同的比例混合提示,具体来说,我们使用时间步嵌入 t 来表示生成过程中的每一步噪声水平。对于给定的 t,门控网络 G 创建提示混合,从而重新定义公式(4)为:
其中 σ 是 softmax 函数,⊙ 表示元素乘法。在实际应用中,G 被实现为一个简单的线性层。它还将 DiT 块深度 i 作为输入,根据深度产生不同的结果。这种动态门控机制有效地使用少量提示处理不同噪声水平,赋予模型在生成过程不同阶段使用不同提示知识集的灵活性。
3.2 训练
零初始化。我们实验证明,随机初始化提示可能在训练初期破坏原始信息流,导致不稳定和发散。为了确保预训练扩散模型的稳定的进一步训练,我们从零初始化提示。采用之前选择的提示添加策略,零初始化有助于防止有害噪声影响神经网络层的深层特征,并在训练开始时保留原始知识。在第一步训练中,由于提示参数初始化为零,公式(5)中的 p^(i−1) 项评估为零,因此:
随着训练的进行,模型可以逐渐结合提示中的额外信号。
提示平衡损失。我们采用 Shazeer 等人【53】的两个软约束来平衡提示混合的激活。
- 重要性平衡:在多专家设置【21, 24】中,Eigen 等人【9】指出,一旦专家被选择,它们往往会持续被选择。在我们的设置中,重要性平衡损失防止门控网络 G 过度偏向权重较大的少数提示,鼓励所有提示具有相似的整体重要性。
- 负载平衡:尽管具有相似的整体重要性,提示仍可能以不平衡的权重激活。例如,一个提示可能在少数去噪步骤中被分配较大权重,而另一个提示可能在许多步骤中被分配较小权重。负载平衡损失确保提示在所有去噪步骤(或噪声水平)中几乎均匀地激活。
在我们的 DMP 方法的训练阶段,我们采用提示平衡损失来防止门仅选择少数特定的提示,这是一种已知的模式崩溃问题。提示平衡损失受到了混合专家中使用的平衡损失的启发。当将第 i 个 DiT 块层中的第 n 个提示门定义为 g^i_n 时,重要性损失和负载平衡损失的表达式如下:
其中,L 是 DiT 块层的总数,N 是提示的总数,ϵ = 1e−5 以防止除以零,而 I 是指示函数,计算满足条件(g^i_n > 0)的元素数量。重要性损失使用平方变异系数,这使得方差值对均值具有鲁棒性。通过这些提示平衡损失,我们对门中的提示选择进行正则化,确保未选择的提示只有少量,如图 5 所示。
提示效率。表 1 展示了不同版本 DiT 架构【43】在有无 DMP 情况下的模型参数,从 DiT-B/2 到DiT-XL/2(其中 “2” 表示补丁大小 K)。假设固定分辨率为 256×256,使用 DMP 会使 DiT-B/2 的参数增加 1.96%。对于最大的模型 DiT-XL/2,使用 DMP 会使参数增加 1.26%,达到 683.5M。随着模型规模的增大,DMP 参数占总模型参数的比例减少,相对于整个模型,仅需微调少量参数。
4 实验
4.2 设计选择
提示深度。为了研究在不同数量的块中插入提示的影响,我们使用包含 12 个 DiT 块的 DiT-B/2 模型进行消融研究。我们评估了在不同深度(仅在第一个块、最多一半的块以及所有块)应用混合提示的性能差异,如图 4 所示。无论深度如何,性能相比基线(FID=6.27,无提示)始终有所提升。我们的研究发现,提示深度与性能呈正相关关系,在更多块中使用混合提示能取得更好的结果。图 5 中展示了每个块选择的提示。
门控架构。我们的 DMP 框架引入了动态门控机制,选择针对每个去噪步骤(或噪声水平)专门设计的混合提示。我们在表 4a 中比较了两种门控架构的影响:线性门控和注意力门控。
- 线性门控使用单个线性层,以时间步嵌入为输入,生成每个可学习提示的软加权掩码。
- 注意力门控使用注意力层【59】,将可学习提示作为查询,将时间步嵌入作为键和值,生成直接添加到潜在向量的加权提示。
- 比较两种门控架构后发现,简单的线性门控实现了更好的 FID(5.87),相比之下,更复杂的注意力门控(6.41)反而比基线(6.27)更差。因此,我们采用线性门控作为默认设置。
门控类型。在表 4b 中,我们评估了创建混合提示的两种设计选择:硬选择和软选择。
- 硬选择基于门控函数输出概率选择 top-k 提示,以权重 1 直接使用这些提示,这种方法明确区分了每个时间步的提示组合。我们设定 k=192(占总提示的 75%)。
- 软选择使用所有提示,但为每个提示分配不同的权重。
- 软选择进一步提高了性能,而硬选择的性能实际上比预训练模型更差。因此,我们将软选择设为默认设置。
提示选择。默认情况下,DMP 在扩散模型的每个块的输入空间插入可学习提示。在此背景下有两种选择:
- 统一:方程(5)中的门控函数 G 仅接收时间步嵌入 t 作为输入,并对每个块的提示应用输出权重,因此所有块的提示选择一致。
- 不同:G 不仅以 t,还以当前块的深度 i 作为输入,为每个块生成不同的权重。
- 如表 4c 所示,使用不同的提示选择能提升性能。因此,我们将时间步嵌入和当前块深度信息结合到门控函数中,使得在不同块深度使用不同的提示组合成为默认设置。
提示位置。以前的提示调优方法【22, 29, 31, 61, 71】通常将可学习提示预先添加到图像标记(token)中。我们的方法是直接将提示逐元素添加到图像标记中,从而保持输入序列长度。表 4d 比较了将提示插入输入空间的不同选择及其对性能的影响。我们比较了两种方法:prepend 和 add。对于 “add”,我们使用 256 个提示以匹配图像标记的数量;对于 “prepend”,我们使用 50 个提示。尽管理想情况下,“prepend” 也应使用 256 个标记以公平比较,但我们经验上发现,prepend 256 个提示标记会导致扩散模型发散,即使两种方法都同样初始化为零。因此,我们将 “prepend” 限制为 50 个标记。
提示平衡。提示平衡损失作为对门控函数的软约束,帮助减轻在生成混合提示时的提示选择偏差。我们研究了两种平衡损失的效果:重要性平衡损失和负载平衡损失。如表 4e 所示,同时使用两种损失使扩散模型达到最佳性能。单独使用时,每种损失对性能略有提升,但结合使用时,它们表现出互补效果,整体性能提升。这表明,平衡在不同时间步上激活提示的权重和激活提示的数量对于创建有效的混合提示至关重要。因此,我们采用重要性平衡损失和负载平衡损失来进行提示平衡。
4.3 分析
提示激活。门控函数在动态地从可学习提示集合中构建混合提示中起着关键作用,其基于输入中存在的噪声水平。如图 5 所示,使用颜色进行视觉突出显示的激活情况。随着去噪过程的进行,不同时间步中提示的选择呈现显著变化。在噪声水平较高的时间步中,门控函数倾向于利用更广泛的提示组合。相反,在噪声较低的时间步中,随着噪声的减少,提示变得更加专业化,更专注于需要更密切关注的输入特定特征。这种提示的战略部署使得模型能够在每个去噪步骤中形成专业化的 “专家”。这种灵活性使得模型能够根据当前的噪声水平调整其方法,有效地利用不同的提示来提高其整体性能。使用针对不同噪声水平定制的提示混合物确保了图像生成过程既灵活又精确,满足了每一步输入噪声特性所要求的具体需求。
门控条件。在类条件图像生成和文本到图像生成任务中,条件引导在决定生成图像结果方面起着至关重要的作用。为了研究条件引导对选择提示混合物的影响,我们评估了两种情况的性能:一种情况下,门控函数 G 仅接收时间步嵌入 t,另一种情况下,它接收 t 和类别或文本条件嵌入 c。在后一种情况下,我们修改了门控函数在方程(5)中的输入,从 t 变为 t+c。我们在第 4.2 节中的分析表明,在 ImageNet 数据集上,两种方法的性能相当,而在 COCO 数据集上,使用 t + c 相较于仅使用 t 会产生更好的性能。这表明,结合条件引导有助于确定如何在每个去噪步骤中组合提示,以生成与文本条件相匹配的图像。因此,我们在文本到图像任务中将 t + c 作为门控函数的默认输入。
定性分析。图 6 呈现了三种方法之间的视觉比较:基线 DiT 模型、提示调整 [29] 和我们的 DMP。这些方法是在 FFHQ [25] 和 COCO [33] 数据集上进行无条件、文本到图像生成任务的评估。总的来说,DMP 展示出更具逼真和自然的图像,并且具有更少的伪影。我们在附录 E 中提供了额外的定性结果。
5 结论
在这项研究中,我们介绍了 Diffusion Model Patching(DMP),这是一种轻量级且计算效率高的方法,用于增强已经收敛的预训练扩散模型。大多数预训练扩散模型在其模型设计中没有明确地整合去噪阶段特异性。鉴于去噪阶段的特征略有不同,我们的目标是修改模型以了解这些差异。关键思想是将时间步特定的可学习提示集成到每个时间步的模型行为中。为了实现这一目标,DMP 利用动态门控,根据去噪过程中的当前时间步(或噪声水平)自适应地组合提示。这使得我们的方法能够有效地扩展到成千上万个去噪步骤。我们的结果表明,虽然在原始数据集上对模型进行微调不会带来进一步的改进,但 DMP 独特地提高了性能。当应用于 DiT-XL 骨干时,DMP 在 FFHQ 256×256 上仅增加了 1.43% 的额外参数,在 50K 次迭代中产生了显著的 FID 增益达到 10.38%。值得注意的是,DMP 的有效性跨越不同的模型大小和各种图像生成任务。
局限性。我们的 DMP 方法采用提示添加(prompt-adding)策略以确保稳定训练并保持采样速度。然而,由于输入块的数量是固定的,提示数量的灵活性受到限制。在保持稳定训练的同时,将我们的 DMP 扩展到使用 prepend 方法是一个有趣的未来方向。