Contents
- 1. Creating the diffusion object
- 2. Definition of GaussianDiffusion
- 3. Understanding the code
- def __init__(self, model, img_size, img_channels, num_classes, betas, loss_type="l2", ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):
- def remove_noise(self, x, t, y, use_ema=True):
- def sample(self, batch_size, device, y=None, use_ema=True):
- def perturb_x(self, x, t, noise):
- def get_losses(self, x, t, y):
- def forward(self, x, y=None):
- def generate_cosine_schedule(T, s=0.008) and def generate_linear_schedule(T, low, high):
1. Creating the diffusion object
```python
diffusion = GaussianDiffusion(
    model,
    args.img_size,
    args.img_channels,
    args.num_classes,
    betas,
    ema_decay=args.ema_decay,
    ema_update_rate=args.ema_update_rate,
    ema_start=2000,
    loss_type=args.loss_type,
)
```
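For context, here is a minimal sketch of how the arguments above could be prepared. The schedule endpoints, image size and class count are illustrative assumptions rather than values from the training script; `model` is assumed to be a UNet-style noise-prediction network, and `generate_linear_schedule` is the helper shown at the end of this post.

```python
# Minimal sketch with illustrative values: a standard linear beta schedule, T = 1000 steps.
betas = generate_linear_schedule(1000, 1e-4, 0.02)

diffusion = GaussianDiffusion(
    model,               # assumed: a UNet that predicts the added noise
    img_size=(32, 32),
    img_channels=3,
    num_classes=10,
    betas=betas,
    ema_decay=0.9999,
    ema_update_rate=1,
    ema_start=2000,
    loss_type="l2",
)
```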
2. Definition of GaussianDiffusion
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from copy import deepcopy

from ddpm.ema import EMA
from ddpm.utils import extract


class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,
        img_size,
        img_channels,
        num_classes,
        betas,
        loss_type="l2",
        ema_decay=0.9999,
        ema_start=5000,
        ema_update_rate=1,
    ):
        super().__init__()

        self.model = model
        self.ema_model = deepcopy(model)

        self.ema = EMA(ema_decay)
        self.ema_decay = ema_decay
        self.ema_start = ema_start
        self.ema_update_rate = ema_update_rate
        self.step = 0

        self.img_size = img_size
        self.img_channels = img_channels
        self.num_classes = num_classes

        if loss_type not in ["l1", "l2"]:
            raise ValueError("__init__() got unknown loss type")

        self.loss_type = loss_type
        self.num_timesteps = len(betas)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Precompute all per-timestep coefficients as buffers.
        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas", to_torch(alphas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))

        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))

        self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("sigma", to_torch(np.sqrt(betas)))

    def update_ema(self):
        self.step += 1
        if self.step % self.ema_update_rate == 0:
            # Before ema_start, just copy the weights; afterwards, update the moving average.
            if self.step < self.ema_start:
                self.ema_model.load_state_dict(self.model.state_dict())
            else:
                self.ema.update_model_average(self.ema_model, self.model)

    @torch.no_grad()
    def remove_noise(self, x, t, y, use_ema=True):
        if use_ema:
            return (
                (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) *
                extract(self.reciprocal_sqrt_alphas, t, x.shape)
            )
        else:
            return (
                (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *
                extract(self.reciprocal_sqrt_alphas, t, x.shape)
            )

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

        return x.cpu().detach()

    @torch.no_grad()
    def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

            yield x.cpu().detach()

    def perturb_x(self, x, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def get_losses(self, x, t, y):
        noise = torch.randn_like(x)
        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)

        return loss

    def forward(self, x, y=None):
        b, c, h, w = x.shape
        device = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        if w != self.img_size[0]:
            raise ValueError("image width does not match diffusion parameters")

        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)


def generate_cosine_schedule(T, s=0.008):
    def f(t, T):
        return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2

    alphas = []
    f0 = f(0, T)

    for t in range(T + 1):
        alphas.append(f(t, T) / f0)

    betas = []

    for t in range(1, T + 1):
        betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))

    return np.array(betas)


def generate_linear_schedule(T, low, high):
    return np.linspace(low, high, T)
```
3. Understanding the code
Input:
- x: (N, img_channels, *img_size)
- y: (N)

Output:
- scalar loss tensor

Args:
- model (nn.Module): the model that estimates the Gaussian noise
- img_size (tuple): (H, W)
- img_channels (int): number of image channels
- betas (np.ndarray): array of diffusion betas
- loss_type (string): loss type, "l1" or "l2"
- ema_decay (float): model weights exponential moving average decay
- ema_start (int): number of steps before EMA
- ema_update_rate (int): number of steps before each EMA update
def __init__(self, model, img_size, img_channels, num_classes, betas, loss_type="l2", ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):
np.cumprod returns the cumulative product of the array elements along the given axis. For example, with a = [a1, a2, a3, a4, a5]:
np.cumprod(a) = array([a1, a1*a2, a1*a2*a3, a1*a2*a3*a4, a1*a2*a3*a4*a5]).
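As a quick sanity check, the sketch below reproduces the buffer computations from __init__ with plain NumPy on a tiny hypothetical 5-step schedule (the values are illustrative only):

```python
import numpy as np

# Hypothetical 5-step linear schedule, purely for illustration.
betas = np.linspace(1e-4, 0.02, 5)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)          # [a1, a1*a2, ..., a1*a2*a3*a4*a5]

sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1 - alphas_cumprod)
reciprocal_sqrt_alphas = np.sqrt(1 / alphas)
remove_noise_coeff = betas / np.sqrt(1 - alphas_cumprod)
sigma = np.sqrt(betas)

print(alphas_cumprod)  # strictly decreasing, because every alpha is < 1
```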
def remove_noise(self, x, t, y, use_ema=True):
```python
(x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) * extract(self.reciprocal_sqrt_alphas, t, x.shape)
```
This function removes the noise that was added between step t-1 and step t, i.e. it takes x_t one step back toward x_{t-1}.
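In DDPM notation this expression is the predicted mean of the reverse step, where remove_noise_coeff stores the precomputed coefficient and reciprocal_sqrt_alphas stores 1/√α_t:

$$
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t, y)\right)
$$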
Inside this function the extract function is called. Its job: pick out the coefficient that corresponds to time step t.
```python
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
```
- a: Tensor of shape (1000,)
- t: Tensor of shape (128,)
- x_shape: torch.Size([128, 1, 28, 28])
- return value: Tensor of shape (128, 1, 1, 1) (see the shape check below)
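A small sketch that reproduces these shapes (1000 timesteps, a batch of 128 single-channel 28×28 images; the numbers are just the example values above):

```python
import torch
from ddpm.utils import extract

a = torch.rand(1000)                  # one coefficient per timestep
t = torch.randint(0, 1000, (128,))    # one timestep index per sample
x = torch.zeros(128, 1, 28, 28)

out = extract(a, t, x.shape)
print(out.shape)                      # torch.Size([128, 1, 1, 1])
# The trailing singleton dims let `out` broadcast elementwise against x.
```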
The model itself is defined in the initialization function, while the call to the model is made in the forward function.
def sample(self, batch_size, device, y=None, use_ema=True):
```python
def sample(self, batch_size, device, y=None, use_ema=True):
    if y is not None and batch_size != len(y):
        raise ValueError("sample batch size different from length of given y")

    x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

    for t in range(self.num_timesteps - 1, -1, -1):
        t_batch = torch.tensor([t], device=device).repeat(batch_size)
        x = self.remove_noise(x, t_batch, y, use_ema)

        if t > 0:
            x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

    return x.cpu().detach()
```
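A minimal usage sketch, assuming a trained diffusion object and a model built with num_classes=10 (these values are illustrative, not taken from this post):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

y = torch.arange(10, device=device)   # one sample per class label 0..9
samples = diffusion.sample(batch_size=10, device=device, y=y, use_ema=True)
print(samples.shape)                  # (10, img_channels, H, W), returned on the CPU
```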
def perturb_x(self, x, t, noise):
Adds noise to the image.
```python
return (
    extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
    extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
)
```
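This is the closed-form forward (noising) process of DDPM:

$$
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$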
def get_losses(self, x, t, y):
Computes the loss between the noise that was added and the noise estimated by the model.
```python
noise = torch.randn_like(x)
perturbed_x = self.perturb_x(x, t, noise)
estimated_noise = self.model(perturbed_x, t, y)

if self.loss_type == "l1":
    loss = F.l1_loss(estimated_noise, noise)
elif self.loss_type == "l2":
    loss = F.mse_loss(estimated_noise, noise)

return loss
```
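With loss_type="l2" this is the simplified DDPM training objective:

$$
L_{\text{simple}} = \left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\right) \right\rVert^2
$$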
def forward(self, x, y=None):
The forward function is straightforward: sample b random time steps t, then compute the corresponding noise loss.
```python
def forward(self, x, y=None):
    b, c, h, w = x.shape
    device = x.device

    if h != self.img_size[0]:
        raise ValueError("image height does not match diffusion parameters")
    if w != self.img_size[0]:
        raise ValueError("image width does not match diffusion parameters")

    t = torch.randint(0, self.num_timesteps, (b,), device=device)
    return self.get_losses(x, t, y)
```
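Putting it together, a minimal training-loop sketch; the optimizer, learning rate and train_loader are assumptions for illustration, not taken from this post:

```python
import torch

optimizer = torch.optim.Adam(diffusion.model.parameters(), lr=2e-4)

for x, y in train_loader:          # assumed DataLoader yielding (N, C, H, W) images and (N,) labels
    x, y = x.to(device), y.to(device)

    loss = diffusion(x, y)         # forward() draws random t and calls get_losses()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    diffusion.update_ema()         # keep the EMA copy of the model in sync
```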
def generate_cosine_schedule(T, s=0.008) and def generate_linear_schedule(T, low, high):
These two functions are two different ways of generating the betas. The betas array goes from small values to large ones.
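A quick way to compare the two schedules; T = 1000 and the linear endpoints 1e-4 / 0.02 are the values commonly used for DDPM, chosen here only for illustration:

```python
import numpy as np

T = 1000
linear_betas = generate_linear_schedule(T, 1e-4, 0.02)
cosine_betas = generate_cosine_schedule(T)

print(linear_betas[0], linear_betas[-1])   # 0.0001 ... 0.02, increasing linearly
print(cosine_betas[0], cosine_betas[-1])   # tiny at t=0, clipped to 0.999 near t=T
```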