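Score-based diffusion models via stochastic differential equations (Score-Based Generative Modeling through Stochastic Differential Equations)
A score-based diffusion model estimates the gradient of the log density of the data distribution (the score). It can generate images of the same high quality as a GAN without any adversarial training. It comes from the paper: Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." International Conference on Learning Representations, 2021
Score-based diffusion is another milestone that followed the surge of interest in diffusion models: it unifies diffusion models and score-based generative models in a single framework. The original diffusion models also have a drawback: sampling is slow, and drawing a single sample typically takes thousands of network evaluations. Score-based diffusion models can complete sampling in less time.
There are many introductions to the theory of score-based diffusion online, along with application examples and paper walkthroughs, which are all worth consulting. Few of them walk through the code, however, so this post provides a simple, runnable code example of a score-based diffusion model.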
1. Define the time-dependent score-based model
Import the required modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm
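1.1 Projection layer for embedding the time t
Strictly speaking there is no such thing as a "projection layer". The name is used here for the way a time step t is encoded: the frequencies ω are sampled randomly at initialization, and [sin(2πωt); cos(2πωt)] then gives the Gaussian random feature vector for t. Note that these parameters are not trainable.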
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample the weights during initialization. These weights are
        # fixed during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        # x has shape (batch,); x_proj has shape (batch, embed_dim // 2)
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
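To make the [sin(2πωt); cos(2πωt)] encoding concrete, here is a minimal sketch in plain Python, so it runs without PyTorch. The helper name fourier_features and the small embed_dim of 8 are illustrative assumptions, not part of the original code.

```python
import math
import random


def fourier_features(t, weights):
    """Return sin(2*pi*w*t) for every w, followed by cos(2*pi*w*t) for every w."""
    angles = [2.0 * math.pi * w * t for w in weights]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


random.seed(0)
embed_dim, scale = 8, 30.0  # illustrative; the model itself uses embed_dim=256
# Fixed random frequencies, drawn once and never updated (mirrors self.W).
weights = [random.gauss(0.0, 1.0) * scale for _ in range(embed_dim // 2)]

features = fourier_features(0.5, weights)
print(len(features))  # embed_dim values: first half sines, second half cosines
```

Because scale is large, nearby values of t map to clearly different feature vectors, which is what lets the network tell time steps apart.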
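The time-embedding layer is needed because a score-based diffusion model is trained differently from a standard diffusion model. During training, each data sample x is perturbed with random noise at a randomly sampled time t; the network takes both the noisy x and t as input, and the model loss is computed from its output.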
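The perturbation step described above can be sketched as follows. This is a plain-Python illustration under stated assumptions: the helper names marginal_std and perturb, the toy three-pixel "image", and the lower bound eps are hypothetical, and the standard deviation formula is the one used by marginal_prob_std in section 2.

```python
import math
import random


def marginal_std(t, sigma=25.0):
    """Std of the perturbation kernel (same formula as marginal_prob_std in section 2)."""
    return math.sqrt((sigma ** (2.0 * t) - 1.0) / (2.0 * math.log(sigma)))


def perturb(x0, t, rng):
    """Return a noisy copy of x0 at time t, together with the noise that was drawn."""
    std = marginal_std(t)
    z = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [x + std * n for x, n in zip(x0, z)]
    return xt, z


rng = random.Random(0)
x0 = [0.0, 0.5, 1.0]        # toy "image" with three pixels (hypothetical data)
eps = 1e-5                  # assumed lower bound that keeps t away from 0
t = rng.uniform(eps, 1.0)   # one random time step per training example
xt, z = perturb(x0, t, rng)
# The network would receive the pair (xt, t) as its input.
```

The noise level grows with t, so the network has to know t in order to output a score at the right scale.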
A fully connected layer that reshapes its output to feature-map dimensions:
class Dense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Append two singleton axes so the result broadcasts over height and width.
        return self.dense(x)[..., None, None]
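1.2 Time-dependent score-based U-Net model
This is the time-dependent score-based model, built as a U-Net. Besides x, the forward function also takes the time t as input. The time t is embedded by GaussianFourierProjection and merged into the feature maps at every layer, and the final output is divided by marginal_prob_std(t) to normalize it.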
class ScoreNet(nn.Module):
    """A time-dependent score-based model built on a U-Net architecture."""

    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
            channels: The number of channels for feature maps of each resolution.
            embed_dim: The dimensionality of the Gaussian random feature embedding,
                the same as in GaussianFourierProjection from section 1.1.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time t
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The Swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        # Gaussian random feature embedding for t
        embed = self.act(self.embed(t))
        # Encoding path
        h1 = self.conv1(x)
        ## Incorporate information from t
        h1 += self.dense1(embed)
        ## Group normalization
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path
        h = self.tconv4(h4)
        ## Skip connection from the encoding path
        h += self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize the output by the marginal standard deviation at time t
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
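The skip connections concatenate decoder and encoder feature maps, so their spatial sizes must match at each level. The sketch below traces those sizes with the standard output-size formulas for unpadded convolutions and transposed convolutions, assuming 28×28 MNIST inputs; the helper names conv_out and tconv_out are illustrative, not from the original code.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tconv_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    """Spatial output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: conv1 (stride 1), then conv2..conv4 (stride 2).
h1 = conv_out(28, stride=1)
h2 = conv_out(h1, stride=2)
h3 = conv_out(h2, stride=2)
h4 = conv_out(h3, stride=2)

# Decoder: tconv4, then tconv3 and tconv2 with output_padding=1, then tconv1.
d4 = tconv_out(h4, stride=2)                    # must match h3 for torch.cat
d3 = tconv_out(d4, stride=2, output_padding=1)  # must match h2
d2 = tconv_out(d3, stride=2, output_padding=1)  # must match h1
out = tconv_out(d2, stride=1)                   # back to the input size

print([h1, h2, h3, h4], [d4, d3, d2, out])
```

Under these assumptions the encoder sizes are 26, 12, 5, 2 and the decoder sizes are 5, 12, 26, 28. Without output_padding=1, tconv3 and tconv2 would produce 11 and 25, which could not be concatenated with h2 and h1.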
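2. Set up the SDE
The SDE perturbs the data distribution p_0 into p_T. It involves two important functions: marginal_prob_std, mentioned earlier, which computes the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)) (it returns only the standard deviation, not the mean), and diffusion_coeff, which computes the diffusion coefficient of the SDE.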
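The text does not write the SDE out. Reading the two functions together, with diffusion coefficient σ^t and no drift term, suggests the following form; treat it as an inference from the code rather than a statement from the text. Under that assumption the standard deviation follows from a single integral:

```latex
\begin{aligned}
\mathrm{d}x &= \sigma^{t}\,\mathrm{d}w, \qquad t \in [0, 1], \\
\operatorname{Var}\bigl[x(t) \mid x(0)\bigr]
  &= \int_{0}^{t} \sigma^{2s}\,\mathrm{d}s
   = \left[\frac{\sigma^{2s}}{2\ln\sigma}\right]_{0}^{t}
   = \frac{\sigma^{2t} - 1}{2\ln\sigma}, \\
p_{0t}\bigl(x(t) \mid x(0)\bigr)
  &= \mathcal{N}\!\left(x(t);\; x(0),\; \frac{\sigma^{2t} - 1}{2\ln\sigma}\,I\right).
\end{aligned}
```

The square root of this variance is exactly the expression returned by marginal_prob_std, and with no drift the mean of the kernel stays at x(0).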
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}

def marginal_prob_std(t, sigma):
    """Compute the standard deviation of p_{0t}(x(t) | x(0)).

    Args:
        t: A vector of time steps.
        sigma: The sigma in our SDE.

    Returns:
        The standard deviation.
    """
    # as_tensor avoids the copy warning raised when t is already a tensor
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient of the SDE.

    Args:
        t: A vector of time steps.
        sigma: The sigma in our SDE.

    Returns:
        The vector of diffusion coefficients.
    """
    return torch.as_tensor(sigma**t, device=device)

sigma = 25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
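As a sanity check on the two functions, the squared standard deviation should equal the integral of the squared diffusion coefficient from 0 to t, assuming the SDE has no drift term. The plain-Python sketch below compares the closed form with a midpoint-rule integral; the helper names closed_form_var and numeric_var and the step count are illustrative assumptions.

```python
import math


def closed_form_var(t, sigma):
    """Variance implied by marginal_prob_std: (sigma^(2t) - 1) / (2 ln sigma)."""
    return (sigma ** (2.0 * t) - 1.0) / (2.0 * math.log(sigma))


def numeric_var(t, sigma, steps=10000):
    """Midpoint-rule integral of the squared diffusion coefficient sigma^(2s) on [0, t]."""
    h = t / steps
    return sum(sigma ** (2.0 * (i + 0.5) * h) for i in range(steps)) * h


sigma = 25.0
for t in (0.1, 0.5, 1.0):
    print(t, closed_form_var(t, sigma), numeric_var(t, sigma))

std_at_1 = math.sqrt(closed_form_var(1.0, sigma))
print(std_at_1)  # std of the fully perturbed data at t = 1
```

With sigma = 25 the standard deviation at t = 1 comes out a little under 10, far larger than the [0, 1] range of MNIST pixel values, so p_T is dominated by noise.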