I've picked up another technique: the class-conditioned (class-guided) diffusion model. Let me record how to use it, with the MNIST dataset as the example.
Diffusion in Practice series:
[Diffusion in Practice] Training a diffusion model to generate an S-curve (PyTorch code walkthrough)
[Diffusion in Practice] Training a diffusion model to generate butterfly images (PyTorch code walkthrough)
[Diffusion in Practice] Guiding a diffusion model to generate images from text (PyTorch code walkthrough)
Diffusion Survey series:
[Diffusion Survey] Diffusion models in medical image analysis (Part 1)
[Diffusion Survey] Diffusion models in medical image analysis (Part 2)
1. Dataset Loading
We train the class-conditioned diffusion model on the MNIST dataset, since it is simple and clean:
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Inspect a few MNIST samples
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()
A look at our humble samples:
2. Building the Conditional Diffusion Model
We define a conditional diffusion model called ClassConditionedUnet. It holds a learnable embedding layer that maps each digit class to a feature vector; the class embedding is broadcast to the spatial size of the input, concatenated with it as extra channels, and the result is fed into a regular UNet.
Related reading: [Python functions] torch.nn.Embedding usage, illustrated
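Before the full model, here is a minimal standalone sketch of the conditioning trick itself (toy shapes, not part of the original code): embed the integer labels, broadcast the embedding over the spatial dimensions, then concatenate it with the image as extra channels:

import torch
from torch import nn

emb = nn.Embedding(10, 4)                           # 10 digit classes -> 4-dim vectors
labels = torch.tensor([3, 7])                       # a toy batch of two labels
vecs = emb(labels)                                  # [2, 4]
maps = vecs.view(2, 4, 1, 1).expand(2, 4, 28, 28)   # one constant 28x28 map per embedding dim
imgs = torch.randn(2, 1, 28, 28)                    # stand-in images
net_input = torch.cat((imgs, maps), dim=1)          # [2, 5, 28, 28], ready for the UNet
print(net_input.shape)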
class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # Embedding layer: maps digit classes to feature vectors
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # A regular UNet
        self.model = UNet2DModel(
            sample_size=28,                  # image size
            in_channels=1 + class_emb_size,  # extra channels for class conditioning
            out_channels=1,                  # output channels
            layers_per_block=2,              # ResNet layers per block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",      # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",    # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",        # a regular ResNet upsampling block
            ),
        )

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape  # [8, 1, 28, 28]
        # The class condition enters as extra input channels
        class_cond = self.class_emb(class_labels)  # [8, 4]
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]
        # Concatenate the original input with the class-condition map
        net_input = torch.cat((x, class_cond), 1)  # (8, 5, 28, 28)
        # Model prediction
        return self.model(net_input, t).sample  # (8, 1, 28, 28)

# Quick sanity check with dummy inputs
noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps, y)
print(model_prediction.shape)  # check that the output shape matches the input shape
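As a small extra (not in the original post), you can also assert the shape match explicitly and count the trainable parameters:

# Added check: the denoiser must map [8, 1, 28, 28] back to [8, 1, 28, 28]
assert model_prediction.shape == noisy_xb.shape
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {n_params}')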
3. Model Training
The training loop is the same as in the previous posts:
# Create the noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

losses = []
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        # Get the data and prepare the noise
        x = x.to(device) * 2 - 1  # normalize to [-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        # Forward noising
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
        # Model prediction (note that the class labels are passed in)
        pred = net(noisy_x, timesteps, y)
        # Compute the loss
        loss = loss_fn(pred, noise)
        # Backprop and parameter update
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Record the loss
        losses.append(loss.item())
    # Report the loss
    avg_loss = sum(losses[-100:]) / 100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# Plot the loss curve
plt.figure(dpi=300)
plt.plot(losses)
plt.show()
The resulting loss curve:
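As an aside (a minimal sketch, not from the original post), you can visualize what the scheduler's forward process does by applying add_noise to one batch at a few timesteps:

# Sketch: visualize forward noising (assumes dataset and noise_scheduler from above)
xb, _ = next(iter(DataLoader(dataset, batch_size=8, shuffle=True)))
xb = xb * 2 - 1  # same [-1, 1] normalization as in training
fig, axs = plt.subplots(1, 4, figsize=(12, 3))
for ax, t in zip(axs, [0, 250, 500, 999]):
    noise = torch.randn_like(xb)
    noisy = noise_scheduler.add_noise(xb, noise, torch.full((xb.shape[0],), t).long())
    ax.imshow(torchvision.utils.make_grid(noisy.clip(-1, 1))[0], cmap='Greys')
    ax.set_title(f't = {t}')
    ax.axis('off')
plt.show()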
4. Model Inference
Run the sampling loop, using the class labels to guide image generation:
x = torch.randn(80, 1, 28, 28).to(device)  # random noise
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # class labels

# Sampling loop
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Model prediction
    with torch.no_grad():
        residual = net(x, t, y)
    # Update the image using the predicted noise and the current timestep
    x = noise_scheduler.step(residual, t, x).prev_sample

# Visualize the results
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
ax.axis('off')
plt.show()
The class-guided results are shown below; they turned out pretty well:
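If you want to keep the generated grid, here is a quick extra (the output path is a hypothetical example):

from torchvision.utils import save_image

# Map from [-1, 1] back to [0, 1] before saving; 'mnist_generated.png' is a hypothetical path
save_image(x.detach().cpu().clip(-1, 1) * 0.5 + 0.5, 'mnist_generated.png', nrow=8)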
5. Complete Code
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# -----------------------------------------------------------------------------
# 1. Dataset loading
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Inspect a few MNIST samples
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 2. Build the conditional diffusion model
class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # Embedding layer: maps digit classes to feature vectors
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # A regular UNet
        self.model = UNet2DModel(
            sample_size=28,                  # image size
            in_channels=1 + class_emb_size,  # extra channels for class conditioning
            out_channels=1,                  # output channels
            layers_per_block=2,              # ResNet layers per block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",      # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",    # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",        # a regular ResNet upsampling block
            ),
        )

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape  # [8, 1, 28, 28]
        # The class condition enters as extra input channels
        class_cond = self.class_emb(class_labels)  # [8, 4]
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]
        # Concatenate the original input with the class-condition map
        net_input = torch.cat((x, class_cond), 1)  # (8, 5, 28, 28)
        # Model prediction
        return self.model(net_input, t).sample  # (8, 1, 28, 28)

# Quick sanity check with dummy inputs
noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps, y)
print(model_prediction.shape)  # check that the output shape matches the input shape
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 3. Model training
# Create the noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

losses = []
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        # Get the data and prepare the noise
        x = x.to(device) * 2 - 1  # normalize to [-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        # Forward noising
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
        # Model prediction (note that the class labels are passed in)
        pred = net(noisy_x, timesteps, y)
        # Compute the loss
        loss = loss_fn(pred, noise)
        # Backprop and parameter update
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Record the loss
        losses.append(loss.item())
    # Report the loss
    avg_loss = sum(losses[-100:]) / 100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# Plot the loss curve
plt.figure(dpi=300)
plt.plot(losses)
plt.show()
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 4. Model inference
x = torch.randn(80, 1, 28, 28).to(device)  # random noise
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # class labels

# Sampling loop
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Model prediction
    with torch.no_grad():
        residual = net(x, t, y)
    # Update the image using the predicted noise and the current timestep
    x = noise_scheduler.step(residual, t, x).prev_sample

# Visualize the results
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
ax.axis('off')
plt.show()
# -----------------------------------------------------------------------------
And with that, my diffusion skills have leveled up once again~