【AI理论学习】手把手利用PyTorch实现扩散模型DDPM

手把手利用PyTorch实现扩散模型DDPM

  • DDPM代码实现
    • 神经网络
      • 定义辅助函数
      • 位置嵌入
      • ResNet block
      • 注意力模块
      • 分组归一化
      • Conditional U-Net
    • 定义前向扩散过程
    • 定义PyTorch数据集+DataLoader
    • 采样
    • 训练模型
    • 采样
    • 后续阅读
  • 参考链接

上一篇博文已经手把手推导了扩散模型DDPM,本文利用PyTorch在Google Colab notebook中实现扩散模型。

DDPM代码实现

注意,扩散模型有几种观点。在这里,我们采用discrete-time(潜变量模型)的观点,但请务必查看其他观点。

神经网络

神经网络需要在特定的时间步接收噪声图像,并返回预测的噪声。需要注意的是,预测的噪声是一个与输入图像具有相同大小/分辨率的张量。因此从技术上讲,网络接收和输出具有相同形状的张量。在这种情况下,可以使用什么类型的神经网络呢?

在这里通常使用的方法与自编码器(Autoencoder)非常相似,你可能还记得它出现在典型的“intro to deep learning门”教程中。Autoencoders在encoder和decoder之间具有一个所谓的“bottleneck”层。编码器首先将图像编码为较小的隐藏表示,称为“bottleneck”,然后解码器将该隐表示解码回实际图像。这迫使网络在瓶颈层中仅保留最重要的信息。

在体系结构方面,DDPM 的作者采用了一个 U-Net 结构,该结构由(Ronneberger等人,2015)引入,当时在医学图像分割领域取得了最先进的结果。这个网络,像任何自编码器一样,由一个位于中间的瓶颈层组成,确保网络仅学习最重要的信息。重要的是,它在编码器和解码器之间引入了残差连接,大大改善了梯度流动(灵感来自于 He等人,2015年的 ResNet)。
U-net
如图所示,U-Net 模型首先对输入进行下采样(即在空间分辨率方面使输入变小),然后进行上采样

接下来,我们逐步实现这个网络。

!pip install -q -U einops datasets matplotlib tqdm

导入相关依赖库

import math
from inspect import isfunction
from functools import partial%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrangeimport torch
from torch import nn, einsum
import torch.nn.functional as F

定义辅助函数

首先,定义一些在实现神经网络时将使用的辅助函数和类。重要的是,定义了一个残差模块,它将输入简单地添加到特定函数的输出中(换句话说,将残差连接添加到特定函数中)。

def exists(x):return x is not Nonedef default(val, d):if exists(val):return valreturn d() if isfunction(d) else ddef num_to_groups(num, divisor):groups = num // divisorremainder = num % divisorarr = [divisor] * groupsif remainder > 0:arr.append(remainder)return arrclass Residual(nn.Module):def __init__(self, fn):super().__init__()self.fn = fndef forward(self, x, *args, **kwargs):return self.fn(x, *args, **kwargs) + x

我们还为上采样和下采样操作定义了别名。

def Upsample(dim, dim_out=None):return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),)def Downsample(dim, dim_out=None):# 不再有阶梯卷积或池return nn.Sequential(Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),nn.Conv2d(dim * 4, default(dim_out, dim), 1),)

位置嵌入

由于神经网络的参数在不同时间(噪声水平)之间共享,作者采用了受 Transformer(Vaswani et al., 2017)启发的正弦位置嵌入(sinusoidal position embeddings)来编码 t t t。这使得神经网络可以“know”它正在处理批次中的每个图像的特定时间步(噪声水平)。

SinusoidalPositionEmbeddings 模块接受形状为(batch_size,1)的张量作为输入(即批次中多个带噪声图像的噪声水平),并将其转换为形状为(batch_size,dim)的张量,其中 dim 是位置嵌入的维度。然后将其添加到每个残差块中,我们将在后面看到。

class SinusoidalPositionEmbeddings(nn.Module):def __init__(self, dim):super().__init__()self.dim = dimdef forward(self, time):device = time.devicehalf_dim = self.dim // 2embeddings = math.log(10000) / (half_dim - 1)embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)embeddings = time[:, None] * embeddings[None, :]embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)return embeddings

总之就是将 t t t 编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step。

ResNet block

接下来,定义 U-Net 模型的核心构建块。DDPM 的作者使用了 Wide ResNet block(Zagoruyko et al., 2016),但 Phil Wang 将标准卷积层替换为“weight standardized”版本,这与分组归一化的结合效果更好(有关详细信息,请参见(Kolesnikov et al., 2019))。

class WeightStandardizedConv2d(nn.Conv2d):"""https://arxiv.org/abs/1903.10520weight standardization purportedly works synergistically with group normalization"""def forward(self, x):eps = 1e-5 if x.dtype == torch.float32 else 1e-3weight = self.weightmean = reduce(weight, "o ... -> o 1 1 1", "mean")var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))normalized_weight = (weight - mean) * (var + eps).rsqrt()return F.conv2d(x,normalized_weight,self.bias,self.strides,self.padding,self.dilation,self.groups,)class Block(nn.Module):def __init__(self, dim, dim_out, groups=8):super().__init__()self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)self.norm = nn.GroupNorm(groups, dim_out)self.act = nn.SiLU()def forward(self, x, scale_shift=None):x = self.proj(x)x = self.norm(x)if exists(scale_shift):scale, shift = scale_shiftx = x * (scale + 1) + shiftx = self.act(x)return xclass ResnetBlock(nn.Module):"""https://arxiv.org/abs/1512.03385"""def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):super().__init__()self.mlp = (nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))if exists(time_emb_dim)else None)self.block1 = Block(dim, dim_out, groups=groups)self.block2 = Block(dim_out, dim_out, groups=groups)self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()def forward(self, x, time_emb = None):scale_shift = Noneif exists(self.mlp) and exists(time_emb):time_emb = self.mlp(time_emb)time_emb = rearrange(time_emb, "b c -> b c 1 1")scale_shift = time_emb.chunk(2, dim=1)h = self.block1(x, scale_shift=scale_shift)h = self.block2(h)return h + self.res_conv(x)

注意力模块

现在,定义注意力模块,这是 DDPM 的作者在卷积块之间添加的。注意力是著名的 Transformer 架构(Vaswani et al., 2017)的构建块,在人工智能的各个领域,从自然语言处理和视觉到蛋白质折叠都取得了巨大的成功。Phil Wang 使用了两种注意力的变体:一种是常规的多头自注意力(multi-head self-attention)(就像在 Transformer 中使用的那样),另一种是线性注意力变体(linear attention variant)(Shen et al., 2018),其时间和内存要求与序列长度线性缩放,而不是常规注意力的二次缩放。

关于注意力机制的详细解释,请参阅 Jay Allamar 的精彩博客文章。

class Attention(nn.Module):def __init__(self, dim, heads=4, dim_head=32):super().__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)self.to_out = nn.Conv2d(hidden_dim, dim, 1)def forward(self, x):b, c, h, w = x.shapeqkv = self.to_qkv(x).chunk(3, dim=1)q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)q = q * self.scalesim = einsum("b h d i, b h d j -> b h i j", q, k)sim = sim - sim.amax(dim=-1, keepdim=True).detach()attn = sim.softmax(dim=-1)out = einsum("b h i j, b h d j -> b h i d", attn, v)out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)return self.to_out(out)class LinearAttention(nn.Module):def __init__(self, dim, heads=4, dim_head=32):super().__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))def forward(self, x):b, c, h, w = x.shapeqkv = self.to_qkv(x).chunk(3, dim=1)q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)q = q.softmax(dim=2)k = k.softmax(dim=-1)q = q * self.scalecontext = torch.einsum("b h d n, b h e n -> b h d e", k, v)out = torch.einsum("b h d e, b h d n -> b h e n", context, q)out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)return self.to_out(out)

分组归一化

DDPM 的作者在 U-Net 的卷积/注意力层之间交错使用了分组归一化group normalization)(Wu et al., 2018)。在下面,定义了一个 PreNorm 类,该类将在注意力层之前应用分组归一化,正如我们将在后面看到的。值得注意的是,关于在 Transformer 中是在注意力之前还是之后应用归一化一直存在争议。

class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.GroupNorm(1, dim)def forward(self, x):x = self.norm(x)return self.fn(x)

Conditional U-Net

现在,我们已经定义了所有构建块(position embeddings,ResNet blocks,attention和group normalization),现在该定义整个神经网络了。回想一下网络 ϵ θ ( x t , t ) \boldsymbol{\epsilon}_\theta\left(\mathbf{x_t}, t\right) ϵθ(xt,t)的工作是获取一批有噪声的图像及其各自的噪声水平,并输出添加到输入的噪声。更正式地说:

  • 网络采集一批形状为(batch_size, num_channels, height, width)的噪声图像和一批形状为 (batch_size, 1)的噪声水平作为输入,并返回一个形状为 (batch_size, num_channels, height, width)的张量

网络构建如下:

  • 首先,在一批有噪声的图像上应用卷积层,并计算噪声水平的位置嵌入(position embeddings)
  • 然后,执行一系列的下采样阶段(downsampling stages)。每个下采样阶段由2个ResNet blocks+ groupnorm+ attention+residual connection+a downsample operation组成
  • 在网络的中间,再次应用ResNet block,与attention交错
  • 接下来,执行一系列上采样阶段(upsampling stages)。每个上采样阶段由2个ResNet blocks + groupnorm +attention + residual connection + an upsample operation组成
  • 最后,在一个卷积层后面应用一个ResNet block。

最终,神经网络就像乐高积木一样层层堆叠(但了解它们是如何工作的很重要)。

class Unet(nn.Module):def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, self_condition=False,resnet_block_groups=4):super().__init__()# determine dimensionsself.channels = channelsself.self_condition = self_conditioninput_channels = channels * (2 if self_condition else 1)init_dim = default(init_dim, dim)self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)  #  changed to 1 and 0 from 7,3dims = [init_dim, *map(lambda m: dim * m, dim_mults)]in_out = list(zip(dims[:-1], dims[1:]))block_klass = partial(ResnetBlock, groups=resnet_block_groups)# time embeddingstime_dim = dim * 4self.time_mlp = nn.Sequential(SinusoidalPositionEmbeddings(dim),nn.Linear(dim, time_dim),nn.GELU(),nn.Linear(time_dim, time_dim),)# layersself.downs = nn.ModuleList([])self.ups = nn.ModuleList([])num_resolutions = len(in_out)for ind, (dim_in, dim_out) in enumerate(in_out):is_last = ind >= (num_resolutions - 1)self.downs.append(nn.ModuleList([block_klass(dim_in, dim_in, time_emb_dim=time_dim),block_klass(dim_in, dim_in, time_emb_dim=time_dim),Residual(PreNorm(dim_in, LinearAttention(dim_in))),Downsample(dim_in, dim_out)if not is_lastelse nn.Conv2d(dim_in, dim_out, 3, padding=1),]))mid_dim = dims[-1]self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)for ind, (dim, dim_out) in enumerate(reversed(in_out)):is_last = ind == (len(in_out) - 1)self.ups.append(nn.ModuleList([block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),Residual(PreNorm(dim_out, LinearAttention(dim_out))),Upsample(dim_out, dim_in)if not is_lastelse nn.Conv2d(dim_out, dim_in, 3, padding=1),]))self.out_dim = default(out_dim, channels)self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)self.final_conv = nn.Conv2d(dim, self.out_dim, 1)def forward(self, x, time, x_self_cond=None):if self.self_condition:x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))x = torch.cat((x_self_cond, x), dim=1)x = self.init_conv(x)r = x.clone()t = self.time_mlp(time)h = []for block1, block2, attn, downsample in self.downs:x = block1(x, t)h.append(x)x = block2(x, t)x = attn(x)h.append(x)x = downsample(x)x = self.mid_block1(x, t)x = self.mid_attn(x)x = self.mid_block2(x, t)for block1, block2, attn, upsample in self.ups:x = torch.cat((x, h.pop()), dim=1)x = block1(x, t)x = torch.cat((x, h.pop()), dim=1)x = block2(x, t)x = attn(x)x = upsample(x)x = torch.cat((x, r), dim=1)x = self.final_res_block(x, t)return self.final_conv(x)

定义前向扩散过程

forward diffusion process在 T T T个时间步内逐渐将噪声从真实分布添加到图像中,这是根据variance schedule发生的。最初的DDPM作者采用了linear schedule

我们将前向过程的方差设置为线性增加的常数 from β 1 = 1 0 − 4 \beta_1=10^{-4} β1=104 to β T = 0.02 \beta_T=0.02 βT=0.02.

然而,在(Nichol et al.,2021)中表明,使用cosine schedule可以获得更好的结果。下面,我们定义 T T T个时间步的不同的schedule(我们稍后会选择一个):

def cosine_beta_schedule(timesteps, s=0.008):"""cosine schedule as proposed in https://arxiv.org/abs/2102.09672"""steps = timesteps + 1x = torch.linspace(0, timesteps, steps)alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2alphas_cumprod = alphas_cumprod / alphas_cumprod[0]betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])return torch.clip(betas, 0.0001, 0.9999)def linear_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start, beta_end, timesteps)def quadratic_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2def sigmoid_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02betas = torch.linspace(-6, 6, timesteps)return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

首先,使用T=300个时间步的linear schedule,并从 β t \beta_t βt中定义我们需要的变量,例如,方差的累积乘积KaTeX parse error: Undefined control sequence: \bat at position 1: \̲b̲a̲t̲{\alpha}_t。下面的每个变量都只是一维张量,存储从 t t t T T T的数值。注意,我们还定义了一个extract函数,它允许我们按照 t t t提取一个批次的索引。

timesteps = 300# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)def extract(a, t, x_shape):batch_size = t.shape[0]out = a.gather(-1, t.cpu())return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

我们将用猫图像说明如何在扩散过程的每个时间步中添加噪声:

from PIL import Image
import requestsurl = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # PIL image of shape HWC
image

cat
将噪声添加到Pytorch张量,而不是Pillow Images中。首先定义能够将PIL图像转换为Pytorch张量(可以在上面添加噪声)的图像转换,反之亦然。

这些转换非常简单:我们首先通过除以255(其结果能在[0,1]范围),然后确保它们在[-1,1]范围。DDPM 文中提到:

我们假设图像数据由在集合{ 0 ,1 、. . . , 255}中的整数组成,然后线性缩放到[−1, 1]。 这确保了神经网络逆向过程能够从标准正态先验 p ( x T ) p(x_T) p(xT)开始且一致缩放的输入上运行。"

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resizeimage_size = 128
transform = Compose([Resize(image_size),CenterCrop(image_size),ToTensor(),  # turn into torch Tensor of shape CHW, divide by 255Lambda(lambda t: (t * 2) - 1),])x_start = transform(image).unsqueeze(0)
x_start.shape

输出结果:
torch.Size([1, 3, 128, 128])

另外,还定义了反向变换(reverse transform),它接收一个PyTorch张量,其中包含[-1,1],并将它们重新转换回PIL图像:

import numpy as npreverse_transform = Compose([Lambda(lambda t: (t + 1) / 2),Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWCLambda(lambda t: t * 255.),Lambda(lambda t: t.numpy().astype(np.uint8)),ToPILImage(),
])
reverse_transform(x_start.squeeze())

cat
现在,可以像论文中定义前向扩散过程:

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):if noise is None:noise = torch.randn_like(x_start)sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

在特定的时间步中进行测试:

def get_noisy_image(x_start, t):# add noisex_noisy = q_sample(x_start, t=t)# turn back into PIL imagenoisy_image = reverse_transform(x_noisy.squeeze())return noisy_image
# take time step
t = torch.tensor([40])get_noisy_image(x_start, t)

noise cat
可视化不同时间步的结果:

import matplotlib.pyplot as plt# use seed for reproducability
torch.manual_seed(0)# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):if not isinstance(imgs[0], list):# Make a 2d grid even if there's just 1 rowimgs = [imgs]num_rows = len(imgs)num_cols = len(imgs[0]) + with_origfig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)for row_idx, row in enumerate(imgs):row = [image] + row if with_orig else rowfor col_idx, img in enumerate(row):ax = axs[row_idx, col_idx]ax.imshow(np.asarray(img), **imshow_kwargs)ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])if with_orig:axs[0, 0].set(title='Original image')axs[0, 0].title.set_size(8)if row_title is not None:for row_idx in range(num_rows):axs[row_idx, 0].set(ylabel=row_title[row_idx])plt.tight_layout()
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

不同时间步结果可视化
在给定模型的情况下定义损失函数:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):if noise is None:noise = torch.randn_like(x_start)x_noisy = q_sample(x_start=x_start, t=t, noise=noise)predicted_noise = denoise_model(x_noisy, t)if loss_type == 'l1':loss = F.l1_loss(noise, predicted_noise)elif loss_type == 'l2':loss = F.mse_loss(noise, predicted_noise)elif loss_type == "huber":loss = F.smooth_l1_loss(noise, predicted_noise)else:raise NotImplementedError()return loss

dnoise_model就是上面定义的U-Net。在真实噪声和预测噪声之间使用Huber损失

定义PyTorch数据集+DataLoader

这里定义一个常规的PyTorch数据集。该数据集仅由真实数据集(如Fashion、MNIST、CIFAR-10或ImageNet)的图像组成,线性缩放至 [ − 1 , 1 ] [-1,1] [1,1]
每个图像都被调整为相同的大小,同时是随机水平翻转的。从论文中:

我们在 CIFAR10 的训练过程中使用了随机水平翻转;我们尝试了有翻转和没有翻转的训练,发现翻转可以稍微提高样本质量。

在这里,使用Datasets库轻松地从hub加载 Fashion MNIST 数据集。该数据集由已经具有相同分辨率的图像组成,即 28x28。

from datasets import load_dataset# load dataset from the hub
dataset = load_dataset('fashion_mnist')
image_size = 28
channels = 1
batch_size = 128

接下来,定义一个函数,将在整个数据集上即时应用它。为此使用该with_transform功能。该函数只是应用了一些基本的图像预处理:随机水平翻转、重新缩放并最终使它们在[-1,1]范围。

from torchvision import transforms
from torch.utils.data import DataLoader# define image transformations(e.g. using torchvision)
transform = Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Lambda(lambda t: (t * 2) -1)
])# define function
def transforms(examples):examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]del examples["image"]return examplestransformed_dataset = dataset.with_transform(transforms).remove_columns("label")# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())  # dict_keys(['pixel_values'])

采样

由于将在训练期间从模型中采样(来跟踪进度),因此定义了下面的代码。采样方法总结如下:
Sampling
从扩散模型生成新图像是通过逆扩散过程来实现的:从 T T T开始,从高斯分布中采样纯噪声,然后使用神经网络逐渐去噪(使用它学到的条件概率),直到在时间步 t = 0 t=0 t=0结束。如上所示,可以得到一个稍微降噪的图像 x t − 1 x_{t-1} xt1通过使用我们的噪声预测器插入均值的重新参数化。注意,方差是提前知道的。

理想情况下,最终会得到一张看起来像是来自真实数据分布的图像。下面的代码实现了这一点。

@torch.no_grad()
def p_sample(model, x, t, t_index):betas_t = extract(betas, t, x.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)# Equation 11 in the paper# Use our model (noise predictor) to predict the meanmodel_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)if t_index == 0:return model_meanelse:posterior_variance_t = extract(posterior_variance, t, x.shape)noise = torch.randn_like(x)# Algorithm 2 line 4:return model_mean + torch.sqrt(posterior_variance_t) * noise # Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):device = next(model.parameters()).deviceb = shape[0]# start from pure noise (for each example in the batch)img = torch.randn(shape, device=device)imgs = []for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)imgs.append(img.cpu().numpy())return imgs@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

训练模型

接下来,以常规 PyTorch 方式训练模型。我们还定义了一些逻辑来定期保存生成的图像,使用上面定义的sample定义的方法。

from pathlib import Pathdef num_to_groups(num, divisor):groups = num // divisorremainder = num % divisorarr = [divisor] * groupsif remainder > 0:arr.append(remainder)return arrresults_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

下面,定义模型,并将其移动到GPU,还定义了一个标准优化器(Adam)。

from torch.optim import Adamdevice = "cuda" if torch.cuda.is_available() else "cpu"model = Unet(dim=image_size,channels=channels,dim_mults=(1, 2, 4,)
)
model.to(device)optimizer = Adam(model.parameters(), lr=1e-3)

开始训练:

from torchvision.utils import save_imageepochs = 6for epoch in range(epochs):for step, batch in enumerate(dataloader):optimizer.zero_grad()batch_size = batch["pixel_values"].shape[0]batch = batch["pixel_values"].to(device)# Algorithm 1 line 3: sample t uniformally for every example in the batcht = torch.randint(0, timesteps, (batch_size,), device=device).long()loss = p_losses(model, batch, t, loss_type="huber")if step % 100 == 0:print("Loss:", loss.item())loss.backward()optimizer.step()# save generated imagesif step != 0 and step % save_and_sample_every == 0:milestone = step // save_and_sample_everybatches = num_to_groups(4, batch_size)all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))all_images = torch.cat(all_images_list, dim=0)all_images = (all_images + 1) * 0.5save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

训练过程:

Loss: 0.5570111274719238
Loss: 0.06583500653505325
Loss: 0.06006840616464615
Loss: 0.051015421748161316
Loss: 0.0394190177321434
Loss: 0.04075610265135765
Loss: 0.039987701922655106
Loss: 0.03415030241012573
Loss: 0.030019590631127357
Loss: 0.036297883838415146
Loss: 0.037256866693496704
Loss: 0.03864285722374916
Loss: 0.03298967331647873
Loss: 0.03331328555941582
Loss: 0.027535393834114075
Loss: 0.03803558647632599
Loss: 0.03721949830651283
Loss: 0.03478413075208664
Loss: 0.03918925300240517
Loss: 0.03608154132962227
Loss: 0.027622627094388008
Loss: 0.02948344498872757
Loss: 0.029868196696043015
Loss: 0.03154699504375458
Loss: 0.029723389074206352
Loss: 0.039195798337459564
Loss: 0.032130151987075806
Loss: 0.031276602298021317
Loss: 0.03440115600824356
Loss: 0.030476151034235954

采样

要从模型中采样,可以使用上面定义的采样函数:

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

采样结果
看起来模型能够生成一件漂亮的 T 恤!请记住,用来训练的数据集的分辨率非常低(28x28)。还可以创建去噪过程的 gif 图像:

import matplotlib.animation as animationrandom_index = 53fig = plt.figure()
ims = []
for i in range(timesteps):im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)ims.append([im])animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

gif图像

后续阅读

注意,DDPM 论文表明,扩散模型是(无)条件图像生成的一个promising的方向。从DDPM提出到现在已经(极大地)得到了改进,尤其是在文本条件图像生成方面。下面,列出了一些截至到2022年6月7日之前的重要的(但远非详尽的)后续工作:

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021): 发现学习条件分布的方差(除均值外)有助于提高性能
  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021): 引入了级联扩散,它包含多个扩散模型的pipeline,可生成分辨率不断提高的图像,用于高保真图像合成
  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021): 表明扩散模型可以通过改进 U-Net 架构以及引入分类器引导,达到比SOTA生成式模型更好的效果。
  • Classifier-Free Diffusion Guidance (Ho et al., 2021): 表明不需要使用分类器来指导扩散模型,只需要使用单个神经网络联合训练条件扩散模型和无条件扩散模型
  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022): 使用先验将文字说明转换为 CLIP 图像嵌入,然后使用扩散模型将其解码为图像
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) (Saharia et al., 2022): 表明将大型预训练语言模型(例如 T5)与级联扩散相结合非常适用于文本到图像合成

参考链接

  1. The Annotated Diffusion Model
  2. 带你深入理解扩散模型DDPM
  3. 扩散模型全新课程:扩散模型从0到1实现!
  4. Denoising Diffusion Probabilitistic Models
  5. 《Diffusion Models Beat GANs on Image Synthesis》阅读笔记
  6. How Diffusion Models Work
  7. DDPM交叉熵损失函数推导
  8. DDPM(Denoising Diffusion Probabilistic Models)扩散模型简述
  9. What are Diffusion Models?
  10. 由浅入深了解Diffusion Model
  11. 什么是Diffusion模型?
  12. Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
  13. Denoising Diffusion Probabilistic Model, in Pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【mysql】—— 表的约束

目录 序言 (一)空属性 (二)默认值 (三)列描述 (四)zerofill (五)主键 (六)自增长 (七)唯一键 &#…

Angular安全专辑 —— CSP防止XSS攻击

什么是 CSP(Content Security Policy) CSP(Content Security Policy)是一种Web安全策略,用于减轻和防止跨站脚本攻击(XSS)等安全漏洞。它通过允许网站管理员定义哪些资源可以加载到网页中&#…

面试热题(LRU缓存)

请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类: LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中,则返回关键字的值,否则返回 -1 …

如何在页面中嵌入音频和视频?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 嵌入音频⭐ 嵌入视频⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅!这个专栏是为那些对Web开发感兴趣、刚刚踏…

SpringSecurity 详解(通俗易懂)

SpringSecurity 详解 1、SpringSecurity讲解1.1、SpringSecurity完整流程1.2、认证流程 2、登录,退出,注册_分析说明2.1、登录2.2、校验2.3、退出2.4、注册2.5、SecurityContextHolder说明 3、代码实现3.1、引入依赖3.2、登录 退出 注册3.2.1、SpringSec…

k8部署安装

1 环境初始化 1.1 检查操作系统的版本 此方式下安装kubernetes集群要求Centos版本要在7.5或之上 [rootmaster ~]# cat /etc/redhat-release Centos Linux 7.5.1804 (Core)1.2主机名成解析 三台服务器的/etc/hosts文件 192.168.90.100 master 192.168.90.106 node1 192.168.…

18 | 基于DDD的微服务设计实例

为了更好地理解 DDD 的设计流程,这篇文章会用一个项目来带你了解 DDD 的战略设计和战术设计,走一遍从领域建模到微服务设计的全过程,一起掌握 DDD 的主要设计流程和关键点。 项目基本信息 项目的目标是实现在线请假和考勤管理。功能描述如下…

论文阅读---《Unsupervised ECG Analysis: A Review》

题目 无监督心电图分析一综述 摘要 电心图(ECG)是检测异常心脏状况的黄金标准技术。自动检测心电图异常有助于临床医生分析心脏监护仪每天产生的大量数据。由于用于训练监督式机器学习模型的带有心脏病专家标签的异常心电图样本数量有限,对…

K8S调度

K8S调度 一、List-Watch 机制 controller-manager、scheduler、kubelet 通过 List-Watch 机制监听 apiserver 发出的事件,apiserver 通过 List-Watch 机制监听 etcd 发出的事件1.scheduler 的调度策略 预选策略/预算策略:通过调度算法过滤掉不满足条件…

Ceph集群安装部署

Ceph集群安装部署 目录 Ceph集群安装部署 1、环境准备 1.1 环境简介1.2 配置hosts解析(所有节点)1.3 配置时间同步2、安装docker(所有节点)3、配置镜像 3.1 下载ceph镜像(所有节点执行)3.2 搭建制作本地仓库(ceph-01节点执行)3.3 配置私有仓库(所有节点执行)3.4 为 Docker 镜像…

opencv实战项目 手势识别-手势控制鼠标

手势识别系列文章目录 手势识别是一种人机交互技术,通过识别人的手势动作,从而实现对计算机、智能手机、智能电视等设备的操作和控制。 1. opencv实现手部追踪(定位手部关键点) 2.opencv实战项目 实现手势跟踪并返回位置信息&…

AAAI论文阅读

文章目录 Open-Vocabulary Multi-Label Classifcation via Multi-Modal Knowledge Transfer——知识蒸馏的范畴Med-EASi: Finely Annotated Dataset and Models for Controllable Simplifcation of Medical Texts——医学领域数据集构建“Nothing Abnormal”: Disambiguating M…

UE 5 GAS 在项目中处理AttributeSet相关

这一篇文章是个人的实战经验记录,如果对基础性的内容不了解的,可以看我前面一篇文章对基础的概念以及内容的讲解。 设置AttributeSet 使用GAS之前,首先需要设置参数集AS,这个是用于同步的一些参数,至于如何设置GAS&a…

Apollo让自动驾驶如此简单

前言: 最近被新能源的电价闹的不行,买了电车的直呼上当了、不香了。但电车吸引人不只是公里油耗低,还有良好的驾车使用感。比如辅助驾驶、甚至是自动驾驶。今天来介绍一个头部自动驾驶平台Apollo,Apollo是一个开源的、自动驾驶的软…

【论文阅读】基于深度学习的时序预测——FEDformer

系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文四:2022 Non-Stationary Transformers:非平…

3.6 Spring MVC文件上传

1. 文件上传到本地 实现方式 Spring MVC使用commons-fileupload实现文件上传,注意事项如下: l HTTP请求方法是POST。 l HTTP请求头的Content-Type是multipart/form-data。 SpringMVC配置 配置commons-fileupload插件的文件上传解析器CommonsMultip…

SpringBoot复习:(34)@EnableWebMvc注解为什么让@WebMvcAutoconfiguration失效?

它导入了DelegatingWebMvcConfiguration 它会把容器中的类型为WebMvcConfigurer的bean注入到类型为WebMvcConfigurerComposite的成员变量configurers中。 可以看到它继承了WebMvcConfigurerSupport类 而WebMvcConfigureAutoConfiguration类定义如下 可以看到一个Conditional…

Flink源码之RPC

Flink是一个典型的Master/Slave分布式实时处理系统,分布式系统组件之间必然涉及通信,也即RPC,以下图展示Flink组件之间的关系: RPCGateWay 一般RPC框架可根据用户业务类生成客户端和服务器端通信底层代码,此时只需定…

基于机器学习进行降雨预测 -- 机器学习项目基础篇(13)

在本文中,我们将学习如何构建一个机器学习模型,该模型可以根据一些大气因素预测今天是否会有降雨。这个问题与使用机器学习的降雨预测有关,因为机器学习模型往往在以前已知的任务上表现得更好,而这些任务需要高技能的个人来完成。…

mysql正则表达式

mysql正则表达式 : REGEXP 规则基本和java正则规则一样,有些特殊字符有差异. 举例: -- 匹配以英文开头,4位英文字母8位数字/4位数字-4位数字 select * from tableName where code REGEXP ^[A-Z]{4}[0-9]{8}/[0-9]{3}-[0-9]{3}$ ;将code中的3位数字改成4位.高位补0. -- 数据…