Stable Diffusion 源码解析(1)

参考1:https://blog.csdn.net/Eric_1993/article/details/129393890
参考2:https://zhuanlan.zhihu.com/p/613337342

  • 1.StableDiffusion基本原理
    • 1.1 UNetModel、FrozenCLIP 模型
    • 1.2 DDPM、DDIM、PLMS算法
  • 2. Runwayml SD 源码
    • 2.1 Img2Img Pipeline
    • 2.2 DDIMSampler
    • 2.3 LatentDiffusion Model
    • 2.4 UNet Model
  • 3. Diffusers StableDiffusionPipeline源码

1.StableDiffusion基本原理

在这里插入图片描述
训练阶段 (查看图中 Part 1 和 Part 2),主要包含:

  1. 使用 AutoEncoderKL 自编码器将图像 Image 从 pixel space 映射到 latent space,学习图像的隐式表达,注意 AutoEncoderKL 编码器已提前训练好,参数是固定的。此时 Image 的大小将从 [B, C, H, W] 转换为 [B, Z, H/8, W/8],其中 Z 表示 latent space 下图像的 Channel 数。这一过程在 Stable Diffusion 代码中被称为 encode_first_stage
  2. 使用 FrozenCLIPEmbedder 文本编码器对 Prompt 提示词进行编码,生成大小为 [B, K, E] 的 embedding 表示(即 context),其中 K 表示文本最大编码长度 max length, E 表示 embedding 的大小。这一过程在 Stable Diffusion 代码中被称为 get_learned_conditioning
  3. 进行前向扩散过程(Diffusion Process),对图像的隐式表达进行不断加噪,该过程调用 UNetModel 完成;UNetModel 同时接收图像的隐式表达 latent image 以及文本 embedding context,在训练时以 context 作为 condition,使用 Attention 机制来更好的学习文本与图像的匹配关系;
  4. 扩散模型输出噪声 ϵ θ \epsilon_{\theta} ϵθ ,计算和真实噪声之间的误差作为 Loss,通过反向传播算法更新 UNetModel 模型的参数,注意这个过程中 AutoEncoderKL 和 FrozenCLIPEmbedder 中的参数不会被更新。

采样阶段(查看图中 Part 2 和 Part 3),也就是我们加载模型参数后,输入提示词就能产出图像的阶段。主要包含:

  1. 使用 FrozenCLIPEmbedder 文本编码器对 Prompt 提示词进行编码,生成大小为 [B, K, E] 的 embedding 表示(即 context);
  2. 随机产出大小为 [B, Z, H/8, W/8] 的噪声 Noise,利用训练好的 UNetModel 模型,按照 DDPM/DDIM/PLMS 等算法迭代 T 次,将噪声不断去除,恢复出图像的 latent 表示;
  3. 使用 AutoEncoderKL 对图像的 latent 表示(大小为 [B, Z, H/8, W/8])进行 decode(解码),最终恢复出 pixel space 的图像,图像大小为 [B, C, H, W]; 这一过程在 Stable Diffusion 中被称为 decode_first_stage

Stable Diffusion 的模块大致包括:

  • FrozenCLIPEmbedder
  • UNetModel
  • AutoEncoderKL & VQModelInterface (也是一种变分自动编码器,图上没画)
  • DDPM、DDIM、PLMS 算法

1.1 UNetModel、FrozenCLIP 模型

画了一下 Stable Diffusion 中使用的 UNetModel,就不分析代码了,看图很容易将代码写出来。Stable Diffusion 采用 UNetModel 这种 Encoder-Decoder 结构来实现扩散的过程,对噪声进行预估, 网络结构如下:

在这里插入图片描述
模型的输入包含三个部分:

  • 大小为 [B, C, H, W] 的图像 image; 注意不用在意表示大小时所用的符号,应将它们视作接口,比如 UNetModel 接收大小为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就等于 Z, H 就等于 H/8, W 就等于 W/8
  • 大小为 [B,] 的 timesteps
  • 大小为 [B, K, E] 的文本 embedding 表示 context, 其中 K 表示最大编码长度,E 表示 embedding 大小

模型使用 DownSampleUpSample 来对样本进行下采样和上采样,此外出现最多的模块是 ResBlock 以及 SpatialTransformer,其中图中每一个 ResBlock 接收来自上一个模块的输入以及 timesteps 对应的 embedding timestep_emb (大小为 [B, 4*M]M 是可配置的参数);而图中每一个 SpatialTransformer接收来自上一个模块的输入以及 context (Prompt 文本的 embedding 表示),使用 Cross Attention,以 context 为 condition,学习 Prompt 和图像的匹配关系。但图上只在虚线框中显示了两个模块有多个输入,其他模块没有画出来)

可以看到,最后模型的输出大小为 [B, C, H, W], 和输入大小相同,也就是说 UNetModel 不改变输入输出的大小。

下面再分别看看 ResBlocktimestep_embeddingcontext 以及 SpatialTransformer 的实现。

ResBlock 的实现
ResBlock 网络结构图如下,它接受两个输入,图像 x 以及 timestep 对应的 embedding:
在这里插入图片描述
timestep_embedding 实现
timestep_embedding 的生成方式如下,用的是 Tranformer(Attention is All you Need)这篇 paper 中的方法:
在这里插入图片描述
Prompt 文本 embedding 的实现
context 的实现。Prompt 使用 CLIP 模型进行编码,我没有对 CLIP 模型详细学习,暂时也没有深入看的打算,后续有机会再补充;代码中使用预训练好的 CLIP 生成 context:
在这里插入图片描述
SpatialTransformer 的实现
最后再看下 SpatialTransformer 的实现,其模块比较多,在接收图像作为输入时,还使用 context 文本作为 condition 信息,二者使用 Cross Attention 进行建模。进一步展开 SpatialTransformer, 发现包含 BasicTransformerBlock ,它实际调用 Cross Attention 模块,而在 Cross Attention 模块中,图像信息作为 Query,文本信息作为 Key & Value,模型会关注图像和文本各部分内容的相关性:
在这里插入图片描述

我觉得可以用一种朴素的想法来理解这里 Cross Attention 的作用,比如训练时给定一张马吃草的图,以及文本提示词:“一匹白色的马在沙漠吃草”,在做 Attention 时,文本中的 “马” 这个关键词和图像中的动物(也是 “马”)的关联性更强,因为权重也更大,而 “一匹”、 “白色”、“沙漠”、 “草” 等权重更低;此时,当模型被训练的很好后,模型不仅将可以学习到图像和文本之间的匹配关系,通过 Attention 还可以学习到文本中的各个关键词想突出图像中哪些主体。

1.2 DDPM、DDIM、PLMS算法

扩散模型使用 DDPM、DDIM、PLMS 等采样Sample算法通过迭代去除噪声,从而生成图像的潜在空间(latent space)表示。

在图像生成前,模型会首先在Latent Space中生成一个完全随机的图像,然后噪声预测器会开始工作,从图像中减去预测的噪声。随着这个步骤的不断重复,最终我们得到了一个清晰的图像。Stable Diffusion在每个步骤中都会生成一张新的采样后的图像,整个去噪的过程,即为采样Sample,使用的采样手段,即为采样器Sampler或称为采样方法

DDPM

对原理进行朴素回顾

DDPM (Denoising Diffusion Probabilistic Models)算法之前在 扩散模型 (Diffusion Model) 简要介绍与源码分析 介绍过,推导有些复杂,这里就用朴素的大白话描述一下我觉得最重要的几个公式,然后分析代码实现,核心是理清楚推导的逻辑链。

首先扩散模型的整个思路是先在图像上不断的加噪,从而对图像进行破坏,然后再对破坏后的图像进行不断的去噪,最后恢复出原始图像。这个过程可以用如下公式描述:
在这里插入图片描述
现在的一个问题是如何求逆向阶段的分布,也就是如果给定了一张加噪的图像,我们如何才能求得它前一时刻没有被破坏的那么严重的图像。经过数学高手们的一顿推导,发现两个重要结论:1. 逆向过程也服从高斯分布;2. 在知晓初始干净图像的情况下,我们能通过贝叶斯公式将逆向过程转换成前向过程,从而算出逆向过程的分布; 在公式上体现如下:
在这里插入图片描述
算出逆向过程的分布后,我们就可以训练一个模型,去尽力拟合这个分布,那么模型预估出来的结果也应该服从高斯分布:
在这里插入图片描述
现在逆向过程的分布有了(可以理解为 label),模型的预估分布也有了,就差一个 Loss 函数,而经过数学高手的又一顿推导,发现 Loss 居然是计算两个分布的 KL 散度,而且还是两个高斯分布的 KL 散度!朴素的说,KL 散度可以用来描述两个分布之间的差距。不得不感慨,数学就是这么神奇,左推右推,最后能得到一个美妙的结果:
在这里插入图片描述
多元高斯分布的 KL 散度是有闭式解的,详见维基百科,具体公式如下:
在这里插入图片描述

最后得到训练过程和采样过程分别如下:
在这里插入图片描述
DDPM 代码分析
再次提醒,我对源码进行了抽象,以伪代码的形式呈现。详细列出每行代码完全没有必要,太多的细节会淹没真正重要的信息。另外注意两点:1. 在实现上,我保持类名、函数名和源码一致,这样就可以方便快速了解类或者函数的功能;2. 函数尽量按调用顺序进行组织;
在这里插入图片描述
不客气的说,非常简洁。PyTorch 中 forward() 函数是入口,输出噪声之间的 Loss;

  • 采样阶段:

按顺序阅读,核心在 p_sample 函数中,使用重参数技巧生成样本:

在这里插入图片描述
针对 DDPM 的改进
下面简单介绍 DDIM 和 PLMS算法,它们均是对 DDPM 算法的改进。DDPM 在采样阶段需要迭代很多次(比如 1000)才能得到一个比较好的效果,而 DDIM、PLMS 算法则尝试使用较少的迭代次数来加速采样过程。下图是 DDIM 论文中给出的实验结果分析:

在这里插入图片描述
其中第一行(绿线…)是 DDIM 的结果,最后一行是 DDPM 的实验结果,使用 FID 来评估生成图像的质量,该值越小,表示结果越好;S 为迭代次数,只看红框中的 CIFAR10 数据集上的效果,可以发现随着迭代次数的增加,FID 越小,生成图像质量越好;另外可以注意到 DDIM 迭代到第 50 次左右时,就几乎能达到 DDPM 迭代到 1000 次的效果 (4.67 vs. 3.17);

DDIM
DDIM 将图像的采样过程定义为非马尔科夫链:
在这里插入图片描述
并重新推导了图像的生成公式:
在这里插入图片描述
其中 σ t \sigma_t σt, 定义如下:
在这里插入图片描述

根据推导,如果系数 η = 1, 那么此时采样过程和 DDPM 相同;而当系数 η = 0 时,即为 DDIM 算法的采样过程,注意到此时均方差为 0,图像的生成过程是确定的。另外需要注意在 DDIM paper 的公式中, α t \alpha_t αt以及 β t \beta_t βt 等的含义和 DDPM 论文中不同,它们被重新定义了…

伪代码如下(DDIM 默认只迭代 50 步):
在这里插入图片描述

PLMS
PLMS是对DDIM的改进,论文中给出采样过程的公式如下:
在这里插入图片描述伪代码如下:
在这里插入图片描述

2. Runwayml SD 源码

参考runwayml的stable-diffusion-v1.5代码,使用SDEdit进行图像编辑推理:先将原始图像经过Inversion加噪作为Sample的起点,在text prompt(condition+uncondition)的引导下迭代Sample,采样噪声,迭代去噪得到编辑后的图像。

2.1 Img2Img Pipeline

Img2Img 图像编辑时,运行scripts/img2img.py,传入text promptinit image

python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img /home/pgao/yue/Stable_Diffusion/data/sketch-mountains-input.jpg --strength 0.8

总结超参数如下:

--prompt:要渲染的提示文本,默认为 "a painting of a virus monster playing guitar"--init-img:输入图像的路径。
--outdir:结果保存的目录,默认为 "outputs/img2img-samples"--skip_grid:是否跳过保存网格图像,仅保存单独的样本图像。在评估大量样本时很有帮助。
--skip_save:是否不保存单独的样本图像,用于加速测量。
--ddim_steps:ddim采样步骤的数量。
--plms:是否使用plms采样。
--fixed_code:如果启用,则在所有样本中使用相同的起始编码。
--ddim_eta:ddim采样的eta值(eta=0.0表示确定性采样)。
--n_iter:采样次数。
--C:潜变量通道数。
--f:下采样因子,通常为816--n_samples:对于每个给定的提示文本,要生成的样本数量,也称为批处理大小。
--n_rows:网格中的行数(默认为n_samples)。
--scale:无条件引导尺度。
--strength:噪声/去噪的强度。1.0表示Inversion完全破坏初始图像中的信息。
--from-file:如果指定,从该文件加载提示文本。
--config:构建模型的配置文件路径。
--ckpt:模型的检查点路径。
--seed:种子值(用于可重复采样)。
--precision:在此精度下评估,可选值为 "full""autocast",默认为 "autocast"

就从入口scripts/img2img.pymain开始阅读,跳过传入参数的parser部分:

  1. 设置随机种子seed:
	seed_everything(opt.seed)
  1. 加载Stable Diffusion模型:
	config = OmegaConf.load(f"{opt.config}")model = load_model_from_config(config, f"{opt.ckpt}")device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")model = model.to(device)

其中 opt.config= 'configs/stable-diffusion/v1-inference.yaml',指向一个预定义好的SD model的配置文件,opt.ckpt是预先下载好的模型权重。

然后看load_model_from_config函数,这一函数就定义在同一个文件(img2img.py文件)中,但是它调用了ldm.util中的两个方法instantiate_from_configget_obj_from_str,将ckpt权重加载到ldm.models.diffusion.ddpm.LatentDiffusion中。这里一起写出来:

def instantiate_from_config(config):# target = ldm.models.diffusion.ddpm.LatentDiffusionif not "target" in config:  if config == '__is_first_stage__':return Noneelif config == "__is_unconditional__":return Noneraise KeyError("Expected key `target` to instantiate.")# 等价于return LatentDiffusion(**config.get("params", dict()))return get_obj_from_str(config["target"])(**config.get("params", dict()))def get_obj_from_str(string, reload=False):# string = ldm.models.diffusion.ddpm.LatentDiffusionmodule, cls = string.rsplit(".", 1)# from ldm.models.diffusion.ddpm import LatentDiffusionif reload:module_imp = importlib.import_module(module)importlib.reload(module_imp)return getattr(importlib.import_module(module, package=None), cls)def load_model_from_config(config, ckpt, verbose=False):print(f"Loading model from {ckpt}")pl_sd = torch.load(ckpt, map_location="cpu")if "global_step" in pl_sd:print(f"Global Step: {pl_sd['global_step']}")sd = pl_sd["state_dict"]# 实例化 LatentDiffusion model model = instantiate_from_config(config.model)# 为 model 加载权重sdmissing, unexpected = model.load_state_dict(sd, strict=False)if len(missing) > 0 and verbose:print("missing keys:")print(missing)if len(unexpected) > 0 and verbose:print("unexpected keys:")print(unexpected)model.cuda()model.eval()return model

其中,missing是一个列表,包含在加载状态字典时模型中缺失的键(参数)。unexpected是一个列表,包含加载状态字典时模型中未预期到的额外键(参数)。理想情况下,两者都是空的。“verbose” 参数通常是一个布尔值或整数,用来控制程序在执行时是否输出详细信息,以及输出信息的程度。

实际上等效于,先实例化ldm.models.diffusion.ddpm.LatentDiffusion,再加载权重:

from ldm.models.diffusion.ddpm import LatentDiffusion
model = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt, map_location="cpu")["state_dict"], strict=False)

原code使用importlib.import_module,来读取字典中的模块名称进行灵活的import。从方便理解代码运行和算法原理的视角来看,在实际使用LatentDiffusion时,上下两种写法是完全等效的。

# 初始化模型的全部逻辑:from ldm.models.diffusion.ddpm import LatentDiffusion
import torch
from omegaconf import OmegaConf# 读取config
config = OmegaConf.load(f"{opt.config}")# 初始化模型并传入config中的参数
model = LatentDiffusion(**config.model.get("params", dict()))
model.load_state_dict(torch.load(ckpt, map_location="cpu")["state_dict"], strict=False)device = torch.device("cuda")
model = model.to(device)
  1. 设置Sampler:

有了model之后是sampler的初始化 (基于命令行传入的 --plms,执行判断语句的第一条):sampler = PLMSSampler(model)

    if opt.plms:raise NotImplementedError("PLMS sampler not (yet) supported")sampler = PLMSSampler(model)else:sampler = DDIMSampler(model)# 设置Noise Schedulesampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)

make_schedule()函数是PLMSSampler和DDIMSampler的函数,设置Noise Schedulealphabeta等参数(用于将self.model的参数注册为self)。先用lambda函数register_bufferself.model的参数提取copy一份,再用register_bufferself注册这些参数:

	def register_buffer(self, name, attr):if type(attr) == torch.Tensor:if attr.device != torch.device("cuda"):attr = attr.to(torch.device("cuda"))setattr(self, name, attr)def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)alphas_cumprod = self.model.alphas_cumprodassert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)self.register_buffer('betas', to_torch(self.model.betas))self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))# ddim sampling parametersddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),ddim_timesteps=self.ddim_timesteps,eta=ddim_eta,verbose=verbose)self.register_buffer('ddim_sigmas', ddim_sigmas)self.register_buffer('ddim_alphas', ddim_alphas)self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt((1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev))self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
  1. 准备prompt:

紧接着,根据batch_size设置prompt数量。原代码提供了两种输入prompt的方法,分别是命令行输入和从文件读取,不关键。总之最后prompt进入了data这个变量

    batch_size = opt.n_samplesn_rows = opt.n_rows if opt.n_rows > 0 else batch_sizeif not opt.from_file:prompt = opt.promptassert prompt is not Nonedata = [batch_size * [prompt]]else:print(f"reading prompts from {opt.from_file}")with open(opt.from_file, "r") as f:data = f.read().splitlines()data = list(chunk(data, batch_size))sample_path = os.path.join(outpath, "samples")os.makedirs(sample_path, exist_ok=True)base_count = len(os.listdir(sample_path))grid_count = len(os.listdir(outpath)) - 1
  1. 准备编辑的image并转换为latent:

根据batch_size设置init_image数量,并且将images使用VAE encoder编码为latents

    assert os.path.isfile(opt.init_img)init_image = load_img(opt.init_img).to(device)init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

load_img函数把image转换为tensor,同时完成一系列数据增强:

def load_img(path):image = Image.open(path).convert("RGB")w, h = image.sizeprint(f"loaded input image of size ({w}, {h}) from {path}")w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 32image = image.resize((w, h), resample=PIL.Image.LANCZOS)image = np.array(image).astype(np.float32) / 255.0image = image[None].transpose(0, 3, 1, 2)image = torch.from_numpy(image)return 2.*image - 1.

get_first_stage_encoding函数完成imagelatent

def get_first_stage_encoding(self, encoder_posterior):if isinstance(encoder_posterior, DiagonalGaussianDistribution):z = encoder_posterior.sample()elif isinstance(encoder_posterior, torch.Tensor):z = encoder_posteriorelse:raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")return self.scale_factor * z

到这里,我们有了

  • model - [LatentDiffusion]
  • sampler - [PLMSSampler]
  • text prompt
  • image latent

这样就可以开始编辑图片了(重新生成)。

  1. 图像编辑(再生成):

这里有两个重要的部分,一个是PLMSSampler的定义,一个是LatentDiffusion的定义。我们先将这两个模块视作黑箱,假定它们能完美的完成各自的任务,之后再详细看它们的代码。

在开始一张图像的time_step步的迭代sample之前,先计算Inversion的步骤t_enc ,设置推理的精度with autocast("cuda"),关闭梯度with torch.no_grad(),设置ema指数滑动平均with model.ema_scope()。然后开始迭代采样,trangetqdmrange

	assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'t_enc = int(opt.strength * opt.ddim_steps)print(f"target t_enc is {t_enc} steps")precision_scope = autocast if opt.precision == "autocast" else nullcontextwith torch.no_grad():with precision_scope("cuda"):with model.ema_scope():all_samples = list()for n in trange(opt.n_iter, desc="Sampling"):

这里先简单回忆一下classifier-free guidance的方法:
在这里插入图片描述
因此除了prompt,也就是上式中c所对应的条件,还需要unconditional的Null text prompt

	# unconditional promptuc = Noneif opt.scale != 1.0:uc = model.get_learned_conditioning(batch_size * [""])# conditional promptif isinstance(prompts, tuple):prompts = list(prompts)c = model.get_learned_conditioning(prompts)

这里可以看到model中的一个方法get_learned_conditioning() : 输入text, 输出text的embedding。self.cond_stage_modelFrozenCLIPEmbedderself.cond_stage_model.encode()就是FrozenCLIPEmbedderforward(),完成分别送入CLIPTokenizerCLIPTextModel,返回编码后的prompt tensor

	def get_learned_conditioning(self, c):if self.cond_stage_forward is None:if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):c = self.cond_stage_model.encode(c)if isinstance(c, DiagonalGaussianDistribution):c = c.mode()else:c = self.cond_stage_model(c)else:assert hasattr(self.cond_stage_model, self.cond_stage_forward)c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)return c# FrozenCLIPEmbedderdef forward(self, text):batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,return_overflowing_tokens=False, padding="max_length", return_tensors="pt")tokens = batch_encoding["input_ids"].to(self.device)outputs = self.transformer(input_ids=tokens)z = outputs.last_hidden_statereturn z  # torch.Size([batch_size, 77, 768])

接着,为了编辑图像,先采用DDIM Inversion反演原始图像x0,对原始图像的latent进行随机加噪

# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))

x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon xt=αˉt x0+1αˉt ϵ
用到的stochastic_encode函数用于确定性加噪x0到xt,extract_into_tensor函数用于返回alpha序列中第t个值:

	def extract_into_tensor(a, t, x_shape):b, *_ = t.shapeout = a.gather(-1, t)return out.reshape(b, *((1,) * (len(x_shape) - 1)))@torch.no_grad()def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):# fast, but does not allow for exact reconstruction# t serves as an index to gather the correct alphasif use_original_steps:sqrt_alphas_cumprod = self.sqrt_alphas_cumprodsqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprodelse:sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphasif noise is None:noise = torch.randn_like(x0)return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

之后就是图像生成,调用sampler实例的decode方法。

# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,unconditional_conditioning=uc,)

传入:采样起点z_enc(原图Inversion的Noise Latent)、condition text embedding c、uncondition text embedding uc、unconditional_guidance_scale opt.scale、需要去噪的步数 t_enc(因为Inversion时只加噪的t_enc步,40步)。

迭代执行p_sample_ddim去噪,p_sample_ddim实现单步去噪,将 x t x_t xt去噪为 x t − 1 x_{t-1} xt1,最终得到编辑好的 x 0 x_0 x0

	@torch.no_grad()def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,use_original_steps=False):timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timestepstimesteps = timesteps[:t_start]time_range = np.flip(timesteps)total_steps = timesteps.shape[0]print(f"Running DDIM Sampling with {total_steps} timesteps")iterator = tqdm(time_range, desc='Decoding image', total=total_steps)x_dec = x_latentfor i, step in enumerate(iterator):index = total_steps - i - 1ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,unconditional_guidance_scale=unconditional_guidance_scale,unconditional_conditioning=unconditional_conditioning)return x_dec

p_sample_ddim实现timestep=t的单步去噪:拼接uc和c的text prompt embedding,self.model.apply_model是Unet预测uc和c引导下的noise,再进行classifier-free guidance的加权组合得到noise e_t,将 x t x_t xt去噪为 x t − 1 x_{t-1} xt1

	@torch.no_grad()def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,unconditional_guidance_scale=1., unconditional_conditioning=None):b, *_, device = *x.shape, x.deviceif unconditional_conditioning is None or unconditional_guidance_scale == 1.:e_t = self.model.apply_model(x, t, c)# classifier-free guidanceelse:  x_in = torch.cat([x] * 2)t_in = torch.cat([t] * 2)if isinstance(c, dict):assert isinstance(unconditional_conditioning, dict)c_in = dict()for k in c:if isinstance(c[k], list):c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]])for i in range(len(c[k]))]else:c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])# do itelse:c_in = torch.cat([unconditional_conditioning, c])e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)if score_corrector is not None:assert self.model.parameterization == "eps"e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphasalphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prevsqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphassigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas# select parameters corresponding to the currently considered timestepa_t = torch.full((b, 1, 1, 1), alphas[index], device=device)a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)# current prediction for x_0pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()if quantize_denoised:pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)# direction pointing to x_tdir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_tnoise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperatureif noise_dropout > 0.:noise = torch.nn.functional.dropout(noise, p=noise_dropout)x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noisereturn x_prev, pred_x0

到这里为止,diffusion的任务已经结束了,x_samples_ddim 再经过decode_first_stage(VAE的decoder),就是最终的结果。

以上就是img2img.py文件的全部内容。这一部分绝大多数代码都是数据的读写和准备工作,核心逻辑部分比较少,还是比较好理解的。

接下来进入ddim文件去看sampler的代码实现。

2.2 DDIMSampler

这一模块的定义在ldm/models/diffusion/ddim.py中。

这一class包含以下方法:

class DDIMSampler(object):def __init__(self, model, schedule="linear", **kwargs):super().__init__()self.model = modelself.ddpm_num_timesteps = model.num_timestepsself.schedule = scheduledef register_buffer(self, name, attr):...setattr(self, name, attr)def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):...to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)self.register_buffer('hyper_params_name', to_torch(self.model.hyper_params_name))  # alpha, beta...@torch.no_grad()def sample(...):  # make_schedule_params and ddim_sampling...self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)samples, intermediates = self.ddim_sampling(...)return samples, intermediates@torch.no_grad()def ddim_sampling(...):  # x_t -> x_0...for i, step in enumerate(iterator):img, intermediates = self.p_sample_ddim(...)return img, intermediates@torch.no_grad()def p_sample_ddim(...):  # x_t -> x_t-1...return x_prev, pred_x0

其中register_buffer实际功能是:将attr(tensor)设置为name(module)的属性。也就是为模型注册更新变量的,这一简单的方法在make_schedule中被大量的调用,总之就是在生成的每一个step中计算并且更新diffusion过程的各个参数。

def register_buffer(self, name, attr):if type(attr) == torch.Tensor:if attr.device != torch.device("cuda"):attr = attr.to(torch.device("cuda"))setattr(self, name, attr)

接下来就是在上一篇中出现的sampling的入口函数sampler.sample(.....)

    def sample(self, S, batch_size, shape, c=None, eta=0., x_T=None, uc_scale=1., uc=None, **kwargs):self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)# samplingC, H, W = shapesize = (batch_size, C, H, W)samples, intermediates = self.ddim_sampling(c,size,x_T=x_T,uc_scale=uc_scale,uc=uc,)return samples, intermediates

这里先用了S,eta生成了参数schedule,将其他的参数[ batch_size, shape, conditioning, unconditional_guidance_scale, ]

接下来看ddim_sampling:返回采样结果x_0和中间结果intermediates

def ddim_sampling(self, cond, shape, x_T=None, uc_scale=1., uc=None,):device = self.model.betas.deviceb = shape[0]   # 实际就是batch_size# 如果x_T不存在,生成随机噪声作为第一步的图片if x_T is None:img = torch.randn(shape, device=device)else:img = x_Tif timesteps is None:timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timestepselif timesteps is not None and not ddim_use_original_steps:subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1timesteps = self.ddim_timesteps[:subset_end]intermediates = {'x_inter': [img], 'pred_x0': [img]}time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]print(f"Running DDIM Sampling with {total_steps} timesteps")iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)for i, step in enumerate(iterator):index = total_steps - i - 1ts = torch.full((b,), step, device=device, dtype=torch.long)img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index,uc_scale=uc_scale,uc=uc,t_next=ts_next)if index % log_every_t == 0 or index == total_steps - 1:intermediates['x_inter'].append(img)intermediates['pred_x0'].append(pred_x0)return img, intermediates

在这里搭建了关于逐步重建图片的反向去噪过程的for循环,在其中对于每一步的图片都应用了self.p_sample_ddim()这一方法,以下是其核心逻辑。

	def p_sample_ddim(self, x, c, t, index, uc_scale=1., uc=None, t_next=None):b, *_, device = *x.shape, x.device# classifier-free guidence text conditionx_in = torch.cat([x] * 2)t_in = torch.cat([t] * 2)c_in = torch.cat([unconditional_conditioning, c])# pred classifier-free guidence epsilon_te_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)# paramsalphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphasalphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prevsqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphassigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas# select parameters corresponding to the currently considered timestep ta_t = torch.full((b, 1, 1, 1), alphas[index], device=device)a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)# current prediction for x_0pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()if quantize_denoised:pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)# direction pointing to x_tdir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_tnoise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperatureif noise_dropout > 0.:noise = torch.nn.functional.dropout(noise, p=noise_dropout)x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noisereturn x_prev, pred_x0

终于用到了上一篇中提到的Unet model

e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)

显然,model对象中定义了apply_model这一函数,接收图像img,step t 和 text prompt 的embedding并返回对应的预测噪声e_t

之后再通过prediction for x_0、pointing to x_t,基于diffusion model的原理采样下一步的图片x_prev ,即x_t-1

回看一下上面的内容可以发现sampler在逻辑上并不复杂,但是中间(尤其是make_schedule这一方法中)大量的运算都是基于diffusion model的数学原理的。但是由于这些运算都是预先定义好的,从让代码跑起来/理解每一部分的代码在做什么这样的角度出发的话当作黑箱就可以了。

2.3 LatentDiffusion Model

在这一篇中我们对model的内部进行解读。model 的定义在ldm/models/diffusion/ddpm.py中。

这一文件长达1400+行,非常劝退。其中定义了DDPMLatentDiffusion以及DiffusionWrapper这3个类。

不过如果仅仅考虑inference阶段的代码逻辑,并且去除掉大量的条件判断的话,事实上可以简化到100行左右。

DDPM

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from functools import partial
from tqdm import tqdm
from ldm.util import default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_scheduleclass DDPM(pl.LightningModule):def __init__(self,unet_config,timesteps=1000,beta_schedule="linear",linear_start=1e-4,linear_end=2e-2,cosine_s=8e-3,**kwargs):super().__init__()self.model = DiffusionWrapper(unet_config)self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)def register_schedule(self, beta_schedule="linear", timesteps=1000,linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,cosine_s=cosine_s)alphas = 1. - betasalphas_cumprod = np.cumprod(alphas, axis=0)alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])timesteps, = betas.shapeself.num_timesteps = int(timesteps)to_torch = partial(torch.tensor, dtype=torch.float32)self.register_buffer('betas', to_torch(betas))self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))class LatentDiffusion(DDPM):def __init__(self, first_stage_config, cond_stage_config, scale_factor=1.0, *args, **kwargs):super().__init__(*args, **kwargs)self.scale_factor = scale_factorself.instantiate_first_stage(first_stage_config)self.instantiate_cond_stage(cond_stage_config)def register_schedule(self, beta_schedule="linear", timesteps=1000,linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):super().register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s)def instantiate_first_stage(self, config):model = instantiate_from_config(config)self.first_stage_model = model.eval()for param in self.first_stage_model.parameters():param.requires_grad = Falsedef instantiate_cond_stage(self, config):model = instantiate_from_config(config)self.cond_stage_model = model.eval()for param in self.cond_stage_model.parameters():param.requires_grad = Falsedef get_learned_conditioning(self, c): return self.cond_stage_model.encode(c)  @torch.no_grad()def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):z = 1. / self.scale_factor * zreturn self.first_stage_model.decode(z)def apply_model(self, x_noisy, t, cond, return_ids=False):cond = {'c_crossattn': [cond]}x_recon = self.model(x_noisy, t, **cond)return x_reconclass DiffusionWrapper(pl.LightningModule):def __init__(self, diff_model_config):super().__init__()self.diffusion_model = instantiate_from_config(diff_model_config)def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):cc = torch.cat(c_crossattn, 1)out = self.diffusion_model(x, t, context=cc)return out

这里LatentDiffusion继承自DDPM,而它们有:

	self.model = DiffusionWrapper(unet_config)...class DiffusionWrapper(pl.LightningModule):def __init__(self, diff_model_config):self.diffusion_model = instantiate_from_config(diff_model_config)

也就是,初始化LatentDiffusion的时候,同时会实例化一个DiffusionWrapper,并运行DiffusionWrapper的__init__中的instantiate_from_config,读取传入的config中的unet_config这一键值对应的Config字典,来初始化Unet。

如2.1提到的,一系列的模块都这样被初始化。【ctrl+F 搜索 instantiate_from_config 可以看到它被多次调用来进行模块的实例化】

回忆一下,在上一篇中提到,sampler中最后调用了model.apply这一方法来进行图像噪声的预测。sampler中的model实际上是LatentDiffusion的实例。

class LatentDiffusion(DDPM):def apply_model(self, x_noisy, t, cond, return_ids=False):cond = {'c_crossattn': [cond]}x_recon = self.model(x_noisy, t, **cond)return x_recon

可以看到 apply_model这一方法又调用了self.model(),而LatentDiffusion由于继承自DDPM,它的self.model实际上指向的是DiffusionWrapper,所以最终承担这一步运算任务的是DiffusionWrapper.diffusion_model

diffusion_model 这一变量是用instantiate_from_config根据unet_config这一字典实例化得来的。所以我们去看一下Config配置文件

model:base_learning_rate: 1.0e-04target: ldm.models.diffusion.ddpm.LatentDiffusionparams:...scheduler_config: # 10000 warmup stepstarget: ldm.lr_scheduler.LambdaLinearSchedulerparams:...unet_config:target: ldm.modules.diffusionmodules.openaimodel.UNetModelparams:...first_stage_config:target: ldm.models.autoencoder.AutoencoderKLparams:...cond_stage_config:target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

unet_config的target是UNetModel, 层层的嵌套终于快到末端了。

2.4 UNet Model

2.3 中,也就是说调用了UNetModel.forward(x_noisy, t, **cond)

class LatentDiffusion(DDPM):def apply_model(self, x_noisy, t, cond, return_ids=False):cond = {'c_crossattn': [cond]}x_recon = self.model(x_noisy, t, **cond)return x_recon

文件:ldm/modules/diffusionmodules/openaimodel.py

原代码太长了,还是和之前一样,按照原作者给的config删掉用不上的部分,原代码可以简化到不到200行。其中最重要的部分是对UnetModel的定义。

上采样和下采样

  • 上采样Upsample:先插值再卷积
class Upsample(nn.Module):def __init__(self, channels, out_channels=None, padding=1):super().__init__()self.conv = nn.Conv2d(channels, out_channels, 3, padding=padding)def forward(self, x):x = F.interpolate(x, scale_factor=2, mode="nearest")x = self.conv(x)return x
  • 下采样Downsample:卷积
class Downsample(nn.Module):def __init__(self, channels, out_channels=None,padding=1):super().__init__()self.op = nn.Conv2d(channels, out_channels, 3, stride=2, padding=padding)def forward(self, x):return self.op(x)
  • ResBlock:Resblock同时接受图像和 step embedding 两个输入,网络结构由in layer,out layer,emb layer以及skip connection四个部分组成。
    在这里插入图片描述
from abc import abstractmethodimport torch as th
import torch.nn as nn
import torch.nn.functional as Ffrom ldm.modules.diffusionmodules.util import normalizationclass TimestepBlock(nn.Module):@abstractmethoddef forward(self, x, emb):"""Apply the module to `x` given `emb` timestep embeddings."""class ResBlock(TimestepBlock):def __init__(self, channels, emb_channels, dropout, out_channels=None):super().__init__()self.channels = channelsself.dropout = dropoutself.out_channels = out_channels or channelsself.in_layers = nn.Sequential(normalization(channels),nn.SiLU(),nn.Conv2d(channels, self.out_channels, 3, padding=1))self.emb_layers = nn.Sequential(nn.SiLU(),nn.Linear(emb_channels,self.out_channels))self.out_layers = nn.Sequential(normalization(self.out_channels),nn.SiLU(),nn.Dropout(p=dropout),nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1))if self.out_channels == channels:self.skip_connection = nn.Identity()else:self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)def forward(self, x, emb):h = self.in_layers(x)emb_out = self.emb_layers(emb).type(h.dtype)while len(emb_out.shape) < len(h.shape):emb_out = emb_out[..., None]h = h + emb_outh = self.out_layers(h)
  • SpatialTransformer
    SpatialTransformer有以下必要的组件:
  1. CrossAttention
class CrossAttention(nn.Module):def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head * headscontext_dim = default(context_dim, query_dim)self.scale = dim_head ** -0.5self.heads = headsself.to_q = nn.Linear(query_dim, inner_dim, bias=False)self.to_k = nn.Linear(context_dim, inner_dim, bias=False)self.to_v = nn.Linear(context_dim, inner_dim, bias=False)self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),nn.Dropout(dropout))def forward(self, x, context=None, mask=None):h = self.headsq = self.to_q(x)context = default(context, x)k = self.to_k(context)v = self.to_v(context)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))sim = einsum('b i d, b j d -> b i j', q, k) * self.scale# attention, what we cannot get enough ofattn = sim.softmax(dim=-1)out = einsum('b i j, b j d -> b i d', attn, v)out = rearrange(out, '(b h) n d -> b n (h d)', h=h)return self.to_out(out)

在这里插入图片描述
2. BasicTransformerBlock

class FeedForward(nn.Module):def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):super().__init__()inner_dim = int(dim * mult)dim_out = default(dim_out, dim)project_in = nn.Sequential(nn.Linear(dim, inner_dim),nn.GELU()) if not glu else GEGLU(dim, inner_dim)self.net = nn.Sequential(project_in,nn.Dropout(dropout),nn.Linear(inner_dim, dim_out))def forward(self, x):return self.net(x)class BasicTransformerBlock(nn.Module):def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):super().__init__()self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attentionself.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,heads=n_heads, dim_head=d_head, dropout=dropout)  # cross attentionself.norm1 = nn.LayerNorm(dim)self.norm2 = nn.LayerNorm(dim)self.norm3 = nn.LayerNorm(dim)def forward(self, x, context=None):x = self.attn1(self.norm1(x)) + xx = self.attn2(self.norm2(x), context=context) + xx = self.ff(self.norm3(x)) + xreturn x

在这里插入图片描述
注意这里的attn1是self-attention, attn2才是cross-attention

  1. 定义spatialTransformer
class SpatialTransformer(nn.Module):def __init__(self, in_channels, n_heads, d_head,depth=1, dropout=0., context_dim=None):super().__init__()self.in_channels = in_channelsinner_dim = n_heads * d_headself.norm = Normalize(in_channels)self.proj_in = nn.Conv2d(in_channels,inner_dim,kernel_size=1,stride=1,padding=0)self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)for d in range(depth)])self.proj_out = nn.Conv2d(inner_dim,in_channels,kernel_size=1,stride=1,padding=0)def forward(self, x, context=None):b, c, h, w = x.shapex_in = xx = self.norm(x)x = self.proj_in(x)x = rearrange(x, 'b c h w -> b (h w) c')for block in self.transformer_blocks:x = block(x, context=context)x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)x = self.proj_out(x)return x + x_in

在block中是若干个上面定义的BasicTransformerBlock,数量由参数depth指定

  • UNetModel
    UNet基于config将上述模块组合起来。
    在这里插入图片描述
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):def forward(self, x, emb, context=None):for layer in self:if isinstance(layer, TimestepBlock):x = layer(x, emb)elif isinstance(layer, SpatialTransformer):x = layer(x, context)else:x = layer(x)return xclass UNetModel(nn.Module):def __init__(self,image_size,in_channels,model_channels,out_channels,num_res_blocks,attention_resolutions,dropout=0,channel_mult=(1, 2, 4, 8),num_heads=-1,transformer_depth=1,              # custom transformer supportcontext_dim=None):super().__init__()self.image_size = image_sizeself.in_channels = in_channelsself.model_channels = model_channelsself.out_channels = out_channelsself.num_res_blocks = num_res_blocksself.attention_resolutions = attention_resolutionsself.dropout = dropoutself.channel_mult = channel_multself.num_heads = num_headstime_embed_dim = model_channels * 4self.time_embed = nn.Sequential(nn.Linear(model_channels, time_embed_dim),nn.SiLU(),nn.Linear(time_embed_dim, time_embed_dim),)self.input_blocks = nn.ModuleList([TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, 3, padding=1))])self._feature_size = model_channelsinput_block_chans = [model_channels]ch = model_channelsds = 1for level, mult in enumerate(channel_mult):for _ in range(num_res_blocks):layers = [ResBlock(ch,time_embed_dim,dropout,out_channels=mult * model_channels)]ch = mult * model_channelsif ds in attention_resolutions:dim_head = ch // num_headslayers.append(SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim))self.input_blocks.append(TimestepEmbedSequential(*layers))self._feature_size += chinput_block_chans.append(ch)if level != len(channel_mult) - 1:out_ch = chself.input_blocks.append(TimestepEmbedSequential(Downsample(ch, out_channels=out_ch)))ch = out_chinput_block_chans.append(ch)ds *= 2self._feature_size += chdim_head = ch // num_headsself.middle_block = TimestepEmbedSequential(ResBlock(ch, time_embed_dim, dropout),SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),ResBlock(ch, time_embed_dim, dropout))self._feature_size += chself.output_blocks = nn.ModuleList([])for level, mult in list(enumerate(channel_mult))[::-1]:for i in range(num_res_blocks + 1):ich = input_block_chans.pop()layers = [ResBlock(ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult)]ch = model_channels * multif ds in attention_resolutions:dim_head = ch // num_headslayers.append(SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim))if level and i == num_res_blocks:out_ch = chlayers.append(Upsample(ch, out_channels=out_ch))ds //= 2self.output_blocks.append(TimestepEmbedSequential(*layers))self._feature_size += chself.out = nn.Sequential(normalization(ch),nn.SiLU(),nn.Conv2d(model_channels, out_channels, 3, padding=1))def forward(self, x, timesteps=None, context=None, y=None,**kwargs):hs = []t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)emb = self.time_embed(t_emb)h = x.type(th.float32)for module in self.input_blocks:h = module(h, emb, context)hs.append(h)h = self.middle_block(h, emb, context)for module in self.output_blocks:h = th.cat([h, hs.pop()], dim=1)h = module(h, emb, context)h = h.type(x.dtype)return self.out(h)

3. Diffusers StableDiffusionPipeline源码

本系列文章将深入diffusers的源码一步步进行解析,主要涉及:

  • Stable Diffusion 整体结构 与 模型导出
    https://zhuanlan.zhihu.com/p/603161500

  • Text encoder source code
    https://zhuanlan.zhihu.com/p/603168346

  • Unet source code
    https://zhuanlan.zhihu.com/p/603962505

  • Vae source code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/221769.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蚁群优化算法ACO

蚁群优化算法模拟了自然界中蚂蚁的觅食行为&#xff0c;信息素浓度的大小表征路径的远近&#xff0c;信息素浓度越高&#xff0c;表示对应的路径距离越短。同时&#xff0c;路径上的信息素浓度会随着时间的推进而逐渐衰减。 1.过程 &#xff08;1&#xff09;初始化参数 蚁群…

Nginx+Tomcat实现负载均衡和动静分离

目录 前瞻 动静分离和负载均衡原理 实现方法 实验&#xff08;七层代理&#xff09; 部署Nginx负载均衡服务器(192.168.75.50:80) 部署第一台Tomcat应用服务器&#xff08;192.168.75.60:8080&#xff09; 多实例部署第二台Tomcat应用服务器&#xff08;192.168.75.70:80…

8080端口被占用怎么解决,并结束释放8080端口

8080端口是被用于WWW代理服务的&#xff0c;可以实现网页浏览&#xff0c;经常在访问某个网站或使用代理服务器的时候&#xff0c;Win10 8080端口被占用解决方法吧。 1、按【 Win r】 2、运行窗口&#xff0c;输入【cmd】命令&#xff0c;按【确定或回车】&#xff0c;打开命…

基于SpringBoot 校园招聘系统设计与实现(源码+文档+可视化HTML+数据库)

摘 要 基于SpringBoot 校园招聘系统是一种基于Java技术的校园招聘和可视化展示的系统。该系统通过采集和整合各类招聘网站、社交媒体等渠道的数据&#xff0c;对招聘岗位进行深入分析&#xff0c;并将分析结果以直观、易懂的可视化形式呈现。系统能够自动从多个数据源获取招聘…

电子学会C/C++编程等级考试2023年03月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:拼点游戏 C和S两位同学一起玩拼点游戏。有一堆白色卡牌和一堆蓝色卡牌,每张卡牌上写了一个整数点数。C随机抽取n张白色卡牌,S随机抽取n张蓝色卡牌,他们进行n回合拼点,每次两人各出一张卡牌,点数大者获得三颗巧克力,小者获…

Flutter:web项目跨域问题解决

前后端解决系列 文章目录 一、Flutter web客户端解决本地环境调试跨域问题二、Flutter web客户端解决线上环境跨域问题 一、Flutter web客户端解决本地环境调试跨域问题 就一句命令【--web-browser-flag "--disable-web-security"】&#xff0c;用来屏蔽浏览器域名请…

大模型时代-让AI自己开发自己

一、前言 AI能自己开发自己或者开发和一个很像自己的东西吗&#xff1f;显然是可以的&#xff01;因为AI模型的算法&#xff0c;基本就是学习和递归 二、大模型的算法实现例子 本例子就是通过AI模型来写 大模型的实现通常涉及到深度学习框架和大量的计算资源。具体的算法代…

Java入门学习笔记二

一、抽象类 当编写一个类时&#xff0c;我们往往会为该类定义一些方法&#xff0c;这些方法是用来描述该类的行为方式&#xff0c;那么这些方法都有具体的方法体。 分析事物时&#xff0c;发现了共性内容&#xff0c;就出现向上抽取。会有这样一种特殊情况&#xff0c;就是功…

Python实现高效摸鱼,批量识别银行卡号并自动写入Excel表格

前言 每当有新员工入职&#xff0c;人事小姐姐都要收集大量的工资卡信息&#xff0c;并且生成Excel文档&#xff0c;看到小姐姐这么辛苦&#xff0c;我就忍不住要去帮她了… 于是我用1行代码就实现了自动识别银行卡信息并且自动生成Excel文件&#xff0c;小姐姐当场就亮眼汪汪…

智能冶钢厂环境监控与设备控制系统(边缘物联网网关)

目录 1、项目背景 2、项目功能介绍 3、模块框架 3.1 架构框图 3.2 架构介绍 4、系统组成与工作原理 4.1 数据采集 4.2 指令控制 4.3 其他模块 4.3.1 网页、qt视频流 4.3.2 qt搜索进程 5、成果呈现 6、问题解决 7、项目总结 1、项目背景 这个项目的背景是钢铁行业的…

tesseract-ocr安装使用

描述&#xff1a; 在centos上安装 tesseract 并在springboot项目中使用 步骤一&#xff1a;安装 确认使用的版本tesseract和test4j版本需要匹配&#xff0c;这里选择最新版 tesseract5.3.3 &#xff0c;test4j 5.9.0 版本匹配可查看 Releases nguyenq/tess4j GitHub 或…

使用 iperf 和 iftop 测试网络带宽

博主历时三年精心创作的《大数据平台架构与原型实现&#xff1a;数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行&#xff0c;点击《重磅推荐&#xff1a;建大数据平台太难了&#xff01;给我发个工程原型吧&#xff01;》了解图书详情&#xff0c;…

CSS中神奇的filter属性

CSS是Web开发中不可或缺的一部分&#xff0c;它可以帮助开发者在页面上添加各种各样的样式和效果。其中一个比较神奇的CSS属性就是filter&#xff0c;它可以让我们实现各种有趣的图形处理效果。 一、filter属性的基础 filter属性是CSS中用于对元素进行图形效果处理的属性之一…

C语言—每日选择题—Day47

第一题 1. 以下逗号表达式的值为&#xff08;&#xff09; (x 4 * 5, x * 5), x 25 A&#xff1a;25 B&#xff1a;20 C&#xff1a;100 D&#xff1a;45 答案及解析 D 本题考查的就是逗号表达式&#xff0c;逗号表达式是依次计算每个表达式&#xff0c;但是只输出最后一个表…

【算法题】开源项目热度榜单(js)

解法 const lines ["4","8 6 2 8 6","camila 66 70 46 158 80","victoria 94 76 86 189 211","athony 29 17 83 21 48","emily 53 97 1 19 218", ]; const lines2 ["5","5 6 6 1 2","…

ArkTS入门

代码结构分析 struct Index{ } 「自定义组件&#xff1a;可复用的UI单元」 xxx 「装饰器&#xff1a;用来装饰类结构、方法、变量」 Entry 标记当前组件是入口组件&#xff08;该组件可被独立访问&#xff0c;通俗来讲&#xff1a;它自己就是一个页面&#xff09;Component 用…

tuxera2023破解版免费下载 NTFS for Mac读写工具(附序列号)

Tuxera ntfs 2023 破解安装包是一个mac读写ntfs磁盘工具允许您访问&#xff0c;它允许您访问NFTS 驱动器上的文件。 该应用程序提供访问访问Mac 设备中NFTS 格式文件的驱动力&#xff0c;因此您有权基于格式文件进行无困难的访问Windows 数据。 在发生电力灾难或断电时使用防损…

Signal EM的流程与分析

RedhawkTM 提供了一种在设计中分析Power EM和SignalEM的单一平台方法。Power EM通常作为Static IR和Dynamic IR分析的组成部分进行。Signal EM分析是单独进行分析的,检查设计中所有信号线和过孔的平均(单向或双向)、RMS和峰值电流密度【1】。 1 SignalEM 流程介绍 如图7…

ArrayList集合的两个实例应用,有趣的洗牌算法与杨辉三角

本节课的内容&#xff0c;就让我们来学习一下ArrayList集合的应用&#xff0c;ArrayList的本质就是一个顺序表&#xff0c;那下面一起来学习吧 目录 一、杨辉三角 1.题目详情及链接 2.剖析题目 3.思路及代码 二、洗牌算法 1.创造牌对象 2.创造一副牌 3.洗牌操作 4.发…

人工智能(pytorch)搭建模型22-基于pytorch搭建SimpleBaseline(人体关键点检测)模型,并详细介绍该网络模型与代码实现

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能(pytorch)搭建模型22-基于pytorch搭建SimpleBaseline(人体关键点检测)模型&#xff0c;并详细介绍该网络模型与代码实现。本文将介绍关于SimpleBaseline模型的原理&#xff0c;以及利用pytorch框架搭建模型…