《昇思 25 天学习打卡营第 19 天 | 生成式对抗网络(GAN)实践指南 》
活动地址:https://xihe.mindspore.cn/events/mindspore-training-camp
签名:Sam9029
GAN 模型概述
生成式对抗网络(GAN)是一种前沿的无监督学习模型,由 Goodfellow 等人于 2014 年提出。GAN 由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成逼真的假图像,而判别器则负责区分图像是真实的还是生成器生成的。
模型组成
- 生成器(G):从标准正态分布中采样隐码(latent code),并将其映射到数据空间,生成假图像。
- 判别器(D):接收输入图像,并预测图像为真实或假的概率。
训练过程
GAN 的训练是一个动态的博弈过程,生成器和判别器相互竞争,共同进步。
- 初始阶段:生成器生成质量较差的图像,判别器容易区分真假。
- 训练过程中:生成器不断学习生成更逼真的图像,判别器则提高其识别能力。
- 平衡点:理想情况下,生成器生成的图像与真实图像分布一致,判别器无法区分。
数据集
本案例使用 MNIST 手写数字数据集,包含 60000 张训练样本和 10000 张测试样本,图像大小为 28x28。
数据加载与预处理
使用 MindSpore 的MnistDataset
接口加载数据集,并进行必要的预处理,如归一化和数据增强。
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')
模型构建
生成器
生成器采用五层全连接层,每层后接 BatchNorm 和 ReLU 激活层,输出通过 Tanh 函数以产生[-1, 1]范围内的图像数据。
class Generator(nn.Cell):# 定义生成器结构# ...def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))
判别器
判别器使用一系列全连接层和 LeakyReLU 激活层,最后通过 Sigmoid 激活函数输出概率。
class Discriminator(nn.Cell):# 定义判别器结构# ...def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)
损失函数与优化器
使用二元交叉熵损失函数(BCELoss)和 Adam 优化器。
adversarial_loss = nn.BCELoss(reduction='mean')
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr)
模型训练
训练过程包括训练判别器和生成器,记录损失并在每个 epoch 结束时生成图像以跟踪进度。
for epoch in range(total_epoch):# 训练循环# ...gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)
效果展示
训练过程中,生成器生成的图像质量逐渐提高,最终可生成与真实图像相似的假图像。GAN 模型的强大之处在于其生成高质量图像的能力,但其训练过程可能不稳定,需要仔细调整参数。此外,GAN 的应用不仅限于图像生成,还可以扩展到其他领域,如风格迁移、数据增强等、 通过本实践指南,学习了如何使用 MindSpore 框架构建和训练 GAN 模型,并使用 MNIST 数据集进行训练。见证了从模型初始化到训练,再到生成高质量图像的整个过程。随着技术的不断发展,GAN 有望在更多领域展现其潜力。