文章目录
- 模型介绍
- 网络结构
- 数据集可视化
- 网络的其他细节
- 模型推理
模型介绍
CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。
该模型一个重要应用领域是域迁移(Domain Adaptation),即图像风格迁移。在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix昇思学习打卡-19-生成式/Pix2Pix实现图像转换 ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。
网络结构
CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,下面这个例子以苹果和橘子为例介绍,讲解的很形象:
下图中𝑋可以理解为苹果,𝑌为橘子;𝐺为将苹果生成橘子风格的生成器,𝐹为将橘子生成的苹果风格的生成器,𝐷𝑋和𝐷𝑌为其相应判别器。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。
该网络需要多个损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的,可以这样理解:
下图中苹果图片𝑥经过生成器𝐺得到伪橘子𝑌̂,然后将伪橘子𝑌̂结果送进生成器 𝐹又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂ 与原苹果图片 𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。
循环一致损失能够保证重建图像与输入图像紧密匹配。
数据集可视化
import numpy as np
import matplotlib.pyplot as pltmean = 0.5 * 255
std = 0.5 * 255plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):if i < 5:show_images_a = data["image_A"].asnumpy()show_images_b = data["image_B"].asnumpy()plt.subplot(2, 5, i+1)show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis("off")plt.subplot(2, 5, i+6)show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis("off")else:break
plt.show()
网络的其他细节
-
构建生成器时,此模型使用ResNet 模型的结构
-
构建判别器,判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。
-
定义优化器和损失函数,优化器使用Adam,关于损失函数,主要关注循环一致损失函数
-
前向计算使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。
-
计算梯度和反向传播,其中梯度计算也是分开不同的模型来进行的
-
最后是模型训练,模型训练训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。
-
- 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2];
-
- 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]来训练生成器,以产生更好的虚假图像。
模型推理
%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA = load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):def read_img():for dir in os.listdir(dir_path):path = os.path.join(dir_path, dir)img = Image.open(path).convert('RGB')yield img, dirdataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]dataset = dataset.map(operations=trans, input_columns=["image"])dataset = dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img = data["image"]fake = net(img)fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i+1+a)plt.axis("off")plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i+9+a)plt.axis("off")plt.imshow(fake.asnumpy())eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()
推理结果如下:
此章节学习到此结束,感谢昇思平台。