0. Project video walkthrough
Video tutorial on Bilibili: https://www.bilibili.com/video/BV1e8411a7mz
1. Diffusion model theory (deriving the loss function)
1.1 Background
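AI techniques for image generation, text generation and multimodal generation have accumulated steadily, including generative adversarial networks (GAN), variational autoencoders (VAE), normalizing flow models, autoregressive models (AR), energy-based models, and, most recently, the hugely popular diffusion models.
The success of diffusion models did not come out of nowhere. Similar ideas were proposed as early as 2015, and in 2020 they led to the now well-known "denoising diffusion probabilistic models" (DDPM).
NovelAI's recent image generator is also built on diffusion models and produces impressive results.
With this project, generating the "sunflower" class with cfg = 7 already gives fairly good results.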
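1.2 Training and sampling algorithms
First comes the overall training and sampling procedure; sections 1.3 and 1.4 then walk through the process and derive the formulas. Our goal is to derive the loss function used during training.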
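As a rough, framework-free sketch of the two algorithms (assuming the standard DDPM formulation), the NumPy code below runs one training-loss computation (Algorithm 1) and the full sampling loop (Algorithm 2). `dummy_model` is a hypothetical placeholder for the U-Net noise predictor, and the schedule (500 steps, β from 1e-4 to 0.02, time steps 1 to T−1) mirrors the defaults of the `Diffusion` class used later; this is an illustration, not the project's implementation.

```python
import numpy as np

T = 500
beta = np.linspace(1e-4, 0.02, T)      # linear noise schedule
alpha = 1.0 - beta
alpha_bar = np.cumprod(alpha)          # cumulative product of alpha
rng = np.random.default_rng(0)


def dummy_model(x_t, t):
    # Hypothetical stand-in for eps_theta(x_t, t); a trained U-Net would predict the added noise
    return np.zeros_like(x_t)


def training_loss(model, x0):
    """Algorithm 1 (one step): sample t and eps, build x_t in closed form, regress the noise."""
    n = x0.shape[0]
    t = rng.integers(1, T, size=n)                     # one time step per sample, in [1, T-1]
    eps = rng.standard_normal(x0.shape)
    ab = alpha_bar[t].reshape(-1, *([1] * (x0.ndim - 1)))
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps
    return np.mean((eps - model(x_t, t)) ** 2)         # MSE between true and predicted noise


def sample(model, shape):
    """Algorithm 2: start from N(0, I) and denoise one step at a time."""
    x = rng.standard_normal(shape)
    for i in reversed(range(1, T)):
        t = np.full(shape[0], i)
        eps_hat = model(x, t)
        z = rng.standard_normal(shape) if i > 1 else np.zeros(shape)  # no noise at the last step
        x = (x - (1 - alpha[i]) / np.sqrt(1 - alpha_bar[i]) * eps_hat) / np.sqrt(alpha[i]) \
            + np.sqrt(beta[i]) * z
    return x


x0 = rng.uniform(-1, 1, size=(4, 3, 8, 8))   # toy "images" already scaled to [-1, 1]
loss = training_loss(dummy_model, x0)
imgs = sample(dummy_model, (2, 3, 8, 8))
print(loss, imgs.shape)
```

With the placeholder model the loss is just the mean of ε², roughly 1; a real network drives it down as it learns to predict the noise.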
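1.3 Deriving the forward (noising) diffusion formula
The forward process of a diffusion model gradually adds Gaussian noise to the original image until the final image is close to a Gaussian distribution. Because noise makes up an ever larger share of the image, the strength of the added noise also grows over time. In short: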
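- Each image is obtained by adding noise to the image at the previous time step.
- The final image becomes pure noise.
- The strength of the added noise differs at each time step; common choices are a linear scheduler and a cosine scheduler.
- This process builds the labels (training targets), as we will see later.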
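To make the two schedulers concrete, here is a NumPy sketch of both. The linear one matches the `prepare_noise_schedule` used later (β from 1e-4 to 0.02 over 500 steps); the cosine one follows the commonly used improved-DDPM formula with offset s = 0.008 and β clipped at 0.999. The cosine variant is included only for comparison; this project itself uses the linear schedule.

```python
import numpy as np


def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    # Noise strength grows linearly with the time step
    return np.linspace(beta_start, beta_end, T)


def cosine_beta_schedule(T, s=0.008):
    # Define alpha_bar through a squared cosine, then recover each beta_t from consecutive ratios
    steps = np.arange(T + 1) / T
    f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0, 0.999)   # avoid beta = 1 at the very end


T = 500
lin = linear_beta_schedule(T)
cos = cosine_beta_schedule(T)
lin_ab = np.cumprod(1 - lin)          # fraction of the signal variance left after t steps
cos_ab = np.cumprod(1 - cos)

mid = T // 2 - 1
print(f"alpha_bar halfway: linear {lin_ab[mid]:.3f}, cosine {cos_ab[mid]:.3f}")
print(f"alpha_bar at the end: linear {lin_ab[-1]:.2e}, cosine {cos_ab[-1]:.2e}")
```

Both schedules end close to pure noise, but the cosine schedule keeps more of the image in the middle of the process, which spreads the denoising work more evenly across time steps.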
The derivation below shows how to obtain the image at time step t directly from the initial image.
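Written out with β_t, α_t = 1 − β_t and ᾱ_t (the quantities the code later calls `beta`, `alpha` and `alpha_hat`), the standard DDPM steps are reconstructed below; the key move is that the sum of two independent Gaussian noise terms is again Gaussian, with the variances added.

```latex
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1},
  \qquad \alpha_t = 1-\beta_t,\ \ \epsilon_{t-1},\epsilon_{t-2},\dots \sim \mathcal{N}(0,\mathbf{I}) \\
&= \sqrt{\alpha_t}\big(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-2}\big) + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2}
  + \underbrace{\sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-2} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}}_{\text{variance } \alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) \,=\, 1-\alpha_t\alpha_{t-1}} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}_{t-2} \\
&\;\;\vdots \\
&= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
  \qquad \bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s,\ \ \epsilon \sim \mathcal{N}(0,\mathbf{I})
\end{aligned}
```

The last line is exactly what `noise_images` computes: `sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ`.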
This formula lays the groundwork for the next section, which derives the key result: the loss function.
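A quick numerical sanity check of the closed form, assuming a shorter 100-step linear schedule to keep it fast: running the step-by-step forward process many times from the same starting value x_0 should produce samples whose mean and variance match √ᾱ_T·x_0 and 1 − ᾱ_T.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 100
beta = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - beta)

x0 = 0.7                      # a single pixel value, noised independently many times
n_runs = 100_000
x = np.full(n_runs, x0)
for b in beta:                # step-by-step forward process q(x_t | x_{t-1})
    x = np.sqrt(1.0 - b) * x + np.sqrt(b) * rng.standard_normal(n_runs)

print("empirical mean/var:  ", x.mean(), x.var())
print("closed-form mean/var:", np.sqrt(alpha_bar[-1]) * x0, 1.0 - alpha_bar[-1])
```

The two lines agree up to Monte Carlo error, which is why training can jump straight to any time step t instead of simulating every intermediate step.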
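1.4 Optimization objective: deriving the loss function
The forward diffusion above is not hard. Next we derive the reverse diffusion process, i.e., going from x_t to x_{t-1}.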
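The main steps of the standard DDPM derivation are reconstructed below (same symbols as in 1.3). The true posterior q(x_{t−1} | x_t, x_0) is Gaussian; rewriting its mean in terms of the noise ε and letting the network predict that noise turns the per-step KL term into a weighted MSE. Dropping the weight gives the simple loss used in the training code, and the mean of p_θ gives the update in `Diffusion.sample`.

```latex
\begin{aligned}
q(x_{t-1}\mid x_t, x_0) &= \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t\mathbf{I}\big),\qquad
  \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \\
\tilde{\mu}_t(x_t, x_0) &= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t \\
x_0 &= \frac{1}{\sqrt{\bar{\alpha}_t}}\big(x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon\big)
  \;\Longrightarrow\;
  \tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\Big) \\
p_\theta(x_{t-1}\mid x_t) &= \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2\mathbf{I}\big),\qquad
  \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\Big) \\
L_{t-1} &= \mathbb{E}\Big[\frac{1}{2\sigma_t^2}\big\|\tilde{\mu}_t - \mu_\theta(x_t, t)\big\|^2\Big] + C
  = \mathbb{E}\Big[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}\big\|\epsilon - \epsilon_\theta(x_t, t)\big\|^2\Big] + C \\
L_{\text{simple}} &= \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\big)\big\|^2\Big] \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\Big) + \sigma_t z,
  \qquad z \sim \mathcal{N}(0, \mathbf{I}),\ \ \sigma_t^2 = \beta_t
\end{aligned}
```

Here 1 − α_t = β_t, and choosing σ_t² = β_t matches the `paddle.sqrt(beta) * noise` term in the sampling code; σ_t² = β̃_t is the other common choice.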
2. Unconditional generation (generating random images)
We use the Stanford Cars images as the example; no class labels are used.
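2.1 Training process
The noise added by the forward process serves as the training label: the model is a U-Net that takes the noised image and must predict that noise. An encoding of the time step is also fed into the model, similar to positional encoding in Transformer models, which makes the model easier to train.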
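A minimal NumPy sketch of a Transformer-style sinusoidal time-step encoding. The embedding size (256) and the base 10000 are assumptions for illustration; the project's `modules.UNet` may build its time encoding differently, and a network typically projects this vector and adds it to its feature maps.

```python
import numpy as np


def timestep_embedding(t, dim=256):
    """Sinusoidal encoding of integer time steps.

    t: integer array of shape (batch,); returns an array of shape (batch, dim).
    """
    assert dim % 2 == 0
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))   # dim/2 geometrically spaced frequencies
    angles = t[:, None].astype(np.float64) * inv_freq[None, :]  # (batch, dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = timestep_embedding(np.array([0, 1, 499]))
print(emb.shape)  # (3, 256)
```

Each time step gets a distinct, smoothly varying vector, so the same network weights can behave differently at high-noise and low-noise steps.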
2.2 Unzip the data
Unzip the dataset. This is only needed the first time you run the project!
In [13]
import os
os.makedirs("work/cars", exist_ok=True)  # create the target folder if it does not exist yet
!unzip -oq data/data173302/stanford_cars.zip -d work/cars
In [14]
# Remove files we don't need
!rm -rf work/cars/cars_test
!rm -rf work/cars/devkit
!rm -rf work/cars/car_devkit.tgz
!rm -rf work/cars/cars_train.tgz
!rm -rf work/cars/cars_test.tgz
!rm -rf work/cars/cars_test_annos_withlabels.mat
2.3 Data preview
Let's look at some of the car images.
In [1]
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline


# Show a list of images in a grid with `cols` columns
def show_images(imgs_paths, cols=4):
    num_samples = len(imgs_paths)
    rows = (num_samples + cols - 1) // cols  # ceil(num_samples / cols)
    plt.figure(figsize=(15, 15))
    for i, img_path in enumerate(imgs_paths):
        img = Image.open(img_path)
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img)
    plt.show()


imgs_paths = [
    "work/cars/cars_train/05930.jpg", "work/cars/cars_train/06816.jpg",
    "work/cars/cars_train/02885.jpg", "work/cars/cars_train/07471.jpg",
    "work/cars/cars_train/06600.jpg", "work/cars/cars_train/06020.jpg",
    "work/cars/cars_train/04818.jpg", "work/cars/cars_train/06088.jpg",
]
show_images(imgs_paths)
[Figure: the 8 car images listed above, in a 4-column grid]
2.4 Build the dataset
The dataset API in paddle.vision is all we need.
In [2]
import os
import paddle
import paddle.nn as nn
import paddle.vision as V
from PIL import Image
from matplotlib import pyplot as plt
from paddle.io import DataLoader


# We don't need image labels here, so we can use the dataset class provided by paddle.vision directly
def get_data(args):
    transforms = V.transforms.Compose([
        V.transforms.Resize(80),  # args.image_size + 1/4 * args.image_size
        V.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale pixels to [-1, 1]
    ])
    # ImageFolder returns each sample as a one-element list [image], hence images[0] in the training loop
    dataset = V.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader
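The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step maps pixels from [0, 1] (after ToTensor) to [−1, 1], the same scale as the Gaussian noise, and `Diffusion.sample` maps results back with (x.clip(-1, 1) + 1) / 2 * 255. A small NumPy sketch of that round trip; the helper names are hypothetical, and unlike the project's plain astype("uint8") (which truncates) it rounds before converting.

```python
import numpy as np


def normalize(img01):
    # What Normalize(mean=0.5, std=0.5) does per channel: [0, 1] -> [-1, 1]
    return (img01 - 0.5) / 0.5


def to_uint8(x):
    # Inverse used after sampling: clip to [-1, 1], map to [0, 1], then to [0, 255]
    return np.rint((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8)


pixels = np.array([0, 128, 255], dtype=np.uint8)
x = normalize(pixels / 255.0)       # ToTensor divides uint8 pixels by 255 -> [0, 1]
print(x, to_uint8(x))
```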
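2.5 Training
Hyperparameters are set by editing the ARGS class. Once you know that the loss is just the mean squared error between the noise that was added and the noise the model predicts (two image-shaped tensors), the code is fairly simple. Compared with GANs, diffusion models are easier to tune and easier to train.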
In [3]
"""DDPM (unconditional)"""
import os
import logging

import numpy as np
import paddle
import paddle.nn as nn
from paddle import optimizer
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm

from modules import UNet  # noise-prediction U-Net from the project's modules.py

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device  # kept for reference; Paddle uses its default device
        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)  # alpha_bar_t

    def prepare_noise_schedule(self):
        # Linear beta schedule
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        # One random time step per image
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))  # start from pure noise
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)  # no noise at the last step
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        x = (x.clip(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    model = UNet()
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    # logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)
    os.makedirs("car_models", exist_ok=True)  # checkpoint folder
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, images in enumerate(pbar):
            t = diffusion.sample_timesteps(images[0].shape[0])
            x_t, noise = diffusion.noise_images(images[0], t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)  # loss function
            opt.clear_grad()
            loss.backward()
            opt.step()
            pbar.set_postfix(MSE=loss.item())
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        if epoch % 20 == 0:  # save a checkpoint and show 8 samples
            paddle.save(model.state_dict(), f"car_models/ddpm_uncond{epoch}.pdparams")
            sampled_images = diffusion.sample(model, n=8)
            for i in range(8):
                img = sampled_images[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 4, i + 1)
                plt.imshow(img)
            plt.show()


def launch():
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Unconditional"
            self.epochs = 150
            self.batch_size = 24
            self.image_size = 64
            self.dataset_path = r"/home/aistudio/work/cars"
            self.device = "cuda"
            self.lr = 1.5e-4
    args = ARGS()
    train(args)


if __name__ == '__main__':
    launch()
11:03:25 - INFO: Starting epoch 0: 100%|██████████| 340/340 [02:13<00:00, 3.70it/s, MSE=0.15] 11:05:39 - INFO: Sampling 8 new images.... 499it [00:20, 23.93it/s]
[Figure: 8 images sampled after epoch 0]
11:06:00 - INFO: Starting epoch 1: 100%|██████████| 340/340 [02:13<00:00, 3.11it/s, MSE=0.0725] 11:08:14 - INFO: Starting epoch 2: 100%|██████████| 340/340 [02:12<00:00, 3.37it/s, MSE=0.0777] 11:10:26 - INFO: Starting epoch 3: 100%|██████████| 340/340 [02:12<00:00, 3.44it/s, MSE=0.0814] 11:12:38 - INFO: Starting epoch 4: 100%|██████████| 340/340 [02:12<00:00, 3.30it/s, MSE=0.0579] 11:14:51 - INFO: Starting epoch 5: 100%|██████████| 340/340 [02:13<00:00, 3.40it/s, MSE=0.107] 11:17:05 - INFO: Starting epoch 6: 100%|██████████| 340/340 [02:14<00:00, 3.49it/s, MSE=0.0742] 11:19:19 - INFO: Starting epoch 7: 100%|██████████| 340/340 [02:14<00:00, 3.20it/s, MSE=0.0422] 11:21:34 - INFO: Starting epoch 8: 100%|██████████| 340/340 [02:13<00:00, 3.26it/s, MSE=0.0527] 11:23:47 - INFO: Starting epoch 9: 100%|██████████| 340/340 [02:13<00:00, 3.45it/s, MSE=0.064] 11:26:01 - INFO: Starting epoch 10: 100%|██████████| 340/340 [02:15<00:00, 2.91it/s, MSE=0.043] 11:28:17 - INFO: Starting epoch 11: 100%|██████████| 340/340 [02:14<00:00, 2.60it/s, MSE=0.0712] 11:30:31 - INFO: Starting epoch 12: 100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0674] 11:32:44 - INFO: Starting epoch 13: 100%|██████████| 340/340 [02:14<00:00, 3.00it/s, MSE=0.0464] 11:34:59 - INFO: Starting epoch 14: 100%|██████████| 340/340 [02:14<00:00, 2.93it/s, MSE=0.0349] 11:37:13 - INFO: Starting epoch 15: 100%|██████████| 340/340 [02:13<00:00, 3.58it/s, MSE=0.0279] 11:39:26 - INFO: Starting epoch 16: 100%|██████████| 340/340 [02:14<00:00, 2.62it/s, MSE=0.0436] 11:41:40 - INFO: Starting epoch 17: 100%|██████████| 340/340 [02:15<00:00, 3.06it/s, MSE=0.0278] 11:43:55 - INFO: Starting epoch 18: 100%|██████████| 340/340 [02:13<00:00, 3.03it/s, MSE=0.0318] 11:46:09 - INFO: Starting epoch 19: 100%|██████████| 340/340 [02:13<00:00, 3.01it/s, MSE=0.0743] 11:48:22 - INFO: Starting epoch 20: 100%|██████████| 340/340 [02:12<00:00, 3.26it/s, MSE=0.0721] 11:50:36 - INFO: Sampling 8 new images.... 499it [00:20, 24.05it/s]
[Figure: 8 images sampled after epoch 20]
11:50:57 - INFO: Starting epoch 21: 100%|██████████| 340/340 [02:13<00:00, 3.32it/s, MSE=0.0275] 11:53:10 - INFO: Starting epoch 22: 100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.028] 11:55:24 - INFO: Starting epoch 23: 100%|██████████| 340/340 [02:13<00:00, 2.89it/s, MSE=0.0155] 11:57:37 - INFO: Starting epoch 24: 100%|██████████| 340/340 [02:13<00:00, 3.17it/s, MSE=0.0386] 11:59:51 - INFO: Starting epoch 25: 100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0189] 12:02:04 - INFO: Starting epoch 26: 100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0285] 12:04:18 - INFO: Starting epoch 27: 100%|██████████| 340/340 [02:13<00:00, 3.47it/s, MSE=0.0593] 12:06:31 - INFO: Starting epoch 28: 100%|██████████| 340/340 [02:14<00:00, 2.98it/s, MSE=0.0151] 12:08:45 - INFO: Starting epoch 29: 100%|██████████| 340/340 [02:12<00:00, 3.40it/s, MSE=0.0552] 12:10:57 - INFO: Starting epoch 30: 100%|██████████| 340/340 [02:14<00:00, 3.53it/s, MSE=0.0335] 12:13:12 - INFO: Starting epoch 31: 100%|██████████| 340/340 [02:13<00:00, 3.01it/s, MSE=0.00773] 12:15:25 - INFO: Starting epoch 32: 100%|██████████| 340/340 [02:13<00:00, 3.03it/s, MSE=0.0907] 12:17:39 - INFO: Starting epoch 33: 100%|██████████| 340/340 [02:15<00:00, 3.65it/s, MSE=0.0412] 12:19:54 - INFO: Starting epoch 34: 100%|██████████| 340/340 [02:13<00:00, 3.55it/s, MSE=0.0359] 12:22:08 - INFO: Starting epoch 35: 100%|██████████| 340/340 [02:13<00:00, 3.30it/s, MSE=0.0563] 12:24:21 - INFO: Starting epoch 36: 100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0299] 12:26:35 - INFO: Starting epoch 37: 100%|██████████| 340/340 [02:13<00:00, 3.24it/s, MSE=0.0315] 12:28:49 - INFO: Starting epoch 38: 100%|██████████| 340/340 [02:13<00:00, 3.08it/s, MSE=0.0455] 12:31:02 - INFO: Starting epoch 39: 100%|██████████| 340/340 [02:12<00:00, 3.23it/s, MSE=0.024] 12:33:15 - INFO: Starting epoch 40: 100%|██████████| 340/340 [02:13<00:00, 3.32it/s, MSE=0.0416] 12:35:29 - INFO: Sampling 8 new images.... 
499it [00:20, 23.89it/s]
[Figure: 8 images sampled after epoch 40]
12:35:50 - INFO: Starting epoch 41: 100%|██████████| 340/340 [02:13<00:00, 3.18it/s, MSE=0.0134] 12:38:03 - INFO: Starting epoch 42: 100%|██████████| 340/340 [02:12<00:00, 3.77it/s, MSE=0.0948] 12:40:16 - INFO: Starting epoch 43: 100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0208] 12:42:30 - INFO: Starting epoch 44: 100%|██████████| 340/340 [02:13<00:00, 3.29it/s, MSE=0.0421] 12:44:44 - INFO: Starting epoch 45: 100%|██████████| 340/340 [02:13<00:00, 2.88it/s, MSE=0.0296] 12:46:57 - INFO: Starting epoch 46: 100%|██████████| 340/340 [02:12<00:00, 3.00it/s, MSE=0.0398] 12:49:10 - INFO: Starting epoch 47: 100%|██████████| 340/340 [02:13<00:00, 3.06it/s, MSE=0.0269] 12:51:24 - INFO: Starting epoch 48: 100%|██████████| 340/340 [02:12<00:00, 3.34it/s, MSE=0.0635] 12:53:37 - INFO: Starting epoch 49: 100%|██████████| 340/340 [02:12<00:00, 3.58it/s, MSE=0.0687] 12:55:49 - INFO: Starting epoch 50: 100%|██████████| 340/340 [02:12<00:00, 3.08it/s, MSE=0.0253] 12:58:01 - INFO: Starting epoch 51: 100%|██████████| 340/340 [02:12<00:00, 3.33it/s, MSE=0.0219] 01:00:14 - INFO: Starting epoch 52: 100%|██████████| 340/340 [02:12<00:00, 3.13it/s, MSE=0.0422] 01:02:27 - INFO: Starting epoch 53: 100%|██████████| 340/340 [02:12<00:00, 3.26it/s, MSE=0.0187] 01:04:39 - INFO: Starting epoch 54: 100%|██████████| 340/340 [02:14<00:00, 3.39it/s, MSE=0.0453] 01:06:54 - INFO: Starting epoch 55: 100%|██████████| 340/340 [02:14<00:00, 3.45it/s, MSE=0.101] 01:09:08 - INFO: Starting epoch 56: 100%|██████████| 340/340 [02:15<00:00, 3.22it/s, MSE=0.016] 01:11:23 - INFO: Starting epoch 57: 100%|██████████| 340/340 [02:14<00:00, 3.21it/s, MSE=0.0173] 01:13:38 - INFO: Starting epoch 58: 100%|██████████| 340/340 [02:13<00:00, 2.65it/s, MSE=0.0127] 01:15:52 - INFO: Starting epoch 59: 100%|██████████| 340/340 [02:14<00:00, 3.56it/s, MSE=0.112] 01:18:06 - INFO: Starting epoch 60: 100%|██████████| 340/340 [02:14<00:00, 3.01it/s, MSE=0.0155] 01:20:21 - INFO: Sampling 8 new images.... 
499it [00:21, 23.74it/s]
[Figure: 8 images sampled after epoch 60]
01:20:42 - INFO: Starting epoch 61: 100%|██████████| 340/340 [02:15<00:00, 3.17it/s, MSE=0.0143] 01:22:58 - INFO: Starting epoch 62: 100%|██████████| 340/340 [02:15<00:00, 3.26it/s, MSE=0.0731] 01:25:14 - INFO: Starting epoch 63: 100%|██████████| 340/340 [02:14<00:00, 3.38it/s, MSE=0.0484] 01:27:28 - INFO: Starting epoch 64: 100%|██████████| 340/340 [02:16<00:00, 3.30it/s, MSE=0.0154] 01:29:45 - INFO: Starting epoch 65: 100%|██████████| 340/340 [02:15<00:00, 3.31it/s, MSE=0.0224] 01:32:00 - INFO: Starting epoch 66: 100%|██████████| 340/340 [02:15<00:00, 3.14it/s, MSE=0.0265] 01:34:16 - INFO: Starting epoch 67: 100%|██████████| 340/340 [02:14<00:00, 3.10it/s, MSE=0.0326] 01:36:30 - INFO: Starting epoch 68: 100%|██████████| 340/340 [02:14<00:00, 3.35it/s, MSE=0.0656] 01:38:44 - INFO: Starting epoch 69: 100%|██████████| 340/340 [02:14<00:00, 3.20it/s, MSE=0.0591] 01:40:58 - INFO: Starting epoch 70: 100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0196] 01:43:12 - INFO: Starting epoch 71: 100%|██████████| 340/340 [02:15<00:00, 2.64it/s, MSE=0.021] 01:45:28 - INFO: Starting epoch 72: 100%|██████████| 340/340 [02:14<00:00, 2.85it/s, MSE=0.0166] 01:47:42 - INFO: Starting epoch 73: 100%|██████████| 340/340 [02:15<00:00, 3.31it/s, MSE=0.0408] 01:49:57 - INFO: Starting epoch 74: 100%|██████████| 340/340 [02:14<00:00, 3.06it/s, MSE=0.0705] 01:52:12 - INFO: Starting epoch 75: 100%|██████████| 340/340 [02:14<00:00, 3.06it/s, MSE=0.0326] 01:54:26 - INFO: Starting epoch 76: 100%|██████████| 340/340 [02:13<00:00, 3.55it/s, MSE=0.016] 01:56:39 - INFO: Starting epoch 77: 100%|██████████| 340/340 [02:13<00:00, 2.98it/s, MSE=0.0122] 01:58:53 - INFO: Starting epoch 78: 100%|██████████| 340/340 [02:13<00:00, 3.57it/s, MSE=0.0304] 02:01:06 - INFO: Starting epoch 79: 100%|██████████| 340/340 [02:14<00:00, 3.17it/s, MSE=0.0186] 02:03:21 - INFO: Starting epoch 80: 100%|██████████| 340/340 [02:14<00:00, 3.37it/s, MSE=0.0248] 02:05:35 - INFO: Sampling 8 new images.... 
499it [00:21, 22.82it/s]
[Figure: 8 images sampled after epoch 80]
02:05:57 - INFO: Starting epoch 81: 100%|██████████| 340/340 [02:13<00:00, 2.93it/s, MSE=0.0321] 02:08:11 - INFO: Starting epoch 82: 100%|██████████| 340/340 [02:15<00:00, 2.76it/s, MSE=0.0274] 02:10:26 - INFO: Starting epoch 83: 100%|██████████| 340/340 [02:16<00:00, 3.49it/s, MSE=0.0069] 02:12:42 - INFO: Starting epoch 84: 100%|██████████| 340/340 [02:13<00:00, 3.05it/s, MSE=0.0847] 02:14:56 - INFO: Starting epoch 85: 100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0237] 02:17:09 - INFO: Starting epoch 86: 100%|██████████| 340/340 [02:13<00:00, 2.71it/s, MSE=0.0124] 02:19:23 - INFO: Starting epoch 87: 100%|██████████| 340/340 [02:14<00:00, 3.69it/s, MSE=0.0537] 02:21:37 - INFO: Starting epoch 88: 100%|██████████| 340/340 [02:13<00:00, 3.13it/s, MSE=0.0463] 02:23:51 - INFO: Starting epoch 89: 100%|██████████| 340/340 [02:13<00:00, 2.85it/s, MSE=0.0137] 02:26:04 - INFO: Starting epoch 90: 100%|██████████| 340/340 [02:12<00:00, 3.05it/s, MSE=0.0198] 02:28:17 - INFO: Starting epoch 91: 100%|██████████| 340/340 [02:12<00:00, 3.31it/s, MSE=0.0205] 02:30:30 - INFO: Starting epoch 92: 100%|██████████| 340/340 [02:12<00:00, 2.79it/s, MSE=0.0146] 02:32:43 - INFO: Starting epoch 93: 100%|██████████| 340/340 [02:12<00:00, 2.94it/s, MSE=0.00888] 02:34:56 - INFO: Starting epoch 94: 100%|██████████| 340/340 [02:12<00:00, 3.20it/s, MSE=0.0572] 02:37:08 - INFO: Starting epoch 95: 100%|██████████| 340/340 [02:13<00:00, 3.11it/s, MSE=0.021] 02:39:22 - INFO: Starting epoch 96: 100%|██████████| 340/340 [02:13<00:00, 3.24it/s, MSE=0.0392] 02:41:35 - INFO: Starting epoch 97: 100%|██████████| 340/340 [02:12<00:00, 2.66it/s, MSE=0.0166] 02:43:48 - INFO: Starting epoch 98: 100%|██████████| 340/340 [02:14<00:00, 2.51it/s, MSE=0.0591] 02:46:03 - INFO: Starting epoch 99: 100%|██████████| 340/340 [02:16<00:00, 3.14it/s, MSE=0.0283] 02:48:19 - INFO: Starting epoch 100: 100%|██████████| 340/340 [02:13<00:00, 3.19it/s, MSE=0.0276] 02:50:33 - INFO: Sampling 8 new images.... 
499it [00:21, 23.23it/s]
[Figure: 8 images sampled after epoch 100]
02:50:55 - INFO: Starting epoch 101: 100%|██████████| 340/340 [02:14<00:00, 3.48it/s, MSE=0.0293] 02:53:10 - INFO: Starting epoch 102: 100%|██████████| 340/340 [02:16<00:00, 3.12it/s, MSE=0.0518] 02:55:27 - INFO: Starting epoch 103: 100%|██████████| 340/340 [02:14<00:00, 3.46it/s, MSE=0.0133] 02:57:42 - INFO: Starting epoch 104: 100%|██████████| 340/340 [02:15<00:00, 3.32it/s, MSE=0.0207] 02:59:58 - INFO: Starting epoch 105: 100%|██████████| 340/340 [02:14<00:00, 3.26it/s, MSE=0.00727] 03:02:12 - INFO: Starting epoch 106: 100%|██████████| 340/340 [02:15<00:00, 3.81it/s, MSE=0.0319] 03:04:28 - INFO: Starting epoch 107: 100%|██████████| 340/340 [02:15<00:00, 3.11it/s, MSE=0.0348] 03:06:44 - INFO: Starting epoch 108: 100%|██████████| 340/340 [02:15<00:00, 3.34it/s, MSE=0.0245] 03:08:59 - INFO: Starting epoch 109: 100%|██████████| 340/340 [02:15<00:00, 3.24it/s, MSE=0.0139] 03:11:14 - INFO: Starting epoch 110: 100%|██████████| 340/340 [02:15<00:00, 3.23it/s, MSE=0.0311] 03:13:29 - INFO: Starting epoch 111: 100%|██████████| 340/340 [02:15<00:00, 3.53it/s, MSE=0.0234] 03:15:45 - INFO: Starting epoch 112: 100%|██████████| 340/340 [02:16<00:00, 3.13it/s, MSE=0.0158] 03:18:01 - INFO: Starting epoch 113: 100%|██████████| 340/340 [02:15<00:00, 3.44it/s, MSE=0.0315] 03:20:17 - INFO: Starting epoch 114: 100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0187] 03:22:30 - INFO: Starting epoch 115: 100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0228] 03:24:43 - INFO: Starting epoch 116: 100%|██████████| 340/340 [02:14<00:00, 3.04it/s, MSE=0.0607] 03:26:57 - INFO: Starting epoch 117: 100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0217] 03:29:10 - INFO: Starting epoch 118: 100%|██████████| 340/340 [02:13<00:00, 3.28it/s, MSE=0.0131] 03:31:24 - INFO: Starting epoch 119: 100%|██████████| 340/340 [02:15<00:00, 3.54it/s, MSE=0.0618] 03:33:39 - INFO: Starting epoch 120: 100%|██████████| 340/340 [02:15<00:00, 3.08it/s, MSE=0.0388] 03:35:55 - INFO: Sampling 8 new 
images.... 499it [00:21, 23.36it/s]
[Figure: 8 images sampled after epoch 120]
03:36:16 - INFO: Starting epoch 121: 100%|██████████| 340/340 [02:19<00:00, 3.14it/s, MSE=0.0142] 03:38:36 - INFO: Starting epoch 122: 100%|██████████| 340/340 [02:19<00:00, 2.97it/s, MSE=0.0112] 03:40:56 - INFO: Starting epoch 123: 100%|██████████| 340/340 [02:19<00:00, 2.84it/s, MSE=0.0243] 03:43:15 - INFO: Starting epoch 124: 100%|██████████| 340/340 [02:19<00:00, 3.11it/s, MSE=0.0312] 03:45:35 - INFO: Starting epoch 125: 100%|██████████| 340/340 [02:19<00:00, 3.26it/s, MSE=0.0513] 03:47:54 - INFO: Starting epoch 126: 100%|██████████| 340/340 [02:18<00:00, 3.10it/s, MSE=0.0254] 03:50:13 - INFO: Starting epoch 127: 100%|██████████| 340/340 [02:17<00:00, 3.18it/s, MSE=0.00965] 03:52:30 - INFO: Starting epoch 128: 100%|██████████| 340/340 [02:17<00:00, 3.35it/s, MSE=0.0183] 03:54:47 - INFO: Starting epoch 129: 100%|██████████| 340/340 [02:17<00:00, 3.36it/s, MSE=0.0158] 03:57:05 - INFO: Starting epoch 130: 100%|██████████| 340/340 [02:18<00:00, 3.29it/s, MSE=0.0326] 03:59:24 - INFO: Starting epoch 131: 100%|██████████| 340/340 [02:17<00:00, 3.18it/s, MSE=0.0224] 04:01:42 - INFO: Starting epoch 132: 100%|██████████| 340/340 [02:16<00:00, 3.11it/s, MSE=0.0367] 04:03:58 - INFO: Starting epoch 133: 100%|██████████| 340/340 [02:18<00:00, 2.95it/s, MSE=0.0231] 04:06:16 - INFO: Starting epoch 134: 100%|██████████| 340/340 [02:19<00:00, 3.34it/s, MSE=0.0195] 04:08:35 - INFO: Starting epoch 135: 100%|██████████| 340/340 [02:18<00:00, 3.30it/s, MSE=0.00914] 04:10:54 - INFO: Starting epoch 136: 100%|██████████| 340/340 [02:19<00:00, 2.76it/s, MSE=0.0355] 04:13:13 - INFO: Starting epoch 137: 100%|██████████| 340/340 [02:19<00:00, 3.14it/s, MSE=0.0365] 04:15:33 - INFO: Starting epoch 138: 100%|██████████| 340/340 [02:20<00:00, 3.38it/s, MSE=0.0182] 04:17:53 - INFO: Starting epoch 139: 100%|██████████| 340/340 [02:18<00:00, 3.19it/s, MSE=0.057] 04:20:11 - INFO: Starting epoch 140: 100%|██████████| 340/340 [02:16<00:00, 3.27it/s, MSE=0.0156] 04:22:28 - INFO: Sampling 8 new 
images.... 499it [00:21, 22.81it/s]
[Figure: 8 images sampled after epoch 140]
04:22:51 - INFO: Starting epoch 141: 100%|██████████| 340/340 [02:17<00:00, 3.11it/s, MSE=0.0256] 04:25:09 - INFO: Starting epoch 142: 100%|██████████| 340/340 [02:16<00:00, 2.82it/s, MSE=0.0271] 04:27:26 - INFO: Starting epoch 143: 100%|██████████| 340/340 [02:16<00:00, 3.35it/s, MSE=0.041] 04:29:42 - INFO: Starting epoch 144: 100%|██████████| 340/340 [02:16<00:00, 3.04it/s, MSE=0.0126] 04:31:59 - INFO: Starting epoch 145: 100%|██████████| 340/340 [02:16<00:00, 3.38it/s, MSE=0.0186] 04:34:16 - INFO: Starting epoch 146: 100%|██████████| 340/340 [02:19<00:00, 3.21it/s, MSE=0.0195] 04:36:36 - INFO: Starting epoch 147: 100%|██████████| 340/340 [02:19<00:00, 2.58it/s, MSE=0.00809] 04:38:55 - INFO: Starting epoch 148: 100%|██████████| 340/340 [02:20<00:00, 3.04it/s, MSE=0.0113] 04:41:15 - INFO: Starting epoch 149: 100%|██████████| 340/340 [02:19<00:00, 3.17it/s, MSE=0.013]
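2.6 Sampling with the trained model
We can load whichever checkpoint looked good during training and sample from it. This project is only a demonstration, and generating cars may not be especially useful in itself. However, the latest NovelAI can already produce very high-quality anime-style illustrations, so understanding the underlying principles of diffusion models through this project should make it easier to work with the many improved diffusion models that follow.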
In [6]
import paddle
import numpy as np
from matplotlib import pyplot as plt

model = UNet()
model.set_state_dict(paddle.load("car_models/ddpm_uncond140.pdparams"))  # load the checkpoint saved at epoch 140
diffusion = Diffusion(img_size=64, device="cuda")
sampled_images = diffusion.sample(model, n=8)  # sample 8 images
for i in range(8):
    img = sampled_images[i].transpose([1, 2, 0])  # CHW -> HWC for plotting
    img = np.array(img).astype("uint8")
    plt.subplot(2, 4, i + 1)
    plt.imshow(img)
plt.show()
05:37:15 - INFO: Sampling 8 new images.... 499it [00:22, 22.61it/s]
[Figure: 8 images sampled from the epoch-140 checkpoint]
3. Conditional generation (guiding image generation with labels)
3.1 Training process
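As in unconditional generation, the noise added by the forward process is the training label, the network is a U-Net, and an encoding of the time step is fed into the model, similar to positional encoding in Transformer models. Here we additionally feed an encoding of the class label into the model. cfg is the guidance scale s that balances conditional and unconditional generation: at each sampling step the model predicts the noise both with and without the label, and the two predictions are combined as ε̂ = (1 − s)·ε_uncond + s·ε_cond = ε_uncond + s·(ε_cond − ε_uncond). The larger cfg is, the more strongly the generated image follows the label; s > 1 pushes beyond the purely conditional prediction. So that the same model can also predict unconditionally, the training loop drops the labels for about 10% of the batches.
- cfg: classifier-free guidance (label guidance)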
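A NumPy sketch of the guidance formula and of the label-dropout trick it relies on. `cfg_combine` and `drop_labels` are hypothetical helper names; in the project the combination is done with paddle.lerp inside `Diffusion.sample`, and the dropout is the `np.random.random() < 0.1` check in the training loop, which drops the labels of a whole batch at once.

```python
import numpy as np


def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond).

    This is linear interpolation lerp(eps_uncond, eps_cond, scale);
    scale > 1 extrapolates past the conditional prediction.
    """
    return eps_uncond + scale * (eps_cond - eps_uncond)


def drop_labels(labels, rng, p=0.1):
    """With probability p, train this step without labels (None), so one network
    learns both the conditional and the unconditional noise prediction."""
    return None if rng.random() < p else labels


eps_u = np.array([0.0, 1.0])   # toy unconditional prediction
eps_c = np.array([1.0, 1.5])   # toy conditional prediction
for s in (0.0, 1.0, 3.0):
    print(s, cfg_combine(eps_u, eps_c, s))   # s=0 gives eps_u, s=1 gives eps_c
```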
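In addition, the training below maintains an exponential moving average (EMA) of the model weights, blending the previous averaged weights with the current ones. This damps the influence of outlier updates. Note that it does not change the gradient steps themselves; it keeps a separate, smoother copy of the weights (ema_model).
- ema: exponential moving average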
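A minimal sketch of the EMA update on a dictionary of NumPy arrays, using the decay 0.995 passed to `EMA(0.995)` in the training code. `ema_update` is a hypothetical helper; the project's `modules.EMA.step_ema` may add details such as a warm-up phase before averaging starts.

```python
import numpy as np


def ema_update(ema_params, params, beta=0.995):
    """ema <- beta * ema + (1 - beta) * current, applied to every parameter.

    Only the averaged copy changes; the optimizer keeps updating `params` as usual.
    """
    for name, value in params.items():
        ema_params[name] = beta * ema_params[name] + (1.0 - beta) * value
    return ema_params


params = {"w": np.array([1.0, 2.0])}
ema = {k: v.copy() for k, v in params.items()}   # start from a copy, like copy.deepcopy(model)
params["w"] = np.array([2.0, 0.0])               # pretend one optimizer step changed the weights
ema_update(ema, params)
print(ema["w"])   # moves only 0.5% of the way toward the new weights
```

Because each update moves the average only slightly, a single outlier batch barely affects the EMA weights, which is why they are often preferred for sampling.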
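Restart the kernel before running the code below to free GPU memory!
3.2 Unzip the dataset
We use a flower dataset with 5 classes, so that at sampling time we can choose which class to generate.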
In [1]
# Unzip the flower dataset
import os
os.makedirs("work/flowers", exist_ok=True)  # create the target folder if it does not exist yet
!unzip -oq data/data173680/flowers.zip -d work/flowers
In [2]
# Load the dataset
"""Conditional generation also needs image labels, so we build a custom dataset."""
# 1. Write image paths and labels to a txt file. flowers is originally a classification dataset;
#    we take both its training and validation splits as the training set of our generative model.
import os

# Label of each class: 0 sunflower, 1 rose, 2 tulip, 3 dandelion, 4 daisy
classes = ["sunflower", "rose", "tulip", "dandelion", "daisy"]

with open("flowers_data.txt", "w") as f:
    for label, name in enumerate(classes):
        for split in ["train", "validation"]:
            folder = f"work/flowers/pic/{split}/{name}"
            for image in os.listdir(folder):
                f.write(f"{folder}/{image};{label}\n")  # one "path;label" entry per line
3.3 Build the dataset
Because our data iterator must return both images and labels, we build the dataset on top of the basic paddle.io.Dataset API.
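A quick standard-library check of how readlines() treats a file in which every line ends with "\n", as flowers_data.txt does (the writer appends "\n" to each entry): no empty trailing element is produced, so slicing off the last element would discard a real sample, while filtering out blank lines is safe either way. The file contents here are a made-up three-line example.

```python
import io

# Simulated flowers_data.txt: every entry is written as "path;label\n"
buf = io.StringIO("a.jpg;0\nb.jpg;1\nc.jpg;4\n")
lines = buf.readlines()
print(lines)   # ['a.jpg;0\n', 'b.jpg;1\n', 'c.jpg;4\n']  (no empty last element)

# Keep every non-blank line instead of dropping the last one
entries = [line.strip() for line in lines if line.strip()]
pairs = [(path, int(label)) for path, label in (e.split(";") for e in entries)]
print(pairs)   # [('a.jpg', 0), ('b.jpg', 1), ('c.jpg', 4)]
```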
In [3]
# 2. Build the dataset
# Returns (image, label) pairs
import paddle.vision as V
from PIL import Image
from paddle.io import Dataset, DataLoader
from tqdm import tqdm

# Data transforms
transforms = V.transforms.Compose([
    V.transforms.Resize(80),  # image_size + 1/4 * image_size
    V.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
    V.transforms.ToTensor(),
    V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale pixels to [-1, 1]
])


class TrainDataFlowers(Dataset):
    def __init__(self, txt_path="flowers_data.txt"):
        with open(txt_path, "r") as f:
            # Every line ends with "\n", so there is no empty last line to discard;
            # skipping blank lines keeps all real entries
            self.image_paths = [line.strip() for line in f if line.strip()]

    def __getitem__(self, index):
        image_path, label = self.image_paths[index].split(";")
        image = Image.open(image_path).convert("RGB")
        image = transforms(image)
        label = int(label)
        return image, label

    def __len__(self):
        return len(self.image_paths)


dataset = TrainDataFlowers()
dataloader = DataLoader(dataset, batch_size=24, shuffle=True)

if __name__ == "__main__":  # check that the dataset loads
    pbar = tqdm(dataloader)
    for i, (images, labels) in enumerate(pbar):
        pass
    print("ok")
100%|██████████| 181/181 [00:15<00:00, 11.37it/s]
ok
3.4 Training
As before, hyperparameters are set in the ARGS class, and the loss is again the mean squared error between the added noise and the predicted noise.
In [4]
import os
import copy
import logging

import numpy as np
import paddle
import paddle.nn as nn
from paddle import optimizer
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm

from modules import UNet_conditional, EMA  # conditional U-Net and EMA helper from the project's modules.py

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)
        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        cfg_weight = paddle.to_tensor(cfg_scale).astype("float32")  # convert once, outside the loop
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    # Classifier-free guidance: uncond + cfg_scale * (cond - uncond)
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = paddle.lerp(uncond_predicted_noise, predicted_noise, cfg_weight)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = args.dataloader
    model = UNet_conditional(num_classes=args.num_classes)
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model)
    ema_model.eval()
    os.makedirs("models", exist_ok=True)  # checkpoint folder
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            t = diffusion.sample_timesteps(images.shape[0])
            x_t, noise = diffusion.noise_images(images, t)
            # Drop the labels for ~10% of batches so the model also learns unconditional prediction (needed for CFG)
            if np.random.random() < 0.1:
                labels = None
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)  # loss function
            opt.clear_grad()
            loss.backward()
            opt.step()
            ema.step_ema(ema_model, model)
            pbar.set_postfix(MSE=loss.item())
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        if epoch % 30 == 0:  # save the model and visualize training progress
            paddle.save(model.state_dict(), f"models/ddpm_cond{epoch}.pdparams")
            labels = paddle.arange(5).astype("int64")
            # 10 images in total (two rows of 5); left to right: sunflower, rose, tulip, dandelion, daisy
            sampled_images1 = diffusion.sample(model, n=len(labels), labels=labels)
            sampled_images2 = diffusion.sample(model, n=len(labels), labels=labels)
            # ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            for i in range(5):
                img = sampled_images1[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 5, i + 1)
                plt.imshow(img)
            for i in range(5):
                img = sampled_images2[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 5, i + 1 + 5)
                plt.imshow(img)
            plt.show()


def launch():
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Conditional"
            self.epochs = 300
            self.batch_size = 48  # not used: the dataloader above was built with batch_size=24
            self.image_size = 64
            self.device = "cuda"
            self.lr = 1.5e-4
            self.num_classes = 5
            self.dataloader = dataloader  # defined in the dataset cell above
    args = ARGS()
    train(args)


if __name__ == '__main__':
    # Train
    launch()
03:56:58 - INFO: Starting epoch 0: 100%|██████████| 181/181 [01:04<00:00, 3.76it/s, MSE=0.172] 03:58:03 - INFO: Sampling 5 new images.... 499it [00:44, 11.13it/s] 03:58:48 - INFO: Sampling 5 new images.... 499it [00:44, 11.28it/s]
[Figure: 10 samples after epoch 0 (2 rows × 5 classes)]
03:59:33 - INFO: Starting epoch 1: 100%|██████████| 181/181 [01:02<00:00, 3.81it/s, MSE=0.104] 04:00:36 - INFO: Starting epoch 2: 100%|██████████| 181/181 [01:02<00:00, 3.78it/s, MSE=0.103] 04:01:38 - INFO: Starting epoch 3: 100%|██████████| 181/181 [01:02<00:00, 3.75it/s, MSE=0.0912] 04:02:41 - INFO: Starting epoch 4: 100%|██████████| 181/181 [01:02<00:00, 3.80it/s, MSE=0.0649] 04:03:43 - INFO: Starting epoch 5: 100%|██████████| 181/181 [00:59<00:00, 4.66it/s, MSE=0.0631] 04:04:43 - INFO: Starting epoch 6: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.179] 04:05:38 - INFO: Starting epoch 7: 100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0908] 04:06:33 - INFO: Starting epoch 8: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.158] 04:07:29 - INFO: Starting epoch 9: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.171] 04:08:24 - INFO: Starting epoch 10: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0362] 04:09:20 - INFO: Starting epoch 11: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0444] 04:10:16 - INFO: Starting epoch 12: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.0393] 04:11:12 - INFO: Starting epoch 13: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.064] 04:12:07 - INFO: Starting epoch 14: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.035] 04:13:03 - INFO: Starting epoch 15: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.063] 04:13:58 - INFO: Starting epoch 16: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0157] 04:14:54 - INFO: Starting epoch 17: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0159] 04:15:49 - INFO: Starting epoch 18: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0212] 04:16:45 - INFO: Starting epoch 19: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0252] 04:17:40 - INFO: Starting epoch 20: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0192] 04:18:35 - INFO: Starting epoch 21: 100%|██████████| 181/181 [00:55<00:00, 
4.61it/s, MSE=0.0361] 04:19:31 - INFO: Starting epoch 22: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0177] 04:20:26 - INFO: Starting epoch 23: 100%|██████████| 181/181 [00:55<00:00, 4.54it/s, MSE=0.0527] 04:21:22 - INFO: Starting epoch 24: 100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0458] 04:22:17 - INFO: Starting epoch 25: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0539] 04:23:13 - INFO: Starting epoch 26: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.205] 04:24:09 - INFO: Starting epoch 27: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0463] 04:25:04 - INFO: Starting epoch 28: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.152] 04:26:00 - INFO: Starting epoch 29: 100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.284] 04:26:55 - INFO: Starting epoch 30: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0896] 04:27:51 - INFO: Sampling 5 new images.... 499it [00:44, 11.27it/s] 04:28:36 - INFO: Sampling 5 new images.... 499it [00:45, 11.01it/s]
[Figure: 10 samples after epoch 30 (2 rows × 5 classes)]
04:29:21 - INFO: Starting epoch 31: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.299] 04:30:17 - INFO: Starting epoch 32: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0226] 04:31:12 - INFO: Starting epoch 33: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.00727] 04:32:08 - INFO: Starting epoch 34: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.132] 04:33:03 - INFO: Starting epoch 35: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0498] 04:33:59 - INFO: Starting epoch 36: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0107] 04:34:55 - INFO: Starting epoch 37: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0116] 04:35:50 - INFO: Starting epoch 38: 100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.044] 04:36:46 - INFO: Starting epoch 39: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.167] 04:37:41 - INFO: Starting epoch 40: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0359] 04:38:37 - INFO: Starting epoch 41: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0064] 04:39:33 - INFO: Starting epoch 42: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0107] 04:40:28 - INFO: Starting epoch 43: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0216] 04:41:24 - INFO: Starting epoch 44: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0361] 04:42:20 - INFO: Starting epoch 45: 100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0368] 04:43:15 - INFO: Starting epoch 46: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0283] 04:44:10 - INFO: Starting epoch 47: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.0352] 04:45:06 - INFO: Starting epoch 48: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0499] 04:46:01 - INFO: Starting epoch 49: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0359] 04:46:56 - INFO: Starting epoch 50: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0555] 04:47:52 - INFO: Starting epoch 51: 100%|██████████| 181/181 
[00:55<00:00, 4.68it/s, MSE=0.14] 04:48:47 - INFO: Starting epoch 52: 100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0136] 04:49:42 - INFO: Starting epoch 53: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0242] 04:50:38 - INFO: Starting epoch 54: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0252] 04:51:33 - INFO: Starting epoch 55: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0274] 04:52:29 - INFO: Starting epoch 56: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0727] 04:53:24 - INFO: Starting epoch 57: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.023] 04:54:20 - INFO: Starting epoch 58: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0457] 04:55:15 - INFO: Starting epoch 59: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0123] 04:56:10 - INFO: Starting epoch 60: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0125] 04:57:07 - INFO: Sampling 5 new images.... 499it [00:45, 10.91it/s] 04:57:52 - INFO: Sampling 5 new images.... 499it [00:45, 10.90it/s]
[Figure: 10 samples after epoch 60 (2 rows × 5 classes)]
04:58:39 - INFO: Starting epoch 61: 100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.00765] 04:59:35 - INFO: Starting epoch 62: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.00355] 05:00:30 - INFO: Starting epoch 63: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0256] 05:01:26 - INFO: Starting epoch 64: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0413] 05:02:22 - INFO: Starting epoch 65: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0146] 05:03:17 - INFO: Starting epoch 66: 100%|██████████| 181/181 [00:56<00:00, 4.57it/s, MSE=0.00737] 05:04:13 - INFO: Starting epoch 67: 100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.00363] 05:05:09 - INFO: Starting epoch 68: 100%|██████████| 181/181 [00:56<00:00, 4.58it/s, MSE=0.121] 05:06:06 - INFO: Starting epoch 69: 100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.0124] 05:07:02 - INFO: Starting epoch 70: 100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.0235] 05:07:59 - INFO: Starting epoch 71: 100%|██████████| 181/181 [00:55<00:00, 4.52it/s, MSE=0.084] 05:08:55 - INFO: Starting epoch 72: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.022] 05:09:50 - INFO: Starting epoch 73: 100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.00922] 05:10:47 - INFO: Starting epoch 74: 100%|██████████| 181/181 [00:56<00:00, 4.27it/s, MSE=0.0059] 05:11:43 - INFO: Starting epoch 75: 100%|██████████| 181/181 [00:56<00:00, 4.60it/s, MSE=0.00901] 05:12:40 - INFO: Starting epoch 76: 100%|██████████| 181/181 [00:56<00:00, 4.60it/s, MSE=0.0261] 05:13:36 - INFO: Starting epoch 77: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0317] 05:14:32 - INFO: Starting epoch 78: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0379] 05:15:27 - INFO: Starting epoch 79: 100%|██████████| 181/181 [00:54<00:00, 4.62it/s, MSE=0.0126] 05:16:22 - INFO: Starting epoch 80: 100%|██████████| 181/181 [00:55<00:00, 4.57it/s, MSE=0.0129] 05:17:17 - INFO: Starting epoch 81: 100%|██████████| 
181/181 [00:55<00:00, 4.61it/s, MSE=0.0174] 05:18:13 - INFO: Starting epoch 82: 100%|██████████| 181/181 [00:57<00:00, 4.59it/s, MSE=0.00267] 05:19:11 - INFO: Starting epoch 83: 100%|██████████| 181/181 [00:57<00:00, 4.61it/s, MSE=0.00863] 05:20:08 - INFO: Starting epoch 84: 100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0928] 05:21:04 - INFO: Starting epoch 85: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0151] 05:21:59 - INFO: Starting epoch 86: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0231] 05:22:55 - INFO: Starting epoch 87: 100%|██████████| 181/181 [00:55<00:00, 4.50it/s, MSE=0.0442] 05:23:51 - INFO: Starting epoch 88: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00999] 05:24:47 - INFO: Starting epoch 89: 100%|██████████| 181/181 [00:55<00:00, 4.57it/s, MSE=0.00467] 05:25:42 - INFO: Starting epoch 90: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0219] 05:26:38 - INFO: Sampling 5 new images.... 499it [00:44, 11.12it/s] 05:27:23 - INFO: Sampling 5 new images.... 499it [00:45, 11.06it/s]
[Figure: 10 samples after epoch 90 (2 rows × 5 classes)]
05:28:09 - INFO: Starting epoch 91: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.00285] 05:29:05 - INFO: Starting epoch 92: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.112] 05:30:00 - INFO: Starting epoch 93: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0108] 05:30:56 - INFO: Starting epoch 94: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0281] 05:31:51 - INFO: Starting epoch 95: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0355] 05:32:47 - INFO: Starting epoch 96: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.133] 05:33:42 - INFO: Starting epoch 97: 100%|██████████| 181/181 [00:56<00:00, 4.56it/s, MSE=0.0138] 05:34:39 - INFO: Starting epoch 98: 100%|██████████| 181/181 [00:56<00:00, 4.66it/s, MSE=0.00963] 05:35:35 - INFO: Starting epoch 99: 100%|██████████| 181/181 [00:56<00:00, 4.59it/s, MSE=0.0298] 05:36:31 - INFO: Starting epoch 100: 100%|██████████| 181/181 [00:56<00:00, 4.65it/s, MSE=0.00709] 05:37:27 - INFO: Starting epoch 101: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0737] 05:38:23 - INFO: Starting epoch 102: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0105] 05:39:18 - INFO: Starting epoch 103: 100%|██████████| 181/181 [00:56<00:00, 4.56it/s, MSE=0.00631] 05:40:14 - INFO: Starting epoch 104: 100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.00662] 05:41:10 - INFO: Starting epoch 105: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.262] 05:42:05 - INFO: Starting epoch 106: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0206] 05:43:00 - INFO: Starting epoch 107: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.00979] 05:43:56 - INFO: Starting epoch 108: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0121] 05:44:52 - INFO: Starting epoch 109: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.00493] 05:45:48 - INFO: Starting epoch 110: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0158] 05:46:43 - INFO: Starting epoch 111: 
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00567] 05:47:39 - INFO: Starting epoch 112: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.00994] 05:48:34 - INFO: Starting epoch 113: 100%|██████████| 181/181 [00:56<00:00, 4.57it/s, MSE=0.00712] 05:49:31 - INFO: Starting epoch 114: 100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.0414] 05:50:27 - INFO: Starting epoch 115: 100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.00445] 05:51:23 - INFO: Starting epoch 116: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0967] 05:52:19 - INFO: Starting epoch 117: 100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.0384] 05:53:15 - INFO: Starting epoch 118: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0122] 05:54:10 - INFO: Starting epoch 119: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0342] 05:55:06 - INFO: Starting epoch 120: 100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.0257] 05:56:02 - INFO: Sampling 5 new images.... 499it [00:44, 11.29it/s] 05:56:46 - INFO: Sampling 5 new images.... 499it [00:45, 11.03it/s]
[Figure: 10 samples after epoch 120 (2 rows × 5 classes)]
05:57:32 - INFO: Starting epoch 121: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00285] 05:58:28 - INFO: Starting epoch 122: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0274] 05:59:24 - INFO: Starting epoch 123: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0629] 06:00:19 - INFO: Starting epoch 124: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0203] 06:01:15 - INFO: Starting epoch 125: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0619] 06:02:10 - INFO: Starting epoch 126: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0456] 06:03:05 - INFO: Starting epoch 127: 100%|██████████| 181/181 [00:55<00:00, 4.54it/s, MSE=0.0157] 06:04:01 - INFO: Starting epoch 128: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0136] 06:04:57 - INFO: Starting epoch 129: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.115] 06:05:52 - INFO: Starting epoch 130: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0519] 06:06:48 - INFO: Starting epoch 131: 100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.0179] 06:07:44 - INFO: Starting epoch 132: 100%|██████████| 181/181 [00:56<00:00, 4.72it/s, MSE=0.0211] 06:08:40 - INFO: Starting epoch 133: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0172] 06:09:36 - INFO: Starting epoch 134: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.134] 06:10:31 - INFO: Starting epoch 135: 100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.201] 06:11:26 - INFO: Starting epoch 136: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0325] 06:12:22 - INFO: Starting epoch 137: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0203] 06:13:17 - INFO: Starting epoch 138: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00265] 06:14:13 - INFO: Starting epoch 139: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.00424] 06:15:08 - INFO: Starting epoch 140: 100%|██████████| 181/181 [00:55<00:00, 4.73it/s, MSE=0.00383] 06:16:03 - INFO: Starting epoch 141: 
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0153] 06:16:58 - INFO: Starting epoch 142: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.0284] 06:17:53 - INFO: Starting epoch 143: 100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.00366] 06:18:48 - INFO: Starting epoch 144: 100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.0912] 06:19:44 - INFO: Starting epoch 145: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.00704] 06:20:40 - INFO: Starting epoch 146: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0042] 06:21:36 - INFO: Starting epoch 147: 100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.011] 06:22:31 - INFO: Starting epoch 148: 100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0379] 06:23:26 - INFO: Starting epoch 149: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.11] 06:24:21 - INFO: Starting epoch 150: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.00667] 06:25:17 - INFO: Sampling 5 new images.... 499it [00:43, 11.47it/s] 06:26:01 - INFO: Sampling 5 new images.... 499it [00:42, 11.65it/s]
[Figure: 10 samples after epoch 150 (2 rows × 5 classes)]
06:26:44 - INFO: Starting epoch 151: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.00426] 06:27:39 - INFO: Starting epoch 152: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0859] 06:28:34 - INFO: Starting epoch 153: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0238] 06:29:29 - INFO: Starting epoch 154: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0261] 06:30:24 - INFO: Starting epoch 155: 100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.049] 06:31:19 - INFO: Starting epoch 156: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00625] 06:32:14 - INFO: Starting epoch 157: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0107] 06:33:09 - INFO: Starting epoch 158: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.13] 06:34:04 - INFO: Starting epoch 159: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0495] 06:34:59 - INFO: Starting epoch 160: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0112] 06:35:54 - INFO: Starting epoch 161: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00525] 06:36:49 - INFO: Starting epoch 162: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00437] 06:37:44 - INFO: Starting epoch 163: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00408] 06:38:39 - INFO: Starting epoch 164: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0177] 06:39:35 - INFO: Starting epoch 165: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00417] 06:40:30 - INFO: Starting epoch 166: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0786] 06:41:25 - INFO: Starting epoch 167: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0205] 06:42:20 - INFO: Starting epoch 168: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0952] 06:43:15 - INFO: Starting epoch 169: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0118] 06:44:10 - INFO: Starting epoch 170: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.253] 06:45:06 - INFO: Starting epoch 171: 
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00373] 06:46:01 - INFO: Starting epoch 172: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00618] 06:46:56 - INFO: Starting epoch 173: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0083] 06:47:50 - INFO: Starting epoch 174: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0171] 06:48:45 - INFO: Starting epoch 175: 100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0216] 06:49:40 - INFO: Starting epoch 176: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00168] 06:50:35 - INFO: Starting epoch 177: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0166] 06:51:30 - INFO: Starting epoch 178: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00656] 06:52:25 - INFO: Starting epoch 179: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.114] 06:53:20 - INFO: Starting epoch 180: 100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.00226] 06:54:15 - INFO: Sampling 5 new images.... 499it [00:42, 11.66it/s] 06:54:58 - INFO: Sampling 5 new images.... 499it [00:43, 11.56it/s]
[Figure: 10 samples after epoch 180 (2 rows × 5 classes)]
06:55:42 - INFO: Starting epoch 181: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.0484] 06:56:37 - INFO: Starting epoch 182: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0196] 06:57:32 - INFO: Starting epoch 183: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00695] 06:58:27 - INFO: Starting epoch 184: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0515] 06:59:21 - INFO: Starting epoch 185: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00296] 07:00:16 - INFO: Starting epoch 186: 100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0878] 07:01:11 - INFO: Starting epoch 187: 100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0574] 07:02:06 - INFO: Starting epoch 188: 100%|██████████| 181/181 [00:54<00:00, 4.75it/s, MSE=0.00468] 07:03:00 - INFO: Starting epoch 189: 100%|██████████| 181/181 [00:54<00:00, 4.61it/s, MSE=0.0289] 07:03:55 - INFO: Starting epoch 190: 100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0167] 07:04:50 - INFO: Starting epoch 191: 100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0505] 07:05:45 - INFO: Starting epoch 192: 100%|██████████| 181/181 [00:54<00:00, 4.78it/s, MSE=0.00374] 07:06:39 - INFO: Starting epoch 193: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0176] 07:07:34 - INFO: Starting epoch 194: 100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0161] 07:08:29 - INFO: Starting epoch 195: 100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0161] 07:09:24 - INFO: Starting epoch 196: 100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0358] 07:10:18 - INFO: Starting epoch 197: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0694] 07:11:13 - INFO: Starting epoch 198: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.107] 07:12:09 - INFO: Starting epoch 199: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0383] 07:13:04 - INFO: Starting epoch 200: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0169] 07:13:59 - INFO: Starting epoch 201: 
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0855] 07:14:54 - INFO: Starting epoch 202: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.00749] 07:15:49 - INFO: Starting epoch 203: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00324] 07:16:45 - INFO: Starting epoch 204: 100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.0965] 07:17:40 - INFO: Starting epoch 205: 100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.0277] 07:18:34 - INFO: Starting epoch 206: 100%|██████████| 181/181 [00:54<00:00, 4.71it/s, MSE=0.0146] 07:19:29 - INFO: Starting epoch 207: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00659] 07:20:24 - INFO: Starting epoch 208: 100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0176] 07:21:19 - INFO: Starting epoch 209: 100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.12] 07:22:14 - INFO: Starting epoch 210: 100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0688] 07:23:10 - INFO: Sampling 5 new images.... 499it [00:43, 11.17it/s] 07:23:53 - INFO: Sampling 5 new images.... 499it [00:43, 11.55it/s]
[Figure: 10 samples after epoch 210 (2 rows × 5 classes)]
07:24:37 - INFO: Starting epoch 211: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00553] 07:25:31 - INFO: Starting epoch 212: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0851] 07:26:26 - INFO: Starting epoch 213: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0147] 07:27:21 - INFO: Starting epoch 214: 100%|██████████| 181/181 [00:55<00:00, 4.75it/s, MSE=0.0669] 07:28:16 - INFO: Starting epoch 215: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00531] 07:29:11 - INFO: Starting epoch 216: 100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0315] 07:30:06 - INFO: Starting epoch 217: 100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.147] 07:31:01 - INFO: Starting epoch 218: 100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0547] 07:31:56 - INFO: Starting epoch 219: 100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.036] 07:32:50 - INFO: Starting epoch 220: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00479] 07:33:45 - INFO: Starting epoch 221: 100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0225] 07:34:40 - INFO: Starting epoch 222: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0192] 07:35:35 - INFO: Starting epoch 223: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00701] 07:36:30 - INFO: Starting epoch 224: 100%|██████████| 181/181 [00:54<00:00, 4.56it/s, MSE=0.036] 07:37:25 - INFO: Starting epoch 225: 100%|██████████| 181/181 [00:54<00:00, 4.79it/s, MSE=0.0908] 07:38:19 - INFO: Starting epoch 226: 100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.00345] 07:39:14 - INFO: Starting epoch 227: 100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0657] 07:40:09 - INFO: Starting epoch 228: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0841] 07:41:04 - INFO: Starting epoch 229: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00407] 07:41:59 - INFO: Starting epoch 230: 100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0249] 07:42:54 - INFO: Starting epoch 231: 
100%|██████████| 181/181 [00:54<00:00, 4.71it/s, MSE=0.0563] 07:43:48 - INFO: Starting epoch 232: 100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.052] 07:44:43 - INFO: Starting epoch 233: 100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.0698] 07:45:38 - INFO: Starting epoch 234: 100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0553] 07:46:33 - INFO: Starting epoch 235: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00646] 07:47:27 - INFO: Starting epoch 236: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00258] 07:48:22 - INFO: Starting epoch 237: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0236] 07:49:17 - INFO: Starting epoch 238: 100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.0339] 07:50:12 - INFO: Starting epoch 239: 100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.00555] 07:51:07 - INFO: Starting epoch 240: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0571] 07:52:03 - INFO: Sampling 5 new images.... 499it [00:42, 11.79it/s] 07:52:45 - INFO: Sampling 5 new images.... 499it [00:42, 11.62it/s]
[Figure: 10 samples after epoch 240 (2 rows × 5 classes)]
07:53:28 - INFO: Starting epoch 241: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.00577] 07:54:24 - INFO: Starting epoch 242: 100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0155] 07:55:18 - INFO: Starting epoch 243: 100%|██████████| 181/181 [00:54<00:00, 4.64it/s, MSE=0.0322] 07:56:13 - INFO: Starting epoch 244: 100%|██████████| 181/181 [00:55<00:00, 4.76it/s, MSE=0.00787] 07:57:08 - INFO: Starting epoch 245: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.116] 07:58:03 - INFO: Starting epoch 246: 100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.0187] 07:58:58 - INFO: Starting epoch 247: 100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.059] 07:59:52 - INFO: Starting epoch 248: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0248] 08:00:48 - INFO: Starting epoch 249: 100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0254] 08:01:42 - INFO: Starting epoch 250: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.133] 08:02:38 - INFO: Starting epoch 251: 100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0752] 08:03:33 - INFO: Starting epoch 252: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00802] 08:04:28 - INFO: Starting epoch 253: 100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.254] 08:05:23 - INFO: Starting epoch 254: 100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0261] 08:06:18 - INFO: Starting epoch 255: 100%|██████████| 181/181 [00:54<00:00, 4.62it/s, MSE=0.0514] 08:07:13 - INFO: Starting epoch 256: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00751] 08:08:08 - INFO: Starting epoch 257: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0209] 08:09:03 - INFO: Starting epoch 258: 100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0484] 08:09:58 - INFO: Starting epoch 259: 100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0255] 08:10:53 - INFO: Starting epoch 260: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00507] 08:11:49 - INFO: Starting epoch 261: 
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0218] 08:12:45 - INFO: Starting epoch 262: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0203] 08:13:40 - INFO: Starting epoch 263: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.036] 08:14:36 - INFO: Starting epoch 264: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0266] 08:15:32 - INFO: Starting epoch 265: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0145] 08:16:27 - INFO: Starting epoch 266: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00483] 08:17:23 - INFO: Starting epoch 267: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0604] 08:18:18 - INFO: Starting epoch 268: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0466] 08:19:14 - INFO: Starting epoch 269: 100%|██████████| 181/181 [00:56<00:00, 4.59it/s, MSE=0.00358] 08:20:10 - INFO: Starting epoch 270: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0104] 08:21:06 - INFO: Sampling 5 new images.... 499it [00:45, 11.06it/s] 08:21:51 - INFO: Sampling 5 new images.... 499it [00:44, 11.33it/s]
[Figure: 10 samples after epoch 270 (2 rows × 5 classes)]
08:22:36 - INFO: Starting epoch 271: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0111] 08:23:31 - INFO: Starting epoch 272: 100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0474] 08:24:26 - INFO: Starting epoch 273: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.106] 08:25:22 - INFO: Starting epoch 274: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.00758] 08:26:18 - INFO: Starting epoch 275: 100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.00715] 08:27:13 - INFO: Starting epoch 276: 100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0412] 08:28:08 - INFO: Starting epoch 277: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00941] 08:29:04 - INFO: Starting epoch 278: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0251] 08:29:59 - INFO: Starting epoch 279: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0385] 08:30:55 - INFO: Starting epoch 280: 100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0308] 08:31:50 - INFO: Starting epoch 281: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.108] 08:32:45 - INFO: Starting epoch 282: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.049] 08:33:41 - INFO: Starting epoch 283: 100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0046] 08:34:37 - INFO: Starting epoch 284: 100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.00371] 08:35:32 - INFO: Starting epoch 285: 100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00421] 08:36:27 - INFO: Starting epoch 286: 100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00429] 08:37:22 - INFO: Starting epoch 287: 100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0258] 08:38:17 - INFO: Starting epoch 288: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0109] 08:39:13 - INFO: Starting epoch 289: 100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.152] 08:40:08 - INFO: Starting epoch 290: 100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.0362] 08:41:04 - INFO: Starting epoch 291: 
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0161] 08:41:59 - INFO: Starting epoch 292: 100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.167] 08:42:56 - INFO: Starting epoch 293: 100%|██████████| 181/181 [00:56<00:00, 4.46it/s, MSE=0.00373] 08:43:52 - INFO: Starting epoch 294: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0298] 08:44:48 - INFO: Starting epoch 295: 100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0306] 08:45:43 - INFO: Starting epoch 296: 100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.0111] 08:46:40 - INFO: Starting epoch 297: 100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0232] 08:47:36 - INFO: Starting epoch 298: 100%|██████████| 181/181 [00:56<00:00, 4.03it/s, MSE=0.0217] 08:48:32 - INFO: Starting epoch 299: 100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.05]
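The `MSE=` value in each progress bar comes from `pbar.set_postfix(MSE=loss.item())`, so once an epoch finishes it shows only that epoch's last batch. That is why it jumps between roughly 0.002 and 0.3 instead of falling smoothly. A rough trend can be recovered by pulling these values out of the log and averaging them. The parser below is a hypothetical helper written against the log format shown here, not part of the project, and a trailing mean over last-batch values is only a crude proxy for the true per-epoch average loss.

```python
import re

# Matches "Starting epoch N: ... MSE=value]" even when an entry wraps onto a new line
LOG_RE = re.compile(r"Starting epoch (\d+):.*?MSE=([0-9.eE+-]+)\]", re.DOTALL)

def parse_last_batch_mse(log_text):
    """Return (epoch, mse) pairs; each mse is the loss of that epoch's final batch."""
    return [(int(epoch), float(value)) for epoch, value in LOG_RE.findall(log_text)]

def moving_average(values, window=10):
    """Trailing mean over at most `window` most recent values."""
    averaged = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        averaged.append(sum(chunk) / len(chunk))
    return averaged

# Usage on an excerpt copied from the log above (epochs 0 to 4)
excerpt = (
    "03:56:58 - INFO: Starting epoch 0: 100%|██████████| 181/181 [01:04<00:00, 3.76it/s, MSE=0.172] "
    "03:58:03 - INFO: Sampling 5 new images.... 499it [00:44, 11.13it/s] "
    "03:59:33 - INFO: Starting epoch 1: 100%|██████████| 181/181 [01:02<00:00, 3.81it/s, MSE=0.104] "
    "04:00:36 - INFO: Starting epoch 2: 100%|██████████| 181/181 [01:02<00:00, 3.78it/s, MSE=0.103] "
    "04:01:38 - INFO: Starting epoch 3: 100%|██████████| 181/181 [01:02<00:00, 3.75it/s, MSE=0.0912] "
    "04:02:41 - INFO: Starting epoch 4: 100%|██████████| 181/181 [01:02<00:00, 3.80it/s, MSE=0.0649]"
)
pairs = parse_last_batch_mse(excerpt)
print(pairs)
print(moving_average([m for _, m in pairs], window=3))
```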
3.5、Sampling different flowers with the trained model
In [12]
import numpy as np
import matplotlib.pyplot as plt
import paddle

model = UNet_conditional(num_classes=5)
model.set_state_dict(paddle.load("models/ddpm_cond270.pdparams"))  # load the last checkpoint (saved after epoch 270)
diffusion = Diffusion(img_size=64, device="cuda")
# Labels 0, 1, 2, 3, 4 correspond to sunflower, rose, tulip, dandelion, daisy
labels = paddle.to_tensor([0, 0, 0, 0, 0]).astype("int64")  # five sunflowers
# Classifier-free guidance scale: how strongly sampling is pushed toward the label
cfg_scale = 7
sampled_images = diffusion.sample(model, n=len(labels), labels=labels, cfg_scale=cfg_scale)
for i in range(5):
    img = sampled_images[i].transpose([1, 2, 0])
    img = np.array(img).astype("uint8")
    plt.subplot(1, 5, i + 1)
    plt.imshow(img)
plt.show()
09:01:17 - INFO: Sampling 5 new images.... 499it [00:43, 11.60it/s]
[Figure: 5 sampled sunflower images (label 0, cfg_scale=7)]
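`cfg_scale` controls classifier-free guidance: at each denoising step the model is evaluated twice, once with the label and once with `labels=None` (which is why about 10% of training batches drop their labels), and the two noise predictions are extrapolated. The sketch below shows one common formulation, ε̂ = ε_uncond + s·(ε_cond − ε_uncond), inside a DDPM ancestral sampling loop. It is an assumption about how `diffusion.sample` works, not its actual code; `toy_model` is a hypothetical stand-in for `UNet_conditional`, and the linear β schedule (1e-4 to 0.02, 500 steps) is likewise assumed.

```python
import numpy as np

def linear_schedule(noise_steps=500, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (assumed DDPM-default values)."""
    beta = np.linspace(beta_start, beta_end, noise_steps)
    alpha = 1.0 - beta
    return beta, alpha, np.cumprod(alpha)

def cfg_noise(eps_cond, eps_uncond, cfg_scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

def reverse_step(x, t, eps, beta, alpha, alpha_hat, rng):
    """One DDPM ancestral step x_t -> x_{t-1}, using sigma_t^2 = beta_t."""
    noise = rng.standard_normal(x.shape) if t > 1 else np.zeros_like(x)
    mean = (x - (1.0 - alpha[t]) / np.sqrt(1.0 - alpha_hat[t]) * eps) / np.sqrt(alpha[t])
    return mean + np.sqrt(beta[t]) * noise

def toy_model(x, t, label):
    """Hypothetical stand-in for UNet_conditional; label=None means unconditional."""
    shift = 0.0 if label is None else 0.1 * (label + 1)
    return 0.5 * x + shift

def sample(label, cfg_scale, noise_steps=500, shape=(3, 8, 8), seed=0):
    rng = np.random.default_rng(seed)
    beta, alpha, alpha_hat = linear_schedule(noise_steps)
    x = rng.standard_normal(shape)  # start from pure Gaussian noise
    for t in reversed(range(1, noise_steps)):  # 499 iterations when noise_steps=500
        eps = toy_model(x, t, label)
        if cfg_scale > 0:
            eps = cfg_noise(eps, toy_model(x, t, None), cfg_scale)
        x = reverse_step(x, t, eps, beta, alpha, alpha_hat, rng)
    return x

# Usage: label 0 (sunflower) with the same guidance strength as above
out = sample(label=0, cfg_scale=7)
print(out.shape)
```

With s = 1 the guided prediction reduces to the plain conditional one; larger values such as 7 generally trade sample diversity for closer adherence to the label. Guidance doubles the number of network evaluations per step, and the number of steps grows with T.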
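4、Summary
- We derived the loss function of the diffusion model: starting from maximizing the log-likelihood (equivalently, minimizing the negative log-likelihood), we optimized its variational lower bound, simplified that bound, and arrived at the final objective of predicting the added noise.
- Two versions of the code were provided. The conditional version works on the same principle as today's popular text2image models; the difference is that text2image conditions on an encoding of the text prompt rather than on a single class label (see NovelAI, for example).
- As a newer family of generative models, diffusion models train very stably, and in our experience they are considerably easier to tune than GANs.
- For better results, increase T (the number of diffusion steps) and train for more epochs. Both raise the cost: sampling time grows linearly with T (here, 499 denoising steps took about 45 s per batch of 5 images).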
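For reference, the final objective the derivation arrives at can be written in the usual DDPM notation (assumed here to match the symbols of section 1.4): with noise schedule $\beta_t$, $\alpha_t = 1-\beta_t$, $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$, $t$ drawn uniformly from $\{1,\dots,T\}$ and $\epsilon \sim \mathcal{N}(0, I)$,

$$
L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\rVert^2\right]
$$

In the conditional version, $\epsilon_\theta$ also receives the class label $c$, which is replaced by the empty label for about 10% of batches. `mse(noise, predicted_noise)` in the training loop is the per-element mean of this squared norm, which only rescales the objective by a constant.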