论文名称:High-Resolution Image Synthesis with Latent Diffusion Models
发表时间:CVPR2022
作者及组织:Robin Rombach, Andreas Blattmann, Dominik Lorenz,Patrick Esser和 Bjorn Ommer, 来自Ludwig Maximilian University of Munich & IWR, Heidelberg University, Germany。
开源地址:https://github.com/CompVis/latent-diffusion
前言
本文就是VQGAN和DDPM的结合。在图像的2D特征向量上做加噪去噪,从而降低DDPM在全像素空间上生成造成计算量大的问题。而且在隐变量上训练DDPM在一定程度上并不会损失生成的细节。
1、方法
以VQGAN为例,第一个阶段是感知压缩阶段,旨在去掉无关的像素细节;第二个阶段是语义压缩阶段,让自回归模型来预测图像的语义。而本文就是找到两个压缩阶段之间的一个trade-off。
1.1.感知压缩阶段
该阶段用T-UNet来提取图像特征向量: z = E ( x ) z = E(x) z=E(x) ,其中 z z z 并不像VQGAN中一样是经过codebook后的特征向量,而是未经过codebook,因为作者认为此时 z z z 天然具有一定归纳偏置 ,有利于后续生成。而压缩的比例用变量 $f $ 进行表示(比如f=2就表示下采样2倍,f=1就是原始像素空间)。
解码器为 x ^ = D ( z ) \hat x = D(z) x^=D(z) 。
为了防止隐空间的特征向量有高方差,加了两个正则化,KL-reg和VQ-reg,分别对应VAE和VQGAN中的两种损失函数。
1.2.LDM
DM损失函数为:
L D M = E x , ϵ N ( 0 , 1 ) , t [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 2 ] \begin{equation} L_{DM}= E_{x,\epsilon~N(0,1),t} [||\epsilon-\epsilon_\theta(x_t,t)||_2^2] \tag{1} \end{equation} LDM=Ex,ϵ N(0,1),t[∣∣ϵ−ϵθ(xt,t)∣∣22](1)
LDM的损失函数就是将采样样本x变成了隐空间 z = E ( x ) z=E(x) z=E(x) :
L D M = E E ( x ) , ϵ N ( 0 , 1 ) , t [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 2 ] \begin{equation} L_{DM}= E_{E(x),\epsilon~N(0,1),t} [||\epsilon-\epsilon_\theta(x_t,t)||_2^2] \tag{2} \end{equation} LDM=EE(x),ϵ N(0,1),t[∣∣ϵ−ϵθ(xt,t)∣∣22](2)
而如果加一些条件(文本,layout,mask…)则损失函数为:
L L D M = E E ( x ) , y , ϵ N ( 0 , 1 ) , t [ ∣ ∣ ϵ − ϵ θ ( x t , t , τ θ ( y ) ) ∣ ∣ 2 2 ] \begin{equation} L_{LDM}= E_{E(x),y,\epsilon~N(0,1),t} [||\epsilon-\epsilon_\theta(x_t,t, \tau_\theta(y))||_2^2] \tag{3} \end{equation} LLDM=EE(x),y,ϵ N(0,1),t[∣∣ϵ−ϵθ(xt,t,τθ(y))∣∣22](3)
其中条件注入用了CrossAttn。
2、实验
2.1. class conditional
数据集:ImageNet和Celeb-A数据集。
下图表示LDM-4/8收敛速度快,且生成图像的保真度高。
下图表示相同采样步数,LDM-8吞吐量高且生成图像逼真。
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
2.2. ConditionalLDM
Text2img训了一个1.45B的模型在LAION-400M。下图说明 class free guide 的trick非常有用,但训练资源加倍。
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
2.3. rescale
在AE和DM训练中,为了防止隐空间尺度任意变换,对 z z z 做了一下正则化,如下图所示,若不做正则化,生成图像细节不足。
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
思考:
LDM还能完成好多其余工作:比如text2img,img inpaint, mask2img, super等。是后续生成模型的基本组件。