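Summary:
Notes on using the Pix2Pix model in the MindSpore AI framework to generate images and to estimate the probability that an image is real. Covers the underlying principles and the practical steps: environment setup, dataset download, data loading and processing, building the cGAN generator and discriminator, model training, and model inference.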
I. Concepts
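1. The Pix2Pix model
Based on a conditional generative adversarial network (cGAN, Conditional Generative Adversarial Network)
A deep-learning image-to-image translation model
Capabilities
Semantic labels to realistic photos
Grayscale images to color images
Aerial photos to maps
Daytime scenes to nighttime scenes
Line sketches to photos of the object
A classic application of cGANs to supervised image-to-image translation
Two models
Generator
Discriminator
The same architecture and objective handle every task; only the training data changes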
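2. Basic principles
cGAN generator
Uses the input image as conditioning (guidance) information
Starting from the input image, keeps trying to generate "fake" images that fool the discriminator
Pix2Pix discriminator
Judges whether an image is a real training image or a "fake" produced by the generator
Equilibrium point
Through the ongoing game between the two, the generator's outputs become so close to the real training data that the discriminator is right only 50% of the time.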
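Here is a quick numerical check of this equilibrium. It is a minimal sketch in plain Python using only the standard library. The helper name `bce` is ours, not a MindSpore API. A discriminator that outputs D = 0.5 for every pair has a binary cross-entropy of ln 2 ≈ 0.6931 on both real and generated pairs. The discriminator loss in the training section is 0.5 × (loss on real pairs + loss on generated pairs), so it also equals ln 2.

```python
import math


def bce(p: float, label: int) -> float:
    """Binary cross-entropy of a predicted probability p against a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


# At the equilibrium the discriminator outputs 0.5 for real and generated pairs alike.
p_eq = 0.5

# Discriminator loss weighted as in the training code: 0.5 * (real -> 1, fake -> 0).
d_loss_eq = 0.5 * (bce(p_eq, 1) + bce(p_eq, 0))

print(f"{d_loss_eq:.4f}")    # 0.6931
print(f"{math.log(2):.4f}")  # 0.6931
```

The first logged step of the training run reports Dloss:0.6924, close to this value. That is consistent with an untrained discriminator that cannot yet separate real pairs from generated ones.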
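Notation:
x: observed image data
z: random noise
y = G(x, z): generator network
Produces a generated ("fake") image from the observed image x and the random noise z
x ∈ training data (x comes from the training set, not from the generator)
D(x, G(x, z)): discriminator network
Outputs the probability that an image is judged real
x ∈ training data
G(x, z) ∈ generator output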
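cGAN loss function (here y is the real target image paired with x):
$$\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}\left[\log D(x,y)\right]+\mathbb{E}_{x,z}\left[\log\left(1-D(x,G(x,z))\right)\right]$$
log D(x, y): the discriminator's parameters are trained to maximize it (classify real pairs as real)
log(1 − D(x, G(x, z))): the generator's parameters are trained to minimize it (avoid being caught)
cGAN objective:
$$G^{*}=\arg\min_{G}\max_{D}\mathcal{L}_{cGAN}(G,D)$$
At its core, image translation is a pixel-to-pixel mapping problem.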
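The expectations in the cGAN loss can be estimated by averaging over a batch. This minimal sketch assumes the discriminator's outputs are already available as probabilities. The function name `cgan_value` and the sample numbers are illustrative, not taken from the tutorial.

- D raises the value by scoring real pairs near 1 and generated pairs near 0.
- G lowers the value by making D score its outputs closer to 1.

```python
import math
from typing import Sequence


def cgan_value(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Batch estimate of L_cGAN(G, D).

    d_real: D(x, y) for real (input, target) pairs
    d_fake: D(x, G(x, z)) for (input, generated) pairs
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


d_real = [0.9, 0.8, 0.95]

# A weak generator: D easily spots its outputs, so the value stays high.
weak_g = cgan_value(d_real, [0.05, 0.1, 0.2])
# A stronger generator: D scores its outputs closer to 1, lowering the value.
strong_g = cgan_value(d_real, [0.6, 0.7, 0.5])

print(weak_g > strong_g)  # True
```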
II. Environment setup
%%capture captured_output
# The lab environment already ships with mindspore==2.2.14; to switch versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Show the installed mindspore version
!pip show mindspore 
 
Output:
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:  
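1. Download the data
Dataset
Facades (building facade) data
Read with mindspore.dataset (the archive contains a preprocessed MindRecord file)
Download the dataset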
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"
download(url, "./dataset", kind="tar", replace=True) 
Output:
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar (840.0 MB)
file_sizes: 100%|█████████████████████████████| 881M/881M [00:04<00:00, 197MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset
'./dataset' 
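2. Visualize the data
Rather than calling Pix2PixDataset and create_train_dataset, the code below reads the already-processed training set (train.mindrecord) directly with ds.MindDataset and plots a few samples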
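The plotting code converts each sample with `(image.transpose(1, 2, 0) + 1) / 2`. Judging from that expression, the stored images are channels-first (C × H × W) with pixel values in [-1, 1]. `plt.imshow`, by contrast, expects height × width × channels with float values in [0, 1]. Below is a dependency-free sketch of the same conversion on nested lists; the name `to_display` is illustrative.

```python
from typing import List

Image = List[List[List[float]]]


def to_display(image: Image) -> Image:
    """Reorder a C x H x W image to H x W x C and map values from [-1, 1] to [0, 1]."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    return [
        [[(image[c][h][w] + 1) / 2 for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 3-channel, 1 x 2 image: channel 0 = [-1, 1], channel 1 = [0, 0], channel 2 = [1, -1]
chw = [[[-1.0, 1.0]], [[0.0, 0.0]], [[1.0, -1.0]]]
print(to_display(chw))  # [[[0.0, 0.5, 1.0], [1.0, 0.5, 0.0]]]
```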
from mindspore import dataset as ds
import matplotlib.pyplot as plt
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# Visualize part of the training data
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    # CHW -> HWC, and map pixel values from [-1, 1] to [0, 1] for imshow
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show() 
Output:

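III. Building the networks
Generator G
U-Net architecture
Encodes the input outline image x, then decodes it into a realistic image
Discriminator D
Conditional PatchGAN discriminator
Conditioned on the outline x, judges the generated image G(x) as fake and the real image as real
Loss functions (defined in Section IV, Training)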
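1. Structure of generator G
U-Net was proposed by the Pattern Recognition and Image Processing group at the University of Freiburg, Germany, as an image segmentation network
Fully convolutional structure
Two parts
Left: contracting path
Convolution
Downsampling
Right: expanding path
Convolution
Upsampling
The input of each block on the expanding path is the upsampled features from the previous layer concatenated with the features from the contracting path
The network as a whole is U-shaped
Compared with an encoder-decoder network that first downsamples to a low-dimensional representation and then upsamples back to the original resolution
U-Net adds
skip connections
that concatenate each encoder feature map
with the decoded feature map of the same resolution
along the channel dimension
This preserves pixel-level detail at every resolution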

Define the U-Net skip connection block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        # Downsampling conv: kernel 4, stride 2, padding 1 halves the spatial size
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            # Outermost block: no norm, Tanh output, no skip concat
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Innermost block (bottleneck): no submodule, so up_conv sees only inner_nc channels
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            # Intermediate block: the submodule output already carries its skip concat (inner_nc * 2)
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            # Skip connection: concatenate the block output and its input along channels
            out = ops.concat((out, x), axis=1)
        return out
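Two facts make the concatenation in `construct` line up:

- Spatially, a kernel-4, stride-2, padding-1 convolution halves an even input size, and the matching transposed convolution doubles it back. So `out` and `x` have the same height and width.
- Along channels, every non-outermost block outputs 2 × `outer_nc` channels after the concat. That is why a block with a submodule builds its `up_conv` with `inner_nc * 2` input channels; the innermost block has no submodule and uses `inner_nc`.

The sketch below does this bookkeeping with the standard output-size formulas. The helper names are ours.

```python
from typing import Tuple


def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a convolution with explicit padding."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def block_channels(outer_nc: int, inner_nc: int, innermost: bool) -> Tuple[int, int]:
    """(up_conv input channels, block output channels) for a non-outermost block."""
    up_in = inner_nc if innermost else inner_nc * 2  # submodule output already includes its skip concat
    return up_in, outer_nc * 2  # this block's skip concat adds its own outer_nc input channels


for size in (256, 128, 2):
    print(size, conv_out(size), conv_transpose_out(conv_out(size)))
print(block_channels(512, 512, innermost=True))   # (512, 1024)
print(block_channels(64, 128, innermost=False))   # (256, 128)
```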
2. U-Net-based generator
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        # Build the U-Net from the inside out, starting with the innermost (bottleneck) block
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        # n_layers - 5 intermediate blocks that keep ngf * 8 channels
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        # Gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        # Outermost block maps in_planes input channels to out_planes output channels
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)
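With the defaults (ngf=64, n_layers=8), the generator stacks n_layers downsampling blocks:

- the outermost block;
- three blocks that grow the channels from ngf to ngf * 8;
- n_layers - 5 middle blocks;
- the innermost block.

The sketch below traces channels and spatial size through the encoder. It assumes a 256 × 256 input (val_pic_size = 256 in the training code). Under that assumption, eight halvings bring the bottleneck down to 1 × 1. The function name is illustrative.

```python
from typing import List, Tuple


def unet_encoder_trace(ngf: int = 64, n_layers: int = 8, size: int = 256) -> List[Tuple[int, int]]:
    """(channels, spatial size) after each downsampling conv of UNetGenerator, outermost first."""
    # inner_nc of each block from the outermost block inward, as built in UNetGenerator
    channels = [ngf, ngf * 2, ngf * 4, ngf * 8] + [ngf * 8] * (n_layers - 5) + [ngf * 8]
    trace = []
    for c in channels:
        size = (size + 2 * 1 - 4) // 2 + 1  # kernel 4, stride 2, padding 1: halves the size
        trace.append((c, size))
    return trace


print(unet_encoder_trace())
# [(64, 128), (128, 64), (256, 32), (512, 16), (512, 8), (512, 4), (512, 2), (512, 1)]
```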
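The original cGAN takes both a condition x and noise z; this generator uses only the condition, so by itself it cannot produce diverse results
Pix2Pix therefore applies dropout at both training and test time to obtain diverse results (note that UNetGenerator defaults to dropout=False and is instantiated below without enabling it)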
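Dropout can act as a noise source because each forward pass draws a fresh random mask. As a result, the same input yields different outputs. Below is a minimal sketch of inverted dropout in plain Python, using p = 0.5 (the value in `nn.Dropout(p=0.5)` in UNetSkipConnectionBlock). It illustrates the idea only and is not MindSpore's implementation.

```python
import random
from typing import List


def dropout(values: List[float], p: float, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each value with probability p, scale survivors by 1 / (1 - p)."""
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]


features = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
rng = random.Random(0)  # fixed seed only to make the demo repeatable

# Two passes over the same input give two different "generated" outputs.
out_a = dropout(features, 0.5, rng)
out_b = dropout(features, 0.5, rng)
print(out_a)
print(out_b)
```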
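3. PatchGAN-based discriminator
PatchGAN structure
Built entirely from convolutions
Each element of the output matrix corresponds to a small region (patch) of the input image
Each value in the matrix judges whether its patch is real or fake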
import mindspore.nn as nn
import mindspore.ops as ops
class ConvNormRelu(nn.Cell):
    """Conv2d -> BatchNorm2d -> (Leaky)ReLU building block."""
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            # Explicit padding layer (e.g. REFLECT) followed by an unpadded conv
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        # Stride-2 blocks that double the channels (capped at 8 * ndf)
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        # Final conv outputs a 1-channel map of per-patch real/fake logits
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        # Condition on the input image: concatenate (input, candidate) along channels
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output
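With the defaults used here (n_layers=3, ndf=64), the discriminator applies five 4 × 4 convolutions, all with padding 1: three with stride 2 and two with stride 1. The stdlib-only sketch below does the size arithmetic. It assumes a 256 × 256 input (val_pic_size = 256 in the training code). Under that assumption, the output is a 30 × 30 map, and each of its elements sees a 70 × 70 region of the input; that region is the "patch" the element judges. The layer list is transcribed from the Discriminator class; the helper names are ours.

```python
from typing import List, Tuple

# (kernel, stride, padding) of each conv in Discriminator(n_layers=3), in forward order
LAYERS: List[Tuple[int, int, int]] = [
    (4, 2, 1),  # first Conv2d
    (4, 2, 1),  # ConvNormRelu, i = 1
    (4, 2, 1),  # ConvNormRelu, i = 2
    (4, 1, 1),  # ConvNormRelu with stride 1
    (4, 1, 1),  # final Conv2d -> 1 channel
]


def output_size(size: int, layers: List[Tuple[int, int, int]]) -> int:
    """Spatial size of the output map after applying each conv in order."""
    for k, s, p in layers:
        size = (size + 2 * p - k) // s + 1
    return size


def receptive_field(layers: List[Tuple[int, int, int]]) -> int:
    """Side length of the input region seen by one output element."""
    rf = 1
    for k, s, _ in reversed(layers):
        rf = (rf - 1) * s + k
    return rf


print(output_size(256, LAYERS), receptive_field(LAYERS))  # 30 70
```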
4. Initializing the Pix2Pix generator and discriminator
Instantiate the generator and discriminator, then initialize their weights.
import mindspore.nn as nn
from mindspore.common import initializer as init
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'
net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class Pix2Pix(nn.Cell):
    """Pix2Pix model network: holds the discriminator and generator; the forward pass runs the generator only."""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb
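The two initialization loops apply the same scheme to both networks:

- convolution weights come from N(0, init_gain²), or from Xavier-uniform or a constant, depending on init_type;
- BatchNorm gamma is set to 1 and beta to 0.

Below is a stdlib-only sketch of the three weight branches for a single Conv2d kernel. It assumes the (out_channels, in_channels, kh, kw) weight layout. The Xavier bound uses the standard Glorot formula gain × sqrt(6 / (fan_in + fan_out)). The function name `init_conv_weight` is hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple


def init_conv_weight(shape: Tuple[int, int, int, int], init_type: str = "normal",
                     init_gain: float = 0.02, rng: Optional[random.Random] = None) -> List[float]:
    """Flat list of initial weights for a conv kernel of shape (out_c, in_c, kh, kw)."""
    rng = rng or random.Random(0)
    out_c, in_c, kh, kw = shape
    n = out_c * in_c * kh * kw
    if init_type == "normal":
        return [rng.gauss(0.0, init_gain) for _ in range(n)]  # mean 0, std init_gain
    if init_type == "xavier":
        fan_in, fan_out = in_c * kh * kw, out_c * kh * kw
        bound = init_gain * math.sqrt(6.0 / (fan_in + fan_out))
        return [rng.uniform(-bound, bound) for _ in range(n)]
    if init_type == "constant":
        return [0.001] * n
    raise NotImplementedError(f"initialization method [{init_type}] is not implemented")


w = init_conv_weight((64, 6, 4, 4))  # shape of the discriminator's first conv: 6 -> 64 channels
std = math.sqrt(sum(x * x for x in w) / len(w))
print(len(w), round(std, 3))
```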
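IV. Training
Train the discriminator
Goal: maximize the probability of correctly telling real images from generated ones
Train the generator
Goal: produce more convincing generated images
Record the two losses separately
The loop below prints both every 2 steps and stores them in d_losses and g_losses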
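At each step, the discriminator is pushed to label (input, target) pairs as real and (input, generated) pairs as fake. The generator is pushed to make its output fool the discriminator and also match the target pixel by pixel. Below is a stdlib-only sketch of the two losses used in the training code. It replaces nn.BCEWithLogitsLoss and nn.L1Loss with hand-written mean-reduced versions; the function names are ours. The weights lambda_dis = 0.5, lambda_gan = 0.5 and lambda_l1 = 100 come from the tutorial code.

```python
import math
from typing import Sequence


def bce_with_logits(logits: Sequence[float], target: float) -> float:
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    total = 0.0
    for z in logits:
        total += max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def l1(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(real_logits, fake_logits, lambda_dis=0.5):
    # real pairs -> label 1, generated pairs -> label 0
    return lambda_dis * (bce_with_logits(real_logits, 1.0) + bce_with_logits(fake_logits, 0.0))


def generator_loss(fake_logits, fake_pixels, target_pixels, lambda_gan=0.5, lambda_l1=100.0):
    # adversarial term (want D to say "real") + heavily weighted L1 reconstruction term
    return lambda_gan * bce_with_logits(fake_logits, 1.0) + lambda_l1 * l1(fake_pixels, target_pixels)


print(round(discriminator_loss([0.0, 0.0], [0.0, 0.0]), 4))    # 0.6931
print(round(generator_loss([0.0], [0.0, 0.5], [0.2, 0.1]), 4))  # 30.3466
```

Because the L1 term is weighted by 100, it dominates the generator loss. A mean absolute pixel error of about 0.35 alone contributes 35, which is consistent with the Gloss values in the 30s in the training log.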
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100
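def get_lr():
    # Constant lr for n_epochs epochs, then linear decay over n_epochs_decay epochs (one value per step).
    # Note: the schedule counts dataset_size (400) steps per epoch, but one epoch here is only
    # steps_per_epoch (25) batches, so in this 3-epoch run the lr stays at its initial value.
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    # Pads with the last lr if epoch_num exceeds n_epochs + n_epochs_decay (empty list here)
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))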
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forward_dis(reala, realb):
    """Discriminator loss: (input, target) pairs -> 1, (input, generated) pairs -> 0, scaled by 0.5."""
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)  # logits on generated pairs
    pred1 = net_discriminator(reala, realb)  # logits on real pairs
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forward_gan(reala, realb):
    """Generator loss: fool the discriminator (adversarial term) + match the target (L1 term)."""
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

# Each loss is differentiated only with respect to its own network's parameters
grad_d = value_and_grad(forward_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forward_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        # total_seconds() covers the whole duration (.microseconds alone drops whole seconds)
        delta_ms = (end_time - start_time).total_seconds() * 1000
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format(
                delta_ms, (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        # ckpt_dir + "Generator.ckpt" yields "results/ckptGenerator.ckpt"; inference loads the same path
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
Output:
ms per step:500.71  epoch:1/3  step:0/25  Dloss:0.6924  Gloss:38.2835 
ms per step:112.14  epoch:1/3  step:2/25  Dloss:0.6490  Gloss:33.7575 
ms per step:105.57  epoch:1/3  step:4/25  Dloss:0.5474  Gloss:35.7007 
ms per step:104.50  epoch:1/3  step:6/25  Dloss:0.6045  Gloss:38.9824 
ms per step:105.54  epoch:1/3  step:8/25  Dloss:0.2939  Gloss:37.5004 
ms per step:109.78  epoch:1/3  step:10/25  Dloss:0.2635  Gloss:37.8297 
ms per step:109.31  epoch:1/3  step:12/25  Dloss:0.4991  Gloss:36.2161 
ms per step:109.15  epoch:1/3  step:14/25  Dloss:0.2570  Gloss:36.8445 
ms per step:109.72  epoch:1/3  step:16/25  Dloss:0.2443  Gloss:37.3726 
ms per step:108.75  epoch:1/3  step:18/25  Dloss:0.3285  Gloss:36.3953 
ms per step:105.53  epoch:1/3  step:20/25  Dloss:0.4726  Gloss:37.0197 
ms per step:106.82  epoch:1/3  step:22/25  Dloss:0.2093  Gloss:39.1963 
ms per step:106.30  epoch:1/3  step:24/25  Dloss:0.2402  Gloss:38.0046 
ms per step:103.31  epoch:2/3  step:0/25  Dloss:0.3002  Gloss:31.6780 
ms per step:105.05  epoch:2/3  step:2/25  Dloss:0.3453  Gloss:34.5222 
ms per step:101.52  epoch:2/3  step:4/25  Dloss:0.1365  Gloss:36.2112 
ms per step:106.16  epoch:2/3  step:6/25  Dloss:0.2867  Gloss:36.5928 
ms per step:102.05  epoch:2/3  step:8/25  Dloss:0.2066  Gloss:35.2446 
ms per step:106.42  epoch:2/3  step:10/25  Dloss:0.7759  Gloss:39.4841 
ms per step:106.11  epoch:2/3  step:12/25  Dloss:0.4025  Gloss:33.4352 
ms per step:102.23  epoch:2/3  step:14/25  Dloss:0.3659  Gloss:31.1093 
ms per step:106.73  epoch:2/3  step:16/25  Dloss:0.2157  Gloss:38.9941 
ms per step:108.85  epoch:2/3  step:18/25  Dloss:0.3607  Gloss:37.1816 
ms per step:105.55  epoch:2/3  step:20/25  Dloss:0.1683  Gloss:30.8941 
ms per step:105.53  epoch:2/3  step:22/25  Dloss:1.3473  Gloss:35.3493 
ms per step:105.66  epoch:2/3  step:24/25  Dloss:0.5771  Gloss:36.7098 
ms per step:105.42  epoch:3/3  step:0/25  Dloss:0.6560  Gloss:36.6383 
ms per step:106.17  epoch:3/3  step:2/25  Dloss:0.3694  Gloss:36.5227 
ms per step:105.67  epoch:3/3  step:4/25  Dloss:0.3402  Gloss:34.6686 
ms per step:106.52  epoch:3/3  step:6/25  Dloss:0.3173  Gloss:29.8994 
ms per step:105.17  epoch:3/3  step:8/25  Dloss:0.2101  Gloss:34.6112 
ms per step:104.83  epoch:3/3  step:10/25  Dloss:0.5902  Gloss:36.7074 
ms per step:105.24  epoch:3/3  step:12/25  Dloss:0.4989  Gloss:35.7287 
ms per step:106.48  epoch:3/3  step:14/25  Dloss:0.4831  Gloss:31.4974 
ms per step:100.69  epoch:3/3  step:16/25  Dloss:0.2834  Gloss:39.2598 
ms per step:107.46  epoch:3/3  step:18/25  Dloss:0.2820  Gloss:36.3580 
ms per step:101.32  epoch:3/3  step:20/25  Dloss:0.2407  Gloss:36.4354 
ms per step:101.03  epoch:3/3  step:22/25  Dloss:0.4778  Gloss:35.9755 
ms per step:102.05  epoch:3/3  step:24/25  Dloss:0.6571  Gloss:35.1790 
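`get_lr` keeps the learning rate at `lr` for n_epochs epochs, then decays it linearly over n_epochs_decay epochs, with one value per step. Below is a stdlib-only sketch of the same schedule with the steps per epoch made an explicit parameter. `linear_decay_schedule` is a hypothetical name, and passing the real number of batches per epoch (25 in this run) is our assumption about the intent. With n_epochs = 100, all 75 steps of the 3-epoch run fall in the constant phase.

```python
from typing import List


def linear_decay_schedule(base_lr: float, steps_per_epoch: int,
                          n_epochs: int, n_epochs_decay: int) -> List[float]:
    """Per-step learning rates: constant for n_epochs, then linear decay over n_epochs_decay."""
    lrs = [base_lr] * (steps_per_epoch * n_epochs)
    for epoch in range(n_epochs_decay):
        # same factor as get_lr: (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [base_lr * (n_epochs_decay - epoch) / n_epochs_decay] * steps_per_epoch
    return lrs


small = linear_decay_schedule(1.0, steps_per_epoch=2, n_epochs=1, n_epochs_decay=2)
print(small)  # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5]

full = linear_decay_schedule(0.0002, steps_per_epoch=25, n_epochs=100, n_epochs_decay=100)
print(len(full), set(full[:75]))  # 5000 {0.0002}
```

As in `get_lr`, the last decay epoch uses base_lr / n_epochs_decay rather than 0.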
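V. Inference
Take the checkpoint file saved at the end of training
Load its weights into the model with
load_checkpoint
load_param_into_net
Fetch a batch of data
Run inference
Plot the results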
from mindspore import load_checkpoint, load_param_into_net
param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    # Top row: input images (CHW in [-1, 1] -> HWC in [0, 1])
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    # Bottom row: generator outputs for the same inputs
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show() 
Output:

The inference results on each dataset are shown below
