1.1 Introduction

Training corresponds to the forward diffusion process: in every epoch, noise is added to each image in the training set, and the noised images are used to compute the distribution $q(x_{t-1}|x_t,x_0)$ (this continues the previous post on DDPM). In brief, this distribution is computed because it serves as the ground truth for the reverse diffusion process $p(x_{t-1}|x_t)$ in the MSE loss, which amounts to guiding the reverse diffusion process.
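For reference, the posterior mentioned above has a closed form in the DDPM paper. The symbols $\beta_t$, $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$ are that paper's noise-schedule notation and are not defined in this post, so treat them as assumptions here.

```latex
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right)

\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0
                        + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t,
\qquad
\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t
```

Both this posterior and the learned reverse step are Gaussian, so comparing them with a fixed variance reduces to a squared error between their means, which is where the MSE comes from.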

1.2 Loss Function

1.2.1 Maximum Likelihood Estimation (MLE)


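$P(x|\theta)$: the probability of obtaining the result $x$ when the parameter $\theta$ is known.


$\mathcal{L}(\theta|x)$: how plausible the parameter $\theta$ is (its likelihood) when the result $x$ is known.


Probability and likelihood are numerically equal, $P(x|\theta)=\mathcal{L}(\theta|x)$, but they mean different things, because the roles of known and unknown are swapped between $\theta$ and $x$.
$\mathcal{L}(\theta|x)$ is a function of $\theta$ (with $x$ fixed) and $P(x|\theta)$ is a function of $x$ (with $\theta$ fixed); the two describe the same thing from different angles.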
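A small sketch of the two readings, using a coin-toss (binomial) model that is purely a hypothetical illustration and not part of the original post. With $\theta$ fixed, the values over all $x$ sum to 1; with $x$ fixed, the same formula viewed as a function of $\theta$ does not integrate to 1.

```python
import math


def binom_pmf(k: int, n: int, theta: float) -> float:
    """P(x = k | theta): probability of k heads in n tosses of a coin with bias theta."""
    return math.comb(n, k) * theta**k * (1.0 - theta) ** (n - k)


n = 10

# View 1: theta is fixed and x varies, giving a probability distribution over x.
theta_fixed = 0.3
prob_sum = sum(binom_pmf(k, n, theta_fixed) for k in range(n + 1))

# View 2: x is fixed and theta varies, giving the likelihood function L(theta | x).
k_observed = 7
steps = 10_000
# Midpoint rule for the integral of L(theta | x) over theta in [0, 1].
likelihood_area = sum(
    binom_pmf(k_observed, n, (i + 0.5) / steps) for i in range(steps)
) / steps

print(f"sum over x of P(x | theta=0.3)       = {prob_sum:.6f}")
print(f"integral over theta of L(theta | x=7) = {likelihood_area:.6f}")
```

For this model the area under the likelihood curve is $1/(n+1)$, which shows that the likelihood is not a density over $\theta$.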

Likelihood Function

The likelihood function helps us find the best parameters for our distribution.
$$\mathcal{L}(\theta|x_1,x_2,\cdots,x_n)=f(x_1,x_2,\cdots,x_n|\theta)=\prod_{i=1}^{n}f(x_i|\theta)$$
where $\theta$ is the parameter to maximize over,
$x_1,x_2,\cdots,x_n$ are observations of $n$ random variables drawn from a distribution, and
$f$ is the density function of our distribution with the parameter $\theta$. The last equality (joint density as a product of per-observation densities) holds when the observations are independent and identically distributed.
For example, in the case of a normal distribution, we could have $\theta=(\mu,\sigma)$.
$\mathcal{L}(\theta|x_1,x_2,\cdots,x_n)$ is not a probability density function, which means that integrating it over a particular interval does not yield the "probability" of that interval. Instead, it expresses how well a distribution with a particular parameter value $\theta$ fits our data.
For an image, the variance tells how much the blue intensities in the image vary or deviate from the average blue intensity (0.8).
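To make the product formula concrete, here is a minimal sketch that evaluates the likelihood of a handful of blue intensities under two candidate normal distributions. The sample values and the choice $\sigma = 0.05$ are hypothetical; only the mean of 0.8 comes from the text.

```python
import math


def gaussian_pdf(x: float, mu: float, sigma: float) -> float:
    """Density f(x | theta) of a 1D normal distribution with theta = (mu, sigma)."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def likelihood(data, mu, sigma):
    """L(theta | x_1..x_n) as a product of per-observation densities (i.i.d. assumption)."""
    return math.prod(gaussian_pdf(x, mu, sigma) for x in data)


def log_likelihood(data, mu, sigma):
    """Sum of log-densities; numerically safer than the raw product for large n."""
    return sum(math.log(gaussian_pdf(x, mu, sigma)) for x in data)


# Hypothetical normalised blue intensities scattered around 0.8.
blue = [0.78, 0.83, 0.80, 0.75, 0.86, 0.79, 0.81, 0.82]

good = likelihood(blue, mu=0.8, sigma=0.05)  # parameters close to the data
bad = likelihood(blue, mu=0.5, sigma=0.05)   # mean far from the data

print(f"L(mu=0.8) = {good:.3e}, L(mu=0.5) = {bad:.3e}")
```

The parameters centred on the data give a far larger likelihood, which is the sense in which they "fit" better.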

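Maximum Likelihood Estimation

Maximum likelihood estimation (MLE) is the process of estimating the parameters of a distribution so as to maximize the likelihood that the observed data came from that distribution. In short, when we perform MLE we are trying to find the distribution that best fits our data. The resulting parameter values are called the maximum likelihood estimates.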
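For a 1D normal distribution the maximiser is known in closed form (the sample mean and the $1/n$ standard deviation), so MLE can be sketched in a few lines. The data values below are hypothetical.

```python
import math


def log_likelihood(data, mu, sigma):
    """Log-likelihood of i.i.d. data under a 1D normal distribution N(mu, sigma^2)."""
    n = len(data)
    squared_error = sum((x - mu) ** 2 for x in data)
    return -n * math.log(sigma * math.sqrt(2.0 * math.pi)) - squared_error / (2.0 * sigma**2)


def gaussian_mle(data):
    """Closed-form maximiser: sample mean and the 1/n (not 1/(n-1)) standard deviation."""
    n = len(data)
    mu_hat = sum(data) / n
    sigma_hat = math.sqrt(sum((x - mu_hat) ** 2 for x in data) / n)
    return mu_hat, sigma_hat


data = [0.78, 0.83, 0.80, 0.75, 0.86, 0.79, 0.81, 0.82]  # hypothetical observations
mu_hat, sigma_hat = gaussian_mle(data)
best = log_likelihood(data, mu_hat, sigma_hat)

# Any nearby parameter setting should score no higher than the MLE.
for d_mu, s_scale in [(0.01, 1.0), (-0.01, 1.0), (0.0, 1.1), (0.0, 0.9)]:
    other = log_likelihood(data, mu_hat + d_mu, sigma_hat * s_scale)
    print(f"shift mu by {d_mu:+.2f}, scale sigma by {s_scale}: {other:.4f} <= {best:.4f}")
```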

1.2.2 Image and Probability Distribution

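Each channel of an RGB image takes values in the range [0, 255].
We normalize each channel to the range [0, 1] via ($R/255, G/255, B/255$).
Probability distribution of a single image channel (1D Gaussian)
Probability distribution of two image channels (2D Gaussian)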

$$\boldsymbol{\mu}=[\mu_{x_1},\mu_{x_2}]=[\mu_{blue},\mu_{green}]$$

$$\Sigma=\begin{bmatrix} \sigma_{x_1}^2 & \sigma_{x_1,x_2}\\ \sigma_{x_2,x_1} & \sigma_{x_2}^2 \end{bmatrix}=\begin{bmatrix} \sigma_{blue}^2 & \sigma_{blue,green}\\ \sigma_{green,blue} & \sigma_{green}^2 \end{bmatrix}$$
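The normalisation and the two-channel mean vector and covariance matrix can be computed directly. This is a minimal sketch on a hypothetical 4-pixel image; the helper names are illustrative only.

```python
def normalise(pixel):
    """Map an (R, G, B) pixel from [0, 255] to [0, 1]."""
    return tuple(c / 255 for c in pixel)


def mean(values):
    return sum(values) / len(values)


def covariance(xs, ys):
    """Population covariance; covariance(xs, xs) is the variance of xs."""
    mx, my = mean(xs), mean(ys)
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / len(xs)


# Hypothetical 4-pixel image, one (R, G, B) tuple per pixel.
pixels = [(10, 200, 230), (20, 180, 210), (15, 220, 250), (30, 190, 200)]
normalised = [normalise(p) for p in pixels]
blue = [p[2] for p in normalised]
green = [p[1] for p in normalised]

# Mean vector [mu_blue, mu_green] and the 2x2 covariance matrix.
mu = [mean(blue), mean(green)]
sigma = [
    [covariance(blue, blue), covariance(blue, green)],
    [covariance(green, blue), covariance(green, green)],
]

print("mu    =", mu)
print("Sigma =", sigma)
```

The off-diagonal entries are equal, so the covariance matrix is symmetric; the same holds for the three-channel case.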

Probability distribution of three image channels (3D Gaussian)

$$\boldsymbol{\mu}=[\mu_{x},\mu_{y},\mu_{z}]=[\mu_{red},\mu_{green},\mu_{blue}]$$

$$\Sigma=\begin{bmatrix} \sigma_{x}^2 & \sigma_{xy} & \sigma_{xz}\\ \sigma_{yx} & \sigma_{y}^2 & \sigma_{yz}\\ \sigma_{zx} & \sigma_{zy} & \sigma_{z}^2 \end{bmatrix}$$
During Stable Diffusion training we add noise to the clean image, so we sample randomly from a 3D standard normal distribution; the sampled tensor then has the same shape as the image tensor.
$$\epsilon \sim \mathcal{N}(0,I)$$
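A standard-library sketch of drawing a noise tensor with the same shape as the image tensor. The shape `(3, 32, 32)` and the helper names are assumptions for illustration; in PyTorch the same idea is usually a single call such as `torch.randn_like(x)`.

```python
import random


def sample_standard_normal(shape, rng):
    """Return a nested list of the given shape with i.i.d. N(0, 1) entries."""
    if len(shape) == 1:
        return [rng.gauss(0.0, 1.0) for _ in range(shape[0])]
    return [sample_standard_normal(shape[1:], rng) for _ in range(shape[0])]


def shape_of(tensor):
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return tuple(shape)


def flatten(tensor):
    """Flatten a nested list into a flat list of numbers."""
    if not isinstance(tensor, list):
        return [tensor]
    return [v for sub in tensor for v in flatten(sub)]


rng = random.Random(0)
image_shape = (3, 32, 32)  # hypothetical (channels, height, width)
epsilon = sample_standard_normal(image_shape, rng)

values = flatten(epsilon)
sample_mean = sum(values) / len(values)
sample_var = sum((v - sample_mean) ** 2 for v in values) / len(values)
print(shape_of(epsilon), round(sample_mean, 3), round(sample_var, 3))
```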

1.2.3 Maximize ELBO (Maximize Evidence Lower Bound)


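Let $p_{\theta}(x^i)$ be the probability distribution of the $i$-th sample image. Multiplying the distributions of the $m$ images in the dataset gives the joint probability distribution; maximizing the likelihood of this joint distribution finally yields an optimal parameter $\theta=(\mu,\sigma)$.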
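Written out, the objective described above is the following. The second line is the evidence lower bound named in the heading, in the DDPM paper's notation, where $x_{1:T}$ are the noised versions of $x_0$; that notation is an assumption here because the post does not define it.

```latex
\theta^{*} = \arg\max_{\theta} \prod_{i=1}^{m} p_{\theta}(x^{i})
           = \arg\max_{\theta} \sum_{i=1}^{m} \log p_{\theta}(x^{i})

\log p_{\theta}(x_0) \;\ge\;
\mathbb{E}_{q(x_{1:T} \mid x_0)}\!\left[\log \frac{p_{\theta}(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]
= \mathrm{ELBO}
```

The log-likelihood itself is intractable, so training maximizes the lower bound on the right instead.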

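There are currently three prediction schemes for the UNet in Stable Diffusion:
(1) The UNet predicts $x_0$ directly, but this works poorly.
(2) The UNet predicts the noise to be removed (the scheme used for this training run).

(3) The UNet predicts the score.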
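The three targets are related through the forward noising formula $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, so a prediction of one can be converted into the others. The sketch below uses a single scalar "pixel"; $\bar{\alpha}_t$ and the function names are assumptions borrowed from the DDPM notation, not from this post.

```python
import math


def x0_from_noise(x_t, eps, alpha_bar_t):
    """Invert x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps for x_0."""
    return (x_t - math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)


def score_from_noise(eps, alpha_bar_t):
    """Score of q(x_t | x_0) with respect to x_t, expressed through the noise."""
    return -eps / math.sqrt(1.0 - alpha_bar_t)


# Hypothetical scalar values to keep the arithmetic visible.
x0, eps, alpha_bar_t = 0.6, -1.3, 0.5
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

x0_recovered = x0_from_noise(x_t, eps, alpha_bar_t)  # scheme (1) from scheme (2)
score = score_from_noise(eps, alpha_bar_t)           # scheme (3) from scheme (2)

print(x0_recovered, score)
```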

1.3 Training (from the DDPM paper)

batch size, iteration, and epoch
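The relation between the three terms is simple arithmetic: one iteration processes one batch, and one epoch is as many iterations as it takes to see every sample once. A minimal sketch, where the dataset size of 1,000 is a hypothetical value and the other two numbers are the ones set in the training script:

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample once (the last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


num_samples = 1_000  # hypothetical dataset size
batch_size = 10      # same value as args.batch_size in the script
epochs = 40          # same value as args.epochs in the script

per_epoch = iterations_per_epoch(num_samples, batch_size)
total_iterations = per_epoch * epochs
print(per_epoch, total_iterations)
```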



import logging
import os.path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from create_dataset import train_loader
from ddpm import DDPMSampler
from diffusion import UNET, Diffusion
from pipeline import get_time_embedding

'''
Algorithm Training
1: repeat
2: x_0 ~ q(x_0)
   # sample a batch from an epoch
   # for epoch / for batch / for every image tensor
3: t ~ Uniform({1...T})
   # randomly sample one t for every image tensor
   # t: the timestep index drawn for that image
   # T: num_training_steps (total number of noising steps)
   t = diffusion.sample_timesteps(images.shape[0]).to(device)
4: epsilon ~ N(0,I)
   # 3D standard normal distribution
   # the noise tensor sampled from it has the same shape as the image tensor
   noisy_image_tensor = add_noise(t)
5: Take gradient descent step on
   # nabla_{theta} L2(|| epsilon - epsilon_{theta}(noisy image tensor, t, y) ||)
6: until converged

1. Data Preprocessing
(1) Loading and Transforming Data: Data is loaded from the dataset and transformed to a suitable format for training.
    Common transformations include resizing, normalization, and converting to tensors.
(2) Creating Data Loaders: Data loaders are created to efficiently load the data in batches, shuffle the training data,
    and manage parallel processing.
2. Model Initialization
(1) Define the UNet Model: The UNet architecture is defined, which typically consists of an encoder-decoder structure
    with skip connections. The encoder captures context while the decoder enables precise localization.
(2) Move Model to Device: The model is moved to the appropriate device (CPU or GPU) to leverage hardware acceleration.
3. Loss Function and Optimizer
(1) Loss Function: The loss function measures the difference between the predicted output and the true output.
(2) Optimizer: The optimizer updates the model parameters to minimize the loss. Common optimizers include Adam, SGD, etc.
4. Training Loop
(1) Set Model to Training Mode: The model is set to training mode using model.train().
(2) Iterate Over Data: For each epoch, iterate over batches of data.
    - Forward Pass: Pass input data through the model to get predictions.
      * A random time step t is selected for each training sample (image).
      * Apply the Gaussian noise (corresponding to t) to each image.
      * Convert the time steps to embeddings (vectors).
    - Compute Loss: Calculate the loss using the predictions and ground truth.
    - Backward Pass: Perform backpropagation to compute gradients.
    - Update Parameters: Use the optimizer to update model parameters based on the gradients.
(3) Monitor Training: Track and print training loss to monitor progress.
5. Validation
After each epoch, validate the model using a separate validation set to ensure the model is not overfitting and
to monitor its generalization performance.
6. Checkpoint Saving
Save Model Checkpoint: Save the model's state, optimizer state, and any relevant training information after each epoch
to allow for resuming training if needed.
'''

# A PyTorch random number generator.
generator = torch.Generator(device='cuda')
# Sets the seed for generating random numbers. Returns a torch.Generator object.
# (the seeding call itself is missing in the source)
# Initialize the DDPMSampler with the random generator
ddpm_sampler = DDPMSampler(generator)
diffusion = Diffusion()


def timesteps_to_time_emb(timesteps):
    time_embeddings = []
    for timestep in timesteps:
        # (1,320)
        time_emb_320 = get_time_embedding(timestep).to('cuda')
        # Layer definition missing in the source; per the shape comments it maps (1,320) -> (1,1280).
        embedding = ...
        time_embedding = embedding(time_emb_320).squeeze(0)  # Ensure shape is (1280)
        # (1,1280)
        time_embeddings.append(time_embedding)
    return torch.stack(time_embeddings)  # Final shape should be (batch_size, 1280)


print('Start training now !')


def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)  # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # log the start of the epoch
        progress_bar = tqdm(train_loader)  # progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4
        # Load all data into a batch
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # move images to the device
            # The dataloader adds a batch dimension to the tensor, but one was already added for the VAE
            # and CLIP input, so remove the extra one and keep only the dataloader's batch dimension.
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # move caption to the device
            text_embeddings = torch.squeeze(captions, dim=1)  # squeeze batch_size
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t     (batch_size, channel, Height/8, Width/8)  (bs,4,256/8,256/8)
            # caption (batch_size, seq_len, dim)                (bs,77,768)
            # t       (batch_size, channel)                     (batch_size,1280)
            # output  (bs,320,H/8,W/8)
            # NOTE: torch.no_grad() disables gradient tracking, so loss.backward() cannot update the
            # UNet weights from this forward pass; the versions in section 2.2 use autocast() here instead.
            with torch.no_grad():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # (bs,4,H/8,W/8)
            # Output layer definition missing in the source; per the shape comments it maps 320 -> 4 channels.
            final_output = ...
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero
            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        # Save the optimizer state (the file name is missing in the source):
        # torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, ""))


def launch():
    import argparse  # Import the argparse module for command-line argument parsing
    parser = argparse.ArgumentParser()  # Create an argument parser
    args = parser.parse_args()  # Parse the command-line arguments
    # Set the default values for the arguments
    args.run_name = " Condition_Unet"  # Name for the run, used for logging and saving models
    args.epochs = 40  # Number of epochs to train the model
    args.batch_size = 10  # Batch size for the dataloader
    args.image_size = 256  # Size of the images
    args.device = "cuda"  # Device to run the training on ('cuda' for GPU or 'cpu')
    args.lr = 3e-4  # Learning rate for the optimizer
    train(args)  # Call the train function with the parsed arguments


if __name__ == '__main__':
    launch()  # Call the launch function if this script is run as the main program
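Steps 3 and 4 of the algorithm in the docstring above (draw a timestep, then noise the image in one shot) can be sketched without PyTorch. This is not the `DDPMSampler` used by the script, whose source is not shown here; the linear beta schedule with the DDPM paper's values and all function names below are assumptions.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas; the defaults are the DDPM paper values (an assumption here)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def add_noise(x0, t, alpha_bars, rng):
    """Sample x_t ~ q(x_t | x_0) directly; returns (x_t, eps)."""
    a = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return x_t, eps


rng = random.Random(0)
T = 1000
alpha_bars = cumulative_alphas(linear_beta_schedule(T))

x0 = [0.2, -0.5, 0.9, 0.0]                    # hypothetical flattened latent
t = rng.randrange(T)                          # step 3: t drawn uniformly from the T steps
x_t, eps = add_noise(x0, t, alpha_bars, rng)  # step 4: eps ~ N(0, I)

# Step 5 regresses the network output onto eps; a perfect prediction gives zero MSE.
perfect_prediction = list(eps)
mse = sum((e - p) ** 2 for e, p in zip(eps, perfect_prediction)) / len(eps)
print(t, x_t, mse)
```

Because $\bar{\alpha}_t$ shrinks towards zero as $t$ grows, late timesteps are almost pure noise while early ones stay close to the clean input.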

2.CUDA out of memory

2.1 Reasons

2.1.1 Large Batch Size

Using a batch size that is too large can quickly exhaust GPU memory, especially with large models or high-resolution images.

2.1.2 High Model Complexity

Complex models with many layers and parameters consume more memory. This includes architectures with large fully connected layers, extensive use of skip connections, or multi-headed attention mechanisms.

2.1.3 Large Input Data

High-resolution images or large input tensors consume more memory.
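A back-of-the-envelope sketch of how batch size and resolution drive the size of a single dense tensor. The sizes are hypothetical, and real training holds many such tensors (one set of activations per layer), so this only illustrates the scaling.

```python
def tensor_bytes(batch, channels, height, width, bytes_per_element=4):
    """Memory of one dense tensor; 4 bytes per element corresponds to float32."""
    return batch * channels * height * width * bytes_per_element


base = tensor_bytes(batch=10, channels=3, height=256, width=256)
double_res = tensor_bytes(batch=10, channels=3, height=512, width=512)
half_batch = tensor_bytes(batch=5, channels=3, height=256, width=256)

print(f"256x256, batch 10: {base / 2**20:.1f} MiB")
print(f"512x512, batch 10: {double_res / 2**20:.1f} MiB")
print(f"256x256, batch  5: {half_batch / 2**20:.1f} MiB")
```

Doubling the side length quadruples the footprint, while halving the batch size halves it.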

2.1.4 Insufficient Memory Management

Not clearing intermediate variables or not using memory-efficient operations can lead to memory leaks or inefficient memory usage.

2.1.5 Gradients and Optimizer States

Storing gradients and optimizer states, especially for adaptive optimizers like Adam or RMSprop, can be memory-intensive.
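As a rough estimate, Adam-style optimizers keep two extra values per parameter (first and second moment) on top of the weight and its gradient. The sketch below does that bookkeeping in float32; the parameter count is an assumed ballpark for a Stable Diffusion v1 UNet, not a measured figure.

```python
def training_state_bytes(num_params, bytes_per_value=4, optimizer_slots=2):
    """Weights + gradients + optimizer slots, all stored at the same precision.

    Adam and AdamW keep two extra values per parameter, hence optimizer_slots=2;
    plain SGD without momentum would use 0.
    """
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optimizer = num_params * bytes_per_value * optimizer_slots
    return {
        "weights": weights,
        "grads": grads,
        "optimizer": optimizer,
        "total": weights + grads + optimizer,
    }


num_params = 860_000_000  # assumption: rough size of a Stable Diffusion v1 UNet
adamw = training_state_bytes(num_params)
sgd = training_state_bytes(num_params, optimizer_slots=0)

print(f"AdamW: {adamw['total'] / 2**30:.1f} GiB before any activations")
print(f"SGD:   {sgd['total'] / 2**30:.1f} GiB before any activations")
```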

2.1.6 Memory Fragmentation

Fragmentation occurs when memory is allocated and deallocated in such a way that it becomes difficult to find contiguous blocks of memory, leading to inefficient memory use.

2.2 Solutions

2.2.1 Reduce Batch Size

Decreasing the batch size is the simplest and most effective way to reduce memory usage.

args.batch_size = 5  # Example: reduce the batch size

2.2.2 Use Mixed Precision Training

Mixed precision training can reduce memory usage by using 16-bit floats instead of 32-bit floats for certain operations.

The following is my Stable Diffusion training code as modified by GPT.

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()


def train(args):
    device = args.device
    model = UNET().to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        progress_bar = tqdm(train_loader)
        optimizer.zero_grad()
        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
            time_embeddings = timesteps_to_time_emb(timesteps)
            # Run the forward pass and the loss in mixed precision.
            with autocast():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
                # Output layer definition missing in the source.
                final_output = ...
                predicted_noise = final_output(last_decoder_noise).to(device)
                loss = mse(noises, predicted_noise)
            # Scale the loss before backward so small fp16 gradients do not underflow.
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)  # unscales the gradients, then steps the optimizer
                scaler.update()
                optimizer.zero_grad()
            progress_bar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
            torch.cuda.empty_cache()
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        # Optimizer state (the file name is missing in the source):
        # torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, ""))
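The `GradScaler` in the listing above exists because 16-bit floats cannot represent very small gradients. The standard-library sketch below shows the effect with a hypothetical gradient value of 1e-8 and a loss scale of 2^16 (a power of two chosen for illustration): unscaled, the value rounds to zero in half precision; scaled, it survives and can be divided back in higher precision.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("e", struct.pack("e", value))[0]


tiny_grad = 1e-8   # hypothetical gradient magnitude
scale = 2.0 ** 16  # assumption: a power-of-two loss scale

unscaled = to_fp16(tiny_grad)        # too small for fp16, rounds to zero
scaled = to_fp16(tiny_grad * scale)  # large enough to be stored in fp16
recovered = scaled / scale           # unscale in higher precision

print(f"fp16(1e-8)          = {unscaled}")
print(f"fp16(1e-8 * 2^16)   = {scaled}")
print(f"after unscaling     = {recovered}")
```

This is the reason the scaled loss, not the raw loss, is passed to `backward()` in mixed precision training.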

2.2.3 Gradient Accumulation

Gradient accumulation runs forward and backward passes on several small mini-batches and sums their gradients before updating the model parameters once. This simulates a larger batch size without increasing memory usage, because only one mini-batch has to be held in GPU memory at a time, which is especially useful when GPU memory is limited.

standard training loop.jpg

gradient accumulation.jpg

Key Points
1.Batch Size vs. Mini-Batch Size:
(1) The original batch size is split into smaller mini-batches to fit into GPU memory.
(2) accumulation_steps * mini_batch_size = effective_batch_size.

2.Loss Scaling:
(1) The loss is divided by accumulation_steps to ensure that the gradient magnitudes remain consistent with what they would be if you processed the entire batch at once.

3.Optimizer Step and Gradient Zeroing:
(1) The optimizer step is performed, and gradients are zeroed only after accumulating gradients over several mini-batches.
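The key points above can be checked numerically on a toy problem. The sketch below uses a hypothetical one-parameter model y = w * x with hand-written MSE gradients, and shows that accumulating the gradients of (loss / accumulation_steps) over equally sized mini-batches reproduces the full-batch gradient.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w * x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Hypothetical data for a one-parameter model y = w * x.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5

# Reference: one backward pass over the full batch of 8 samples.
full_batch_grad = mse_grad(w, xs, ys)

# Accumulation: 4 mini-batches of 2 samples each.
accumulation_steps = 4
mini_batch_size = len(xs) // accumulation_steps
effective_batch_size = accumulation_steps * mini_batch_size

accumulated_grad = 0.0
for step in range(accumulation_steps):
    lo = step * mini_batch_size
    hi = lo + mini_batch_size
    # Dividing the loss by accumulation_steps divides its gradient by the same factor.
    accumulated_grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

# Only at this point would optimizer.step() and optimizer.zero_grad() run.
learning_rate = 0.01
w_new = w - learning_rate * accumulated_grad

print(full_batch_grad, accumulated_grad, w_new)
```

Without the division by `accumulation_steps` the accumulated gradient would be four times too large, which is equivalent to silently multiplying the learning rate by four.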

from torch.cuda.amp import GradScaler, autocast

# Assuming you have defined your model, optimizer, loss function, and data loader
model = UNET().to(device)
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scaler = GradScaler()
mse = nn.MSELoss()
accumulation_steps = 4  # Number of mini-batches to accumulate gradients over

for epoch in range(args.epochs):
    model.train()
    optimizer.zero_grad()
    for batch_idx, (images, captions) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        text_embeddings = torch.squeeze(captions, dim=1)
        timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
        noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
        time_embeddings = timesteps_to_time_emb(timesteps)
        with autocast():
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # Output layer definition missing in the source.
            final_output = ...
            predicted_noise = final_output(last_decoder_noise).to(device)
            # Divide by accumulation_steps so the summed gradient matches one large batch.
            loss = mse(noises, predicted_noise) / accumulation_steps
        scaler.scale(loss).backward()
        # Accumulate gradients but do not update the weights yet
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    # Optional: Save model checkpoint after each epoch
    torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

2.2.4 Clear Cache

Manually clear the GPU cache to free up unused memory.

from torch.cuda.amp import GradScaler, autocast


def train(args):
    device = args.device
    model = UNET().to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    scaler = GradScaler()
    mse = nn.MSELoss()
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        progress_bar = tqdm(train_loader)
        optimizer.zero_grad()
        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
            time_embeddings = timesteps_to_time_emb(timesteps)
            with autocast():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
                # Output layer definition missing in the source.
                final_output = ...
                predicted_noise = final_output(last_decoder_noise).to(device)
                loss = mse(noises, predicted_noise) / accumulation_steps
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # Clear cache to free up memory
                torch.cuda.empty_cache()
            progress_bar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save model checkpoint after each epoch
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        # Optimizer state (the file name is missing in the source):
        # torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, ""))





