文章目录
- 昇思MindSpore应用实践
- 基于MindSpore的DCGAN生成漫画头像
- 1、DCGAN 概述
- 零和博弈 vs 极大极小博弈
- GAN的生成对抗损失
- DCGAN原理
- 2、数据预处理
- 3、DCGAN模型构建
- 生成器部分
- 判别器部分
- 4、模型训练
- Reference
昇思MindSpore应用实践
本系列文章主要用于记录昇思25天学习打卡营的学习心得。
基于MindSpore的DCGAN生成漫画头像
1、DCGAN 概述
这部分原理介绍参考昇思官方文档GAN图像生成和昇思25天学习打卡营第5天_GAN图像生成
生成对抗网络简介:
零和博弈 vs 极大极小博弈
生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)
零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。
极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。
假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:
- 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布;
- 判别器通过求取梯度和损失函数对网络进行优化,将接近真实数据分布的数据判定为1( D ( x ) = 1 D(x)=1 D(x)=1),将接近生成器生成数据分布的数据判定为0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min G max D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) GminDmaxV(G,D);
- 生成器通过优化,生成出更加贴近真实数据分布的数据;
- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失
min G max D V ( G , D ) = E x ∼ p data ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。
当生成对抗损失达到纳什均衡时,判别器对真假数据的判别概率都是0.5,即 D ( x ) = 1 − G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1−G(z)=0.5,
即 l o g ( D ( x ) ) = l o g ( 1 − G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1−G(z))≈0.693
由于数据x和G(z)不仅是一张图片,再分别取两者的均值 E \mathbb{E} E,相加,就得到了生成对抗损失。
近十年来著名的GAN网络结构:
DCGAN原理
如上图所示,DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。
不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。
它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量 z z z,输出是3x64x64的RGB图像。
本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。
2、数据预处理
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')# 通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
3、DCGAN模型构建
生成器部分
生成器G
的功能是将隐向量z
映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose
转置卷积层来完成的,每个层都与BatchNorm2d
层和ReLu
激活层配对,输出数据会经过tanh
函数,使其返回[-1,1]
的数据范围内。
DCGAN生成器生成图像的大致流程如下:
1、将一个1x100的高斯潜在噪声向量投影变换为一个4x4x1024的特征图;
2、在经过CONV1卷积输出为8x8x512的特征图;
3、逐步增大分辨率,缩小通道数,经过CONV2卷积输出为16x16x256的特征图;
4、经过CONV3卷积输出为32x32x128的特征图;
5、最后经过CONV4卷积输出为64x64x3的生成图像,与真实图像一起送入判别器进行鉴定;
6、在训练过程中尽可能地生成逼近真实图像分布的效果从而欺骗判别器,令其失效,这样生成对抗就达到了平衡状态,生成器的训练过程完毕,拿去用作模型推理。
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()
判别器部分
class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()
4、模型训练
# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')# 定义训练时要用到的功能函数
def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))# 生成一批图像gen_imgs = generator(z)# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)return g_loss, gen_imgsdef discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_lossgrad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)return g_loss, d_loss, gen_imgsimport mindsporeG_losses = []
D_losses = []
image_list = []total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d] Loss_D:%7.4f Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
cpu训练5个epoch的训练效果:
可以明显看出Loss_D和Loss_G的分数并没有达到0.5:0.5的纳什平衡状态,生成图像自然是很可怕的抽象二次元漫画头像,这里忘了截图了就不放效果了。
申请了Ascend910 NPU的算力,训练50轮效果:
910太快了啊,吃顿饭回来就跑完了,不过结果还是蚌埠住了…
还是很糊,练崩了,今天先到这里了,先打次卡,有时间再调整一下网络结构试试,DCGAN可能对Anime数据集来说还是太简单了,不太好控制的样子。
两个网络训练的log:
Reference
昇思大模型平台
什么是GAN生成对抗网络,使用DCGAN生成动漫头像