昇思25天学习打卡营第25天|GAN图像生成

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

GAN图像生成

模型简介

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

GAN论文逐段精读【论文精读】_哔哩哔哩_bilibili

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷 。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥)需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)数值应该趋近于 0 。因此 𝐷(𝑥)也可以被认为是传统的二分类器。

用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧)的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 的数据分布,我们希望找到 θ 使得尽可能的接近,其中 𝜃代表网络参数。

符号p_G(x; \theta) 表示的是生成模型 G 生成数据x的概率分布,参数为\theta。具体来说:

x:表示数据样本,例如图像或其他形式的数据。
\theta:表示生成器G的参数,这些参数是通过训练过程来优化的。
p_G(x; \theta):表示通过生成器G在参数\theta下生成数据x的概率分布。

表示生成器 𝐺 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数;而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数。因此GAN的损失函数为:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

- 判别器的目标函数:

\max_D \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

- 生成器的目标函数:
\min_G \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

在GAN(生成式对抗网络)中,损失函数反映了生成器和判别器之间的对抗关系。有两个模型:生成器G和判别器D。判别器D的目标是最大化其正确分类真实图像和生成图像的概率,而生成器 G的目标是最小化被判别器正确识别为生成图像的概率。

判别器D的目标是最大化以下期望值:

\mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

其中,x表示真实图像,z是从标准正态分布中采样的隐变量,G(z)表示生成的假的图像。第一项期望值表示真实图像被判别为真实的概率,第二项期望值表示生成的图像被判别为假的概率。

另一方面,生成器G的目标是最小化判别器正确识别为生成图像的概率,这相当于最小化以下期望值:

\mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

综合两个目标, 可以给出生成式对抗网络的损失函数(目标函数)为:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]

这是一个双重优化问题,两个模型在同一时间被训练:判别器D尽可能区分真实和生成图像,生成器G尽可能生成能够欺骗判别器的图像。

以下是解释关键项的详细含义:

1.\mathbb{E}_{x \sim p_{\text{data}}(x)} [\log D(x)]:期望真实数据x被判别为真实的概率的对数,这部分使得判别器更好地区分真实数据。

2. \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]:期望生成的假数据G(z)被判别为假的概率的对数,这部分反映了生成器通过尝试生成能够欺骗判别器的假数据来提高生成效果。

通过这样的对抗训练过程,生成器会逐渐提高其生成质量,生成的图像会愈发接近真实图像,使得判别器愈加难以区分真假图像。

从理论上讲,此博弈游戏的平衡点是,此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:·

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。

在理想的均衡点,生成器生成的数据分布 p_g和真实的数据分布 p_{data} 相同,即 p_g = p_{data}。此时,对于任意一个输入数据,判别器无法区分它是生成器生成的还是实际数据,这就表示判别器的输出应该在0和1之间摇摆不定,最可能的值是0.5。

gan

在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,𝑧 表示隐码,𝑥 表示生成的虚假图像 𝐺(𝑧)。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。

在这四张图片中,黑色虚线代表了真实数据分布的概率密度。我们可以看到左侧通常表示真实数据分布较高的区域。因此,判别器在这些区域应该输出更高的概率值。从图中蓝色虚线的形状可以看出,随着训练的进行,判别器在左侧区域的输出值是大于0.5的,因为这个区域内的样本被判别器认为更可能是真实数据。第四张图片在理想情况下,生成器生成的数据分布(绿色实线)与真实数据分布(黑色虚线)几乎完全重合。判别器无法区分生成数据和真实数据,只能随机猜测,所以在这个阶段的判别器输出会接近1/2。

数据集

数据集简介

MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。

本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。

数据集下载

使用download接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download安装download包。

下载解压后的数据集目录结构如下:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test├─ t10k-images-idx3-ubyte└─ t10k-labels-idx1-ubyte

数据下载的代码如下:

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
# 数据下载
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 116MB/s]
Extracting zip file...
Successfully downloaded / unzipped to .

[3]:

'.'

数据加载

使用MindSpore自己的MnistDatase接口,读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理。

import numpy as np
import mindspore.dataset as dsbatch_size = 64
latent_size = 100  # 隐码的长度# 载入数据集
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')# 定义数据加载函数
def data_load(dataset):dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)# 数据增强mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
Iter size: 156

dataset1 是一个 GeneratorDataset,它从原始数据集中生成一个新的数据集并对其进行打乱(shuffle),限制样本数量为10000。operations 部分是一个匿名函数(lambda 函数),它接收输入的数据 x,然后返回两个新的值:将图像数据转换为 float32 类型、生成一个服从标准正态分布的随机向量,大小为 latent_size,并将其转化为 float32 类型。output_columns=["image", "latent_code"] 定义了输出数据集的列名。project 方法用于选择数据集中指定的列。["image", "latent_code"] 表示只保留图像和生成的潜在编码(latent code)列,丢弃其他不需要的列如 label

数据集可视化

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。

import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()

隐码构造

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)

生成隐码test_noise,用来生成图像。

模型构建

本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,我们在判别器和生成器中采用全连接网络架构和 ReLU 激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout 策略和可学习激活函数 Maxout

生成器

生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。

from mindspore import nn
import mindspore.ops as opsimg_size = 28  # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

定义了一个生成器 (Generator) 模型,用于生成与真实图像大小相同的灰度图像 (或RGB彩色图像),该生成器使用了五层 Dense 全连接层,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,最后一层使用 Tanh 激活函数。
使用五层 Dense 全连接层(全连接神经网络层)主要是为了逐步提升输入的维度以生成高质量的图像。到具体每层参数的选择,比如128、256、512、1024,是通过经验和实验选择的,这些数值通常能够在保持计算高效的同时提升生成图像的质量。
- 输入100维到128维:从一个100维的噪声输入开始,初步扩大到128维。
- 128维到256维:进一步扩大,增加特征复杂度。
- 256维到512维:继续递增维度,捕捉更多特征信息。
- 512维到1024维:最大层,提供足够的容量以生成复杂图像。
- 1024维到784维(28×28):最终输出尺寸为28x28(784个像素),适合常见的小图像数据集如MNIST。
BatchNorm1d 批归一化层
批归一化层(Batch Normalization)在神经网络训练中有如下作用:
1. 稳定训练过程:它通过标准化每个小批次的输入,使其均值为0方差为1,来稳定和加速神经网络的训练。
2. 减轻内部协变量偏移:每一层输入的分布保持相对稳定,使网络各层能够更容易学习。
3. 减轻过拟合:有轻微的正则化效果,从而减少过拟合的现象。
BatchNorm1d 对高维向量进行归一化。其参数(如256,512)与前后层输出的维度相匹配。
代码中使用了两个激活函数:
1. ReLU (Rectified Linear Unit):
用于隐藏层。ReLU激活函数具有稀疏激活和梯度消失问题的解决能力。
2. Tanh (Hyperbolic Tangent):
用于生成器的最后一层。
Tanh激活函数将输出限制在[-1, 1]之间。

在实例化生成器之后需要修改参数的名称,这是因为在静态图模式(如有些深度学习框架中的图计算模式)下,参数名称需要唯一且一致,不然可能会发生名称冲突或不一致的问题,导致内部图计算无法正确构建和运行。new_parameters_prefix 设为 generator,调用 net_g.update_parameters_name('generator') 后,net_g 内部所有参数的名称都会被自动地加上generator前缀,确保参数名称的唯一性。

判别器

如前所述,判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。

 # 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')

Dense层的参数(神经元数)主要是从输入图像大小逐层缩减到单一输出(表示图像为真实图的概率)。
输入层(784 -> 512):输入图像的每个像素点展平成一个一维向量,长度为28*28=784。使用512个神经元来处理这个输入,在信息传递过程中做到信息浓缩而不过多丢失细节。
隐藏层1(512 -> 256将512维的输入映射到256维。继续收缩特征空间,减少特征数量,同时保持重要的表示。
输出层(256 -> 1):最后一层只输出一个值,通过 Sigmoid 激活函数,将其映射到[0,1]区间,表示判别结果的概率分数。
两个激活函数:
LeakyReLU:在隐藏层(512和256)后面使用,它是一种改进的ReLU函数,通过线性保持负数部分而不是直接切断它们(默认斜率为0.2),这有助于缓解神经元的"死亡"问题并允许在负值区域有一定的梯度传播。
Sigmoid:在最后的输出层后面使用,将线性输出值压缩到[0,1]区间,使其成为一个概率值,适合二分类任务。

损失函数和优化器

定义了 Generator 和 Discriminator 后,损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

lr = 0.0002  # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

BCELoss(Binary Cross Entropy Loss)是二分类问题常用的一种损失函数。

Adam(Adaptive Moment Estimation)优化器是一种能很好地适应训练过程中参数调整的优化算法。

模型训练

训练分为两个主要部分。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 的值。

第二部分是训练生成器。如论文所述,最小化来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))
time of epoch 1 is 84.03s
time of epoch 2 is 7.17s
time of epoch 3 is 7.07s
time of epoch 4 is 6.86s
time of epoch 5 is 7.00s
time of epoch 6 is 6.96s
time of epoch 7 is 7.02s
time of epoch 8 is 7.02s
time of epoch 9 is 7.00s
time of epoch 10 is 6.88s
time of epoch 11 is 6.91s
time of epoch 12 is 6.93s

效果展示

运行下面代码,描绘DG损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量生成的图像。

import cv2
import matplotlib.animation as animation# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):plt.axis("off")show_list.append([plt.imshow(image_list[epoch], cmap='gray')])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)

epoch为100时:

epoch为200时:

训练过程测试动态图

从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当 epoch 达到100以上时,生成的手写数字图片与数据集中的较为相似。下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

模型推理

下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

import mindspore as ms# test_ckpt = './result/checkpoints/Generator199.ckpt'# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):fig.add_subplot(5, 5, i + 1)plt.axis("off")plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

epoch为12时效果:

epoch为200时效果:

GAN的基本原理:GAN由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是生成看起来像真实数据的假数据,而判别器的目标是区分真实数据和生成器生成的假数据。通过这种对抗训练过程,生成器和判别器相互竞争,最终生成器能够生成高质量的假数据。
数据集处理:使用了MNIST手写数字数据集进行训练。数据集通过MindSpore的MnistDataset接口进行加载和预处理。
模型构建:生成器和判别器都使用了全连接层(Dense)和激活函数(ReLU、LeakyReLU、Sigmoid)。生成器通过多层全连接层将随机噪声映射到图像空间,而判别器则通过多层全连接层将图像映射到概率空间。
损失函数和优化器:使用了二进制交叉熵损失函数(BCELoss)来计算生成器和判别器的损失。
使用了Adam优化器来更新生成器和判别器的参数。
训练过程:训练过程分为两个主要部分:训练判别器和训练生成器。在每个epoch结束时,生成一组固定的随机噪声输入到生成器中,以评估生成器的性能。训练过程中,生成器和判别器的损失被记录下来,并在训练结束后进行可视化。
结果展示:通过加载预训练的生成器模型参数文件,生成新的手写数字图像。生成的图像通过Matplotlib进行展示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/46218.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.32、 基于区域卷积神经网络(R-CNN)的停车标志检测(matlab)

1、基于区域卷积神经网络(R-CNN)的停车标志检测原理及流程 基于区域卷积神经网络&#xff08;R-CNN&#xff09;的停车标志检测原理及流程如下&#xff1a; 原理&#xff1a; R-CNN 是一种用于目标检测的深度学习模型&#xff0c;其核心思想是首先在输入图像中提取出候选区域&…

springboot3 web

springboot web配置 springboot web的配置有&#xff1a; SpringMvc配置的前缀为&#xff1a;spring.mvcweb场景的通用配置为&#xff1a;spring.web文件上传的配置为&#xff1a;spring.servlet.multipart服务器相关配置为&#xff1a;server 接管SpringMVC 的三种方式 方…

对LinkedList和链表的理解

一.ArrayList的缺陷 二.链表 三.链表部分相关oj面试题 四.LinkedList的模拟实现 五.LinkedList的使用 六.ArrayList和LinkedList的区别 一.ArrayList的缺陷: 1. ArrayList底层使用 数组 来存储元素&#xff0c;如果不熟悉可以来再看看&#xff1a; ArrayList与顺序表-CSDN…

2024年7月13日全国青少年信息素养大赛Python复赛小学高年级组真题

第一题 题目描述 握情况。他决定让每个人输入一个正整数 N (0≤N≤1000)&#xff0c;然后计算并输出(5*N)的值。请用 在一个神秘的王国里&#xff0c;国王希望通过一个简单的测试来评估他的子民对基 础数学运算的掌 Python 编写程序&#xff0c;程序执行后要求用户输入一个正…

Hash表(C++)

本篇将会开始介绍有关于 unordered_map 和 unordered_set 的底层原理&#xff0c;其中底层实现其实就是我们的 Hash 表&#xff0c;本篇将会讲解两种 Hash 表&#xff0c;其中一种为开放定址法&#xff0c;另一种为 hash 桶&#xff0c;在unordered_map 和 unordered_set 的底层…

智驭未来:人工智能与目标检测的深度交融

在科技日新月异的今天&#xff0c;人工智能&#xff08;AI&#xff09;如同一股不可阻挡的浪潮&#xff0c;正以前所未有的速度重塑着我们的世界。在众多AI应用领域中&#xff0c;目标检测以其独特的魅力和广泛的应用前景&#xff0c;成为了连接现实与智能世界的桥梁。本文旨在…

20240715 每日AI必读资讯

&#x1f310; 代号“ 草莓 ”&#xff0c;OpenAI 被曝研发新项目&#xff1a;将 AI 推理能力提至新高度 - OpenAI 公司被曝正在研发代号为“ 草莓 ”的全新项目&#xff0c;进一步延伸去年 11 月宣布的 Q* 项目&#xff0c;不断提高 AI 推理能力&#xff0c;让其更接近人类的…

基于Java的休闲娱乐代理售票系统

你好&#xff0c;我是专注于Java开发的码农小野&#xff01;如果你对系统开发感兴趣&#xff0c;欢迎私信交流。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;Java技术、SpringBoot框架、B/S架构 工具&#xff1a;Eclipse IDE、MySQL数据库管理工具…

牛客小白月赛98 (个人题解)(补全)

前言&#xff1a; 昨天晚上自己一个人打的小白月赛&#xff08;因为准备数学期末已经写烦了&#xff09;&#xff0c;题目难度感觉越来越简单了&#xff08;不在像以前一样根本写不了一点&#xff0c;现在看题解已经能看懂一点了&#xff09;&#xff0c;能感受到自己在不断进步…

2024年是不是闰年?

闰年的由来 闰年的概念最早可以追溯到古罗马时期的朱利叶斯凯撒。当时的罗马历法是根据太阳年来制定的&#xff0c;每年大约有365.25天。为了使日历与季节保持同步&#xff0c;人们需要定期插入一个额外的日子。朱利叶斯凯撒在公元前46年颁布了一项法令&#xff0c;规定每四年增…

SAP PP学习笔记26 - User Status(用户状态)的实例,订单分割中的重要概念 成本收集器,Confirmation(报工)的概述

上面两章讲了生产订单的创建以及生产订单的相关内容。 SAP PP学习笔记24 - 生产订单&#xff08;制造指图&#xff09;的创建_sap 工程外注-CSDN博客 SAP PP学习笔记25 - 生产订单的状态管理(System Status(系统状态)/User Status(用户状态)),物料的可用性检查&#xff0c;生…

最长下降序列

如何理解这个题目呢,我们可以每个人的分数放到排名上&#xff0c;然后求解最长下降序列即可 #include<bits/stdc.h> using namespace std;int n; const int N (int)1e5 5; int a[N]; int b[N]; int d[N]; int dp[N]; int t;int main() {cin >> t;while (t--) {…

Apache Hadoop之历史服务器日志聚集配置

上篇介绍绍了Apache Hadoop的分布式集群环境搭建&#xff0c;并测试了MapReduce分布式计算案例。但集群历史做了哪些任务&#xff0c;任务执行日志等信息还需要配置历史服务器和日志聚集才能更好的查看。 配置历史服务器 在Yarn中运行的任务产生的日志数据不能查看&#xff0…

【计算机毕业设计】013新闻资讯微信小程序

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

Python数据分析案例51——基于K均值的客户聚类分析可视化

案例背景 本次案例带来的是最经典的K均值聚类&#xff0c;对客户进行划分类别的分析&#xff0c;其特点是丰富的可视化过程。这个经典的小案例用来学习或者课程作业在合适不过了。 数据介绍 数据集如下: 客户的编码&#xff0c;性别&#xff0c;年龄&#xff0c;年收入&#…

Vue2-集成路由Vue Router介绍与使用

文章目录 路由&#xff08;Vue2&#xff09;1. SPA 与前端路由2. vue-router基本使用创建路由组件声明路由链接和占位标签创建路由模块挂载路由模块 3. vue-router进阶路由重定向嵌套路由动态路由编程式导航导航守卫 本篇小结 更多相关内容可查看 路由&#xff08;Vue2&#xf…

安全防御----防火墙综合实验2

安全防御----防火墙综合实验2 一、题目 二、实验要求&#xff1a; 1&#xff0c;DMZ区内的服务器&#xff0c;办公区仅能在办公时间内&#xff08;9&#xff1a;00 - 18&#xff1a;00&#xff09;可以访问&#xff0c;生产区的设备全天可以访问. 2&#xff0c;生产区不允许访…

雷赛运动控制卡编程(1)

一、运动控制卡选择 电气常用知识-CSDN博客 如下旋转控制卡 DMC3800八轴高性能点位卡 - 东莞市雅恰达机电有限公司 轴少的时候选择脉冲系列卡 轴多的话就选总线型系列控制卡 样品 架构&#xff1a; 二、 添加文件 dll 添加接口文件 【最全&#xff0c;带注释版】雷赛运动…

OpenCV中使用Canny算法在图像中查找边缘

操作系统&#xff1a;ubuntu22.04OpenCV版本&#xff1a;OpenCV4.9IDE:Visual Studio Code编程语言&#xff1a;C11 算法描述 Canny算法是一种广泛应用于计算机视觉和图像处理领域中的边缘检测算法。它由John F. Canny在1986年提出&#xff0c;旨在寻找给定噪声条件下的最佳边…

Python+wxauto=微信自动化?

Pythonwxauto微信自动化&#xff1f; 一、wxauto库简介 1.什么是wxauto库 wxauto是一个基于UIAutomation的开源Python微信自动化库。它旨在帮助用户通过编写Python脚本&#xff0c;轻松实现对微信客户端的自动化操作&#xff0c;从而提升效率并满足个性化需求。这一工具的出现&…