第四章 Diffusers 实战
安装 Diffusers 库
pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow
扩散模型调度器
from diffusers import DDPMScheduler
# DDPM noise scheduler with a 1000-step noising schedule; every other
# setting is left at the diffusers default.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
)
定义扩散模型
from diffusers import UNet2DModel


def model(*, sample_size=240, in_channels=4, out_channels=4):
    """Build the 2D UNet used as the noise-prediction network.

    The training and sampling scripts import this module as ``model`` and call
    ``model.model()``; with no arguments the original configuration is built
    (240x240 samples, 4 input channels, 4 output channels).

    Args:
        sample_size: Input height/width the UNet is configured for.
        in_channels: Number of channels of the (noisy) input.
        out_channels: Number of channels the network predicts.

    Returns:
        A freshly initialised ``UNet2DModel``.
    """
    return UNet2DModel(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        layers_per_block=2,
        block_out_channels=(64, 128, 128, 256),
        # Attention blocks only in the last two down stages and the mirrored
        # first two up stages; the outer stages use plain ResNet blocks.
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
创建扩散模型训练循环
import torch.utils.data.dataset
import torchvision
from dataset import dataset_brats_2D
from torchvision import transforms
from diffusers import DDPMScheduler
import model
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import time

if __name__ == "__main__":
    device = torch.device("cuda")

    # TODO: custom dataset -- replace None with your own torch Dataset
    # (e.g. one built from dataset_brats_2D). Each batch must be a tensor;
    # presumably (batch, 4, 240, 240) to match the UNet config -- confirm.
    dataset = None
    if dataset is None:
        raise SystemExit("Set `dataset` to your own Dataset before training.")
    # Shuffle so every epoch visits the samples in a different order.
    train_dl = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)

    # Named `net` so the imported `model` module is not shadowed.
    net = model.model().to(device)
    net = torch.nn.DataParallel(net, device_ids=[0])
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
    optimizer = torch.optim.AdamW(net.parameters(), lr=4e-4)

    # Read from the scheduler config (direct attribute access is deprecated).
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    losses = []
    best_loss = float("inf")
    for epoch in range(100):
        net.train()
        for batch in train_dl:
            clean_images = batch.to(device)
            noise = torch.randn_like(clean_images)
            batch_size = clean_images.shape[0]
            # One random diffusion timestep per image in the batch.
            timesteps = torch.randint(
                0, num_train_timesteps, (batch_size,), device=device, dtype=torch.long
            )
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = net(noisy_images, timesteps, return_dict=False)[0]
            # The network is trained to predict the added noise.
            loss = F.mse_loss(noise_pred, noise)
            # Plain backward(): passing `loss` as the gradient argument would
            # scale every gradient by the current loss value.
            loss.backward()
            losses.append(loss.item())
            optimizer.step()
            optimizer.zero_grad()

        # Mean loss over the batches of the epoch that just finished.
        loss_last_epoch = sum(losses[-len(train_dl):]) / len(train_dl)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch:{epoch + 1}, loss:{loss_last_epoch}")
        # Keep only the checkpoint with the lowest mean epoch loss so far.
        if loss_last_epoch < best_loss:
            best_loss = loss_last_epoch
            state = {"net": net.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
            torch.save(state, "best.pth")
图像的生成
import time
import torchvision.utils
from diffusers import DDPMPipeline,DDPMScheduler
import cv2
import torch
import torchvision
from PIL import Image
import model
import numpy as np
import time
def show_images(x):
    """Convert an image tensor with values in [-1, 1] into a PIL image.

    Args:
        x: Tensor accepted by ``torchvision.utils.make_grid`` (a single 2D/3D
            image or a 4D batch, which is tiled into one grid).

    Returns:
        A ``PIL.Image.Image`` built from the uint8 grid.
    """
    x = x * 0.5 + 0.5  # map [-1, 1] -> [0, 1]
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    return Image.fromarray(np.array(grid_im).astype(np.uint8))


if __name__ == "__main__":
    device = torch.device("cuda")
    # Named `net` so the imported `model` module is not shadowed.
    net = model.model().to(device)
    ckpt = torch.load(r"", map_location=device)  # TODO: path to your own checkpoint
    # Keys saved from a DataParallel-wrapped model start with "module.".
    # Strip only that leading prefix and keep keys without it, so checkpoints
    # saved without DataParallel load too.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["net"].items()
    }
    net.load_state_dict(state_dict)
    net.eval()
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

    with torch.no_grad():
        for k in range(10000):
            # Each image starts from pure Gaussian noise.
            sample = torch.randn(1, 4, 240, 240, device=device)
            for t in noise_scheduler.timesteps:
                residual = net(sample, t).sample
                sample = noise_scheduler.step(residual, t, sample).prev_sample
            time_flag = time.time()
            print(sample.shape)
            # Only channel 0 of the first sample is written to disk.
            image = show_images(sample[0][0])
            image.save(str(time_flag) + "_0" + ".png")