无需向量量化的自回归图像生成

摘要

https://arxiv.org/pdf/2406.11838
传统观点认为,用于图像生成的自回归模型通常伴随着向量量化的标记。我们观察到,尽管离散值空间可以方便地表示分类分布,但它对于自回归建模来说并不是必需的。在这项工作中,我们提出使用扩散过程来建模每个标记的概率分布,这使得我们可以在连续值空间中应用自回归模型。我们定义了一个扩散损失函数来建模每个标记的概率,而不是使用分类交叉熵损失。这种方法消除了对离散值标记器的需求。我们在包括标准自回归模型和广义掩码自回归(MAR)变体在内的广泛案例中评估了其有效性。通过去除向量量化,我们的图像生成器在享受序列建模的速度优势的同时,取得了出色的结果。我们希望这项工作能够激发在其他连续值领域和应用中使用自回归生成的想法。

1、引言

自回归模型目前是自然语言处理中生成模型的公认解决方案[38,39,3]。这些模型基于之前的词作为输入来预测序列中的下一个词或标记。由于语言的离散性,这些模型的输入和输出都位于一个分类的、离散值空间中。这种普遍的方法导致了人们广泛认为自回归模型与离散表示法之间存在着固有的联系。

因此,关于将自回归模型推广到连续值域——尤其是图像生成——的研究,强烈地关注于数据的离散化[6, 13, 40]。一个常用的策略是在图像上训练一个离散值标记器,这涉及通过向量量化(VQ)获得的有限词汇表[51, 41]。然后,自回归模型在离散值标记空间上操作,类似于它们在语言处理中的对应物。

在这项工作中,我们旨在回答以下问题:自回归模型是否必须与向量量化表示相结合?我们注意到,自回归性质,即“基于先前的标记预测下一个标记”,与值是离散的还是连续的无关。所需的是建模每个标记的概率分布,这可以通过损失函数来衡量,并用于从中抽取样本。离散值表示可以通过分类分布来方便地建模,但从概念上讲,这并不是必需的。如果提供了替代模型来建模每个标记的概率分布,那么可以在不使用向量量化的情况下应用自回归模型。
在这里插入图片描述

基于这一观察,我们提出通过在连续值域上操作的扩散过程来建模每个标记的概率分布。我们的方法利用了扩散模型[45,24,33,10]的原理来表示任意概率分布。具体来说,我们的方法为每个标记自回归地预测一个向量 z z z,该向量作为去噪网络(例如,一个小型多层感知器MLP)的条件。去噪扩散过程使我们能够表示输出 x x x的潜在分布 p ( x ∣ z ) p(x \mid z) p(xz)(如图1所示)。这个小的去噪网络与自回归模型一起训练,以连续值标记作为输入和目标。从概念上讲,这个应用于每个标记的小预测头就像是一个损失函数,用于衡量 z z z的质量。我们将这个损失函数称为扩散损失(Diffusion Loss)。

我们的方法消除了对离散值标记器的需求。向量量化标记器难以训练,并且对梯度近似策略敏感[51, 41, 40, 27]。与连续值对应物相比,它们的重建质量通常较差[42]。我们的方法允许自回归模型享受更高质量、非量化标记器的优势。

为了拓宽范围,我们进一步将标准的自回归(AR)模型[13]和掩码生成模型[4,29]统一到一个广义的自回归框架中(图3)。从概念上讲,掩码生成模型以随机顺序同时预测多个输出标记,同时仍保持“基于已知标记预测下一个标记”的自回归性质。这导致了掩码自回归(MAR)模型,它可以无缝地与扩散损失(Diffusion Loss)一起使用。

我们通过实验展示了扩散损失在广泛情况下(包括AR和MAR模型)的有效性。它消除了对向量量化标记器的需求,并持续提高了生成质量。我们的损失函数可以灵活地与不同类型的标记器一起应用。此外,我们的方法享有序列模型快速速度的优势。我们的带有扩散损失的MAR模型可以在每秒不到0.3秒的速度下生成图像,同时在ImageNet 256 × 256 256 \times 256 256×256 上实现了低于2.0的FID分数。我们最好的模型可以达到接近1.55的FID分数。

我们方法的有效性揭示了图像生成中一个尚未充分探索的领域:通过自回归来建模标记之间的相互依赖关系,同时利用扩散来建模每个标记的分布。这与典型的潜在扩散模型[42,37]形成了对比,其中扩散过程建模了所有标记的联合分布。鉴于我们方法的有效性、速度和灵活性,我们希望扩散损失将推动自回归图像生成的发展,并在未来的研究中被推广到其他领域。

2、相关工作

序列模型用于图像生成。在自回归图像模型方面的开创性工作[17, 50, 49, 36, 7, 6]是在像素序列上进行的。自回归可以通过RNNs[50]、CNNs[49, 7]以及最近和最流行的Transformers[36, 6]来实现。受到语言模型的启发,另一系列工作[51, 41, 13, 40]将图像建模为离散值的标记。自回归[13, 40]和掩码生成模型[4, 29]可以在离散值标记空间上操作。但是,离散标记器难以训练,这最近引起了特别的关注[27, 54, 32]。

与我们工作相关的是,最近的GIVT工作[48]也关注序列模型中的连续值标记。GIVT和我们的工作都揭示了这一方向的重要性和潜力。在GIVT中,标记分布由高斯混合模型表示。它使用预定义数量的混合,这可能限制了它可以表示的分布类型。相比之下,我们的方法利用扩散过程建模任意分布的有效性。

扩散在表示学习中的应用。去噪扩散过程已被探索作为视觉自监督学习的标准。例如,DiffMAE[53]用去噪扩散解码器替换了原始MAE[21]中的L2损失;DARL[30]使用去噪扩散块解码器训练自回归模型。这些努力一直专注于表示学习,而不是图像生成。在他们的场景中,生成多样化的图像不是目标;这些方法还没有展示出从头开始生成新图像的能力。

扩散在策略学习中的应用。我们的工作在概念上与机器人学中的Diffusion Policy[8]相关。在那些场景中,执行动作的分布被制定为对机器人观测值的去噪过程,这些观测值可以是像素或潜在特征[8, 34]。在图像生成中,我们可以将生成一个标记视为要采取的“动作”。尽管存在这种概念上的联系,但在机器人学中,生成样本的多样性并不是像图像生成那样是一个核心考虑因素。

3、方法

简而言之,我们的图像生成方法是在标记化的潜在空间上操作的序列模型[6, 13, 40]。但与之前基于向量量化标记器(例如VQ-VAE的变体[51, 13])的方法不同,我们旨在使用连续值标记器(例如[42])。我们提出了扩散损失(Diffusion Loss),使序列模型与连续值标记兼容。

3.1、重新思考离散值标记

首先,我们重新审视自回归生成模型中离散值标记的角色。假设 x x x是在下一个位置要预测的真实标记。使用离散标记器, x x x可以表示为一个整数: 0 ≤ x < K 0 \leq x < K 0x<K,其中 K K K是词汇表的大小。自回归模型生成一个连续值的 D D D维向量 z ∈ R D z \in \mathbb{R}^{D} zRD,然后通过一个 K K K路分类器矩阵 W ∈ R K × D W \in \mathbb{R}^{K \times D} WRK×D进行投影。从概念上讲,这种公式化形式将分类概率分布建模为 p ( x ∣ z ) = softmax ⁡ ( W z ) p(x \mid z) = \operatorname{softmax}(W z) p(xz)=softmax(Wz)

在生成建模的上下文中,这个概率分布必须表现出两个基本属性。(i) 一个能够衡量估计分布和真实分布之间差异的损失函数。在分类分布的情况下,这可以通过交叉熵损失简单地实现。(ii) 一个在推理时可以从分布 x ∼ p ( x ∣ z ) x \sim p(x \mid z) xp(xz)中抽取样本的采样器。在分类分布的情况下,这通常通过从 p ( x ∣ z ) = softmax ⁡ ( W z / τ ) p(x \mid z) = \operatorname{softmax}(W z / \tau) p(xz)=softmax(Wz/τ)中抽取样本来实现,其中 τ \tau τ是一个控制样本多样性的温度参数。从分类分布中采样可以通过Gumbel-max方法[18]或逆变换采样来实现。

这种分析表明,离散值标记对于自回归模型来说并不是必要的。相反,建模分布的要求才是本质。离散值标记空间暗示了一个分类分布,其损失函数和采样器定义起来很简单。我们真正需要的是用于分布建模的损失函数及其对应的采样器。

3.2、扩散损失

去噪扩散模型[24]提供了一个有效的框架来建模任意分布。但与常见的用于表示所有像素或所有标记的联合分布的扩散模型用法不同,在我们的案例中,扩散模型是用于表示每个标记的分布。

考虑一个连续值向量 x ∈ R d x \in \mathbb{R}^{d} xRd,它表示要在下一个位置预测的真实标记。自回归模型在该位置生成一个向量 z ∈ R D z \in \mathbb{R}^{D} zRD。我们的目标是建模条件分布 p ( x ∣ z ) p(x \mid z) p(xz),即 x x x在给定 z z z时的概率分布。损失函数和采样器可以根据扩散模型[24, 33, 10]来定义,具体描述如下。

损失函数。根据[24, 33, 10],底层概率分布 p ( x ∣ z ) p(x \mid z) p(xz)的损失函数可以制定为一个去噪准则:

L ( z , x ) = E ε , t [ ∥ ε − ε θ ( x t ∣ t , z ) ∥ 2 ] \mathcal{L}(z, x) = \mathbb{E}_{\varepsilon, t}\left[\left\|\varepsilon - \varepsilon_{\theta}\left(x_{t} \mid t, z\right)\right\|^{2}\right] L(z,x)=Eε,t[εεθ(xtt,z)2]

在这里, ε ∈ R d \varepsilon \in \mathbb{R}^{d} εRd 是一个从 N ( 0 , I ) \mathcal{N}(\mathbf{0}, \mathbf{I}) N(0,I) 采样得到的噪声向量。噪声干扰后的向量 x t x_{t} xt x t = α ˉ t x + 1 − α ˉ t ε x_{t} = \sqrt{\bar{\alpha}_{t}} x + \sqrt{1 - \bar{\alpha}_{t}} \varepsilon xt=αˉt x+1αˉt ε,其中 α ˉ t \bar{\alpha}_{t} αˉt 定义了一个噪声时间表[24, 33]。 t t t 是噪声时间表的时间步。噪声估计器 ε θ \varepsilon_{\theta} εθ,由参数 θ \theta θ 参数化,是一个小型多层感知机(MLP)网络(参见第4节)。符号 ε θ ( x t ∣ t , z ) \varepsilon_{\theta}\left(x_{t} \mid t, z\right) εθ(xtt,z) 意味着这个网络以 x t x_{t} xt 作为输入,并且同时依赖于 t t t z z z。根据[46, 47],等式(1)在概念上类似于一种得分匹配:它与 p ( x ∣ z ) p(x \mid z) p(xz) 的得分函数相关的损失函数有关,即 ∇ log ⁡ x p ( x ∣ z ) \nabla \log_{x} p(x \mid z) logxp(xz)。扩散损失是一种参数化的损失函数,与对抗性损失[15]或感知损失[56]类似。

值得注意的是,条件向量 z z z 是由自回归网络产生的: z = f ( ⋅ ) z = f(\cdot) z=f(),我们稍后会讨论这一点。 z = f ( ⋅ ) z = f(\cdot) z=f() 的梯度是从等式(1)中的损失函数传播的。在概念上,等式(1)定义了一个用于训练网络 f ( ⋅ ) f(\cdot) f() 的损失函数。

我们注意到等式(1)中的期望 E ε , t [ ⋅ ] \mathbb{E}_{\varepsilon, t}[\cdot] Eε,t[] 是针对 t t t 的,对于任何给定的 z z z。由于我们的去噪网络较小,我们可以对任何给定的 z z z 多次采样 t t t。这有助于改善损失函数的利用,而无需重新计算 z z z。在训练过程中,我们为每个图像在每次迭代中采样 t t t 四次。

采样器。在推理时,需要从分布 p ( x ∣ z ) p(x \mid z) p(xz) 中抽取样本。采样是通过反向扩散过程[24]完成的: x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ε θ ( x t ∣ t , z ) ) + σ t δ x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}\left(x_{t} - \frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \varepsilon_{\theta}\left(x_{t} \mid t, z\right)\right) + \sigma_{t} \delta xt1=αt 1(xt1αˉt 1αtεθ(xtt,z))+σtδ。这里 δ \delta δ 是从高斯分布 N ( 0 , I ) \mathcal{N}(\mathbf{0}, \mathbf{I}) N(0,I) 中采样的,而 σ t \sigma_{t} σt 是在时间步 t t t 的噪声水平。从 x T ∼ N ( 0 , I ) x_{T} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) xTN(0,I) 开始,这个过程生成一个样本 x 0 x_{0} x0,使得 x 0 ∼ p ( x ∣ z ) x_{0} \sim p(x \mid z) x0p(xz)[24]。

当使用分类分布(第3.1节)时,自回归模型可以享受通过温度 τ \tau τ 控制样本多样性的好处。事实上,无论是语言还是图像方面的现有文献都表明,温度在自回归生成中起着关键作用。希望扩散采样器能提供一个温度的对应物。我们采用了[10]中提出的温度采样。在概念上,给定温度 τ \tau τ,人们可能希望从(重新归一化)概率 p ( x ∣ z ) 1 τ p(x \mid z)^{\frac{1}{\tau}} p(xz)τ1 中采样,其得分函数是 1 τ ∇ log ⁡ x p ( x ∣ z ) \frac{1}{\tau} \nabla \log _{x} p(x \mid z) τ1logxp(xz)。在实践中,[10]建议要么将 ε θ \varepsilon_{\theta} εθ 除以 τ \tau τ,要么将噪声乘以 τ \tau τ。我们采用了后者:在采样器中,我们将 σ t δ \sigma_{t} \delta σtδ 乘以 τ \tau τ。直观地讲, τ \tau τ 通过调整噪声方差来控制样本多样性。

3.3、自回归模型的扩散损失

接下来,我们描述用于图像生成的自回归模型及其扩散损失。给定一个标记序列 { x 1 , x 2 , … , x n } \left\{x^{1}, x^{2}, \ldots, x^{n}\right\} {x1,x2,,xn},其中上标 1 ≤ i ≤ n 1 \leq i \leq n 1in 指定了顺序,自回归模型[17,50,49,36,7,6]将生成问题表述为“下一个标记预测”:

p ( x 1 , … , x n ) = ∏ i = 1 n p ( x i ∣ x 1 , … , x i − 1 ) p\left(x^{1}, \ldots, x^{n}\right)=\prod_{i=1}^{n} p\left(x^{i} \mid x^{1}, \ldots, x^{i-1}\right) p(x1,,xn)=i=1np(xix1,,xi1)

使用一个网络来表示条件概率 p ( x i ∣ x 1 , … , x i − 1 ) p\left(x^{i} \mid x^{1}, \ldots, x^{i-1}\right) p(xix1,,xi1)。在我们的案例中, x i x^{i} xi 可以是连续值。我们可以将这个表述重写为两个部分。首先,我们通过一个网络(例如Transformer[52])对之前的标记进行操作来生成一个条件向量 z i z^{i} zi z i = f ( x 1 , … , x i − 1 ) z^{i}=f\left(x^{1}, \ldots, x^{i-1}\right) zi=f(x1,,xi1)。然后,我们通过 p ( x i ∣ z i ) p\left(x^{i} \mid z^{i}\right) p(xizi) 来建模下一个标记的概率。等式(1)中的扩散损失可以应用于 p ( x i ∣ z i ) p\left(x^{i} \mid z^{i}\right) p(xizi)。梯度将反向传播到 z i z^{i} zi 以更新 f ( ⋅ ) f(\cdot) f() 的参数。

3.4、统一自回归和掩码生成模型

我们展示了掩码生成模型,如MaskGIT[4]和MAGE[29],可以在自回归的广泛概念下被泛化,即下一个标记的预测。

双向注意力可以执行自回归。自回归的概念与网络架构是正交的:自回归可以通过RNNs[50]、CNNs[49, 7]和Transformers[38, 36, 6]来实现。当使用Transformers时,尽管自回归模型通常通过因果注意力来实现,但我们展示了它们也可以通过双向注意力来实现。参见图2。请注意,自回归的目标是根据先前的标记预测下一个标记;它并不限制先前的标记如何与下一个标记通信。

我们可以采用在Masked Autoencoder(MAE)[21]中使用的双向注意力实现。参见图2(b)。具体来说,我们首先对已知标记(带有位置嵌入[52])应用MAE风格的编码器 1 {}^{1} 1。然后,我们将编码后的序列与掩码标记(再次添加位置嵌入)进行连接,并使用MAE风格的解码器对该序列进行映射。掩码标记上的位置嵌入可以让解码器知道要预测哪些位置。与因果注意力不同,这里的损失仅计算在未知标记上[21]。

通过使用MAE风格的技巧,我们允许所有已知标记相互可见,同时也允许所有未知标记看到所有已知标记。这种全注意力引入了在标记之间比因果注意力更好的通信。在推理时,我们可以使用这种双向表述来生成标记(每步一个或多个),这是一种自回归的形式。作为折衷,我们不能使用因果注意力的键值(kv)缓存[44]来加速推理。但是,由于我们可以一起生成多个标记,我们可以减少生成步骤以加快推理速度。跨标记的全注意力可以显著提高质量,并提供更好的速度/准确度的权衡。

自回归模型在随机顺序中。为了与掩码生成模型[4, 29]相联系,我们考虑了随机顺序的自回归变体。模型会得到一个随机排列的序列。这个随机排列对于每个样本都是不同的。参见图3(b)。在这种情况下,需要预测的下一个标记的位置必须能够被模型访问。我们采用与MAE[21]类似的策略:我们在解码器层中添加位置嵌入(对应于未打乱的位置),这可以告诉模型要预测哪些位置。这种策略既适用于因果版本也适用于双向版本。

如图3(b)©所示,随机顺序的自回归表现得像一种特殊的掩码生成形式,其中一次生成一个标记。我们详细解释如下。

掩码自回归模型。在掩码生成建模[4, 29]中,模型基于已知/预测的标记预测一个随机的标记子集。这可以表述为通过随机顺序排列标记序列,然后基于先前的标记预测多个标记。参见图3©。从概念上讲,这是一个自回归过程,可以写为估计条件分布: p ( { x i , x i + 1 … , x j } ∣ x 1 , … , x i − 1 ) p\left(\left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} \mid x^{1}, \ldots, x^{i-1}\right) p({xi,xi+1,xj}x1,,xi1),其中需要预测多个标记 { x i , x i + 1 … , x j } \left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} {xi,xi+1,xj} i ≤ j i \leq j ij)。我们可以将这个自回归模型写为:

p ( x 1 , … , x n ) = p ( X 1 , … , X K ) = ∏ k K p ( X k ∣ X 1 , … , X k − 1 ) p\left(x^{1}, \ldots, x^{n}\right)=p\left(X^{1}, \ldots, X^{K}\right)=\prod_{k}^{K} p\left(X^{k} \mid X^{1}, \ldots, X^{k-1}\right) p(x1,,xn)=p(X1,,XK)=kKp(XkX1,,Xk1)

在这里, X k = { x i , x i + 1 … , x j } X^{k}=\left\{x^{i}, x^{i+1} \ldots, x^{j}\right\} Xk={xi,xi+1,xj}是在第 k k k步要预测的一组标记,其中 ∪ k X k = { x 1 , … , x n } \cup_{k} X^{k}=\left\{x^{1}, \ldots, x^{n}\right\} kXk={x1,,xn}。从这个意义上讲,这本质上是“下一组标记预测”,因此也是自回归的一种一般形式。我们将这种变体称为掩码自回归(MAR)模型。MAR是一种随机顺序的自回归模型,可以同时预测多个标记。

MAR在概念上与MAGE[29]相关。然而,MAR通过应用于每个标记的概率分布的温度 τ \tau τ来采样标记(这是像GPT这样的生成式语言模型中的标准做法)。相比之下,MAGE(遵循MaskGIT[4])应用一个温度来采样要预测的标记的位置:这不是一个完全随机的顺序,这会在训练时间和推理时间行为之间造成差异。

4、实现

本节描述了我们的实现。我们注意到,本文中介绍的概念是通用的,并不限于特定的实现。更详细的特定信息在附录B中。

4.1、扩散损失

扩散过程。我们的扩散过程遵循[33]。我们的噪声计划具有余弦形状,训练时有1000步;在推理时,它会重新采样为较少的步骤(默认情况下为100步)[33]。我们的去噪网络预测噪声向量 ε \varepsilon ε[24]。损失可以选择性地包括变分下界项 L v l b \mathcal{L}_{\mathrm{vlb}} Lvlb[33]。扩散损失自然支持无分类器指导(CFG)[23](详细见附录B)。

去噪MLP。我们使用一个小型的MLP,由几个残差块[20]组成,用于去噪。每个块依次应用LayerNorm(LN)[1]、一个线性层、SiLU激活函数[12]、和另一个线性层,并通过残差连接合并。默认情况下,我们使用3个块和1024个通道的宽度。去噪MLP的条件是AR/MAR模型产生的向量 z z z(见图1)。向量 z z z被添加到噪声计划时间步 t t t的时间嵌入中,该时间嵌入通过AdaLN[37]在LN层中作为MLP的条件。

4.2、自回归和掩码自回归图像生成

分词器。我们使用LDM[42]提供的公开可用的分词器。我们的实验将涉及他们的VQ-16和KL-16版本[42]。VQ-16是VQ-GAN[13],即带有GAN损失[15]和感知损失[56]的VQ-VAE[51];KL-16是通过Kullback-Leibler (KL)散度正则化的对应版本,没有向量量化。16表示分词器的步长。

Transformer。我们的架构遵循ViT[11]中的Transformer[52]实现。给定来自分词器的标记序列,我们添加位置嵌入[52]并附加类别标记[cls];然后我们用Transformer处理这个序列。默认情况下,我们的Transformer有32个块和1024的宽度,我们称之为Large大小或-\mathrm{L}(~4亿参数)。

自回归基线。因果注意力遵循GPT[38]的常见做法(图2(a))。输入序列通过一个标记(这里是[cls])进行移位。在注意力矩阵上应用三角形掩码[52]。在推理时,应用温度( τ \tau τ)采样。我们使用kv-cache[44]进行高效的推理。

掩码自回归模型。使用双向注意力(图2(b)),我们可以基于任意数量的已知标记预测任意数量的未知标记。在训练时,我们在[0.7,1.0]范围内随机采样一个掩码比率[21,4,29]:例如,0.7意味着70%的标记是未知的。由于采样序列可能非常短,我们总是在编码器序列的开头填充64个[cls]标记,这提高了我们编码的稳定性和容量。如图2所示,在解码器中引入了掩码标记[\mathrm{m}],并添加了位置嵌入。为了简单起见,与[21]不同,我们让编码器和解码器具有相同的大小:每个都有所有块的一半(例如,在MAR-L中为16个块)。

在推理时,MAR执行“下一组标记预测”。它使用余弦计划[4,29]逐步将掩码比率从1.0减少到0。默认情况下,我们在该计划中使用64步。应用温度( τ \tau τ)采样。与[4, 29]不同,MAR总是使用完全随机的顺序。

5、实验

我们在分辨率为 256 × 256 256 \times 256 256×256的ImageNet[9]上进行实验。我们评估FID[22]和IS[43],并提供精度和召回率作为参考,以遵循常见的实践[10]。我们遵循[10]提供的评估套件。

5.1、扩散损失的特性

扩散损失与交叉熵损失的比较。我们首先比较使用扩散损失的连续值标记与标准离散值标记的交叉熵损失(表1)。为了公平比较,我们从LDM代码库[42]下载了这两种分词器(“VQ-16”和“KL-16”)。这些分词器被广泛使用(例如,[13,42,37])。

如表1所示,在所有情况下,扩散损失都一致优于对应的交叉熵损失。具体来说,在MAR(例如默认设置)中,使用扩散损失可以将FID相对降低约50%~60%。这是因为连续值的KL-16比VQ-16具有更小的压缩损失(接下来在表2中讨论),而且扩散过程比分类过程更有效地建模分布。

在以下消融实验中,除非另有说明,我们遵循表1中的“默认”MAR设置。

扩散损失的灵活性。扩散损失的一个显著优势是它与各种分词器的灵活性。我们在表2中比较了几种公开可用的分词器。

即使给定VQ分词器,扩散损失也可以轻松使用。我们简单地将VQ层之前的连续值潜在作为标记。这种变体给出了7.82的FID(不使用CFG),与使用相同VQ分词器的交叉熵损失的8.79 FID(表1)相比表现良好。这表明扩散在建模分布方面的能力更强。

这种变体还使我们能够使用相同的损失来比较VQ-16和KL-16分词器。如表2所示,VQ-16的重建FID(rFID)远差于KL-16,这导致生成的FID也差得多(例如,表2中的7.82与3.50)。

有趣的是,扩散损失还允许我们使用步长不匹配的分词器。在表2中,我们研究了步长为8、输出序列长度为 32 × 32 32 \times 32 32×32的KL-8分词器。在不增加生成器序列长度的情况下,我们将 2 × 2 2 \times 2 2×2个标记组合成一个新标记。尽管存在不匹配,但我们仍然能够获得相当不错的结果,例如,KL-8给出了2.05的FID,与KL-16的1.98 FID相比。此外,这一特性允许我们研究其他分词器,例如Consistency Decoder[35],这是一个针对不同目标设计、具有不同架构/步长的非VQ分词器。

为了全面起见,我们还使用[42]的代码在ImageNet上训练了一个KL-16分词器,注意到[42]中原始的KL-16是在OpenImages[28]上训练的。比较结果列在表2的最后一行。在以下的探索中,我们使用这个分词器。

扩散损失中的去噪MLP。我们研究了表3中的去噪MLP。即使是一个非常小的MLP(例如,2M)也能带来有竞争力的结果。正如预期的那样,增加MLP的宽度有助于提高生成质量;我们已经探索了增加深度并得出了类似的观察结果。请注意,我们默认的MLP大小(1024宽度,21M)仅为MAR-L模型增加了约5%的额外参数。在推理时,扩散采样器有一个相当不错的成本,约占总体运行时间的10%。在我们的实现中,增加MLP的宽度具有可忽略的额外成本(表3),部分原因是主要的开销不在于计算而在于内存通信。

扩散损失的采样步数。我们的扩散过程遵循DDPM的常用做法[24,10]:我们使用1000步的噪声计划进行训练,但在推理时使用更少的步数。图4显示了在推理时使用100步扩散步骤足以达到较强的生成质量。

扩散损失的温度。在交叉熵损失的情况下,温度是至关重要的。扩散损失也提供了一个温度对应物来控制多样性和保真度。图5展示了在推理时扩散采样器中的温度 τ \tau τ(见第3.2节)的影响。在我们的模型中,温度 τ \tau τ起着重要作用,类似于基于交叉熵的对应物的观察结果(注意表1中的交叉熵结果是在其最佳温度下获得的)。

5.2、广义自回归模型的特性

从AR到MAR。表1还对比了AR/MAR的变种,接下来我们将讨论这些变种。首先,将AR中的栅格顺序替换为随机顺序带来了显著的增益,例如,将FID从19.23降低到13.07(无CFG)。接下来,将因果注意力替换为双向的对应物带来了另一个巨大的增益,例如,将FID从13.07降低到3.43(无CFG)。

随机顺序、双向的AR本质上是一种形式的MAR,它一次预测一个标记。在每一步中预测多个标记(’ >1 ')可以有效地减少自回归步骤的数量。在表1中,我们展示了具有64个步骤的MAR变种在生成质量上略有妥协。接下来将讨论更全面的权衡比较。

速度/准确度的权衡。跟随MaskGIT [4],我们的MAR享有一次性预测多个标记的灵活性。这通过在推理时控制自回归步骤的数量来实现。图6展示了速度/准确度的权衡。MAR与其AR对应物相比具有更好的权衡,注意到AR使用了高效的kv-cache。

使用扩散损失时,MAR与最近流行的Diffusion Transformer(DiT)[37]相比也显示出有利的权衡。作为一个潜在扩散模型,DiT通过扩散过程来建模所有标记之间的相互依赖关系。DiT的速度/准确度权衡主要由其扩散步骤控制。与我们在小型MLP上的扩散过程不同,DiT的扩散过程涉及整个Transformer架构。我们的方法更准确且更快。值得注意的是,我们的方法能够以每幅图像小于0.3秒的速度生成图像,同时保持FID小于2.0的强劲性能。

5.3、与先前系统的基准测试

我们在表4中与领先的系统进行了比较。我们探索了不同的模型大小(见附录B)并训练了800个周期。类似于自回归语言模型[3],我们观察到了令人鼓舞的扩展行为。进一步探索扩展可能会很有前景。关于指标,我们报告了不使用CFG时的FID为2.35,大幅超过了其他基于标记的方法。我们最佳的表现FID为1.55,与其他领先系统相比具有竞争力。图7展示了定性结果。

6、讨论与结论

Diffusion Loss在各种自回归模型上的有效性为新的机会指明了方向:通过自回归建模标记之间的相互依赖关系,同时结合扩散过程建模每个标记的分布。这与常见的扩散用法不同,后者建模所有标记的联合分布。我们在图像生成上的出色结果表明,自回归模型或其扩展是超越语言建模的有力工具。这些模型不需要被向量量化表示所限制。我们希望我们的工作能激励研究社区探索在其他领域中使用连续值表示的序列模型。

A、局限性与更广泛的影响

局限性。除了展示我们方法在图像生成方面的潜力外,本文也承认其局限性。

首先,我们的图像生成系统可能会产生带有明显瑕疵的图像(如图8所示)。这种局限性在现有方法中很常见,尤其是在使用受控的学术数据(如ImageNet)进行训练时。与在大量数据上训练的商业模型相比,基于ImageNet训练的研究驱动模型在视觉质量上仍有明显的差距。

其次,我们的图像生成系统依赖于现有的预训练标记器。我们系统的质量可能会受到这些标记器质量的限制。预训练更好的标记器超出了本文的范围。然而,我们希望我们的工作将为未来开发连续值标记器变得更加容易。

最后,我们注意到,由于计算资源有限,我们主要在ImageNet基准测试集上测试了我们的方法。为了评估我们的方法在更多样化和现实场景中的可扩展性和鲁棒性,还需要进一步的验证。

更广泛的影响。我们的主要目标是推动生成模型的基础研究,并相信这将对该领域产生积极影响。我们方法的直接应用之一是将其扩展到大型视觉生成模型,如文本到图像或文本到视频生成。我们的方法有可能显著降低这些大型模型的训练和推理成本。同时,我们的方法可能表明在许多应用中用Diffusion Loss替换传统损失函数的机会。从负面来看,我们的方法从训练数据集中学习统计信息,因此可能会反映数据中的偏差;图像生成系统可能会被滥用以生成虚假信息,这值得进一步考虑。

B、附加实现细节

分类器自由引导(CFG)。为了支持CFG [23],在训练时,对于10%的样本,类别条件被替换为一个虚拟类别标记[23]。在推理时,模型使用给定的类别标记和虚拟标记运行,产生两个输出 z c z_{c} zc z u z_{u} zu。然后,预测的噪声 ε \varepsilon ε被修改为[23]: ε = ε θ ( x t ∣ t , z u ) + ω ⋅ ( ε θ ( x t ∣ t , z c ) − ε θ ( x t ∣ t , z u ) ) \varepsilon=\varepsilon_{\theta}\left(x_{t} \mid t, z_{u}\right)+\omega \cdot\left(\varepsilon_{\theta}\left(x_{t} \mid t, z_{c}\right)-\varepsilon_{\theta}\left(x_{t} \mid t, z_{u}\right)\right) ε=εθ(xtt,zu)+ω(εθ(xtt,zc)εθ(xtt,zu)),其中 ω \omega ω是引导尺度。在推理时,我们遵循[5]的CFG计划。我们针对每个模型扫描最佳引导尺度和温度组合。

训练。默认情况下,模型使用AdamW优化器[31]训练400个周期。AdamW的权重衰减和动量分别为0.02和(0.9,0.95)。我们使用批处理大小为2048,学习率(lr)为 8 e − 4 8 \mathrm{e}-4 8e4。我们的带有Diffusion Loss的模型采用100个周期的线性lr预热[16],之后是恒定[37]的lr计划。交叉熵对应的模型则使用余弦lr计划,这对它们更有效。遵循[37, 25],我们使用0.9999的动量维持模型参数的指数移动平均值(EMA)。

表4中的实现细节。为了探索我们方法的扩展行为,我们研究了以下描述的三种模型大小。除了MAR-L,我们还探索了一个较小的模型(MAR-B)和一个较大的模型(MAR-H)。MAR-B、-L和-H分别具有24、32、40个Transformer块,以及768、1024和1280的宽度。在表4中特别地,去噪MLP分别具有6、8、12个块,以及1024、1280和1536的宽度。训练周期增加到800个周期。在推理时,我们运行256个自回归步骤以达到最佳结果。

Diffusion Loss的伪代码。见算法1。

Diffusion Loss 概念的伪代码。在这里,条件向量 z z z 是从 AR/MAR 模型中输出的。梯度将反向传播到 z z z。为了简化,这里我们省略了推理重调度、温度和变分下界损失项的代码[10],这些可以很容易地合并进来。

计算资源。我们的训练主要在 16 台服务器上完成,每台服务器配备了 8 个 V100 GPU。在这些 GPU 上训练一个 400 轮次的 MAR-L 模型大约需要 2.6 天。相比之下,在同一集群上训练相同轮次的 DiT-XL/2 和 LDM-4 模型分别需要 4.6 天和 9.5 天。

MAR 和 MAGE 的比较。MAR(无论使用何种损失)在概念上与 MAGE [29] 相关。除了实现上的差异(例如,架构特定性和超参数)之外,MAR 和 MAGE 在推理时的扫描顺序上存在主要的概念差异。在 MAGE 中,遵循 MaskGIT [4] 的方法,下一个要预测的令牌的位置是根据每个位置的样本置信度动态确定的,即每个步骤中更自信的位置更有可能被选中[4, 29]。相比之下,MAR 采用完全随机的顺序,并且其温度采样应用于每个令牌。表 5 在受控设置下比较了这种差异。第一行是我们的 MAR 实现,但使用了 MAGE 的即时排序策略,其结果与简单的随机排序相似。完全随机的排序可以使训练和推理过程在令牌顺序的分布上保持一致;它还允许我们以类似于自回归语言模型(例如,GPT [38,39,3])的方式对每个令牌进行温度采样。

D、额外比较

D.1、ImageNet 512 × 512 512 \times 512 512×512

与先前的工作一样,我们也报告了在 512 × 512 512 \times 512 512×512 分辨率下 ImageNet 的结果,并与领先的系统进行了比较(见表6)。为了简化,我们使用 KL-16 分词器,它在 512 × 512 512 \times 512 512×512 图像上给出 32 × 32 32 \times 32 32×32 的序列长度。其他设置遵循表4中描述的 MAR-L 配置。我们的方法在没有使用 CFG 的情况下达到了 FID 分数 2.74,使用 CFG 时达到了 1.73。我们的结果与先前系统的结果相竞争。由于资源有限,我们尚未在 ImageNet 512 × 512 512 \times 512 512×512 上训练更大的 MAR-H,预计会有更好的结果。

D.2、L2 损失与 Diffusion 损失

对于连续值令牌的一个简单基线是直接计算预测和目标令牌之间的均方误差(MSE,即 L2)损失。在栅格顺序的自回归模型的情况下,使用 L2 损失不引入随机性,因此无法生成多样化的样本。在使用 L2 损失的 MAR 模型中,唯一的随机性是序列顺序;对于任何给定的顺序,某个位置的预测是确定的。在我们的实验中,我们训练了一个使用 L2 损失的 MAR 模型,这导致 FID 分数灾难性地大于 100(>100)。

致谢

我们感谢 Congyue Deng 和 Xinlei Chen 的有益讨论。我们感谢 Google TPU 研究云(TRC)为我们提供 TPU 访问权限,以及 Google Cloud Platform 对 GPU 资源的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/36988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

战地战地风云最强的免费加速器 2024低延迟不卡顿加速器推荐

来喽来喽&#xff0c;steam夏季促销它又来喽&#xff0c;战地风云&#xff0c;第一人称射击游戏&#xff0c;而且这次迎来了史低&#xff0c;游戏背景设定为近未来&#xff08;公元2042年&#xff09;&#xff0c;会有动态的天气系统&#xff0c;以及改善后的破坏系统。该作为《…

开源模型应用落地-FastAPI-助力模型交互-WebSocket篇(三)

一、前言 使用 FastAPI 可以帮助我们更简单高效地部署 AI 交互业务。FastAPI 提供了快速构建 API 的能力,开发者可以轻松地定义模型需要的输入和输出格式,并编写好相应的业务逻辑。 FastAPI 的异步高性能架构,可以有效支持大量并发的预测请求,为用户提供流畅的交互体验。此外,F…

关于Mac mini 10G网口的问题

问题: 购入一个10G网口的Mac mini M2&#xff0c;将其和自己的2.5G交换机连接&#xff0c;使用共享屏幕进行远程操作的过程中出现了频率极高的卡顿&#xff0c;几乎是几秒钟卡一下&#xff0c;使用ping进行测试发现卡的时候就ping不通了。测试使用Mac mini的无线网和雷电转2.5G…

React Native 开发常见问题及注意事项

本文只是使用时积累的一些经验 开发环境 1、Android Studio 依赖项下载慢 如果发现依赖下载非常慢&#xff0c;动不动十几KB的 参考&#xff1a;加速 Android Studio 依赖项下载 也可以切换数据源 修改 android/build.gradle中的jcenter()和google() repositories {// goo…

人脑计算机技术与Neuroplatform:未来计算的革命性进展

引言 想象一下&#xff0c;你在某个清晨醒来&#xff0c;准备开始一天的工作&#xff0c;而实际上你的大脑正作为一台生物计算机的核心&#xff0c;处理着大量复杂的信息。这并非科幻电影的情节&#xff0c;而是人脑计算机技术即将带来的现实。本文将深入探讨FinalSpark公司的…

选择适合你的8款原型设计工具

随着互联网的飞速发展&#xff0c;设计行业逐渐成为近年来的热门职业。设计师们需要的掌握的技能也越来越多&#xff0c;例如海报设计、名片设计、产品设计、网页设计等。产品原型设计就是产品设计中非常重要的一个阶段&#xff0c;主要目的是帮助用户更容易了解产品设计的思路…

深度学习 —— 1.单一神经元

深度学习初级课程 1.单一神经元2.深度神经网络3.随机梯度下降法4.过拟合和欠拟合5.剪枝、批量标准化6.二分类 前言 本套课程仍为 kaggle 课程《Intro to Deep Learning》&#xff0c;仍按之前《机器学习》系列课程模式进行。前一系列《Keras入门教程》内容&#xff0c;与本系列…

【机器学习】Whisper:开源语音转文本(speech-to-text)大模型实战

目录 一、引言 二、Whisper 模型原理 2.1 模型架构 2.2 语音处理 2.3 文本处理 三、Whisper 模型实战 3.1 环境安装 3.2 模型下载 3.3 模型推理 3.4 完整代码 3.5 模型部署 四、总结 一、引言 上一篇对​​​​​​​ChatTTS文本转语音模型原理和实战进行了讲解&a…

【语义分割系列】基于cityscape的DDRNet算法

基于cityscape的DDRNet算法 前言 DDRNet是专门为实时语义分割设计的高效主干。该模型由两个深度分支组成,在这两个分支之间执行多次双边融合,并且还设计了一个新的上下文信息抽取器,名为深度聚合金字塔池模块(DAPPM),用于扩大有效的接受域,并基于低分辨率特征映射融合…

计算机网络——数据链路层(数据链路层概述及基本问题)

链路、数据链路和帧的概念 数据链路层在物理层提供服务的基础上向网络层提供服务&#xff0c;其主要作用是加强物理层传输原始比特流的功能&#xff0c;将物理层提供的可能出错的物理连接改造为逻辑上无差错的数据链路&#xff0c;使之对网络层表现为一条无差错的链路。 链路(…

Steam夏促史低游戏推荐 Steam夏促哪有游戏值得入手

steam夏季促销来袭&#xff0c;有这很多的游戏都进行打折出售&#xff0c;而且还有这很多的游戏都迎来了史低&#xff0c;简直是白送&#xff0c;很多玩家都想趁着这个时间入手自己喜欢的游戏&#xff0c;为了方便大家了解&#xff0c;下面我给大家带来steam夏季促销史低的游戏…

CO-DETR利用coco数据集训练和推理过程

CO-DETR利用coco数据集训练和推理过程&#xff0c;参考链接 Co-DETR训练自己的数据集 文章目录 前言训练过程推理过程总结 前言 环境&#xff1a;PyTorch 1.11.0 Python 3.8(ubuntu20.04) Cuda 11.3 先是在github上下载CO-DETR模型 !git clone https://github.com/Sense-X/Co…

陌陌笔试--并发打印文件内最有钱的老板的消费金额(算法)

题目&#xff1a; 算法中需要打印消费前十老板的消费金额&#xff0c;解决保留两位小数&#xff0c;并发是 JAVA 中的常考题&#xff0c; 我这里简单模拟下了数据&#xff0c;关键数据是用户id和消费金额。 解题思路&#xff1a; 1. 最简单的思路是单线程&#xff0c;偷懒…

狂神说Java之 rabbitmq高级分布式事务

分布式事务的完整架构图 案例场景分析 案例一&#xff1a;用RestTemplate演示&#xff08;不可靠生产&#xff0c;会出现问题&#xff09; 创建一个订单模块 创建一个OrderDataBaseService服务 创建一个order的service服务&#xff0c;调用saveOrder()方法 创建一个运单模块…

软件设计流程和开发流程及规范(Word)

2 过程总体描述 2.1 过程概述 2.2 过程流程图 3 过程元素描述 3.1 产品方案 3.2 产品设计 3.3 产品实现 获取方式&#xff1a;本文末个人名片直接获取。 软件资料清单列表部分文档清单&#xff1a;工作安排任务书&#xff0c;可行性分析报告&#xff0c;立项申请审批表&#x…

找不到vcomp140.dll怎么办,总结多种解决方法

​在日常使用电脑的过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中之一就是“vcomp140.dll丢失”。那么&#xff0c;vcomp140.dll是什么&#xff1f;它为什么会丢失&#xff1f;丢失后对电脑有什么影响&#xff1f;又该如何解决呢&#xff1f;本文将详细介绍vc…

根据肥胖类型选择减调方向收获窈窕身材

我们生活中胖子很多&#xff0c;从胖到瘦的人也不少&#xff0c;但瘦了后对自己身材满意的人却是不多的&#xff0c;很多人瘦了也只是减掉了身上的赘肉而已&#xff0c;大体的身形却是没有变化的&#xff0c;因此&#xff0c;并不感到满意。因为他们本身的形体是固定的&#xf…

SpringBoot-SpringBoot整合Swagger使用教程(图文介绍,一篇就够了)

前言 日常开发中&#xff0c;接口都是和开发文档相结合的。不论是和前端对接还是三方对接亦或者是接口留档&#xff0c;当我们开发完接口后&#xff0c;都需要去创建对应的接口文档。而修改接口后也要修改相对应的接口文档&#xff0c;但是这个真的很容易疏漏。而且相对于繁重的…

WEB攻防【6】——Python考点/CTF与CMS/SSTI模板注入/PYC反编译

#知识点 1、PYC文件反编译 2、python-web-SSTI 3、SSTI模板注入利用分析 SSTI 就是服务器端模板注入 &#xff08;Server-Side Template Injection&#xff09; 当前使用的一些框架&#xff0c;比如python的flask&#xff0c;php的tp&#xff0c;java的spring等一般都采用成…

存储管理(三):分区表

什么是分区表 假设存在表t&#xff1a; CREATETABLE t (ftimedatetime NOT NULL,c int(11) DEFAULT NULL,KEY (ftime) )ENGINEInnoDB DEFAULT CHARSETlatin1 PARTITION BY RANGE (YEAR(ftime)) (PARTITION p_2017 VALUES LESS THAN (2017) ENGINE InnoDB,PARTITION p_2018 VA…