1--前言
以论文《High-Resolution Image Synthesis with Latent Diffusion Models》开源的项目为例,剖析 Stable Diffusion 的经典组成部分,以巩固学习、加深印象。
2--DDIM
一个可以debug的小demo:SD_DDIM
以文生图为例,剖析 SD 中 DDIM 的核心组成模块。本质上,SD 中的 DDIM 遵循论文《Denoising Diffusion Implicit Models》中的核心公式。
3--核心模块剖析
见SD_DDIM
4--完整代码
from functools import partial

import numpy as np
import pytorch_lightning as pl
import torch
from tqdm import tqdm

# From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose = True):
    """Pick the DDIM alpha / sigma schedule out of a DDPM cumulative-alpha table.

    From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py

    Args:
        alphacums: the DDPM ``alphas_cumprod`` table, indexable by an integer array.
        ddim_timesteps: integer indices into ``alphacums`` selecting the DDIM steps.
        eta: DDIM stochasticity; ``eta == 0`` gives all-zero sigmas (deterministic sampling).
        verbose: print the selected schedule when true.

    Returns:
        ``(sigmas, alphas, alphas_prev)`` where ``alphas`` are the cumulative alphas at the
        selected steps, ``alphas_prev`` the ones at the previous selected step (the first
        entry is ``alphacums[0]``), and ``sigmas`` follow Eq. (16) of
        https://arxiv.org/abs/2010.02502.
    """
    # alpha_t: cumulative alphas at the chosen steps (the values still come from the DDPM table).
    alphas = alphacums[ddim_timesteps]

    # alpha_{t-1}: every selected value except the last, shifted right by one,
    # with the DDPM table's very first entry placed in front.
    previous = [alphacums[0]]
    previous.extend(alphacums[ddim_timesteps[:-1]].tolist())
    alphas_prev = np.asarray(previous)

    # Eq. (16): sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
# Build the DDIM timestep subsequence.
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose = True):
    """Select which DDPM timesteps the DDIM sampler visits.

    From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py

    Args:
        ddim_discr_method: ``'uniform'`` (every ``num_ddpm // num_ddim``-th step) or
            ``'quad'`` (quadratically spaced up to ``sqrt(0.8 * num_ddpm)`` squared).
        num_ddim_timesteps: requested number of DDIM steps.
        num_ddpm_timesteps: length of the DDPM schedule.
        verbose: print the selected steps when true.

    Returns:
        An ascending integer array of timesteps, each shifted by +1
        (e.g. ``[1, 21, 41, ..., 981]`` for 50 uniform steps out of 1000).

    Raises:
        NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'quad':
        upper = np.sqrt(num_ddpm_timesteps * .8)
        ddim_timesteps = (np.linspace(0, upper, num_ddim_timesteps) ** 2).astype(int)
    elif ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps  # e.g. 1000 // 50 = 20
        ddim_timesteps = np.asarray([*range(0, num_ddpm_timesteps, stride)])  # every stride-th step
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # Shift by one so the final alpha values line up (the step that maps first scale to data).
    shifted = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {shifted}')
    return shifted
def noise_like(shape, device, repeat = False):
    """Return standard-normal noise of ``shape`` on ``device``.

    From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn and tiled
    along the first dimension, so every entry along dim 0 is identical.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    tile = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tile)
def make_beta_schedule(schedule, n_timestep, linear_start = 1e-4, linear_end = 2e-2, cosine_s = 8e-3):
    """Return the per-step noise schedule beta_1..beta_T as a float64 NumPy array.

    From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py

    Args:
        schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``, ``"sqrt"``.
        n_timestep: number of diffusion steps T (length of the result).
        linear_start, linear_end: end points used by the linear-family schedules.
        cosine_s: small offset used by the cosine schedule.

    Raises:
        ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        # Linearly spaced in sqrt(beta), then squared.
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    elif schedule == "cosine":
        # alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), normalized so alpha_bar(0) = 1;
        # beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1).
        timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Use torch.clamp (not np.clip) so betas stays a torch.Tensor and the final
        # ``.numpy()`` call is valid regardless of NumPy's handling of tensor inputs.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
class DDPM(pl.LightningModule):
    """Minimal DDPM that only holds the diffusion schedule used by ``DDIMSampler``.

    Origin: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py,
    modified by ljf. ``apply_model`` is a random stand-in for the UNet, not a real network.
    """

    def __init__(self, given_betas = None, beta_schedule = "linear", timesteps = 1000, linear_start = 0.00085, linear_end = 0.012, cosine_s = 8e-3):
        super().__init__()
        self.v_posterior = 0.0  # weight of beta_t mixed into the posterior variance below
        self.parameterization = "eps"  # selects the "eps" branch of the lvlb weights
        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the noise schedule and register every derived quantity as a buffer.

        If ``given_betas`` is not None it is used as the beta schedule (its length then
        defines the number of timesteps); otherwise ``make_beta_schedule`` builds one from
        ``beta_schedule`` / ``timesteps`` / ``linear_start`` / ``linear_end`` / ``cosine_s``.
        """
        if given_betas is not None:
            betas = np.asarray(given_betas, dtype=np.float64)
        else:
            # e.g. with the __init__ defaults: [0.00085, 0.0008547, ..., 0.012], 1000 values
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas  # e.g. [0.99915, 0.9991453, ..., 0.988]
        # Running product: alphas_cumprod[t] = alphas[0] * ... * alphas[t]
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted right by one with a leading 1: [1, alphas_cumprod[0], ..., alphas_cumprod[T-2]]
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): `2. * 1 - a` evaluates as `2 - a`; confirm whether `2. * (1 - a)` was
            # intended. This branch is not reached here because __init__ fixes parameterization to "eps".
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        # Entry 0 divides by posterior_variance[0] == 0, so it is replaced by entry 1.
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    # Simulated UNet prediction
    def apply_model(self, x_noisy, t, cond, return_ids=False):
        """Stand-in for the UNet: return uniform random values shaped like ``x_noisy``.

        ``t``, ``cond`` and ``return_ids`` are ignored. ``rand_like`` keeps the output on the
        same device/dtype as the input so the sampler arithmetic does not mix devices.
        """
        return torch.rand_like(x_noisy)
class DDIMSampler(object):
    """DDIM sampler that reuses the noise schedule of a ``DDPM`` model.

    Origin: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddim.py,
    modified by ljf.
    """

    def __init__(self, model, schedule = "linear", **kwargs):
        super().__init__()
        self.model = model  # the DDPM model: schedule buffers plus apply_model()
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` as ``self.<name>``, moving tensors to the model's device.

        Tensors follow ``self.model.device`` instead of a hard-coded CUDA device, so the
        sampler also works on CPU-only machines and stays on the same device as the model.
        Non-tensor values (e.g. NumPy arrays) are stored unchanged.
        """
        if isinstance(attr, torch.Tensor) and attr.device != self.model.device:
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize = "uniform", ddim_eta = 0., verbose = True):
        """Precompute DDIM timesteps and per-step coefficients for ``ddim_num_steps`` steps."""
        # DDIM timesteps, e.g. [1, 21, 41, ..., 981] for 50 uniform steps out of 1000.
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod  # reuse the DDPM cumulative alphas
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # Detached float32 copy on the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))  # DDPM betas
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))  # DDPM alphas_cumprod
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))  # DDPM alphas_cumprod_prev

        # calculations for diffusion q(x_t | x_{t-1}) and others (torch ops, no NumPy round-trip)
        self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / alphas_cumprod - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,
                                                                                   verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # Sigmas for sampling over every DDPM step (used when use_original_steps=True).
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self, S, batch_size, shape, conditioning = None, callback = None,
               img_callback = None, quantize_x0 = False, eta = 0., mask = None, x0 = None,
               temperature = 1., noise_dropout = 0., score_corrector = None, corrector_kwargs = None,
               verbose = True, x_T = None, log_every_t = 100, unconditional_guidance_scale = 1.,
               unconditional_conditioning = None):
        """Draw ``batch_size`` samples of per-sample shape ``shape`` = (C, H, W) with ``S`` DDIM steps.

        Returns ``(samples, intermediates)`` from ``ddim_sampling``. ``quantize_x0``, ``mask``,
        ``x0`` and ``score_corrector`` / ``corrector_kwargs`` are passed through but not used
        by this simplified sampler.
        """
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)  # register all schedule buffers
        # sampling
        C, H, W = shape  # e.g. [4, 64, 64]
        size = (batch_size, C, H, W)  # e.g. (3, 4, 64, 64)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    def _sampling_timesteps(self, use_original_steps, timesteps):
        """Return ``(descending timesteps, number of steps)`` for the reverse loop.

        ``timesteps`` may be None (use the whole schedule) or an int that truncates it.
        """
        if use_original_steps:
            # Every DDPM step T-1, ..., 0; each step value equals its index into the DDPM buffers.
            if timesteps is None:
                num_steps = self.ddpm_num_timesteps
            else:
                num_steps = min(int(timesteps), self.ddpm_num_timesteps)
            return np.arange(num_steps)[::-1], num_steps
        ddim_steps = self.ddim_timesteps
        if timesteps is not None:
            # Keep only the first `timesteps` (smallest-t) DDIM steps.
            ddim_steps = ddim_steps[:int(timesteps)]
        return np.flip(ddim_steps), ddim_steps.shape[0]

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T = None, ddim_use_original_steps = False,
                      callback = None, timesteps = None, quantize_denoised = False,
                      mask = None, x0 = None, img_callback = None, log_every_t = 100,
                      temperature = 1., noise_dropout = 0., score_corrector = None, corrector_kwargs = None,
                      unconditional_guidance_scale = 1., unconditional_conditioning = None):
        """Run the reverse DDIM loop starting from ``x_T`` (random normal if None).

        With ``ddim_use_original_steps`` every DDPM step is visited; otherwise the DDIM
        subsequence from ``make_schedule`` is used. Returns ``(x_0, intermediates)`` where
        ``intermediates`` holds ``'x_inter'`` and ``'pred_x0'`` lists. ``mask`` and ``x0`` are
        accepted for interface compatibility but not used here.
        """
        device = self.model.betas.device
        b = shape[0]  # batch size
        img = torch.randn(shape, device=device) if x_T is None else x_T

        time_range, total_steps = self._sampling_timesteps(ddim_use_original_steps, timesteps)
        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(iterator):  # e.g. 981, 961, ..., 1
            index = total_steps - i - 1  # position of `step` in the schedule buffers
            ts = torch.full((b,), step, device=device, dtype=torch.long)  # same t for every batch item

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs  # img becomes x_{t-1}
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """One DDIM update x_t -> x_{t-1}, following Eq. (12) of https://arxiv.org/pdf/2010.02502.

        Returns ``(x_prev, pred_x0)``. ``quantize_denoised``, ``score_corrector`` and
        ``corrector_kwargs`` are accepted for interface compatibility but unused.
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Classifier-free guidance: run unconditional and conditional inputs as one doubled batch.
            x_in = torch.cat([x] * 2)  # e.g. [3, 4, 64, 64] -> [6, 4, 64, 64]
            t_in = torch.cat([t] * 2)  # e.g. [3] -> [6]
            c_in = torch.cat([unconditional_conditioning, c])  # e.g. [3, 77, 768] -> [6, 77, 768]
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        # Full DDPM tables or the DDIM subsequence from make_schedule(); sample() uses the latter.
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        # The original-steps sigmas are registered on the sampler by make_schedule(), not on the model.
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0 (first term of Eq. 12)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        # direction pointing to x_t (second term of Eq. 12)
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        # random noise (third term of Eq. 12); zero when eta == 0 because sigma_t is then 0
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise  # Eq. 12: x_{t-1} from x_t
        return x_prev, pred_x0


if __name__ == "__main__":
    model = DDPM()  # DDPM model holding the noise schedule
    sampler = DDIMSampler(model)

    # Stand-ins for the FrozenCLIPEmbedder output
    batchsize = 3
    c = torch.rand(batchsize, 77, 768)  # embedding with a prompt
    uc = torch.rand(batchsize, 77, 768)  # embedding without a prompt

    # Denoise with DDIM
    shape = [4, 64, 64]
    scale = 7.5  # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
    ddim_eta = 0.0  # ddim eta (eta=0.0 corresponds to deterministic sampling
    samples_ddim, _ = sampler.sample(S=50,  # 50 sampling steps
                                     conditioning=c,  # conditional embedding
                                     batch_size=batchsize,
                                     shape=shape,
                                     verbose=False,
                                     unconditional_guidance_scale=scale,
                                     unconditional_conditioning=uc,  # unconditional embedding
                                     eta=ddim_eta,
                                     x_T=None)

    assert samples_ddim.shape[0] == batchsize
    assert list(samples_ddim.shape[1:]) == shape
    print("samples_ddim.shape: ", samples_ddim.shape)
    print("All Done!")