8. BBDM: Image-to-Image Translation with Brownian Bridge Diffusion Models
本文提出一种基于布朗桥(Brownian Bridge)的扩散模型用于图像到图像的转换。图像到图像转换的目标是将源域 A A A中的图像 I A I_A IA,映射到目标域 B B B中得到图像 I B I_B IB。在一般的扩散模型中(如DDPM),是从目标域 B B B中采集样本作为起点 x 0 x_0 x0对其进行扩散,得到纯噪声 x T x_T xT;然后,再从纯噪声中采样进行反向去噪,生成目标图像 x 0 {x}_0 x0。为了实现图像到图像的转换,一般会将参考图像作为条件 y y y,引入到生成过程中,噪声估计网络 ϵ θ \epsilon_{\theta} ϵθ同时根据前一步的结果 x t x_t xt,时刻 t t t和条件 y y y来估计噪声,进而得到新的去噪结果 x t − 1 x_{t-1} xt−1,如下图A所示。
不同于一般的扩散模型,其扩散过程只依赖于起始点 x 0 x_0 x0,布朗桥扩散过程同时依赖起点 x 0 x_0 x0和终点 x T x_T xT,其数学表达如下 p ( x t ∣ x 0 , x T ) = N ( ( 1 − t T ) x 0 + t T x T , t ( T − t ) T I ) (8-1) p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}, \boldsymbol{x}_{T}\right)=\mathcal{N}\left(\left(1-\frac{t}{T}\right) \boldsymbol{x}_{0}+\frac{t}{T} \boldsymbol{x}_{T}, \frac{t(T-t)}{T} \boldsymbol{I}\right)\tag{8-1} p(xt∣x0,xT)=N((1−Tt)x0+TtxT,Tt(T−t)I)(8-1)基于此,作者将条件 y y y取代纯噪声作为终点 x T x_T xT,然后从条件 y y y开始进行反向去噪得到目标图像 x 0 {x}_0 x0。值得注意的是,在生成过程中,条件 y y y只作为起点,而不作为噪声估计网络 ϵ θ \epsilon_{\theta} ϵθ的条件,如上图B所示。
为了提升学习的效率和泛化能力,作者在浅层空间中完成扩散和重建过程,而不是在图像空间中,作者先利用VQGAN的编码器将图像 I A I_A IA映射到潜在空间中 L A L_A LA,经过扩散和重建后得到目标域的潜在特征 L A → B L_{A\rightarrow B} LA→B,最后再利用VQGAN的解码器恢复得到图像 I A → B I_{A\rightarrow B} IA→B。
这篇文章我读着很迷惑,从源域转换到目标域,那么根据上图的表示源域应该是真实图片,目标域是漫画图像,那么所谓的条件也就是参考图像 y y y应该是来自于源域啊。为什么文章中又说从目标域 B B B中采样得到 y y y呢?而且前文一直在讲,把 y y y作为前向扩散过程的终点和反向去噪过程的起点,那为什么上图灰色区域中前向扩散的终点是目标域的图像呢?不知道是我自己的理解问题,还是作者本身的写作有误。下文会按照我自己的理解来写,可能会与原文有一点点微弱的出入。
分别从源域 A A A和目标域 B B B中采集成对的样本 ( y , x ) (y,x) (y,x),经过VQGAN的编码器处理后得到对应的特征向量 y , x \boldsymbol{y,x} y,x,则布朗桥前向扩散过程可写为 q B B ( x t ∣ x 0 , y ) = N ( x t ; ( 1 − m t ) x 0 + m t y , δ t I ) (8-2) q_{B B}\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}, \boldsymbol{y}\right)=\mathcal{N}\left(\boldsymbol{x}_{t} ;\left(1-m_{t}\right) \boldsymbol{x}_{0}+m_{t} \boldsymbol{y}, \delta_{t} \boldsymbol{I}\right)\tag{8-2} qBB(xt∣x0,y)=N(xt;(1−mt)x0+mty,δtI)(8-2)其中 x 0 = x , m t = t T \boldsymbol{x}_{0}=\boldsymbol{x}, \quad m_{t}=\frac{t}{T} x0=x,mt=Tt T T T表示扩散过程的总步数,方差 δ t \delta_t δt定义为 δ t = 2 s ( m t − m t 2 ) (8-3) \delta_{t}=2 s\left(m_{t}-m_{t}^{2}\right)\tag{8-3} δt=2s(mt−mt2)(8-3)其中 s s s作为一个放缩系数,用于控制采样的多样性,默认值为1。这样的设置,保证了当 t = 0 t=0 t=0和 t = T t=T t=T时, δ t \delta_t δt都为0,而 x t x_t xt分别为 x 0 x_0 x0和 y y y,满足了前文所述的扩散的起点和终点。扩散过程中单步的转移公式如下 q B B ( x t ∣ x t − 1 , y ) = N ( x t ; 1 − m t 1 − m t − 1 x t − 1 + ( m t − 1 − m t 1 − m t − 1 m t − 1 ) y , δ t ∣ t − 1 I ) (8-4) q_{B B}\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1}, \boldsymbol{y}\right)=\mathcal{N}\left(\boldsymbol{x}_{t} ; \frac{1-m_{t}}{1-m_{t-1}} \boldsymbol{x}_{t-1}+\left(m_{t}-\frac{1-m_{t}}{1-m_{t-1}} m_{t-1}\right) \boldsymbol{y}, \delta_{t \mid t-1} \boldsymbol{I}\right) \tag{8-4} qBB(xt∣xt−1,y)=N(xt;1−mt−11−mtxt−1+(mt−1−mt−11−mtmt−1)y,δt∣t−1I)(8-4)其中 δ t ∣ t − 1 = δ t − δ t − 1 ( 1 − m t ) 2 ( 1 − m t − 1 ) 2 (8-5) \delta_{t \mid t-1}=\delta_{t}-\delta_{t-1} \frac{\left(1-m_{t}\right)^{2}}{\left(1-m_{t-1}\right)^{2}}\tag{8-5} δt∣t−1=δt−δt−1(1−mt−1)2(1−mt)2(8-5)
经过前向扩散过程,我们将目标域的图像 x 0 x_0 x0映射到源域中的 x T = y x_T=y xT=y,在接下来的反向去噪过程中,我们将从 y y y出发逐步去噪生成一个新的目标域图像 x 0 {x}_0 x0,单步的去噪过程如下 p θ ( x t − 1 ∣ x t , y ) = N ( x t − 1 ; μ θ ( x t , t ) , δ ~ t I ) (8-6) p_{\theta}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{y}\right)=\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\boldsymbol{x}_{t}, t\right), \tilde{\delta}_{t} \boldsymbol{I}\right)\tag{8-6} pθ(xt−1∣xt,y)=N(xt−1;μθ(xt,t),δ~tI)(8-6)其中均值 μ θ ( x t , t ) \boldsymbol{\mu}_{\theta}\left(\boldsymbol{x}_{t}, t\right) μθ(xt,t)是由一个神经网络根据 x t , t \boldsymbol{x}_{t}, t xt,t估计得到的,而方差 δ ~ t \tilde{\delta}_{t} δ~t则是一个无需学习的仅与 t t t有关的变量。那么下面的任务就是如何训练一个网络来估计均值 μ θ ( x t , t ) \boldsymbol{\mu}_{\theta}\left(\boldsymbol{x}_{t}, t\right) μθ(xt,t)了。与DDPM类似,作者也是给出一个了可变分下界的目标函数 E L B O = − E q ( D K L ( q B B ( x T ∣ x 0 , y ) ∥ p ( x T ∣ y ) ) + ∑ t = 2 T D K L ( q B B ( x t − 1 ∣ x t , x 0 , y ) ∥ p θ ( x t − 1 ∣ x t , y ) ) − log p θ ( x 0 ∣ x 1 , y ) ) (8-7) \begin{aligned} E L B O & =-\mathbb{E}_{q}\left(D_{K L}\left(q_{B B}\left(\boldsymbol{x}_{T} \mid \boldsymbol{x}_{0}, \boldsymbol{y}\right) \| p\left(\boldsymbol{x}_{T} \mid \boldsymbol{y}\right)\right)\right. \\ & +\sum_{t=2}^{T} D_{K L}\left(q_{B B}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right) \| p_{\theta}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{y}\right)\right) \\ & \left.-\log p_{\theta}\left(\boldsymbol{x}_{0} \mid \boldsymbol{x}_{1}, \boldsymbol{y}\right)\right) \end{aligned}\tag{8-7} ELBO=−Eq(DKL(qBB(xT∣x0,y)∥p(xT∣y))+t=2∑TDKL(qBB(xt−1∣xt,x0,y)∥pθ(xt−1∣xt,y))−logpθ(x0∣x1,y))(8-7)其中第一项为常数,可以忽略。重点看第二项, q B B ( x t − 1 ∣ x t , x 0 , y ) q_{B B}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right) qBB(xt−1∣xt,x0,y)根据贝叶斯理论可得 q B B ( x t − 1 ∣ x t , x 0 , y ) = q B B ( x t ∣ x t − 1 , y ) q B B ( x t − 1 ∣ x 0 , y ) q B B ( x t ∣ x 0 , y ) = N ( x t − 1 ; μ ~ t ( x t , x 0 , y ) , δ ~ t I ) (8-8) \begin{aligned} q_{B B}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right) & =\frac{q_{B B}\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1}, \boldsymbol{y}\right) q_{B B}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{0}, \boldsymbol{y}\right)}{q_{B B}\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}, \boldsymbol{y}\right)} \\& =\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \tilde{\boldsymbol{\mu}}_{t}\left(\boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right), \tilde{\delta}_{t} \boldsymbol{I}\right) \end{aligned}\tag{8-8} qBB(xt−1∣xt,x0,y)=qBB(xt∣x0,y)qBB(xt∣xt−1,y)qBB(xt−1∣x0,y)=N(xt−1;μ~t(xt,x0,y),δ~tI)(8-8)其中均值 μ ~ t ( x t , x 0 , y ) \tilde{\boldsymbol{\mu}}_{t}\left(\boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right) μ~t(xt,x0,y)为 μ ~ t ( x t , x 0 , y ) = δ t − 1 δ t 1 − m t 1 − m t − 1 x t + ( 1 − m t − 1 ) δ t ∣ t − 1 δ t x 0 + ( m t − 1 − m t 1 − m t 1 − m t − 1 δ t − 1 δ t ) y (8-9) \begin{aligned} \tilde{\boldsymbol{\mu}}_{t}\left(\boldsymbol{x}_{t}, \boldsymbol{x}_{0}, \boldsymbol{y}\right) & =\frac{\delta_{t-1}}{\delta_{t}} \frac{1-m_{t}}{1-m_{t-1}} \boldsymbol{x}_{t} \\ & +\left(1-m_{t-1}\right) \frac{\delta_{t \mid t-1}}{\delta_{t}} \boldsymbol{x}_{0} \\ & +\left(m_{t-1}-m_{t} \frac{1-m_{t}}{1-m_{t-1}} \frac{\delta_{t-1}}{\delta_{t}}\right) \boldsymbol{y} \end{aligned}\tag{8-9} μ~t(xt,x0,y)=δtδt−11−mt−11−mtxt+(1−mt−1)δtδt∣t−1x0+(mt−1−mt1−mt−11−mtδtδt−1)y(8-9)方差 δ ~ t \tilde{\delta}_{t} δ~t为 δ ~ t = δ t ∣ t − 1 ⋅ δ t − 1 δ t (8-10) \tilde{\delta}_{t}=\frac{\delta_{t \mid t-1} \cdot \delta_{t-1}}{\delta_{t}}\tag{8-10} δ~t=δtδt∣t−1⋅δt−1(8-10)由于在推理过程中 x 0 x_0 x0是未知的,因此可以根据公式8-2由当前的 x t x_t xt反向估计一个 x ^ 0 \hat{x}_0 x^0,将其带入公式8-9中可得 δ ~ t = δ t ∣ t − 1 ⋅ δ t − 1 δ t μ ~ t ( x t , y ) = c x t x t + c y t y + c ϵ t ( m t ( y − x 0 ) + δ t ϵ ) (8-11) \tilde{\delta}_{t}=\frac{\delta_{t \mid t-1} \cdot \delta_{t-1}}{\delta_{t}}\tilde{\boldsymbol{\mu}}_{t}\left(\boldsymbol{x}_{t}, \boldsymbol{y}\right)=c_{x t} \boldsymbol{x}_{t}+c_{y t} \boldsymbol{y}+c_{\epsilon t}\left(m_{t}\left(\boldsymbol{y}-\boldsymbol{x}_{0}\right)+\sqrt{\delta_{t}} \boldsymbol{\epsilon}\right)\tag{8-11} δ~t=δtδt∣t−1⋅δt−1μ~t(xt,y)=cxtxt+cyty+cϵt(mt(y−x0)+δtϵ)(8-11)其中 c x t = δ t − 1 δ t 1 − m t 1 − m t − 1 + δ t ∣ t − 1 δ t ( 1 − m t − 1 ) c y t = m t − 1 − m t 1 − m t 1 − m t − 1 δ t − 1 δ t c ϵ t = ( 1 − m t − 1 ) δ t ∣ t − 1 δ t (8-12) \begin{array}{l} c_{x t}=\frac{\delta_{t-1}}{\delta_{t}} \frac{1-m_{t}}{1-m_{t-1}}+\frac{\delta_{t \mid t-1}}{\delta_{t}}\left(1-m_{t-1}\right) \\ c_{y t}=m_{t-1}-m_{t} \frac{1-m_{t}}{1-m_{t-1}} \frac{\delta_{t-1}}{\delta_{t}} \\ c_{\epsilon t}=\left(1-m_{t-1}\right) \frac{\delta_{t \mid t-1}}{\delta_{t}} \end{array}\tag{8-12} cxt=δtδt−11−mt−11−mt+δtδt∣t−1(1−mt−1)cyt=mt−1−mt1−mt−11−mtδtδt−1cϵt=(1−mt−1)δtδt∣t−1(8-12)与DDPM中一样,作者不直接预测均值 μ ~ t \tilde{\mu}_t μ~t,而是对其中的噪声 ϵ \epsilon ϵ进行预测。 p θ ( x t − 1 ∣ x t , y ) p_{\theta}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{y}\right) pθ(xt−1∣xt,y)中的均值项 μ θ ( x t , t ) \boldsymbol{\mu}_{\theta}\left(\boldsymbol{x}_{t}, t\right) μθ(xt,t)可以重写为 x t , y \boldsymbol{x}_{t},\boldsymbol{y} xt,y和估计噪声 ϵ θ \epsilon_{\theta} ϵθ的线性组合 μ θ ( x t , y , t ) = c x t x t + c y t y + c ϵ t ϵ θ ( x t , t ) (8-13) \boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t}, \boldsymbol{y}, t\right)=c_{x t} \boldsymbol{x}_{t}+c_{y t} \boldsymbol{y}+c_{\epsilon t} \boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, t\right)\tag{8-13} μθ(xt,y,t)=cxtxt+cyty+cϵtϵθ(xt,t)(8-13)则目标函数 E L B O ELBO ELBO可以简化为 E x 0 , y , ϵ [ c ϵ t ∥ m t ( y − x 0 ) + δ t ϵ − ϵ θ ( x t , t ) ∥ 2 ] (8-14) \mathbb{E}_{\boldsymbol{x}_{0}, \boldsymbol{y}, \boldsymbol{\epsilon}}\left[c_{\epsilon t}\left\|m_{t}\left(\boldsymbol{y}-\boldsymbol{x}_{0}\right)+\sqrt{\delta_{t}} \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, t\right)\right\|^{2}\right]\tag{8-14} Ex0,y,ϵ[cϵt mt(y−x0)+δtϵ−ϵθ(xt,t) 2](8-14)
完整的训练流程如下
经过训练得到噪声估计网络 ϵ θ ( x t , t ) \boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, t\right) ϵθ(xt,t),就可以从源域中任意采样一个条件输入 y \boldsymbol{y} y作为生成的起点 x T \boldsymbol{x}_T xT,经过反向去噪得到生成结果 x 0 x_0 x0,如下所示
上述的采样过程也可以利用DDIM提出的加速技巧进行加速。整体上而言,BBDM就是将原本扩散过程从图像到噪声的变换,改成了从目标图像到源图像的变换。然后,在反向去噪时,只需给定一个源图像就能据此生成对应目标域中的样本。虽然不用像其他条件扩散模型那样,将条件引入模型中用于训练,但在BBDM的训练过程需要成对的样本,这限制了BBDM在许多情景中的应用。