所有代码 已上传至GitHub - duhanyue349/diffusion_model_learned_ddpm_main: 扩散模型基础框架源代码
目录结构如下
在train_cifar.py 中展示了扩散模型训练的所有代码
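如果不想使用 wandb 记录日志,可以在 create_argparser() 中设置 log_to_wandb=False。注意:train_cifar.py 顶部仍然会 import wandb(因此仍需安装该包),并且所有 wandb.login / wandb.log 调用都必须受 log_to_wandb 开关控制,否则关闭该选项后运行仍会报错。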
一、加载模型参数 args
这里用了一个create_argparser()函数创建命令行解析器
# Build the command-line parser (training defaults + diffusion hyper-parameters)
# and parse the arguments given on the command line.
args = create_argparser().parse_args()
def create_argparser():
    """Create the command-line parser for the CIFAR-10 DDPM training script.

    Every entry of ``defaults`` below, plus everything returned by
    ``script_utils.diffusion_defaults()``, becomes a ``--<key>`` option whose
    default is the listed value.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # The author trains on the second GPU ("cuda:1"). On a machine with a single
    # GPU that device does not exist and moving the model there fails, so fall
    # back to "cuda:0" in that case. Any choice can be overridden with --device.
    if torch.cuda.is_available():
        gpu_name = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
        device = torch.device(gpu_name)
    else:
        device = torch.device("cpu")

    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=800000,
        log_to_wandb=True,
        # NOTE(review): with 800000 iterations, rates of 10 trigger a full test
        # pass and a checkpoint write every 10 steps; consider larger values.
        log_rate=10,
        checkpoint_rate=10,
        # "~" is not expanded automatically; expand it (os.path.expanduser)
        # before using this value as a filesystem path.
        log_dir="~/ddpm_logs",
        project_name='ddpm',
        run_name=run_name,
        model_checkpoint=None,
        optim_checkpoint=None,
        # Also defined in diffusion_defaults(); those values win because of the
        # update() call below.
        schedule_low=1e-4,
        schedule_high=0.02,
        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser
defaults是基础的一些参数,用defaults.update(script_utils.diffusion_defaults())可以将模型参数加载进来 使用了这个函数diffusion_defaults()其返回的是一个字典
def diffusion_defaults():
    """Return a fresh dict of the default diffusion / UNet hyper-parameters.

    Each key is later turned into a ``--<key>`` command-line option; the
    insertion order of the keys is preserved.
    """
    return {
        "num_timesteps": 1000,
        "schedule": "linear",
        "loss_type": "l2",
        "use_labels": False,
        "base_channels": 128,
        "channel_mults": (1, 2, 2, 2),
        "num_res_blocks": 2,
        "time_emb_dim": 128 * 4,
        "norm": "gn",
        "dropout": 0.1,
        "activation": "silu",
        "attention_resolutions": (1,),
        "schedule_low": 1e-4,
        "schedule_high": 0.02,
        "ema_decay": 0.9999,
        "ema_update_rate": 1,
    }
随后实例化命令行解析器 parser = argparse.ArgumentParser(),再调用 script_utils.add_dict_to_argparser(parser, defaults) 为解析器添加参数。这个函数用字典来描述命令行参数:字典的每个键都会变成一个 --键名 选项,对应的值作为该选项的默认值,并根据默认值的类型决定如何解析命令行传入的字符串。
def _make_sequence_parser(default):
    """Return an argparse ``type`` callable that parses "1,2,4" like ``default``.

    The container (tuple or list) and the element type are taken from
    ``default``; an empty default parses its elements as ``str``.
    """
    container = tuple if isinstance(default, tuple) else list
    elem_type = type(default[0]) if default else str
    if elem_type is bool:
        # bool("False") is True, so reuse the same converter as scalar bools.
        elem_type = str2bool

    def parse(text):
        return container(elem_type(part.strip()) for part in text.split(",") if part.strip())

    return parse


def add_dict_to_argparser(parser, default_dict):
    """Add one ``--<key>`` option per entry of ``default_dict`` to ``parser``.

    Adapted from
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py

    The value is used as the option's default, and its type decides how a
    command-line string is converted:

    * ``None``        -> ``str``
    * ``bool``        -> ``str2bool``
    * tuple / list    -> comma-separated values, e.g. ``--channel_mults 1,2,2``
      (previously ``type=tuple`` split the string into single characters)
    * anything else   -> ``type(value)`` (e.g. ``int``, ``float``, ``torch.device``)
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        elif isinstance(v, (tuple, list)):
            v_type = _make_sequence_parser(v)
        parser.add_argument(f"--{k}", default=v, type=v_type)
基础的创建解析器的步骤为
#parser = argparse.ArgumentParser() 创建命令行解析器
#parser.add_argument() 添加命令行参数
#args = parser.parse_args() 对命令行参数进行解析
二、获得 diffusion 模型架构
# Build the GaussianDiffusion wrapper (UNet + beta schedule) from the parsed
# arguments and move its parameters and buffers to the target device.
diffusion = script_utils.get_diffusion_from_args(args=args).to(device=device)
这里调用 get_diffusion_from_args 函数构建模型,它的输入是刚刚由命令行解析器解析得到的参数 args(而不是解析器本身)。
def get_diffusion_from_args(args):
    """Build a GaussianDiffusion model for 32x32 RGB images from parsed CLI args.

    Args:
        args: parsed arguments providing the UNet options (base_channels,
            channel_mults, time_emb_dim, norm, dropout, activation,
            attention_resolutions, use_labels), the schedule options
            (schedule, num_timesteps, schedule_low, schedule_high) and the
            EMA/loss options (ema_decay, ema_update_rate, loss_type). An
            optional ``ema_start`` attribute overrides the default of 2000.

    Returns:
        GaussianDiffusion wrapping the UNet.

    Raises:
        ValueError: if ``args.activation`` is not "relu", "mish" or "silu", or
            ``args.schedule`` is not "linear" or "cosine". Previously an
            unknown schedule silently fell back to the linear one.
    """
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }
    if args.activation not in activations:
        raise ValueError(
            f"unknown activation {args.activation!r}; expected one of {sorted(activations)}"
        )
    if args.schedule not in ("linear", "cosine"):
        raise ValueError(
            f"unknown schedule {args.schedule!r}; expected 'linear' or 'cosine'"
        )

    model = UNet(
        img_channels=3,
        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,
        # 10 classes for CIFAR-10 when training class-conditionally.
        num_classes=None if not args.use_labels else 10,
        initial_pad=0,
    )

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        # Endpoints are scaled by 1000 / num_timesteps, so they are used
        # unchanged for the default T = 1000.
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    diffusion = GaussianDiffusion(
        model, (32, 32), 3, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=getattr(args, "ema_start", 2000),
        loss_type=args.loss_type,
    )
    return diffusion
返回的是一个 GaussianDiffusion类 把model(UNet)、betas、loss_type等传给了这个类
这里 betas 有两种定义方式:一种是 cosine 调度,一种是 linear 调度。
# Alternative 1: cosine schedule.
betas = generate_cosine_schedule(args.num_timesteps)

# Alternative 2: linear schedule. Both endpoints are multiplied by
# 1000 / num_timesteps, so with the default T = 1000 they are used unchanged
# (presumably to keep the total added noise comparable for other T).
betas = generate_linear_schedule(
    args.num_timesteps,
    args.schedule_low * 1000 / args.num_timesteps,
    args.schedule_high * 1000 / args.num_timesteps,
)
GaussianDiffusion 这个类定义在 diffusion.py 文件中:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from copy import deepcopy

from .ema import EMA
from .utils import extract


class GaussianDiffusion(nn.Module):
    r"""Gaussian diffusion (DDPM) wrapper around a noise-prediction network.

    Forwarding through the module draws one random timestep per image, noises
    the batch accordingly and returns the scalar noise-prediction loss.

    Input:
        x: tensor of shape (N, img_channels, *img_size)
        y: tensor of shape (N) or None; passed through to ``model``
    Output:
        scalar loss tensor

    Args:
        model (nn.Module): model which estimates diffusion noise, called as
            ``model(x, t, y)``
        img_size (tuple): image size tuple (H, W)
        img_channels (int): number of image channels
        num_classes (int): number of label classes (stored as ``self.num_classes``)
        betas (np.ndarray): numpy array of diffusion betas, one per timestep
        loss_type (string): loss type, "l1" or "l2"
        ema_decay (float): model weights exponential moving average decay
        ema_start (int): number of steps before EMA averaging starts; until
            then the EMA copy is overwritten with the live weights
        ema_update_rate (int): the EMA copy is updated every ``ema_update_rate`` steps

    Raises:
        ValueError: if ``loss_type`` is not "l1" or "l2".
    """

    def __init__(
        self,
        model,
        img_size,
        img_channels,
        num_classes,
        betas,
        loss_type="l2",
        ema_decay=0.9999,
        ema_start=5000,
        ema_update_rate=1,
    ):
        super().__init__()

        self.model = model
        # Separate copy holding averaged weights; sampling uses it by default.
        self.ema_model = deepcopy(model)

        self.ema = EMA(ema_decay)
        self.ema_decay = ema_decay
        self.ema_start = ema_start
        self.ema_update_rate = ema_update_rate
        self.step = 0

        self.img_size = img_size
        self.img_channels = img_channels
        self.num_classes = num_classes

        if loss_type not in ["l1", "l2"]:
            raise ValueError("__init__() got unknown loss type")

        self.loss_type = loss_type
        self.num_timesteps = len(betas)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)

        # Precomputed schedule terms, registered as buffers so they follow
        # .to(device) and are stored in state_dict().
        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas", to_torch(alphas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))

        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))

        self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("sigma", to_torch(np.sqrt(betas)))

    def update_ema(self):
        """Advance the step counter and refresh the EMA copy every ``ema_update_rate`` steps.

        Before ``ema_start`` steps the EMA copy is overwritten with the live
        weights; afterwards it is averaged via ``self.ema``.
        """
        self.step += 1
        if self.step % self.ema_update_rate == 0:
            if self.step < self.ema_start:
                self.ema_model.load_state_dict(self.model.state_dict())
            else:
                self.ema.update_model_average(self.ema_model, self.model)

    @torch.no_grad()
    def remove_noise(self, x, t, y, use_ema=True):
        """Return (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t).

        ``eps_hat`` is the noise estimated by the EMA model (default) or by the
        live model. ``extract`` (from .utils) is assumed to gather the
        per-timestep coefficient for ``t`` and reshape it to broadcast against
        ``x`` -- confirm in utils.
        """
        net = self.ema_model if use_ema else self.model
        return (
            (x - extract(self.remove_noise_coeff, t, x.shape) * net(x, t, y))
            * extract(self.reciprocal_sqrt_alphas, t, x.shape)
        )

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        """Run the reverse process from pure noise and return the final images on CPU.

        Raises:
            ValueError: if ``y`` is given and its length differs from ``batch_size``.
        """
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            # No noise is added on the final step (t == 0).
            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

        return x.cpu().detach()

    @torch.no_grad()
    def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
        """Like ``sample`` but return every intermediate x (initial noise first), each on CPU.

        Raises:
            ValueError: if ``y`` is given and its length differs from ``batch_size``.
        """
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        diffusion_sequence = [x.cpu().detach()]

        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

            diffusion_sequence.append(x.cpu().detach())

        return diffusion_sequence

    def perturb_x(self, x, t, noise):
        """Return sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise (forward noising)."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def get_losses(self, x, t, y):
        """Noise ``x`` at timesteps ``t`` and return the L1/L2 loss of the model's noise estimate."""
        noise = torch.randn_like(x)

        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        # loss_type was validated in __init__, so exactly one branch applies.
        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)

        return loss

    def forward(self, x, y=None):
        """Return the training loss for a batch, using one random timestep per image.

        Raises:
            ValueError: if the image height or width differs from ``img_size``.
        """
        b, c, h, w = x.shape
        device = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        # Width must be compared with img_size[1] (W); the previous check
        # compared it with the height and broke non-square images.
        if w != self.img_size[1]:
            raise ValueError("image width does not match diffusion parameters")

        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)


def generate_cosine_schedule(T, s=0.008):
    """Return T betas from the cosine schedule.

    alpha_bar(t) = f(t) / f(0) with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2),
    and beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped at 0.999.
    """
    def f(t, T):
        return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2

    alphas = []
    f0 = f(0, T)

    for t in range(T + 1):
        alphas.append(f(t, T) / f0)

    betas = []

    for t in range(1, T + 1):
        betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))

    return np.array(betas)


def generate_linear_schedule(T, low, high):
    """Return T betas evenly spaced from ``low`` to ``high`` (inclusive)."""
    return np.linspace(low, high, T)
三、获得优化器、数据集
# Adam over diffusion.parameters(). This also yields the parameters of the
# ema_model submodule; the loss is computed through self.model only, so those
# receive no gradients (Adam presumably leaves them untouched -- verify).
optimizer = torch.optim.Adam(params=diffusion.parameters(), lr=args.learning_rate)
# Omitted here from the original script: the `if args.model_checkpoint is not None:`
# and `if args.log_to_wandb:` sections.
batch_size = args.batch_size

# CIFAR-10 train/test splits, both preprocessed to tensors in [-1, 1].
train_dataset = datasets.CIFAR10(
    root='./cifar_train',
    train=True,
    download=True,
    transform=script_utils.get_transform(),
)
test_dataset = datasets.CIFAR10(
    root='./cifar_test',
    train=False,
    download=True,
    transform=script_utils.get_transform(),
)

# cycle() turns the finite loader into an endless stream consumed with
# next(train_loader); drop_last=True keeps every batch at batch_size images.
train_loader = script_utils.cycle(DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
))
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
这里采用了一个cycle 的函数 循环加载数据 后面会和next 一起使用,x, y = next(train_loader)
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T


class RescaleChannels:
    """Linearly map tensor values from [0, 1] to [-1, 1] via ``2 * sample - 1``.

    Used right after ``T.ToTensor()`` in ``get_transform``, whose image values
    lie in [0, 1]. The training script maps generated samples back with
    ``(samples + 1) / 2`` before logging them.

    Why normalise to a fixed, zero-centred range:
    - numerical stability: some activations (e.g. tanh) behave best for
      inputs around [-1, 1];
    - faster convergence: well-scaled inputs reduce the risk of vanishing or
      exploding gradients;
    - consistency: data from different sources ends up on a similar scale.
    """

    def __call__(self, sample):
        return 2 * sample - 1
def get_transform():
    """Return the preprocessing pipeline: image -> tensor in [0, 1] -> tensor in [-1, 1]."""
    pipeline = [T.ToTensor(), RescaleChannels()]
    return T.Compose(pipeline)
# CIFAR-10 training split, converted by get_transform() to tensors in [-1, 1].
train_dataset = datasets.CIFAR10(
    root="./cifar_train",
    train=True,
    transform=get_transform(),
    download=True,
)
def cycle(dl):
    """Yield items from ``dl`` forever, restarting it every time it is exhausted.

    Adapted from https://github.com/lucidrains/denoising-diffusion-pytorch/

    The training loop is driven by an iteration count rather than by epochs,
    so it pulls batches with ``next(train_loader)``. This generator turns a
    finite, re-iterable loader into an endless stream by iterating it again
    after every full pass. It does not create new data or augment anything:
    a small dataset is simply revisited more often.

    Args:
        dl: a re-iterable source of items, e.g. a ``DataLoader``.

    Raises:
        ValueError: if a full pass over ``dl`` yields nothing. Examples are a
            DataLoader with ``drop_last=True`` and fewer samples than
            ``batch_size``, or an already-exhausted one-shot iterator. The
            previous version looped forever without yielding in that case.
    """
    while True:
        produced_any = False
        for data in dl:
            produced_any = True
            yield data
        if not produced_any:
            raise ValueError("cycle() got an empty iterable; there is nothing to repeat")
# Shuffled training batches, wrapped in cycle() so next(train_loader) never
# runs out; drop_last=True keeps every batch exactly batch_size images.
train_loader = cycle(
    DataLoader(
        train_dataset,
        num_workers=2,
        drop_last=True,
        shuffle=True,
        batch_size=batch_size,
    )
)
四、开始训练
# Training loop: one optimisation step per iteration, with batches drawn from
# the endless train_loader produced by cycle().
acc_train_loss = 0  # sum of training losses since the last log step
for iteration in range(1, args.iterations + 1):
    diffusion.train()

    x, y = next(train_loader)
    x = x.to(device)
    y = y.to(device)

    # forward() samples random timesteps and returns the noise-prediction
    # loss; labels are passed only for class-conditional training.
    if args.use_labels:
        loss = diffusion(x, y)
    else:
        loss = diffusion(x)

    acc_train_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Advance the EMA bookkeeping; the EMA copy is refreshed every
    # ema_update_rate steps.
    diffusion.update_ema()
五、测试部分
# Evaluation, logging and checkpointing. In train_cifar.py this code runs
# inside the training loop, right after diffusion.update_ema().
if iteration % args.log_rate == 0:
    test_loss = 0
    with torch.no_grad():
        diffusion.eval()
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)
            test_loss += loss.item()

    # Mean over test batches (guarding against an empty loader) and over the
    # training steps since the last log.
    test_loss /= max(len(test_loader), 1)
    acc_train_loss /= args.log_rate

    # wandb must only be used when it was initialised (log_to_wandb=True).
    if args.log_to_wandb:
        # Sampling runs the whole reverse chain, so only do it when the
        # images are actually logged.
        if args.use_labels:
            samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
        else:
            samples = diffusion.sample(10, device)
        # Convert the tensor back to a displayable image format:
        # (samples + 1) / 2 maps [-1, 1] -> [0, 1]; clip(0, 1) removes values
        # slightly outside that range; permute turns PyTorch's channels-first
        # (N, C, H, W) layout into (N, H, W, C), the usual layout for images.
        samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()
        wandb.log({
            "test_loss": test_loss,
            "train_loss": acc_train_loss,
            "samples": [wandb.Image(sample) for sample in samples],
        })
    else:
        print(f"iteration {iteration}: train_loss={acc_train_loss:.4f} test_loss={test_loss:.4f}")

    acc_train_loss = 0

if iteration % args.checkpoint_rate == 0:
    # os.makedirs/torch.save do not expand "~"; without this the default
    # "~/ddpm_logs" creates a literal "~" directory in the working directory.
    log_dir = os.path.expanduser(args.log_dir)
    model_filename = f"{log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
    optim_filename = f"{log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"
    # Both files share one directory; create it if it does not exist yet.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)
    # Save the model (including the EMA copy and schedule buffers) and optimizer state.
    torch.save(diffusion.state_dict(), model_filename)
    torch.save(optimizer.state_dict(), optim_filename)
六、整个train_cifar 所有代码
import argparse
import datetime
import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils
import os

# Weights & Biases configuration.
# SECURITY: an API key used to be hard-coded here and was published along with
# this code. That key should be revoked. Secrets must not live in source files:
# provide the key through the WANDB_API_KEY environment variable (or a one-off
# `wandb login`), where wandb.init() picks it up. Logging in is no longer done
# at import time, so the script also runs without wandb contact when
# log_to_wandb=False. An existing WANDB_MODE (e.g. "offline") is respected.
os.environ.setdefault("WANDB_MODE", "online")
def _evaluate(diffusion, test_loader, device, use_labels):
    """Return the mean diffusion loss over ``test_loader``.

    Leaves ``diffusion`` in eval mode; the training loop switches back with
    ``diffusion.train()`` at the start of every iteration.
    """
    total = 0.0
    with torch.no_grad():
        diffusion.eval()
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            loss = diffusion(x, y) if use_labels else diffusion(x)
            total += loss.item()
    # drop_last=True can leave the loader empty for a tiny test set.
    return total / max(len(test_loader), 1)


def _sample_images(diffusion, device, use_labels):
    """Generate 10 images (labels 0..9 when class-conditional) as an (N, H, W, C) array in [0, 1]."""
    if use_labels:
        samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
    else:
        samples = diffusion.sample(10, device)
    # (samples + 1) / 2 maps [-1, 1] -> [0, 1]; clip removes small overshoots;
    # permute converts channels-first (N, C, H, W) to channels-last (N, H, W, C).
    return ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()


def _save_checkpoint(args, iteration, diffusion, optimizer):
    """Save model and optimizer state dicts for ``iteration`` under ``args.log_dir``."""
    # os.makedirs/torch.save do not expand "~"; without this the default
    # "~/ddpm_logs" creates a literal "~" directory in the working directory.
    log_dir = os.path.expanduser(args.log_dir)
    model_filename = f"{log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
    optim_filename = f"{log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"
    # Both files live in the same directory; create it once if missing.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)
    torch.save(diffusion.state_dict(), model_filename)
    torch.save(optimizer.state_dict(), optim_filename)


def main():
    """Train a DDPM on CIFAR-10, periodically evaluating, logging and checkpointing."""
    args = create_argparser().parse_args()
    device = args.device
    # Set only once wandb.init() succeeds, so an early Ctrl-C or a run with
    # log_to_wandb=False never touches an undefined run object.
    run = None

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)

        # map_location lets checkpoints saved on another device (e.g. cuda:1)
        # be loaded on the current one.
        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint, map_location=device))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint, map_location=device))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            run = wandb.init(
                project=args.project_name,
                config=vars(args),
                name=args.run_name,
            )
            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.CIFAR10(
            root='./cifar_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )
        test_dataset = datasets.CIFAR10(
            root='./cifar_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        # Endless stream of shuffled training batches, consumed with next().
        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)

        acc_train_loss = 0  # sum of training losses since the last log step

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            loss = diffusion(x, y) if args.use_labels else diffusion(x)
            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()

            if iteration % args.log_rate == 0:
                test_loss = _evaluate(diffusion, test_loader, device, args.use_labels)
                acc_train_loss /= args.log_rate

                # Only talk to wandb when it was initialised; previously this
                # ran unconditionally and failed with log_to_wandb=False.
                if args.log_to_wandb:
                    samples = _sample_images(diffusion, device, args.use_labels)
                    wandb.log({
                        "test_loss": test_loss,
                        "train_loss": acc_train_loss,
                        "samples": [wandb.Image(sample) for sample in samples],
                    })
                else:
                    print(f"iteration {iteration}: train_loss={acc_train_loss:.4f} test_loss={test_loss:.4f}")

                acc_train_loss = 0

            if iteration % args.checkpoint_rate == 0:
                _save_checkpoint(args, iteration, diffusion, optimizer)

        if run is not None:
            run.finish()
    except KeyboardInterrupt:
        if run is not None:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    """Create the command-line parser for this training script.

    Every entry of ``defaults`` below, plus everything returned by
    ``script_utils.diffusion_defaults()``, becomes a ``--<key>`` option whose
    default is the listed value.
    """
    # The author trains on the second GPU ("cuda:1"); fall back to "cuda:0" on
    # single-GPU machines, where "cuda:1" does not exist. Override with --device.
    if torch.cuda.is_available():
        gpu_name = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
        device = torch.device(gpu_name)
    else:
        device = torch.device("cpu")

    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=800000,

        log_to_wandb=True,
        # NOTE(review): rates of 10 trigger a full test pass and a checkpoint
        # write every 10 of 800000 iterations; consider larger values.
        log_rate=10,
        checkpoint_rate=10,
        # "~" is expanded when checkpoints are written (see _save_checkpoint).
        log_dir="~/ddpm_logs",
        project_name='ddpm',
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        # Also defined in diffusion_defaults(); those values win because of
        # the update() call below.
        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()