符号表
符号 | 含义 |
---|---|
x ( i ) = z 0 ( i ) \boldsymbol{x}^{(i)}=\boldsymbol{z}_0^{(i)} x(i)=z0(i) | 第 i i i个训练数据,其为长度为 d d d的向量 |
z t ( i ) \boldsymbol{z}_t^{(i)} zt(i) | 第 i i i个训练数据在第 t t t时刻的加噪版本 |
ϵ t ( i ) \boldsymbol{\epsilon}_t^{(i)} ϵt(i) | 第 i i i个训练数据在第 t t t时刻所添加的高斯噪声 |
β t \beta_t βt | 噪声计划(noise schedule),范围为[0,1] |
α t \alpha_t αt | α t = ∏ s = 1 t ( 1 − β s ) \alpha_t=\prod_{s=1}^t (1-\beta_s) αt=∏s=1t(1−βs) |
N ( μ , σ 2 I ) N(\boldsymbol{\mu},\sigma^2\boldsymbol{I}) N(μ,σ2I) | 均值为 μ \boldsymbol{\mu} μ,标准差为 σ \sigma σ的高斯分布 |
q ( ⋅ ) q\left(·\right) q(⋅) | 正向过程的转移核 |
p ( ⋅ ) p(·) p(⋅) | 反向过程的转移核 |
p ( ⋅ ∣ θ ) p\left(·|\theta\right) p(⋅∣θ) | 受参数 θ \theta θ影响,用于拟合反向过程的真实概率密度函数 |
f t f_t ft | 反向过程中 t t t时刻对应的神经网络 |
g t g_t gt | 反向过程中 t t t时刻对应的神经网络 |
θ t \theta_t θt | 神经网络 f t f_t ft或 g t g_t gt的参数 |
θ 1 : T \theta_{1:T} θ1:T | θ 1 , θ 2 , ⋯ , θ T \theta_1,\theta_2,\cdots,\theta_T θ1,θ2,⋯,θT |
d z 1 : T d\boldsymbol{z}_{1:T} dz1:T | d z 1 d z 2 ⋯ d z T d\boldsymbol{z}_1d\boldsymbol{z}_2\cdots d\boldsymbol{z}_T dz1dz2⋯dzT |
注: | 如没有上标 ( i ) ^{(i)} (i),则表明在此语境下不特别指明对应某个样本 |
扩散模型的扩散过程(编码器)
扩散模型的编码器所做的工作如下:
设有原数据 x \boldsymbol{x} x,经过如下的逐步编码(添加噪声)过程可以得到一个符合标准高斯分布的噪声
z t = 1 − β t z t − 1 + β t ϵ t , t = 1 , 2 , ⋯ , T \begin{equation}\boldsymbol{z}_t=\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}+\sqrt{\beta_t}\boldsymbol{\epsilon}_t, t=1,2,\cdots,T\end{equation} zt=1−βtzt−1+βtϵt,t=1,2,⋯,T
其中 z 0 = x \boldsymbol{z}_0=\boldsymbol{x} z0=x, ϵ 1 , ϵ 2 , ⋯ , ϵ t ∼ N ( 0 , I ) \boldsymbol{\epsilon}_1, \boldsymbol{\epsilon}_2,\cdots,\boldsymbol{\epsilon}_t\sim N(\boldsymbol{0}, \boldsymbol{I}) ϵ1,ϵ2,⋯,ϵt∼N(0,I), β 1 , β 2 , ⋯ , β t ∈ [ 0 , 1 ] \beta_1,\beta_2,\cdots,\beta_t\in [0,1] β1,β2,⋯,βt∈[0,1]为噪声计划(noise schedule),一般逐级递增。当 T → ∞ T\rightarrow \infty T→∞, z T \boldsymbol{z}_T zT将服从高斯分布,该推导在下面会涉及。
由于每一步的扩散结果 z t \boldsymbol{z}_t zt仅依赖于上一个扩散结果 z t − 1 \boldsymbol{z}_{t-1} zt−1,也即只要已知 z t − 1 \boldsymbol{z}_{t-1} zt−1(不需要再知道 z 1 , z 2 , ⋯ , z t − 2 \boldsymbol{z}_{1},\boldsymbol{z}_{2},\cdots,\boldsymbol{z}_{t-2} z1,z2,⋯,zt−2),再经过计算便可以得到 z t \boldsymbol{z}_{t} zt。该扩散特点符合马尔科夫链的性质,即每一时刻的状态仅依赖于上一时刻的状态,而与之前的状态无关。
现在用一个马尔科夫链表达该扩散过程。在 z t − 1 \boldsymbol{z}_{t-1} zt−1是已知的情况下, z t \boldsymbol{z}_t zt的均值
E [ z t ] = E [ 1 − β t z t − 1 ] + E [ β t ϵ t ] = 1 − β t z t − 1 + 0 = 1 − β t z t − 1 \begin{align}E[\boldsymbol{z}_t]&=E[\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}]+E[\sqrt{\beta_t}\boldsymbol{\epsilon}_t]\\&=\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}+\boldsymbol{0}\\&=\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}\end{align} E[zt]=E[1−βtzt−1]+E[βtϵt]=1−βtzt−1+0=1−βtzt−1
z t \boldsymbol{z}_t zt的方差
C o v [ z t ] = C o v [ 1 − β t z t − 1 ] + C o v [ β t ϵ t ] = 0 + ( β t ) 2 I = β t I \begin{align}Cov[\boldsymbol{z}_t]&=Cov[\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}]+Cov[\sqrt{\beta_t}\boldsymbol{\epsilon}_t]\\ &= \boldsymbol{0}+(\sqrt{\beta_t})^2\boldsymbol{I}\\ &= \beta_t\boldsymbol{I}\end{align} Cov[zt]=Cov[1−βtzt−1]+Cov[βtϵt]=0+(βt)2I=βtI
以上推导源自于:① z t − 1 \boldsymbol{z}_{t-1} zt−1是已知的,它不是分布,而是常量② ϵ t \boldsymbol{\epsilon}_t ϵt是标准的高斯分布③若 x ∼ N ( m x , Σ x ) \boldsymbol{x}\sim N(\boldsymbol{m}_{x},\boldsymbol{\Sigma}_{x}) x∼N(mx,Σx), y ∼ N ( m y , Σ y ) \boldsymbol{y}\sim N(\boldsymbol{m}_y,\boldsymbol{\Sigma}_y) y∼N(my,Σy),则 A x + B y + c ∼ N ( A m x + B m y + c , A Σ x A T + B Σ y B T ) \boldsymbol{Ax}+\boldsymbol{By}+\boldsymbol{c}\sim N(\boldsymbol{Am}_x+\boldsymbol{Bm}_y+\boldsymbol{c},\boldsymbol{A\Sigma_{x}A}^T+\boldsymbol{B\Sigma_{y}B}^T) Ax+By+c∼N(Amx+Bmy+c,AΣxAT+BΣyBT)
根据前面的分析,在已知 z t − 1 \boldsymbol{z}_{t-1} zt−1的情况下, z t \boldsymbol{z}_t zt的概率分布,即转移核的表达式如下:
q ( z t ∣ z t − 1 ) = N ( 1 − β t z t − 1 , β t I ) = 1 ( 2 π ) d 2 β t exp ( − ( z t − 1 − β t z t − 1 ) 2 2 β t ) \begin{equation}q(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})=N(\sqrt{1-\beta_t}\boldsymbol{z}_{t-1},\beta_t\boldsymbol{I})=\frac{1}{(2\pi)^{\frac{d}{2}}\sqrt{\beta_t}}\exp{\left(-\frac{(\boldsymbol{z}_{t}-\sqrt{1-\beta_t}\boldsymbol{z}_{t-1})^2}{2\beta_t}\right)}\end{equation} q(zt∣zt−1)=N(1−βtzt−1,βtI)=(2π)2dβt1exp(−2βt(zt−1−βtzt−1)2)
该表达式使用了多元高斯分布的定义,即若随机变量 X = [ X 1 ⋯ X n ] T X=\begin{bmatrix}X_1\cdots X_n\end{bmatrix}^T X=[X1⋯Xn]T 服从均值为 μ ∈ R n \boldsymbol{\mu}\in\mathbb{R}^n μ∈Rn ,协方差为 Σ ∈ S + + n \boldsymbol{\Sigma}\in\mathbb{S}_{++}^n Σ∈S++n 的多元高斯分布,则其概率密度函数为:
1 ( 2 π ) n / 2 ∣ Σ ∣ 1 / 2 exp ( − 1 2 ( x − μ ) T Σ − 1 ( x − μ ) ) . \begin{aligned}\frac{1}{(2\pi)^{n/2}|\boldsymbol{\Sigma}|^{1/2}}\exp\left(-\frac{1}{2}(\boldsymbol{x}-\boldsymbol{\mu})^T\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\right).\end{aligned} (2π)n/2∣Σ∣1/21exp(−21(x−μ)TΣ−1(x−μ)).
因此,在已知 x \boldsymbol{x} x的情况下,将通过 q ( z 1 ∣ x ) q(\boldsymbol{z}_1|\boldsymbol{x}) q(z1∣x)采样得到 z 1 \boldsymbol{z}_1 z1;则 z 1 \boldsymbol{z}_1 z1变为已知,再通过 q ( z 2 ∣ z 1 ) q(\boldsymbol{z}_2|\boldsymbol{z}_1) q(z2∣z1)采样得到 z 2 \boldsymbol{z}_2 z2,类似地递推,最后得到 x T \boldsymbol{x}_T xT。当 T T T非常大的时候,该过程十分耗时,但可以将 z t \boldsymbol{z}_t zt中的 z t − 1 \boldsymbol{z}_{t-1} zt−1逐层次替换为 x \boldsymbol{x} x的表达式,得到
z t = 1 − β t z t − 1 + β t ϵ t = 1 − β t ( 1 − β t − 1 z t − 2 + β t − 1 ϵ t − 1 ) + β t ϵ t = ( 1 − β t ) ( 1 − β t − 1 ) z t − 2 + 1 − β t − ( 1 − β t ) ( 1 − β t − 1 ) ϵ t − 1 + β t ϵ t \begin{aligned} \boldsymbol{z}_{t}& =\sqrt{1-\beta_t}\boldsymbol{z}_{t-1}+\sqrt{\beta_t}\boldsymbol{\epsilon}_{t} \\ &=\sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}\boldsymbol{z}_{t-2}+\sqrt{\beta_{t-1}}\boldsymbol{\epsilon}_{t-1}\right)+\sqrt{\beta_t}\boldsymbol{\epsilon}_{t} \\ &=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{z}_{t-2}+\sqrt{1-\beta_t-(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{\epsilon}_{t-1}+\sqrt{\beta_t}\boldsymbol{\epsilon}_{t} \end{aligned} zt=1−βtzt−1+βtϵt=1−βt(1−βt−1zt−2+βt−1ϵt−1)+βtϵt=(1−βt)(1−βt−1)zt−2+1−βt−(1−βt)(1−βt−1)ϵt−1+βtϵt
再根据高斯分布的混合公式,将 ϵ t − 1 \boldsymbol{\epsilon_{t-1}} ϵt−1和 ϵ t \boldsymbol{\epsilon}_{t} ϵt的项混合为 ϵ \boldsymbol{\epsilon} ϵ的分布,得到
z t = ( 1 − β t ) ( 1 − β t − 1 ) z t − 2 + ( 1 − β t ) − ( 1 − β t ) ( 1 − β t − 1 ) 2 + β t 2 ϵ = ( 1 − β t ) ( 1 − β t − 1 ) z t − 2 + 1 − β t − ( 1 − β t ) ( 1 − β t − 1 ) + β t ϵ = ( 1 − β t ) ( 1 − β t − 1 ) z t − 2 + 1 − ( 1 − β t ) ( 1 − β t − 1 ) ϵ = … = ∏ i = 1 t ( 1 − β i ) x + 1 − ∏ i = 1 t ( 1 − β i ) ϵ = α t x + 1 − α t ϵ , t = 1 , 2 , ⋯ , T \begin{align*} \boldsymbol{z}_{t}&=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{z}_{t-2}+\sqrt{\sqrt{(1-\beta_t)-(1-\beta_t)(1-\beta_{t-1})}^2+\sqrt{\beta_t}^2}\boldsymbol{\epsilon} \\ &=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{z}_{t-2}+\sqrt{1-\beta_t-(1-\beta_t)(1-\beta_{t-1})+\beta_t}\boldsymbol{\epsilon} \\ &=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{z}_{t-2}+\sqrt{1-(1-\beta_t)(1-\beta_{t-1})}\boldsymbol{\epsilon} \\ &=\ldots \\ &=\sqrt{\prod_{i=1}^t(1-\beta_i)}\boldsymbol{x}+\sqrt{1-\prod_{i=1}^t(1-\beta_i)}\boldsymbol{\epsilon} \\ &=\sqrt{\alpha_t}\boldsymbol{x}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon},t=1,2,\cdots,T \end{align*} zt=(1−βt)(1−βt−1)zt−2+(1−βt)−(1−βt)(1−βt−1)2+βt2ϵ=(1−βt)(1−βt−1)zt−2+1−βt−(1−βt)(1−βt−1)+βtϵ=(1−βt)(1−βt−1)zt−2+1−(1−βt)(1−βt−1)ϵ=…=i=1∏t(1−βi)x+1−i=1∏t(1−βi)ϵ=αtx+1−αtϵ,t=1,2,⋯,T
为了区分不同时刻所对应的噪声,对 ϵ \boldsymbol{\epsilon} ϵ添加下标 t t t,可得
z t = α t x + 1 − α t ϵ t , t = 1 , 2 , ⋯ , T \begin{equation}\boldsymbol{z}_t=\sqrt{\alpha_t}\boldsymbol{x}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_t, t=1,2,\cdots,T\end{equation} zt=αtx+1−αtϵt,t=1,2,⋯,T
其中, α t = ∏ s = 1 t ( 1 − β s ) \alpha_t=\prod_{s=1}^t (1-\beta_s) αt=∏s=1t(1−βs), ϵ ∼ N ( 0 , I ) \boldsymbol{\epsilon}\sim N(\boldsymbol{0}, \boldsymbol{I}) ϵ∼N(0,I)。
所以,一旦已知 x \boldsymbol{x} x,便可以得到 z t \boldsymbol{z}_t zt的分布,故:
q ( z t ∣ x ) = N ( α t x , ( 1 − α t ) I ) = 1 ( 2 π ) d 2 ( 1 − α t ) exp ( − ( z t − α t x ) 2 1 − α t ) \begin{equation}q(\boldsymbol{z}_t|\boldsymbol{x})=N(\sqrt{\alpha_t}\boldsymbol{\boldsymbol{x}},(1-\alpha_t)\boldsymbol{I})=\frac{1}{(2\pi)^{\frac{d}{2}}\sqrt{(1-\alpha_t)}}\exp{\left(-\frac{(\boldsymbol{z}_{t}-\sqrt{\alpha_t}\boldsymbol{x})^2}{1-\alpha_t}\right)}\end{equation} q(zt∣x)=N(αtx,(1−αt)I)=(2π)2d(1−αt)1exp(−1−αt(zt−αtx)2)
因此, z t \boldsymbol{z}_t zt可以通过先从标准的高斯分布中采样 ϵ \boldsymbol{\epsilon} ϵ,然后和 z 0 \boldsymbol{z}_0 z0进行混合得到。另外可以观察到,因为 β t \beta_t βt在 t t t很大的时候近似为 1 1 1,那么 α t \alpha_t αt在 t t t很大的时候近似等于0,此时 q ( z t ∣ x ) q(\boldsymbol{z}_t|\boldsymbol{x}) q(zt∣x)近似为一个标准的高斯分布。
扩散模型的去噪过程(解码器)
扩散模型的解码器是为了反转编码过程。如果知道逆向转移核 p ( z t − 1 ∣ z t ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t}) p(zt−1∣zt),那么就可以先从 p ( z T ) = N ( 0 , I ) p(\boldsymbol{z}_T)=N(\boldsymbol{0},\boldsymbol{I}) p(zT)=N(0,I)采样出 z T \boldsymbol{z}_T zT,再通过 p ( z T − 1 ∣ z T ) p(\boldsymbol{z}_{T-1}|\boldsymbol{z}_{T}) p(zT−1∣zT)采样出 z T − 1 \boldsymbol{z}_{T-1} zT−1,依次类推,直到采样出 z 0 \boldsymbol{z}_{0} z0,即 x \boldsymbol{x} x。
贝叶斯公式给出了根据 q ( z t ∣ z t − 1 ) q(\boldsymbol{z}_{t}|\boldsymbol{z}_{t-1}) q(zt∣zt−1)求出 p ( z t − 1 ∣ z t ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t}) p(zt−1∣zt)的方法,即
p ( z t − 1 ∣ z t ) = q ( z t ∣ z t − 1 ) q ( z t − 1 ) q ( z t ) \begin{equation}p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t})=\frac{q(\boldsymbol{z}_{t}|\boldsymbol{z}_{t-1})q(\boldsymbol{z}_{t-1})}{q(\boldsymbol{z}_t)}\end{equation} p(zt−1∣zt)=q(zt)q(zt∣zt−1)q(zt−1)
观察该式可知,由于 q ( z t − 1 ) / q ( z t ) q(\boldsymbol{z}_{t-1})/q(\boldsymbol{z}_{t}) q(zt−1)/q(zt)是未知的,所以求不出任何结果,而且实际上该逆向转移核不一定是高斯分布。
但是,如果给定额外条件 x \boldsymbol{x} x,由(15),可以得到
p ( z t − 1 ∣ z t , x ) = q ( z t ∣ z t − 1 , x ) q ( z t − 1 ∣ x ) q ( z t ∣ x ) \begin{equation}p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})=\frac{q(\boldsymbol{z}_{t}|\boldsymbol{z}_{t-1},\boldsymbol{x})q(\boldsymbol{z}_{t-1}|\boldsymbol{x})}{q(\boldsymbol{z}_t|\boldsymbol{x})}\end{equation} p(zt−1∣zt,x)=q(zt∣x)q(zt∣zt−1,x)q(zt−1∣x)
根据马尔科夫链的性质 q ( z t ∣ z t − 1 , x ) = q ( z t ∣ z t − 1 ) q(\boldsymbol{z}_{t}|\boldsymbol{z}_{t-1},\boldsymbol{x})=q(\boldsymbol{z}_{t}|\boldsymbol{z}_{t-1}) q(zt∣zt−1,x)=q(zt∣zt−1),结合公式(8)和(10),经过很复杂的一段化简(省略过程)得到:
p ( z t − 1 ∣ z t , x ) = q ( z t ∣ z t − 1 ) q ( z t − 1 ∣ x ) q ( z t ∣ x ) ∝ q ( z t ∣ z t − 1 ) q ( z t − 1 ∣ x ) = N z t ( 1 − β t ⋅ z t − 1 , β t I ) N z t − 1 ( α t − 1 ⋅ x , ( 1 − α t − 1 ) I ) \begin{aligned} p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x})& =\quad\frac{q(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})q(\boldsymbol{z}_{t-1}|\boldsymbol{x})}{q(\boldsymbol{z}_t|\boldsymbol{x})} \\ &\propto\quad q(\boldsymbol{z}_t|\boldsymbol{z}_{t-1})q(\boldsymbol{z}_{t-1}|\boldsymbol{x}) \\ &=\quad N_{\boldsymbol{z}_t}\left(\sqrt{1-\beta_t}\cdot\boldsymbol{z}_{t-1},\beta_t\boldsymbol{I}\right)N_{\boldsymbol{z}_{t-1}}\left(\sqrt{\alpha_{t-1}}\cdot\boldsymbol{x},(1-\alpha_{t-1})\boldsymbol{I}\right) \\\end{aligned} p(zt−1∣zt,x)=q(zt∣x)q(zt∣zt−1)q(zt−1∣x)∝q(zt∣zt−1)q(zt−1∣x)=Nzt(1−βt⋅zt−1,βtI)Nzt−1(αt−1⋅x,(1−αt−1)I)
根据高斯随机变量的变量替换定理,即
N v [ A w , B ] ∝ N w [ ( A T B − 1 A ) − 1 A T B − 1 v , ( A T B − 1 A ) − 1 ] N_{\boldsymbol{v}}\left[\boldsymbol{A}\boldsymbol{w},\boldsymbol{B}\right]\propto N_{\boldsymbol{w}}\left[\left(\boldsymbol{A}^T\boldsymbol{B}^{-1}\boldsymbol{A}\right)^{-1}\boldsymbol{A}^T\boldsymbol{B}^{-1}\boldsymbol{v},\left(\boldsymbol{A}^T\boldsymbol{B}^{-1}\boldsymbol{A}\right)^{-1}\right] Nv[Aw,B]∝Nw[(ATB−1A)−1ATB−1v,(ATB−1A)−1]
可得,
N z t ( 1 − β t ⋅ z t − 1 , β t I ) N z t − 1 ( α t − 1 ⋅ x , ( 1 − α t − 1 ) I ) ∝ N z t − 1 ( 1 1 − β t z t , β t 1 − β t I ) N z t − 1 ( α t − 1 ⋅ x , ( 1 − α t − 1 ) I ) \quad N_{\boldsymbol{z}_t}\left(\sqrt{1-\beta_t}\cdot\boldsymbol{z}_{t-1},\beta_t\boldsymbol{I}\right)N_{\boldsymbol{z}_{t-1}}\left(\sqrt{\alpha_{t-1}}\cdot\boldsymbol{x},(1-\alpha_{t-1})\boldsymbol{I}\right)\propto N_{\boldsymbol{z}_{t-1}}\left(\frac{1}{\sqrt{1-\beta_t}}\boldsymbol{z}_t,\frac{\beta_t}{1-\beta_t}\boldsymbol{I}\right)N_{\boldsymbol{z}_{t-1}}\left(\sqrt{\alpha_{t-1}}\cdot\boldsymbol{x},(1-\alpha_{t-1})\boldsymbol{I}\right) Nzt(1−βt⋅zt−1,βtI)Nzt−1(αt−1⋅x,(1−αt−1)I)∝Nzt−1(1−βt1zt,1−βtβtI)Nzt−1(αt−1⋅x,(1−αt−1)I)
再根据
N w [ a , A ] ⋅ N w [ b , B ] ∝ N w [ ( A − 1 + B − 1 ) − 1 ( A − 1 a + B − 1 b ) , ( A − 1 + B − 1 ) − 1 ] \begin{aligned}N_{\boldsymbol{w}}[\boldsymbol{a},\boldsymbol{A}]\cdot N_{\boldsymbol{w}}[\boldsymbol{b},\boldsymbol{B}]\propto N_{\boldsymbol{w}}&\left[\left(\boldsymbol{A}^{-1}+\boldsymbol{B}^{-1}\right)^{-1}(\boldsymbol{A}^{-1}\boldsymbol{a}+\boldsymbol{B}^{-1}\boldsymbol{b}),\left(\boldsymbol{A}^{-1}+\boldsymbol{B}^{-1}\right)^{-1}\right]\end{aligned} Nw[a,A]⋅Nw[b,B]∝Nw[(A−1+B−1)−1(A−1a+B−1b),(A−1+B−1)−1]
最终得到
p ( z t − 1 ∣ z t , x ) = N z t − 1 [ ( 1 − α t − 1 ) 1 − α t 1 − β t z t + α t − 1 β t 1 − α t x , β t ( 1 − α t − 1 ) 1 − α t I ] p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_t,\boldsymbol{x})=N_{\boldsymbol{z}_{t-1}}\left[\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\boldsymbol{z}_t+\frac{\sqrt{\alpha_{t-1}}\beta_t}{1-\alpha_t}\boldsymbol{x},\frac{\beta_t(1-\alpha_{t-1})}{1-\alpha_t}\boldsymbol{I}\right] p(zt−1∣zt,x)=Nzt−1[1−αt(1−αt−1)1−βtzt+1−αtαt−1βtx,1−αtβt(1−αt−1)I]
由此可知 p ( z t − 1 ∣ z t , x ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x}) p(zt−1∣zt,x)是一个高斯分布。
因此,尽管 p ( z t − 1 ∣ z t ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t}) p(zt−1∣zt)不是高斯分布,但给定条件 x \boldsymbol{x} x后得到的 p ( z t − 1 ∣ z t , x ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x}) p(zt−1∣zt,x)是高斯分布。另外,如果 p ( z t − 1 ∣ z t , x ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x}) p(zt−1∣zt,x)的均值和方差被确定,那么进一步可以写出从中采样的公式,得到 z t − 1 \boldsymbol{z}_{t-1} zt−1。因此,可以考虑用神经网络来近似 p ( z t − 1 ∣ z t , x ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x}) p(zt−1∣zt,x)(在后续训练目标的推导中可以看出网络的目标实际上是近似 p ( z t − 1 ∣ z t , x ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\boldsymbol{x}) p(zt−1∣zt,x)),记作 p ( z t − 1 ∣ z t , θ t ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_t) p(zt−1∣zt,θt)。为了简化该分布,将其方差设为固定值,神经网络仅仅估计其均值。
p ( z t − 1 ∣ z t , θ t ) = N ( f t ( z t , θ t ) , σ t 2 I ) \begin{equation}p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_t)=N(f_t(\boldsymbol{z}_t,\theta_t),\sigma_t^2\boldsymbol{I})\end{equation} p(zt−1∣zt,θt)=N(ft(zt,θt),σt2I)
其中 f t f_t ft为神经网络,其接受输入 z t \boldsymbol{z}_t zt并输出一个估计的均值, θ t \theta_t θt为该网络的参数, σ t \sigma_t σt为人为设定的标准差。
如果能训练出使得原数据 z 0 \boldsymbol{z}_0 z0总体出现概率最大的神经网络 f t ( z t , θ t ) f_t(\boldsymbol{z}_t,\theta_t) ft(zt,θt),进而得到 p ( z t − 1 ∣ z t , θ t ) p(\boldsymbol{z}_{t-1}|\boldsymbol{z}_{t},\theta_t) p(zt−1∣zt,θt),那么就可以先从 N ( 0 , I ) N(\boldsymbol{0},\boldsymbol{I}) N(0,I)采样出 z T \boldsymbol{z}_T zT,再通过 p ( z T − 1 ∣ z T , θ t ) p(\boldsymbol{z}_{T-1}|\boldsymbol{z}_{T},\theta_t) p(zT−1∣zT,θt)采样出 z T − 1 \boldsymbol{z}_{T-1} zT−1,依次类推,直到采样出 x \boldsymbol{x} x,即 z 0 \boldsymbol{z}_{0} z0。