参考:
[1] https://github.com/xiaohu2015/nngen/blob/main/models/diffusion_models/ddpm_cifar10.ipynb
[2] https://www.bilibili.com/video/BV1we4y1H7gG/?spm_id_from=333.337.search-card.all.click&vd_source=9e9b4b6471a6e98c3e756ce7f41eb134
TOC
- 1 UNet部分
- 1.1 SelfAttention
- 1.2 DoubleConv
- 1.3 Down
- 1.4 Up
- 1.5 UNet模型
- 2 Diffusion部分以及回顾
- 2.1 beta_schedule
- 2.2 初始化
- 2.3 提取数组中的对应timestep的值
- 2.4 从 $x_0$ 得到 $x_t$
- 2.5 真实后验分布 $q(x_{t-1}|x_t,x_0)$ 的均值和方差
- 2.6 估计重建 $x_0$
- 2.7 计算 $p_\theta$ 的均值和方差
- 2.8 采样
- 3 训练部分
- 4 其他实验
- 4.1 加噪过程
- 4.2 多GPU分布式代码
- 4.3 去噪过程
- 忘记了十万八千次的知识点(随笔记)
1 UNet部分
1.1 SelfAttention
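1)自注意力模块可以直接调用 PyTorch 的 nn.MultiheadAttention(channels, num_heads, batch_first=True),避免重复造轮子。注意它的 forward 需要分别传入 query、key、value 三个参数(自注意力时三者相同),并返回 (attn_output, attn_weights) 元组;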
2)执行顺序为:
- 将输入由(B,C,H,W) -> (B,C,H*W) -> (B,H*W,C)
- 通过LayerNorm模块,得到x_ln
- 将x_ln作为三个qkv参数传入到多头注意力模块,得到attention_value
- 将attention_value和原始输入x进行残差连接
- (可加可不加)再通过前馈神经网络
- 将attention_value变回(B,C,H,W)
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The H*W positions of a (B, C, H, W) input are treated as a sequence of
    C-dimensional tokens; the output has the same (B, C, H, W) shape.

    Args:
        channels: number of channels C (the token width).
        num_heads: attention heads; ``channels`` must be divisible by it.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        # batch_first=True -> the attention module consumes (B, L, C)
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        x = x.reshape(B, self.channels, H * W).swapaxes(1, 2)
        x_ln = self.ln(x)
        # MultiheadAttention needs query, key and value, and returns
        # (attn_output, attn_weights); only the output is used here.
        attention_value, _ = self.mha(x_ln, x_ln, x_ln, need_weights=False)
        attention_value = attention_value + x  # residual connection
        attention_value = self.ff(attention_value) + attention_value
        # swapaxes returns a non-contiguous tensor, so reshape (not view) is required
        return attention_value.swapaxes(1, 2).reshape(B, self.channels, H, W)
测试:
# Quick shape check for SelfAttention: the output must keep the input's (B, C, H, W) shape.
mha = SelfAttention(32)
x = torch.rand(3, 32, 64, 64)
out = mha(x)
print(out.shape)  # torch.Size([3, 32, 64, 64])
1.2 DoubleConv
相当于UNet中的double conv,只不过这里把一些模块换了,并且新增了residual结构。
class DoubleConv(nn.Module):
    """Two (3x3 conv -> GroupNorm) stages with GELU, optionally residual.

    Args:
        in_c: input channels.
        out_c: output channels.
        mid_c: channels after the first conv; defaults to ``out_c``.
        residual: if True, add a shortcut of the input before the final GELU.
    """

    def __init__(self, in_c, out_c, mid_c=None, residual=False):
        super().__init__()
        self.residual = residual
        if mid_c is None:
            mid_c = out_c
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_c, mid_c, kernel_size=3, padding=1),
            nn.GroupNorm(1, mid_c),
            nn.GELU(),
            nn.Conv2d(mid_c, out_c, kernel_size=3, padding=1),
            # normalizes the second conv's output, which has out_c channels
            nn.GroupNorm(1, out_c),
        )
        # 1x1 projection so the shortcut matches out_c when channel counts differ
        if in_c != out_c:
            self.shortcut = nn.Conv2d(in_c, out_c, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.double_conv(x)
        if self.residual:
            out = self.shortcut(x) + out
        return F.gelu(out)
1.3 Down
down模块其实就是一个maxpooling层,再接两个double_conv层,其中double_conv的维度变化为in_c -> out_c -> out_c;
其次,还有一个timestep_embedding层,即一个激活函数+一个线性层,目的是把timestep embedding的维度(B, emb_dim)变换到与要相加的数据一致:先映射为(B, out_c),再扩展到(B, out_c, h, w)。
class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConvs, then add the time embedding.

    Args:
        in_c: input channels.
        out_c: output channels.
        emb_dim: width of the timestep embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_c, out_c, emb_dim=256):
        super().__init__()  # must run before any submodule is assigned
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # kernel_size=2; stride defaults to kernel_size
            DoubleConv(in_c, out_c, residual=True),
            DoubleConv(out_c, out_c),
        )
        # projects the (B, emb_dim) time embedding to (B, out_c)
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # (B, out_c) -> (B, out_c, 1, 1); broadcasting over H and W gives the
        # same result as repeating the embedding h*w times
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb
1.4 Up
Up模块先进行双线性插值上采样,然后与skip connection传来的特征在channel维度进行拼接,之后再进行两次double conv。
同样要有timestep_embedding。
class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, concat skip features, two DoubleConvs,
    then add the time embedding.

    Args:
        in_c: channels AFTER concatenation (upsampled channels + skip channels).
        out_c: output channels.
        emb_dim: width of the timestep embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_c, out_c, emb_dim=256):
        super().__init__()  # must run before any submodule is assigned
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_c, in_c, residual=True),
            DoubleConv(in_c, out_c),
        )
        # projects the (B, emb_dim) time embedding to (B, out_c)
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([x, skip_x], dim=1)  # concatenate along channels
        x = self.conv(x)
        # (B, out_c) -> (B, out_c, 1, 1), broadcast over H and W
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb
1.5 UNet模型
根据常规的UNet模型拼接起来,在每次下采样和上采样之后加上self-attention层
class UNet(nn.Module):
    """Noise-prediction UNet with self-attention after every down/up stage.

    Args:
        in_c: channels of the input image (default 3).
        out_c: channels of the predicted noise (default 3).
        time_dim: width of the sinusoidal timestep embedding fed to every stage.
        device: kept for backward compatibility; the embedding is built on ``t``'s device.
    """

    def __init__(self, in_c=3, out_c=3, time_dim=256, device='cuda'):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(in_c, 64)
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 512, emb_dim=time_dim)
        self.sa3 = SelfAttention(512)

        self.bot1 = DoubleConv(512, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Up's in_c is upsampled channels + skip channels (e.g. 256 + 256 = 512)
        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, out_c, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of 1-D timesteps ``t`` -> tensor of shape (len(t), channels).

        The first half of the columns are sin, the second half cos, with
        frequencies 1 / 10000**(i / channels) for i = 0, 2, 4, ...
        An odd ``channels`` gets one trailing zero column.
        """
        half = channels // 2
        # The exponent must be parenthesized: 10000**arange alone overflows to inf.
        exponent = torch.arange(0, 2 * half, 2, device=t.device).float() / channels
        freq = 1.0 / (10000 ** exponent)
        args = t[:, None].float() * freq[None]  # (B, 1) * (1, half) -> (B, half)
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if channels % 2 != 0:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, x, t):
        t = self.pos_encoding(t, self.time_dim)

        x1 = self.inc(x)
        x2 = self.sa1(self.down1(x1, t))
        x3 = self.sa2(self.down2(x2, t))
        x4 = self.sa3(self.down3(x3, t))

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.sa4(self.up1(x4, x3, t))
        x = self.sa5(self.up2(x, x2, t))
        x = self.sa6(self.up3(x, x1, t))
        return self.outc(x)
关于positional_encoding的讲解:
freq = 1.0/(10000**(torch.arange(0,channels,2,device=self.device).float()/channels))
注意指数部分必须整体加括号,即计算 $10000^{i/\text{channels}}$;否则 `10000**arange` 会先溢出为 inf。以 channels=256 为例,$i$ 取 $0,2,\dots,254$,因此 freq 从 $1$ 递减到 $\frac{1}{10000^{254/256}}$。
args = t[:,None].float()*freq[None]
t 为 1-d 向量,所以这里先进行扩维;freq 是 [channels//2] 维,最终 args 为 [B, channels//2](例如 B=3、channels=256 时为 [3,128])。
embedding = torch.cat([torch.sin(args), torch.cos(args)],dim=-1)
在最后一维进行拼接,得到 [B, channels]。
embedding = torch.cat([embedding,torch.zeros_like(embedding[:,:1])],dim=-1)
这里是为了处理维度为奇数的情况:若为奇数,则在最后一维补 0。
- 使用方法:
在每一层residual block之后,使用emb_layer对timestep_embedding进行维度变换,之后加到数据上即可。
2 Diffusion部分以及回顾
2.1 beta_schedule
linear_beta_schedule
def linear_beta_schedule(self):
    """Return ``noise_steps`` linearly spaced betas.

    Both endpoints are multiplied by 1000 / noise_steps, so with the default
    1000 steps the schedule runs exactly from beta_start to beta_end.
    """
    factor = 1000 / self.noise_steps
    start, end = self.beta_start * factor, self.beta_end * factor
    return torch.linspace(start, end, self.noise_steps)
cosine_beta_schedule
公式为:
$$f(t)=\cos\left(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\right)^2,\qquad \bar\alpha_t=\frac{f(t)}{f(0)},\qquad \beta_t=1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}$$
def cosine_beta_schedule(self, s=0.008):
    """Cosine beta schedule as proposed in the Improved DDPM paper.

    alpha_bar(t) = f(t) / f(0) with f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2,
    beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped to [0, 0.999].

    Returns:
        float64 tensor of length ``noise_steps``.
    """
    T = self.noise_steps
    grid = torch.linspace(0, T, T + 1, dtype=torch.float64)  # 0, 1, ..., T
    f = torch.cos(((grid / T) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]  # normalized so that alpha_bar[0] == 1
    # alpha_bar[1:] holds alpha_bar_t and alpha_bar[:-1] holds alpha_bar_{t-1}
    ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratio, 0, 0.999)  # cap near t = T, where alpha_bar -> 0
2.2 初始化
回顾一下DDPM所有的公式:
- 前向过程
$$x_t=\sqrt{\alpha_t}\,x_{t-1}+\sqrt{1-\alpha_t}\,\epsilon,\qquad x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon$$
- 后验分布 $q(x_{t-1}|x_t,x_0)$ 的均值和方差
$$\mu_q=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})\,x_t+\sqrt{\bar\alpha_{t-1}}(1-\alpha_t)\,\hat x_\theta}{1-\bar\alpha_t}=\frac{1}{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\hat\epsilon(x_t,t)$$
$$\Sigma=\frac{(1-\alpha_t)(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}I=\frac{\beta_t(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}I$$
- 每一次采样得到的估计 $\hat x_0$ 和 $x_{t-1}$
$$\hat x_0=\frac{1}{\sqrt{\bar\alpha_t}}\left(x_t-\sqrt{1-\bar\alpha_t}\,\epsilon\right),\qquad x_{t-1}=\tilde\mu+\sigma_t z,\quad z\sim\mathcal N(0,I)$$
class Diffusion:
    """DDPM noise schedule plus the per-timestep coefficients used by the
    forward process q and the reverse process p.

    ``linear_beta_schedule`` / ``cosine_beta_schedule`` are expected to be
    methods of this class (shown as separate snippets in this post).

    Args:
        noise_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear schedule.
        img_size: image side length.
        beta_schedule: 'linear' or 'cosine'.
        device: device the coefficient tensors are moved to.

    Raises:
        ValueError: on an unknown ``beta_schedule``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256,
                 beta_schedule='linear', device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        if beta_schedule == 'linear':
            self.beta = self.linear_beta_schedule().to(device)
        elif beta_schedule == 'cosine':
            self.beta = self.cosine_beta_schedule().to(device)
        else:
            raise ValueError(f'Unknown beta schedule {beta_schedule}')

        # alpha_t = 1 - beta_t, alpha_bar_t = prod_{s<=t} alpha_s
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        # alpha_bar_{t-1}, with alpha_bar before the first step defined as 1
        self.alpha_hat_prev = F.pad(self.alpha_hat[:-1], (1, 0), value=1.)

        # q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)
        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat)

        # x_0 estimate: sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps
        self.sqrt_recip_alpha_hat = torch.sqrt(1. / self.alpha_hat)
        self.sqrt_recip_minus_alpha_hat = torch.sqrt(1. / self.alpha_hat - 1)

        # posterior q(x_{t-1} | x_t, x_0), reused as p(x_{t-1} | x_t) with x_0 estimated
        self.posterior_variance = self.beta * (1. - self.alpha_hat_prev) / (1. - self.alpha_hat)
        # mean = coef1 * x_0 + coef2 * x_t
        self.posterior_mean_coef1 = self.beta * torch.sqrt(self.alpha_hat_prev) / (1.0 - self.alpha_hat)
        self.posterior_mean_coef2 = (1.0 - self.alpha_hat_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_hat)

    def sample_timesteps(self, n):
        """Draw ``n`` training timesteps uniformly from [0, noise_steps) (0-based indices)."""
        return torch.randint(low=0, high=self.noise_steps, size=(n,))
2.3 提取数组中的对应timestep的值
def _extract(self,arr,t,x_shape):# 根据timestep t从arr中提取对应元素并变形为x_shapebs = x_shape[0]out = arr.to(t.device).gather(0,t).float()out = out.reshape(bs,*((1,)*(len(x_shape)-1))) # reshape为(bs,1,1,1)return out
2.4 从 $x_0$ 得到 $x_t$
根据公式 $x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon$。
模型训练首先要根据随机采样的 $t$ 和 $x_0$ 得到加噪后的 $x_t$ 以及所加的噪声 $\epsilon$(即训练目标),所以返回两个值。
def q_sample(self, x, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x: clean images x_0.
        t: 1-D integer timesteps, one per batch element.
        noise: optional noise with ``x``'s shape; drawn with ``randn_like`` if omitted.

    Returns:
        (x_t, noise). The noise is returned because it is the training target.
    """
    # Previously a caller-supplied noise was ignored and the name was undefined.
    if noise is None:
        noise = torch.randn_like(x)
    sqrt_alpha_hat = self._extract(self.sqrt_alpha_hat, t, x.shape)
    sqrt_one_minus_alpha_hat = self._extract(self.sqrt_one_minus_alpha_hat, t, x.shape)
    return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
2.5 真实后验分布 $q(x_{t-1}|x_t,x_0)$ 的均值和方差
参考上面的公式。
实际上我们将 $p_\theta(x_{t-1}|x_t)$ 的分布(也就是模型拟合的分布)也设为类似形式,然后把模型估计出来的重建 $\hat x_0$ 连同 $x_t$ 一起传入这个函数,得到 $p_\theta$ 的均值和方差。重建的 $x_0$ 从哪里来?模型先预测噪声,再根据 $x_t$ 和预测噪声计算得到。
def q_posterior_mean_variance(self, x, x_t, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0).

    Passing an estimate of x_0 as ``x`` turns this into the model's
    p(x_{t-1} | x_t).

    Args:
        x: x_0 (or its estimate).
        x_t: noisy sample at step t.
        t: 1-D integer timesteps, one per batch element.

    Returns:
        (posterior_mean, posterior_variance). The variance has shape
        (B, 1, ..., 1) and broadcasts against ``x``.
    """
    posterior_mean = (
        self._extract(self.posterior_mean_coef1, t, x.shape) * x
        + self._extract(self.posterior_mean_coef2, t, x.shape) * x_t
    )
    # the _extract call was missing, which returned a raw tuple instead of a tensor
    posterior_variance = self._extract(self.posterior_variance, t, x.shape)
    return posterior_mean, posterior_variance
2.6 估计重建 $x_0$
参考上面的公式,根据 $x_t$ 和预测出的噪声 pred_noise 来估计:$\hat x_0=\sqrt{1/\bar\alpha_t}\,x_t-\sqrt{1/\bar\alpha_t-1}\,\epsilon_{pred}$。直观上相当于从 $x_t$ 中减去预测的噪声,再进行缩放。
def estimate_x0_from_noise(self, x_t, t, noise):
    """Reconstruct x_0 from x_t and a (predicted) noise.

    Inverts x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps:
        x0_hat = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps

    Args:
        x_t: noisy sample at step t.
        t: 1-D integer timesteps, one per batch element.
        noise: noise estimate with x_t's shape.
    """
    # The noise term must be subtracted; the previous '+' did not invert q_sample.
    return (
        self._extract(self.sqrt_recip_alpha_hat, t, x_t.shape) * x_t
        - self._extract(self.sqrt_recip_minus_alpha_hat, t, x_t.shape) * noise
    )
2.7 计算 $p_\theta$ 的均值和方差
首先通过 $x_t$ 和 $t$ 预测噪声,然后估计出重建的 $x_0$ 并将其裁剪到 $[-1,1]$,再据此计算均值和方差。
def p_mean_variance(self, model, x_t, t, clip_denoised=True):
    """Mean and variance of the learned reverse step p_theta(x_{t-1} | x_t).

    The model's noise prediction is turned into an estimate of x_0, optionally
    clamped to [-1, 1]. That estimate is then plugged into the posterior
    q(x_{t-1} | x_t, x_0).
    """
    eps_hat = model(x_t, t)
    x0_hat = self.estimate_x0_from_noise(x_t, t, eps_hat)
    x0_hat = torch.clamp(x0_hat, min=-1., max=1.) if clip_denoised else x0_hat
    mean, var = self.q_posterior_mean_variance(x0_hat, x_t, t)
    return mean, var
2.8 采样
采样就是 $x_{t-1}=\mu+\sigma_t z$,其中 $\sigma_t$ 是固定的(由噪声调度决定),$z\sim\mathcal N(0,I)$ 是随机采样的。当 $t=0$ 时,也就是最后一步,不再加噪声。loop 函数从 $t=\text{noise\_steps}-1$ 一直采样到 $t=0$。
def p_sample(self, model, x_t, t, clip_denoised=True):
    """One reverse step: draw x_{t-1} ~ p_theta(x_{t-1} | x_t).

    x_{t-1} = mean + sqrt(var) * z with z ~ N(0, I). No noise is added for
    batch elements whose t == 0 (the final step).
    """
    model.eval()
    with torch.no_grad():
        p_mean, p_var = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised)
        noise = torch.randn_like(x_t)
        # 1 where t != 0, else 0, shaped (B, 1, ..., 1) to broadcast over x_t
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
        pred_img = p_mean + nonzero_mask * torch.sqrt(p_var) * noise
    return pred_img


def p_sample_loop(self, model, shape):
    """Run the full reverse chain from pure noise, t = noise_steps - 1 down to 0.

    Returns:
        list with the image after every reverse step; the last entry is the final sample.
    """
    model.eval()
    with torch.no_grad():
        bs = shape[0]
        # sample on whatever device the model lives on
        device = next(model.parameters()).device
        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(reversed(range(0, self.noise_steps)), desc='sampling loop time step', total=self.noise_steps):
            img = self.p_sample(model, img, torch.full((bs,), i, device=device, dtype=torch.long))  # from T to 0
            imgs.append(img)
    return imgs


@torch.no_grad()
def sample(self, model, img_size, bs=8, channels=3):
    """Generate ``bs`` images of shape (channels, img_size, img_size); see p_sample_loop."""
    logging.info(f"Sampling {bs} new images....")
    return self.p_sample_loop(model, (bs, channels, img_size, img_size))
3 训练部分
def train(args):
    """Train the noise-prediction UNet with the DDPM objective.

    Each epoch samples a random timestep per image and noises the batch with
    q_sample. It regresses the added noise with MSE, then generates a batch of
    images and saves a checkpoint. ``setup_logging``, ``get_data`` and
    ``save_images`` are helpers not shown in this snippet.
    """
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    # RGB in, predicted RGB noise out; pass device so the time embedding is built there
    model = UNet(3, 3, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    num_batches = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        # sampling at the end of each epoch calls model.eval(); switch back for training
        model.train()
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.q_sample(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * num_batches + i)

        # sample() takes (model, img_size, bs) and returns the image after every
        # reverse step; the last entry is the finished sample in [-1, 1].
        sampled_images = diffusion.sample(model, args.image_size, bs=images.shape[0])[-1]
        # NOTE(review): assumes save_images expects uint8 pixels, as in the referenced implementation — confirm.
        sampled_images = ((sampled_images.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8)
        save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "ckpt.pt"))


def launch():
    """Build the (hard-coded) training configuration and start training."""
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "DDPM_Uncondtional"
    args.epochs = 500
    args.batch_size = 12
    args.image_size = 64
    args.dataset_path = r"C:\Users\dome\datasets\landscape_img_folder"
    args.device = "cuda"
    args.lr = 3e-4
    train(args)


if __name__ == '__main__':
    launch()
4 其他实验
4.1 加噪过程
import os

import ddpm
from PIL import Image
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

# Visualize the forward (noising) process under the linear and cosine beta schedules.
image = Image.open("giraffe.jpg").convert("RGB")  # Normalize below expects exactly 3 channels
image_size = 128

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    # [0, 1] -> [-1, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
x_start = transform(image).unsqueeze(0)  # [1, 3, 128, 128]

# All tensors here live on the CPU; device="cpu" keeps the demo from requiring a GPU.
diffusion_linear = ddpm.Diffusion(noise_steps=500, device="cpu")
diffusion_cosine = ddpm.Diffusion(noise_steps=500, beta_schedule='cosine', device="cpu")


def _to_uint8(x_t):
    """Convert a [1, 3, H, W] tensor in model space to an [H, W, 3] uint8 array."""
    # x_t is NOT confined to [-1, 1]: the added Gaussian noise pushes values
    # outside it, and a float -> uint8 cast does not saturate, so clamp first.
    img = (x_t.squeeze(0).permute(1, 2, 0).clamp(-1, 1) + 1) * 127.5
    return img.numpy().astype(np.uint8)


plt.figure(figsize=(16, 8))
for idx, t in enumerate([0, 50, 100, 200, 499]):
    timestep = torch.tensor([t])
    x_noisy, _ = diffusion_linear.q_sample(x_start, t=timestep)  # use q_sample to produce x_t
    x_noisy2, _ = diffusion_cosine.q_sample(x_start, t=timestep)  # [1, 3, 128, 128]

    plt.subplot(2, 5, 1 + idx)
    plt.imshow(_to_uint8(x_noisy))
    plt.axis("off")
    plt.title(f"t={t}")

    plt.subplot(2, 5, 6 + idx)
    plt.imshow(_to_uint8(x_noisy2))
    plt.axis('off')

plt.figtext(0.5, 0.95, 'Linear Beta Schedule', ha='center', fontsize=16)  # title above the first row
plt.figtext(0.5, 0.48, 'Cosine Beta Schedule', ha='center', fontsize=16)  # title above the second row
os.makedirs('temp_img', exist_ok=True)
plt.savefig('temp_img/add_noise_process.png')
4.2 多GPU分布式代码
4.3 去噪过程
忘记了十万八千次的知识点(随笔记)
view 和 reshape 的区别
https://zhuanlan.zhihu.com/p/593664378