GAN网络理论和实验(二)

  

文章目录

  • 一、说明
  • 二、什么是生成对抗网络?
  • 三、判别模型与生成模型
  • 四、生成对抗网络的架构
  • 五、你的第一个 GAN
  • 六、准备训练数据
  • 七、实现鉴别器
  • 八、实现生成器
  • 九、训练模型
  • 十、检查 GAN 生成的样本
  • 十一、使用 GAN 生成手写数字
  • 十二、准备训练数据
  • 十三、实现鉴别器和生成器
  • 十四、训练模型
  • 十五、检查 GAN 生成的样本
  • 十六、结论

一、说明

   生成对抗网络(GAN) 是一种神经网络,可以生成与人类产生的内容类似的材料,例如图像、音乐、语音或文本。近年来,GAN 一直是一个活跃的研究课题。Facebook 的 AI 研究总监 Yann LeCun 称对抗训练是机器学习领域 “过去 10 年最有趣的想法” 。下面,您将在实现自己的两个生成模型之前了解 GAN 的工作原理。

   在本教程中,您将学习:

   什么是生成模型以及它与判别模型有何不同
   GAN 的构造和训练方式
   如何使用PyTorch构建自己的 GAN
   如何使用 GPU 和 PyTorch训练 GAN以用于实际应用

二、什么是生成对抗网络?

   生成对抗网络是一种机器学习系统,可以学习模仿给定的数据分布。深度学习专家 Ian Goodfellow 及其同事在 2014 年的NeurIPS 论文中首次提出了生成对抗网络。

   GAN 由两个神经网络组成,一个用于生成数据,另一个用于区分假数据和真实数据(因此该模型具有“对抗性”性质)。虽然生成数据的结构的想法并不新鲜,但在图像和视频生成方面,GAN 已经提供了令人印象深刻的结果,例如:

  • 使用CycleGAN进行风格转换,它可以对图像执行许多令人信服的风格转换
  • 使用StyleGAN生成人脸,如网站This Person Does Not Exist上所演示的那样生成数据的结构(包括 GAN)被视为生成模型,与研究更广泛的判别模型相对。在深入研究 GAN 之前,您需要了解这两类模型之间的差异。

三、判别模型与生成模型

   如果你研究过神经网络,那么你遇到的大多数应用程序可能都是使用判别模型实现的。另一方面,生成对抗网络属于另一类称为生成模型的模型。

   判别模型用于大多数监督 分类或回归问题。作为分类问题的一个例子,假设您想要训练一个模型来对 0 到 9 的手写数字图像进行分类。为此,您可以使用一个标记数据集,其中包含手写数字图像及其相关标签,指示每个图像代表哪个数字。

   在训练过程中,您将使用算法来调整模型的参数。目标是最小化损失函数,以便模型学习给定输入的输出的概率分布。训练阶段结束后,您可以使用该模型通过估计输入对应的最可能数字来对新的手写数字图像进行分类,如下图所示:
在这里插入图片描述
   您可以将分类问题的判别模型描绘为使用训练数据来学习类别之间边界的块。然后,它们使用这些边界来区分输入并预测其类别。用数学术语来说,判别模型学习给定输入x时输出y的条件概率P ( y | x ) 。

   除了神经网络之外,其他结构也可以用作判别模型,例如逻辑回归模型和支持向量机(SVM)。

   然而,像 GAN 这样的生成模型经过训练后,可以用概率模型来描述数据集的生成方式。通过从生成模型中采样,您可以生成新数据。虽然判别模型用于监督学习,但生成模型通常与未标记的数据集一起使用,可以看作是一种无监督学习。

   使用手写数字数据集,您可以训练生成模型来生成新数字。在训练阶段,您将使用某种算法来调整模型的参数以最小化损失函数并学习训练集的概率分布。然后,在训练模型后,您可以生成新样本,如下图所示:
在这里插入图片描述
   为了输出新样本,生成模型通常会考虑影响模型生成样本的随机元素。用于驱动生成器的随机样本是从潜在空间中获得的,其中的向量表示生成样本的一种压缩形式。

   与判别模型不同,生成模型学习输入数据x的概率P ( x ) ,并且通过输入数据的分布,它们能够生成新的数据实例。

   注意:生成模型也可以与标记数据集一起使用。当它们被训练时,它们被训练来学习给定输出y时输入x的概率P ( x | y ) 。它们也可以用于分类任务,但一般来说,判别模型在分类方面表现更好。

   您可以在文章“判别分类器与生成分类器:逻辑回归和朴素贝叶斯的比较”中找到有关判别分类器和生成分类器的相对优势和劣势的更多信息。

   尽管 GAN 近年来备受关注,但它们并不是唯一可用作生成模型的架构。除了 GAN,还有各种其他生成模型架构,例如:

  • 玻尔兹曼机
  • 变分自动编码器
  • 隐马尔可夫模型
  • 预测序列中下一个单词的模型,例如GPT-2然而,由于在图像和视频生成方面取得了令人兴奋的成果,GAN 近年来吸引了公众的最大兴趣。

   现在您已经了解了生成模型的基础知识,您将了解 GAN 的工作原理以及如何训练它们。

四、生成对抗网络的架构

   生成对抗网络由两个神经网络组成的整体结构,一个称为生成器,另一个称为鉴别器。

   生成器的作用是估计真实样本的概率分布,以便提供与真实数据相似的生成样本。而鉴别器则经过训练,可以估计给定样本来自真实数据而非生成器的概率。

   这些结构被称为生成对抗网络,因为生成器和鉴别器经过训练可以相互竞争:生成器试图更好地欺骗鉴别器,而鉴别器试图更好地识别生成的样本。

   为了理解 GAN 训练的工作原理,请考虑一个玩具示例,其数据集由二维样本 ( x 1 , x 2 ) (x ₁,x ₂) x1x2组成,其中 x 1 x ₁ x1 在 0 到 2π 的区间内,且 x 2 = s i n ( x 1 ) x ₂ = sin( x ₁) x2=sin(x1),如下图所示:
在这里插入图片描述
   如您所见,此数据集由位于正弦曲线上的点 ( x 1 , x 2 ) ( x ₁,x ₂) x1x2组成,具有非常特殊的分布。用于生成与数据集样本相似的对 ( x ~ 1 , x ~ 2 ) ( x̃ ₁,x̃ ₂) x~1x~2的 GAN 的整体结构如下图所示:
在这里插入图片描述
   生成器G接收来自潜在空间的随机数据,其作用是生成与真实样本相似的数据。在此示例中,您有一个二维潜在空间,因此生成器接收随机 ( z ₁, z ₂) 对,并需要对其进行变换,使它们与真实样本相似。

   神经网络G的结构可以是任意的,允许您将神经网络用作多层感知器(MLP)、卷积神经网络(CNN) 或任何其他结构,只要输入和输出的维度与潜在空间和真实数据的维度相匹配。

   判别器D要么输入来自训练数据集的真实样本,要么输入由G提供的生成样本。其作用是估计输入属于真实数据集的概率。训练的结果是,当输入真实样本时, D输出 1,而当输入生成样本时,D 输出 0。

   与G一样,你可以为D选择任意的神经网络结构,只要它符合必要的输入和输出维度即可。在此示例中,输入是二维的。对于二元鉴别器,输出可能是从 0 到 1 的标量。

   GAN 的训练过程包含一个双人极小最大游戏,其中D用来最小化真实样本和生成样本之间的判别误差,而G用来最大化D犯错的概率。

   尽管包含真实数据的数据集没有标记,但D和G的训练过程是以监督方式进行的。在训练的每个步骤中,D和G的参数都会更新。事实上,在原始 GAN 提案中, D的参数更新了k次,而G的参数在每个训练步骤中仅更新一次。但是,为了使训练更简单,您可以考虑将k设置为 1。

   为了训练D,在每次迭代中,你将从训练数据中获取的一些真实样本标记为 1,将G提供的一些生成样本标记为 0。这样,你可以使用传统的监督训练框架来更新D的参数,以最小化损失函数,如下图所示:
在这里插入图片描述
   对于包含标记的真实样本和生成样本的每批训练数据,您可以更新D的参数以最小化损失函数。更新D的参数后,您可以训练G以生成更好的生成样本。G的输出连接到D,其参数保持冻结,如下所示:
在这里插入图片描述
   您可以将由G和D组成的系统想象为一个单一分类系统,它接收随机样本作为输入并输出分类,在这种情况下可以解释为概率。

   当G 的表现足够好,可以欺骗D时,输出概率应该接近 1。您也可以在这里使用传统的监督训练框架:用于训练由G和D组成的分类系统的数据集将由随机输入样本提供,并且与每个输入样本相关联的标签为 1。

   在训练过程中,随着D和G的参数更新,预计G给出的生成样本将更加接近真实数据,并且D将更难以区分真实数据和生成的数据。

   现在您已经了解了 GAN 的工作原理,您可以使用PyTorch实现自己的 GAN 。

五、你的第一个 GAN

   作为生成对抗网络的首次实验,您将实现上一节中描述的示例。

   要运行该示例,您将使用PyTorch库,您可以使用Anaconda Python 发行版和conda包和环境管理系统安装该库。要了解有关 Anaconda 和 conda 的更多信息,请查看在 Windows 上设置 Python 进行机器学习的教程。

   首先,创建一个 conda 环境并激活它:

$ conda create --name gan
$ conda activate gan

   激活 conda 环境后,你的提示符将显示其名称。gan然后你可以在环境中安装必要的软件包:

$ conda install -c pytorch pytorch=1.4.0
$ conda install matplotlib jupyter

   由于PyTorch是一个非常活跃的开发框架,API 可能会在新版本中发生变化。为了确保示例代码能够运行,您需要安装特定版本1.4.0。

   除了 PyTorch,您还将使用Matplotlib来处理图表,并使用Jupyter Notebook在交互式环境中运行代码。这样做不是强制性的,但它有助于开展机器学习项目。

   要重新了解如何使用 Matplotlib 和 Jupyter Notebook,请参阅使用 Matplotlib 进行 Python 绘图(指南)和Jupyter Notebook:简介。

   在打开 Jupyter Notebook 之前,您需要注册 condagan环境,以便可以使用它作为内核创建 Notebook。为此,在gan激活环境的情况下,运行以下命令:

$ python -m ipykernel install --user --name gan

   现在,您可以通过运行打开 Jupyter Notebook jupyter notebook。单击新建,然后选择gan来创建一个新的 Notebook 。

   在 Notebook 中,首先导入必要的库:

import torch
from torch import nnimport math
import matplotlib.pyplot as plt

   在这里,您使用 导入 PyTorch 库torch。您还导入了nn以便能够以不太冗长的方式设置神经网络。然后您导入math以获取 pi 常数的值,并像plt往常一样导入 Matplotlib 绘图工具。

   设置随机生成器种子是一种很好的做法,这样就可以在任何机器上以相同的方式复制实验。要在 PyTorch 中执行此操作,请运行以下代码:

torch.manual_seed(111)

   该数字111表示用于初始化随机数生成器的随机种子,该种子用于初始化神经网络的权重。尽管实验具有随机性,但只要使用相同的种子,它就必须提供相同的结果。

   现在环境已经设置好了,您可以准备训练数据。

六、准备训练数据

   训练数据由 ( x 1 , x 2 ) ( x ₁, x ₂) (x1,x2)对组成,其中 x 2 x ₂ x2 x 1 x ₁ x1的正弦值组成,其中 x 1 x ₁ x1 在 0 到 2π 的区间内。您可以按如下方式实现它:

train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [(train_data[i], train_labels[i]) for i in range(train_data_length)
]

   1024在这里,您用 ( x 1 , x 2 ) ( x ₁, x ₂) (x1,x2)对组成一个训练集。在第 2 行中,您初始化train_data一个张量,其1024行和2列的维度均为零。张量是一个类似于NumPy 数组的多维数组。

   在第 3 行中,使用 的第一列train_data存储从 到 区间内的随机值0。2π然后,在第 4 行中,计算张量的第二列作为第一列的正弦。

   接下来,你需要一个标签张量,这是 PyTorch 的数据加载器所必需的。由于 GAN 使用无监督学习技术,因此标签可以是任何东西。毕竟它们不会被使用。

   在第 5 行中,创建train_labels一个用零填充的张量。最后,在第 6 行至第 8 行中,创建train_set一个元组列表,每个元组中的每一行train_data和train_labels都按照 PyTorch 的数据加载器的预期表示。

   您可以通过绘制每个点 ( x 1 , x 2 ) ( x ₁,x ₂) x1x2来检查训练数据:

plt.plot(train_data[:, 0], train_data[:, 1], ".")

   输出应类似于下图:
在这里插入图片描述
   使用train_set,你可以创建一个 PyTorch 数据加载器:

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True
)

   在这里,您将创建一个名为的数据加载器train_loader,它将对数据进行混洗train_set并返回32用于训练神经网络的样本批次。

   设置训练数据后,您需要为组成 GAN 的鉴别器和生成器创建神经网络。在下一节中,您将实现鉴别器。

七、实现鉴别器

   在 PyTorch 中,神经网络模型由从 继承的类表示nn.Module,因此您必须定义一个类来创建鉴别器。有关定义类的更多信息,请查看Python 3 中的面向对象编程 (OOP)。

   鉴别器是一个具有二维输入和一维输出的模型。它将从真实数据或生成器中接收样本,并提供样本属于真实训练数据的概率。以下代码显示了如何创建鉴别器:

class Discriminator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(2, 256),nn.ReLU(),nn.Dropout(0.3),nn.Linear(256, 128),nn.ReLU(),nn.Dropout(0.3),nn.Linear(128, 64),nn.ReLU(),nn.Dropout(0.3),nn.Linear(64, 1),nn.Sigmoid(),)def forward(self, x):output = self.model(x)return output

   您使用.init()来构建模型。首先,您需要从 调用 来 super().init()运行。您使用的鉴别器是使用 以顺序方式定义的 MLP 神经网络。它具有以下特点:.init()nn.Modulenn.Sequential()

   第5、6行:输入是二维的,第一个隐藏层由具有ReLU256激活的神经元组成。

   第 8、9、11 和 12 行:128第二和第三隐藏层分别由和神经元组成64,并带有 ReLU 激活。

   第 14 和 15 行:输出由一个具有S 形激活的单个神经元组成,以表示概率。

   第 7、10 和 13 行:在第一、第二和第三个隐藏层之后,使用dropout来避免过度拟合。

   最后,你使用.forward()来描述如何计算模型的输出。这里,x表示模型的输入,它是一个二维张量。在这个实现中,输出是通过将输入馈送x到你定义的模型而获得的,而无需任何其他处理。

   声明鉴别器类后,您应该实例化一个Discriminator对象:

discriminator = Discriminator()

八、实现生成器

   在生成对抗网络中,生成器是一种模型,它从潜在空间中获取样本作为输入,并生成与训练集中的数据相似的数据。在这种情况下,它是一个具有二维输入的模型,它将接收随机点(z ₁,z ₂),以及一个二维输出,该输出必须提供(x̃ ₁,x̃ ₂)与训练数据中的点相似的点。

   该实现与你对鉴别器所做的类似。首先,你必须创建一个Generator继承自的类nn.Module,定义神经网络架构,然后你需要实例化一个Generator对象:

class Generator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(2, 16),nn.ReLU(),nn.Linear(16, 32),nn.ReLU(),nn.Linear(32, 2),)def forward(self, x):output = self.model(x)return outputgenerator = Generator()

   这里,表示生成器神经网络。它由两个带有和神经元generator的隐藏层组成,均具有 ReLU 激活,以及一个带有输出神经元的线性激活层。这样,输出将由一个具有两个元素的向量组成,该向量可以是从负无穷大到无穷大的任何值,它将表示(x̃ ₁,x̃ ₂)。16322

   现在您已经定义了鉴别器和生成器的模型,可以开始训练了!

九、训练模型

   在训练模型之前,您需要设置一些训练期间使用的参数:

lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()

   在此设置以下参数:

   第 1 行设置学习率(lr),您将使用它来调整网络权重。

   第 2 行设置了 epoch 的数量(num_epochs),这定义了使用整个训练集进行多少次重复训练。

   第 3行将变量分配loss_function给二元交叉熵函数BCELoss(),这是用于训练模型的损失函数。

   二元交叉熵函数是一种适合训练鉴别器的损失函数,因为它考虑的是二元分类任务。它也适合训练生成器,因为它将其输出馈送到鉴别器,从而提供二元可观察输出。

   PyTorch 实现了 中的模型训练的各种权重更新规则torch.optim。您将使用Adam 算法来训练鉴别器和生成器模型。要使用 创建优化器torch.optim,请运行以下几行:

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

   最后,需要实现一个训练循环,将训练样本输入到模型中,并更新其权重以最小化损失函数:

for epoch in range(num_epochs):for n, (real_samples, _) in enumerate(train_loader):# Data for training the discriminatorreal_samples_labels = torch.ones((batch_size, 1))latent_space_samples = torch.randn((batch_size, 2))generated_samples = generator(latent_space_samples)generated_samples_labels = torch.zeros((batch_size, 1))all_samples = torch.cat((real_samples, generated_samples))all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))# Training the discriminatordiscriminator.zero_grad()output_discriminator = discriminator(all_samples)loss_discriminator = loss_function(output_discriminator, all_samples_labels)loss_discriminator.backward()optimizer_discriminator.step()# Data for training the generatorlatent_space_samples = torch.randn((batch_size, 2))# Training the generatorgenerator.zero_grad()generated_samples = generator(latent_space_samples)output_discriminator_generated = discriminator(generated_samples)loss_generator = loss_function(output_discriminator_generated, real_samples_labels)loss_generator.backward()optimizer_generator.step()# Show lossif epoch % 10 == 0 and n == batch_size - 1:print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")print(f"Epoch: {epoch} Loss G.: {loss_generator}")

   对于 GAN,您可以在每次训练迭代时更新鉴别器和生成器的参数。与所有神经网络通常所做的一样,训练过程由两个循环组成,一个用于训练时期,另一个用于每个时期的批次。在内循环中,您开始准备数据来训练鉴别器:

   第 2 行:从数据加载器中获取当前批次的真实样本,并将其分配给real_samples。请注意,张量的第一维元素数量等于batch_size。这是 PyTorch 中组织数据的标准方式,张量的每一行代表批次中的一个样本。

   第 4 行:您使用为真实样本torch.ones()创建具有值的标签1,然后将标签分配给real_samples_labels。

   第 5 行和第 6 行:通过将随机数据存储在 中来创建生成的样本latent_space_samples,然后将其提供给生成器以获得generated_samples。

   第 7 行:您使用为生成的样本的标签torch.zeros()分配值0,然后将标签存储在中generated_samples_labels。

   第 8 至 11 行:将真实和生成的样本和标签连接起来,并将它们存储在all_samples和中all_samples_labels,您将使用它们来训练鉴别器。

   接下来,在第 14 行到第 19 行,训练鉴别器:

   第 14 行:在 PyTorch 中,需要在每个训练步骤中清除梯度以避免累积梯度。您可以使用 来执行此操作.zero_grad()。

   第 15 行:使用中的训练数据计算鉴别器的输出all_samples。

   第 16 和 17 行:使用中的模型输出output_discriminator和中的标签计算损失函数all_samples_labels。

   第 18 行:计算梯度来更新权重loss_discriminator.backward()。

   第 19 行:通过调用来更新鉴别器权重optimizer_discriminator.step()。

   接下来,在第 22 行,准备数据来训练生成器。将随机数据存储在 中latent_space_samples,行数等于batch_size。由于您要向生成器提供二维数据作为输入,因此使用两列。

   在第 25 至 32 行中训练生成器:

   第 25 行:使用 清除渐变.zero_grad()。

   第 26 行:您向生成器提供latent_space_samples并将其输出存储在中generated_samples。

   第 27 行:将生成器的输出输入到鉴别器并将其输出存储在中output_discriminator_generated,您将使用它作为整个模型的输出。

   第 28 至 30 行:使用存储在 中的分类系统的输出output_discriminator_generated和 中的标签来计算损失函数real_samples_labels,它们都等于1。

   第 31 行和第 32 行:计算梯度并更新生成器权重。请记住,在训练生成器时,由于您创建了鉴别器权重,因此将optimizer_generator其第一个参数设置为,因此保持鉴别器权重不变generator.parameters()。

   最后,在第 35 行到第 37 行,显示每十个时期结束时鉴别器和生成器损失函数的值。

   由于本例中使用的模型参数很少,训练将在几分钟内完成。在下一节中,您将使用训练好的 GAN 生成一些样本。

十、检查 GAN 生成的样本

   生成对抗网络旨在生成数据。因此,在训练过程完成后,您可以从潜在空间中获取一些随机样本,并将它们输入到生成器中以获得一些生成的样本:

latent_space_samples = torch.randn(100, 2)
generated_samples = generator(latent_space_samples)

   然后,您可以绘制生成的样本并检查它们是否与训练数据相似。在绘制数据之前generated_samples,您需要使用.detach()从 PyTorch 计算图返回张量,然后使用该张量来计算梯度:

generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")

   输出应类似于下图:
在这里插入图片描述
   您可以看到生成的数据的分布与真实数据的分布相似。通过在训练过程中使用固定的潜在空间样本张量并在每次迭代结束时将其提供给生成器,您可以直观地看到训练的演变过程:
注意,在训练过程开始时,生成的数据分布与真实数据有很大不同。然而,随着训练的进行,生成器会学习真实的数据分布。

   现在您已经完成了生成对抗网络的首次实现,您将使用图像进行更实际的应用。

十一、使用 GAN 生成手写数字

   生成对抗网络还可以生成高维样本,例如图像。在此示例中,您将使用 GAN 生成手写数字图像。为此,您将使用包中包含的手写数字MNIST 数据集torchvision来训练模型。

   首先,您需要torchvision在激活的ganconda 环境中安装:

$ conda install -c pytorch torchvision=0.5.0

   再次,您使用特定版本来torchvision确保示例代码可以运行,就像您对 所做的那样pytorch。设置环境后,您可以开始在 Jupyter Notebook 中实现模型。打开它并通过单击新建然后选择gan来创建一个新的 Notebook 。

   与前面的示例一样,首先导入必要的库:

import torch
from torch import nnimport math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

   除了之前导入的库之外,您还需要torchvision获取transforms训练数据并执行图像转换。

   再次设置随机生成器种子以便能够复制实验:

torch.manual_seed(111)

   由于此示例在训练集中使用了图像,因此模型需要更复杂,参数数量也更多。这使得训练过程变慢,在CPU上运行时,每个 epoch 大约需要两分钟。您需要大约 50 个 epoch 才能获得相关结果,因此使用 CPU 时的总训练时间约为一百分钟。

   为了减少训练时间,您可以使用GPU来训练模型(如果有)。但是,您需要手动将张量和模型移动到 GPU 才能在训练过程中使用它们。

   device您可以通过创建指向 CPU 或 GPU(如果有)的对象来确保您的代码可以在任一设置上运行:

device = ""
if torch.cuda.is_available():device = torch.device("cuda")
else:device = torch.device("cpu")

   稍后,您将使用它device来设置应该创建张量和模型的位置(如果可用,则使用 GPU)。

   现在已经设置好了基本环境,您可以准备训练数据。

十二、准备训练数据

   MNIST 数据集由 0 到 9 的手写数字的 28 × 28 像素灰度图像组成。要将它们与 PyTorch 一起使用,您需要执行一些转换。为此,您定义了transform一个在加载数据时使用的函数:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

   该函数有两个部分:

transforms.ToTensor()将数据转换为 PyTorch 张量。
transforms.Normalize()转换张量系数的范围。

   给出的原始系数transforms.ToTensor()范围是从0到1,并且由于图像背景是黑色的,因此当使用该范围表示时,大多数系数等于0。

   transforms.Normalize()0.5通过从原始系数中减去并将结果除以,将系数的范围更改为 -1 到 1。0.5通过这种变换,输入样本中等于 0 的元素数量显著减少,这有助于训练模型。

   的参数transforms.Normalize()是两个元组,(M₁, …, Mₙ)和(S₁, …, Sₙ),其中表示图像的通道n数。灰度图像(例如 MNIST 数据集中的图像)只有一个通道,因此元组只有一个值。然后,对于图像的每个通道,从系数中减去并将结果除以。itransforms.Normalize()MᵢSᵢ

   现在,您可以使用加载训练数据torchvision.datasets.MNIST并使用执行转换transform:

train_set = torchvision.datasets.MNIST(root=".", train=True, download=True, transform=transform
)

   该参数download=True确保第一次运行上述代码时,MNIST 数据集将被下载并存储在当前目录中,如参数所示root。

   现在您已经创建了train_set,您可以像以前一样创建数据加载器:

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True
)

   您可以使用 Matplotlib 绘制一些训练数据样本。为了改善可视化效果,您可以使用cmap=gray_r反转颜色图并在白色背景上用黑色绘制数字:

real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):ax = plt.subplot(4, 4, i + 1)plt.imshow(real_samples[i].reshape(28, 28), cmap="gray_r")plt.xticks([])plt.yticks([])

输出应类似于以下内容:
在这里插入图片描述
   如你所见,数字的笔迹各不相同。随着 GAN 学习数据的分布,它也会生成具有不同笔迹的数字。

   现在您已经准备好训练数据,您可以实现鉴别器和生成器模型。

十三、实现鉴别器和生成器

   在这种情况下,鉴别器是一个 MLP 神经网络,它接收 28×28 像素的图像并提供该图像属于真实训练数据的概率。

   您可以使用以下代码定义模型:

class Discriminator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(784, 1024),nn.ReLU(),nn.Dropout(0.3),nn.Linear(1024, 512),nn.ReLU(),nn.Dropout(0.3),nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, x):x = x.view(x.size(0), 784)output = self.model(x)return output

   要将图像系数输入到 MLP 神经网络中,需要对其进行矢量化,以便神经网络接收带有784系数的向量。

   矢量化发生在 的第一行.forward(),因为 的调用会x.view()转换输入张量的形状。在本例中,输入的原始形状x为 32 × 1 × 28 × 28,其中 32 是您设置的批处理大小。转换后, 的形状x变为 32 × 784,每行代表训练集图像的系数。

   要使用 GPU 运行鉴别器模型,您必须实例化它并使用 将其发送到 GPU .to()。要在有可用 GPU 时使用 GPU,您可以将模型发送到device之前创建的对象:

discriminator = Discriminator().to(device=device)

   由于生成器将生成更复杂的数据,因此有必要增加潜在空间输入的维度。在这种情况下,生成器将输入 100 维数据,并提供具有 784 个系数的输出,这些系数将组织成表示图像的 28 × 28 张量。

   以下是完整的生成器模型代码:

class Generator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 1024),nn.ReLU(),nn.Linear(1024, 784),nn.Tanh(),)def forward(self, x):output = self.model(x)output = output.view(x.size(0), 1, 28, 28)return outputgenerator = Generator().to(device=device)

   在第 12 行中,使用双曲正切函数 Tanh()作为输出层的激活,因为输出系数应该在 -1 到 1 的区间内。在第 20 行中,实例化生成器并将其发送到deviceGPU 以使用 GPU(如果有)。

   现在您已经定义了模型,您将使用训练数据对它们进行训练。

十四、训练模型

   为了训练模型,您需要像前面的示例一样定义训练参数和优化器:

lr = 0.0001
num_epochs = 50
loss_function = nn.BCELoss()optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

   为了获得更好的结果,你可以降低上例中的学习率。你还可以设置 epoch 数以50减少训练时间。

   训练循环与上一个示例中使用的循环非常相似。在突出显示的行中,您将训练数据发送到deviceGPU 以使用 GPU(如果可用):

for epoch in range(num_epochs):for n, (real_samples, mnist_labels) in enumerate(train_loader):# Data for training the discriminatorreal_samples = real_samples.to(device=device)real_samples_labels = torch.ones((batch_size, 1)).to(device=device)latent_space_samples = torch.randn((batch_size, 100)).to(device=device)generated_samples = generator(latent_space_samples)generated_samples_labels = torch.zeros((batch_size, 1)).to(device=device)all_samples = torch.cat((real_samples, generated_samples))all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))# Training the discriminatordiscriminator.zero_grad()output_discriminator = discriminator(all_samples)loss_discriminator = loss_function(output_discriminator, all_samples_labels)loss_discriminator.backward()optimizer_discriminator.step()# Data for training the generatorlatent_space_samples = torch.randn((batch_size, 100)).to(device=device)# Training the generatorgenerator.zero_grad()generated_samples = generator(latent_space_samples)output_discriminator_generated = discriminator(generated_samples)loss_generator = loss_function(output_discriminator_generated, real_samples_labels)loss_generator.backward()optimizer_generator.step()# Show lossif n == batch_size - 1:print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")print(f"Epoch: {epoch} Loss G.: {loss_generator}")

   有些张量不需要使用 明确发送到 GPU 。第 11 行中device的 就是这种情况,它将被发送到可用的 GPU,因为和之前已经发送到 GPU。generated_sampleslatent_space_samplesgenerator

   由于此示例的模型较为复杂,训练可能需要更多时间。训练完成后,您可以通过生成一些手写数字样本来检查结果。

十五、检查 GAN 生成的样本

   要生成手写数字,您必须从潜在空间中取出一些随机样本并将它们提供给生成器:

latent_space_samples = torch.randn(batch_size, 100).to(device=device)
generated_samples = generator(latent_space_samples)

   要绘制generated_samples,您需要将数据移回 CPU 以防它在 GPU 上运行。为此,您只需调用.cpu()。与之前一样,.detach()在使用 Matplotlib 绘制数据之前,您还需要调用:

generated_samples = generated_samples.cpu().detach()
for i in range(16):ax = plt.subplot(4, 4, i + 1)plt.imshow(generated_samples[i].reshape(28, 28), cmap="gray_r")plt.xticks([])plt.yticks([])

输出应该是类似于训练数据的数字,如下图所示:
在这里插入图片描述
   经过 50 个训练周期后,生成的数字与真实数字相似。您可以通过考虑更多训练周期来改善结果。与前面的示例一样,通过在训练过程中使用固定的潜在空间样本张量并在每个周期结束时将其馈送到生成器,您可以直观地看到训练的演变:
在这里插入图片描述
   您可以看到,在训练过程开始时,生成的图像是完全随机的。随着训练的进行,生成器会学习真实数据的分布,并且在大约 20 个 epoch 后,一些生成的数字已经与真实数据相似。

十六、结论

   恭喜!您已经学会了如何实现自己的生成对抗网络。在深入研究生成手写数字图像的实际应用之前,您首先通过一个示例了解了 GAN 结构。

   您会看到,尽管 GAN 非常复杂,但 PyTorch 等机器学习框架通过提供自动区分和简单的 GPU 设置使实现更加直接。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/24520.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

笔记-2024视频会议软件技术选型方案

一、背景 视频会议系统是一种现代化的办公系统,它可以使不同会场的实时现场场景和语音互连起来,同时向与会者提供分享听觉和视觉的空间,使各与会方有“面对面”交谈的感觉。随着社会的发展,视频会议的应用越来越广泛,…

BC6 小飞机

BC6 小飞机 废话不多说先上题目&#xff1a; 代码如下&#xff1a; #include<stdio.h> int main() {printf(" ## \n############\n############\n # # \n # # \n");return 0; }这是用一个printf打印我们还可以用多个printf发打印代码如下…

Django框架中级

Django框架中级 – 潘登同学的WEB框架 文章目录 Django框架中级 -- 潘登同学的WEB框架 中间件自定义中间件常用中间件process_view() 使用中间件进行URL过滤 Django生命周期生命周期分析 Django日志日志配置filter过滤器自定义filter 日志格式化formatter Django信号内置信号定…

类和对象(二)(C++)

初始化列表 class Date{public:Date(int year, int month, int day){_year year;_month month;_day day;}private:int _year;int _month;int _day;}; 虽然上述构造函数调用之后&#xff0c;对象中已经有了一个初始值&#xff0c;但是不能将其称为对对象中成员变量的初始化…

【纯血鸿蒙】——响应式布局如何实现?

前面介绍了自适应布局&#xff0c;但是将窗口尺寸变化较大时&#xff0c;仅仅依靠自适应布局可能出现图片异常放大或页面内容稀疏、留白过多等问题。此时就需要借助响应式布局能力调整页面结构。 响应式布局 响应式布局是指页面内的元素可以根据特定的特征&#xff08;如窗口…

docker部署使用本地文件的fastapi项目

项目背景&#xff1a;项目使用python开发&#xff0c;需要使用ubutun系统部署后端api接口&#xff0c;对外使用8901端口。 1:项目结构&#xff1a; 2&#xff1a;项目需要使用的pyhton版本为3.9&#xff0c;dockerfile内容如下&#xff1a; # FROM python:3.9# WORKDIR /co…

自制植物大战僵尸:HTML5与JavaScript实现的简单游戏

引言 在本文中&#xff0c;我们将一起探索如何使用HTML5和JavaScript来创建一个简单的植物大战僵尸游戏。这不仅是一项有趣的编程挑战&#xff0c;也是学习游戏开发基础的绝佳机会。 什么是植物大战僵尸&#xff1f; 植物大战僵尸是一款流行的策略塔防游戏&#xff0c;玩家需…

如何提高网站排名?

提高网站排名是一个复杂的过程&#xff0c;涉及到多个方面的优化&#xff0c;包括但不限于内容质量、网站结构、用户体验、外部链接建设等&#xff0c;GSR这个系统&#xff0c;它是一种快速提升关键词排名的方案&#xff0c;不过它有个前提&#xff0c;就是你的站点在目标关键词…

超详解——深入详解Python基础语法——小白篇

目录 1 .语句和变量 变量赋值示例&#xff1a; 打印变量的值&#xff1a; 2. 语句折行 反斜杠折行示例&#xff1a; 使用括号自动折行&#xff1a; 3. 缩进规范 缩进示例&#xff1a; 4. 多重赋值&#xff08;链式赋值&#xff09; 多重赋值的应用&#xff1a; 5 .多…

FonesGo Location Changer 用Mac修改iPhone定位的工具

搜索Mac软件之家下载FonesGo Location Changer 用Mac修改iPhone定位的工具 FonesGo Location Changer 7.0.0 可以自定义修改iPhone和Android手机的GPS定位。FonesGo Location Changer 是玩 Pokemon Go 时的最佳搭档。您可以以自定义速度模拟 GPS 运动&#xff0c;例如步行、骑…

【设计模式】JAVA Design Patterns——State(状态模式)

&#x1f50d;目的 允许对象在内部状态改变时改变它的行为。对象看起来好像修改了它的类。 &#x1f50d;解释 真实世界例子 当在长毛象的自然栖息地观察长毛象时&#xff0c;似乎它会根据情况来改变自己的行为。它开始可能很平静但是随着时间推移当它检测到威胁时它会对周围的…

element-plus 的icon 图标的使用

element-plus的icon 已经独立出来了&#xff0c;需要单独安装 1. npm安装 icon包 npm install element-plus/icons-vue2.注册到全局组件中 同时注册到全局组件中&#xff0c;或者按需单独引入&#xff0c;这里只介绍全局引入。 import { createApp } from vue import { cre…

Python易错点总结

目录 多分支选择结构 嵌套选择 用match模式识别 match与if的对比 案例&#xff1a;闰年判断 三角形的判断 用whlie循环 高斯求和 死循环 用for循环 ​编辑continue​编辑 whlie与else结合 pass 序列 列表&#xff08;有序&#xff09; 元组&#xff08;有序&…

LeetCode热题100—链表(二)

19.删除链表的倒数第N个节点 题目 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5] 示例 2&#xff1a; 输入&#xff1a;head [1], n 1 …

Docker中搭建likeadmin

一、使用Docker中的docker-compose搭建likeadmin 1.去网址&#xff1a;https://gitee.com/likeadmin/likeadmin_php中下载likeadmin 注册一个giee账号后 点那个克隆下载 按照序号在终端复制粘贴进去。 接着&#xff0c;输入ls 可以发现有一个这个&#xff1a; 里面有一个like…

摄影店展示服务预约小程序的作用是什么

摄影店包含婚照、毕业照、写真、儿童照、工作照等多个服务项目&#xff0c;虽然如今人们手机打开便可随时拍照摄影&#xff0c;但在专业程度和场景应用方面&#xff0c;却是需要前往专业门店服务获取。 除了进店&#xff0c;也有外部预约及活动、同行合作等场景&#xff0c;重…

Ezsql(buuctf加固题)

开启环境 SSH连接 第一个为页面地址WEB服务 or 11# 利用万能密码登录 密码可以随便输入或者不输入 这里就可以判断这个题目是让我们加固这个登录页面 防止sql注入 查看index.php 添加以下代码 $username addslashes($username); $password addslashes($password);…

2024年京东618红包领取跨店满300减50第二波活动时间什么时候开始到几号结束?

2024年京东618活动时间 整个618红包满减活动时间是从&#xff1a;2024年5月28日12:00开始一直持续到6月20日23:59 第一波红包领取活动时间是从&#xff1a;2024年5月28日12:00开始到6月6日23:59结束 第二波红包领取活动时间是从&#xff1a;2024年6月7日00:00开始到6月18日2…

【HarmonyOS】放大缩小手势实现

【HarmonyOS】放大缩小手势实现 一、鸿蒙中手势的类型&#xff1a; 对于放大缩小手势&#xff0c;在应用开发中使用较为常见&#xff0c;例如预览图片时&#xff0c;扫码时等。 在鸿蒙中对于常见的手势进行的封装&#xff0c;可以通过简单的API进行监听调用&#xff0c;以下是…

k8s测试题

k8s集群k8s集群node01192.168.246.11k8s集群node02192.168.246.12k8s集群master 192.168.246.10 k8s集群nginxkeepalive负载均衡nginxkeepalive01&#xff08;master&#xff09;192.168.246.13负载均衡nginxkeepalive02&#xff08;backup&#xff09;192.168.246.14VIP 192…