相关知识
图像翻译
图像翻译Image translation是一种计算机视觉任务,旨在将一种图像转换为另一种图像。典型的任务包括:图像到图像的转换如白天到黑夜,风格迁移,图像修复。
CGAN
CGAN在GAN的基础上引入了条件信息,使生成器不仅根据随机噪声生成数据,还可以通过给定的条件信息生成相应的数据。具体的,CGAN的生成器除了接受随机噪声向量之外,还要接受一个条件向量,这个条件向量可以是任何附加信息如类别标签、文本描述等。判别器的任务除了要区分是否是真实的数据,还要验证数据与条件信息的一致性。
- x x x:代表观测图像的数据。
- z z z:代表随机噪声的数据。
- y = G ( x , z ) y=G(x,z) y=G(x,z):生成器网络,给出由观测图像 x x x与随机噪声 z z z生成的“假”图片,其中 x x x来自于训练数据而非生成器。
- D ( x , G ( x , z ) ) D(x,G(x,z)) D(x,G(x,z)):判别器网络,给出图像判定为真实图像的概率,其中 x x x来自于训练数据, G ( x , z ) G(x,z) G(x,z)来自于生成器。
cGAN的目标可以表示为:
L c G A N ( G , D ) = E ( x , y ) [ l o g ( D ( x , y ) ) ] + E ( x , z ) [ l o g ( 1 − D ( x , G ( x , z ) ) ) ] L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))] LcGAN(G,D)=E(x,y)[log(D(x,y))]+E(x,z)[log(1−D(x,G(x,z)))]
该公式是cGAN的损失函数,D
想要尽最大努力去正确分类真实图像与“假”图像,也就是使参数 l o g D ( x , y ) log D(x,y) logD(x,y)最大化;而G
则尽最大努力用生成的“假”图像 y y y欺骗D
,避免被识破,也就是使参数 l o g ( 1 − D ( G ( x , z ) ) ) log(1−D(G(x,z))) log(1−D(G(x,z)))最小化。cGAN的目标可简化为:
a r g min G max D L c G A N ( G , D ) arg\min_{G}\max_{D}L_{cGAN}(G,D) argGminDmaxLcGAN(G,D)
Pix2Pix
Pix2Pix是基于CGAN实现的一种深度学习图像转换模型。它可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑天、线稿图到实物图的转换。图像转换问题本质上就是像素到像素的映射问题。
实现
网络构建
生成器G用到的是U-Net结构,输入的轮廓图 x x x编码再解码成真是图片,判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图 x x x的条件下,对于生成的图片 G ( x ) G(x) G(x)判断为假,对于真实判断为真。
生成器结构
U-Net是全卷积结构,分成两个部分:左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径。扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。网络模型整体是一个U形的结构。和常见的先降采样到低维度,再升采样到原始分辨率的编解码结构的网络相比,U-Net的区别是加入skip-connection 跳跃连接,对应的feature maps和decode之后的同样大小的feature maps按通道拼一起,用来保留不同分辨率下像素级的细节信息。
以下先定义跳跃连接块,详细内容见注释。
class UNetSkipConnectionBlock(nn.Cell):def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
# 参数:外层的输出通道数、内层的输出通道数、输入的通道数、 是否dropout、U-net中嵌套的子模块、是否是最外侧模块、是否是最内侧模块、leakyReLU激活函数的负斜率、归一化模式 super(UNetSkipConnectionBlock, self).__init__()down_norm = nn.BatchNorm2d(inner_nc)up_norm = nn.BatchNorm2d(outer_nc)use_bias = Falseif norm_mode == 'instance':down_norm = nn.BatchNorm2d(inner_nc, affine=False)up_norm = nn.BatchNorm2d(outer_nc, affine=False)use_bias = Trueif in_planes is None:in_planes = outer_nc
# 定义上下采样卷积层down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,stride=2, padding=1, has_bias=use_bias, pad_mode='pad')down_relu = nn.LeakyReLU(alpha)up_relu = nn.ReLU()
# 最外层的模块定义
# 包含下采样层、上采样卷积层、tanh激活函数
# 中间包括子模块if outermost:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, pad_mode='pad')down = [down_conv]up = [up_relu, up_conv, nn.Tanh()]model = down + [submodule] + up
# 最内层
# 包含下采样卷积层、上采样层
# 无子模块elif innermost:up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv]up = [up_relu, up_conv, up_norm]model = down + up
# 中间层
# 包含下采样层、上采样层、归一化层
# 包含子模块和drop层else:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv, down_norm]up = [up_relu, up_conv, up_norm]model = down + [submodule] + upif dropout:model.append(nn.Dropout(p=0.5))self.model = nn.SequentialCell(model)self.skip_connections = not outermost
# 层与层之间的连接构建def construct(self, x):out = self.model(x)if self.skip_connections:out = ops.concat((out, x), axis=1)return out
以下进行生成器定义,这里的生成器只使用了条件信息,因此不能生成多样的结果。此时使用dropout可以提升结果的多样性。
class UNetGenerator(nn.Cell):def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):super(UNetGenerator, self).__init__()
# 构建最内层的模块unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,norm_mode=norm_mode, innermost=True)
# 构建多个中间层模块for _ in range(n_layers - 5):unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode, dropout=dropout)
# 逐渐构建向外层的模块
# 每个模块的输入和输出通道数逐步减少
# 最终到达最外层unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,norm_mode=norm_mode)
# 构建最外层的模块self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,outermost=True, norm_mode=norm_mode)def construct(self, x):return self.model(x)
判别器
判别器使用PatchGAN结构,生成的矩阵中的每一点代表原图的一小块区域,也就是一个patch。通过矩阵中的每个值来判断原图中patch的真假。
定义组合模块
import mindspore.nn as nnclass ConvNormRelu(nn.Cell):def __init__(self,in_planes,out_planes,kernel_size=4,stride=2,alpha=0.2,norm_mode='batch',pad_mode='CONSTANT',use_relu=True,padding=None):super(ConvNormRelu, self).__init__()
# 以下逐渐创建卷积层、归一化层和激活层norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if not padding:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output
定义判别器
class Discriminator(nn.Cell):def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):super(Discriminator, self).__init__()kernel_size = 4
# 初始层:卷积+leakyRelulayers = [nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),nn.LeakyReLU(alpha)]nf_mult = ndf
# 中间层for i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
# 最后一层nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))self.features = nn.SequentialCell(layers)def construct(self, x, y):x_y = ops.concat((x, y), axis=1)output = self.features(x_y)return output
初始化
实例化对应的生成器和判别器
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))class Pix2Pix(nn.Cell):"""Pix2Pix模型网络"""def __init__(self, discriminator, generator):super(Pix2Pix, self).__init__(auto_prefix=True)self.net_discriminator = discriminatorself.net_generator = generatordef construct(self, reala):fakeb = self.net_generator(reala)return fakeb
训练
分别训练判别器和生成器,以同时提高判别图像真伪的概率和更好的虚假图像生成效果。
epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100def get_lr():lrs = [lr] * dataset_size * n_epochslr_epoch = 0for epoch in range(n_epochs_decay):lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decaylrs += [lr_epoch] * dataset_sizelrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)return Tensor(np.array(lrs).astype(np.float32))dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()def forword_dis(reala, realb):lambda_dis = 0.5fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)pred1 = net_discriminator(reala, realb)loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))loss_dis = loss_d * lambda_disreturn loss_disdef forword_gan(reala, realb):lambda_gan = 0.5lambda_l1 = 100fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)loss_1 = loss_f(pred0, ops.ones_like(pred0))loss_2 = l1_loss(fakeb, realb)loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1return loss_gand_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())def train_step(reala, realb):loss_dis, d_grads = grad_d(reala, realb)loss_gan, g_grads = grad_g(reala, realb)d_opt(d_grads)g_opt(g_grads)return loss_dis, loss_ganif not os.path.isdir(ckpt_dir):os.makedirs(ckpt_dir)g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)for epoch in range(epoch_num):for i, data in enumerate(data_loader):start_time = datetime.datetime.now()input_image = Tensor(data["input_images"])target_image = Tensor(data["target_images"])dis_loss, gen_loss = train_step(input_image, target_image)end_time = datetime.datetime.now()delta = (end_time - start_time).microsecondsif i % 2 == 0:print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))d_losses.append(dis_loss.asnumpy())g_losses.append(gen_loss.asnumpy())if (epoch + 1) == epoch_num:mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
推理
from mindspore import load_checkpoint, load_param_into_netparam_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):plt.subplot(2, 10, i + 1)plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 10, i + 11)plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
最终的推理效果如下
总结
本章使用Pix2Pix完成了图像翻译任务。它基于cGAN,与传统GAN的不同是增加了条件信息。而在Pix2Pix中,指导信息为图片。而生成器和判别器分别采用U-Net和patchGan。