打卡
目录
打卡
GAN
博弈函数
博弈过程
GAN 案例
数据集
数据加载与可视化
隐码构造
模型构建
生成器
判别器
损失函数和优化器
模型训练
输出展示-1w张训练样本
输出展示-6w张训练样本
输出展示-6w张-100 epoch
效果展示
部分展示如图-12epoch-6w张
部分展示如图-100epoch-6w张编辑
模型推理
GAN
生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,由Ian J. Goodfellow 于2014年发明(Generative Adversarial Nets),其主要由生成器(Generative Model) 和判别器(Discriminative Model) 两个模型互相博弈对抗,实现平衡、更好的输出。
- 生成器模型G:捕捉数据分布,生成看起来像训练图像的“假”图像;
- 判别器模型D:估计样本是否来自训练数据,即判断从生成器输出的图像是真实的训练图像还是虚假的图像。
博弈函数
>> 博弈的平衡点:当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。
>> GAN的损失函数:
其中,(1)x 表示图像数据,D(x) 表示判别器判别图像为真实图像的概率,当x来自训练数据时,D(x)的数值接近1;当 x 来自生成器时,D(x)的数值接近0。因此,D(x) 也可以被认为是传统的而分类器。(2)z代表标准正态分布中提取出的隐码(隐向量),G(z) 表示将隐码(隐向量)z映射到数据空间的生成器函数。函数 G(z) 的目标是将服从高斯分布的随机噪声 z 通过生成网络变换为近似于真实分布 的数据分布,我们希望找到 θ 使得 和 尽可能的接近,其中 𝜃 代表网络参数。(3) 𝐷(𝐺(𝑧)) 表示生成器 𝐺 生成的假图像被判定为真实图像的概率。(4)如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数 log𝐷(𝑥) ;而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数 log(1−𝐷(𝐺(𝑧))) 。
因此,理论上,该博弈游戏的平衡点是,此时判别器D会随机猜测输入是真图像还是假图像。
博弈过程
如下图,蓝色虚线表示判别器D,黑色虚线表示真实数据分布,绿色实线表示生成器G生成的虚假数据分布,𝑧 表示隐码,𝑥 表示生成的虚假图像 𝐺(𝑧) 。
a)在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
b)判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
c)生成器通过优化,生成出更加贴近真实数据分布的数据。
d)生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
GAN 案例
案例说明:使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。
数据集
MNIST手写数字数据集 共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。
数据下载代码:
# 数据下载
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)
数据加载与可视化
用MindSpore.MnistDatase
接口读取、解析MNIST数据集的源文件构建数据集,并做一些前处理。如下代码,只获取了训练数据集的10000张训练样本。
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as dsbatch_size = 64
latent_size = 100 # 隐码的长度train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')def data_load(dataset):## GeneratorDataset:自定义Python数据源,通过迭代该数据源构造数据集。生成的数据集的列名和列类型取决于用户定义的Python数据源。dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], ## 指定数据集生成的列名。默认 None 不指定。 shuffle=True, ## 是否混洗数据集。python_multiprocessing=False, # 启用Python多进程模式加速运算。默认True。num_samples=10000 # 指定从数据集中读取的样本数。默认None = 读取全部样本。)# 数据增强## 给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"),np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])## 从数据集对象中选择需要的列,并按给定的列名的顺序进行排序。mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)## create_dict_iterator 将数据转换成字典迭代器
data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()
数据输出展示图
隐码构造
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码 test_noise
输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。
固定隐码的定义如下所示。
import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型构建
本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,本案例在判别器和生成器中采用全连接网络架构和 ReLU
激活函数,且省略了原论文中用于减少参数的 Dropout
策略和可学习激活函数 Maxout
。
生成器
- 生成器
Generator
:将隐码映射到数据空间(此处是图像,可以是灰度或RGB彩色图像)。 - 生成器代码定义如下。Generator 类继承于 nn.Cell 基类,包括了5层Dnese全连接层,每层都与
BatchNorm1d
批归一化层和ReLU
激活层配对,输出数据会经过Tanh
函数,使其返回 [-1,1] 的数据范围内。
注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。
from mindspore import nn
import mindspore.ops as opsimg_size = 28 # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
判别器
- 判别器
Discriminator
:一个二分类网络模型,输出判定该图像为真实图的概率。 - 判别器代码定义如下。
Discriminator
类继承于 nn.Cell 基类,包括了3个 Dnese全连接层和LeakyReLU
层,最后通过Sigmoid
激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。
注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。
# 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512)) # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU()) # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256)) # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid()) # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')
损失函数和优化器
- 损失函数使用MindSpore中二进制交叉熵损失函数
BCELoss
。 - 生成器和判别器都是使用 Adam 优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数。
lr = 0.0002 # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')
模型训练
训练分为两个主要部分。
- 第一部分是训练判别器D。训练判别器的目的是最大程度地提高判别图像真伪的概率。通过提高其随机梯度来更新判别器,最大化 𝑙𝑜𝑔𝐷(𝑥) + 𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)) 的值。
- 第二部分是训练生成器G。最小化 𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))) 来训练生成器,以产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到生成器中,以直观地跟踪生成器 Generator
的训练效果。
import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12 # 训练周期数
batch_size = 64 # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints" # 结果保存路径
image_path = "./result/images" # 测试结果保存路径# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises) # 生成器生成假的图fake_out = net_d(fake_data) # 判别器判别假的图是训练数据的概率## 计算交叉损失函数loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises) # 生成器生成假的图fake_out = net_d(fake_data) # 判别器判别假的图是训练数据的概率real_out = net_d(real_data) # 判别器判别真的图是训练数据的概率## 计算两个交叉损失函数real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))## 计算总的损失函数loss_d = real_loss + fake_lossreturn loss_d# 梯度方法:判别器 和 生成器
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []## 训练过程
for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = data## 图像归一化image = (image - 127.5) / 127.5 # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])## 训练步骤d_loss, g_loss = train_step(image, latent_code)end1 = time.time()## 每10个迭代图像 打印一次if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))### 描绘D和G损失与训练迭代的关系图
plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
输出展示-1w张训练样本
如下图的训练过程,是使用1w张训练样本训练的。
输出展示-6w张训练样本
如下图的训练过程,是使用6w张训练样本训练的。
输出展示-6w张-100 epoch
效果展示
可视化训练过程中通过隐向量生成的图像。
import cv2
import matplotlib.animation as animation# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):plt.axis("off")show_list.append([plt.imshow(image_list[epoch], cmap='gray')])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)
部分展示如图-12epoch-6w张
部分展示如图-100epoch-6w张
可以看到,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch
达到100以上时,生成的手写数字图片与数据集中的较为相似。
模型推理
通过加载生成器网络模型参数文件来生成图像
import mindspore as mstest_ckpt = './result/checkpoints/Generator199.ckpt'parameter = ms.load_checkpoint(test_ckpt)
ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):fig.add_subplot(5, 5, i + 1)plt.axis("off")plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()