什么是生成对抗网络 (GAN)?
钦吉兹·赛义德贝利
一、说明
GAN(Generative Adversarial Network)网络是一种深度学习模型,由两个神经网络——生成器和判别器组成。生成器负责生成虚假的数据,而判别器负责判断数据的真实性。它们之间通过对抗学习的方式相互影响和学习,最终生成器能够生成更加真实的数据,而判别器能够更准确地判断数据的真伪。GAN网络被认为是生成式模型中最具有潜力的一种方法之一。
二、GAN概论
GAN或生成对抗网络是一种神经网络架构,由两个主要组件组成:生成器网络和鉴别器网络。GAN 的目的是生成模拟输入数据分布的真实数据。
生成器网络采用随机噪声向量作为输入,并生成一个旨在类似于输入数据分布的新数据点。鉴别器网络从输入分布中获取生成的数据点和真实数据点,并预测每个输入是真实的还是生成的。
在训练期间,生成器网络生成一个数据点,鉴别器网络预测它是真实的还是生成的。然后,生成器网络根据鉴别器的输出接收有关其生成的数据的真实程度的反馈。重复此过程,直到生成器网络能够产生判别器网络无法与真实数据区分开来的真实数据。
GAN的训练过程可以被描述为一个双人游戏,其中生成器和鉴别器网络不断尝试相互智取。生成器网络旨在生成足够逼真的数据以欺骗鉴别器网络,而鉴别器网络试图正确识别给定的数据点是真实的还是生成的。
训练后,生成器网络可用于生成类似于输入数据分布的新数据。GAN 已成功用于各种应用,包括图像和视频生成、文本生成和音乐生成。然而,GAN 的训练也可能具有挑战性,并且容易出现模式崩溃等问题,其中发电机网络产生的输出范围有限。
GAN应用程序的一个例子是图像生成。在此方案中,生成器网络接收随机噪声向量并生成类似于输入图像分布的新图像。鉴别器网络从输入分布中获取生成的图像和真实图像,并预测每个图像是真实的还是生成的。
在训练期间,生成器网络生成图像,鉴别器网络预测它是真实的还是生成的。然后,生成器网络根据鉴别器的输出接收有关其生成的图像逼真的反馈。重复此过程,直到生成器网络能够生成判别器网络无法与真实图像区分的真实图像。
训练后,生成器网络可用于生成类似于输入图像分布的新图像。例如,可以在名人面孔数据集上训练 GAN,然后用于生成新的、逼真的名人面孔。GAN还用于其他与图像相关的任务,例如图像到图像的转换,其中GAN用于将图像从一个域(例如,白天)转换为另一个域(例如,夜间),同时保持图像的内容。
让我们为 GAN 网络编写一个伪代码
Initialize the generator network G with random weights
Initialize the discriminator network D with random weights
Set the learning rate for both networks
Set the number of training epochs
Set the batch sizefor epoch in range(num_epochs):for batch in data:# Train the discriminator networkSample a batch of real images from the training dataGenerate a batch of fake images from the generator networkTrain the discriminator network on the real and fake imagesCompute the discriminator loss# Train the generator networkGenerate a new batch of fake images from the generator networkCompute the generator loss based on the discriminator's outputBackpropagate the loss and update the generator's weights# Update the discriminator's weightsBackpropagate the loss and update the discriminator's weights# Generate a sample of fake images from the generatorSave the generator's weights
三、GAN 编码与 Python
要为GAN编写完整的Python代码,需要大量的时间和资源。但是,我可以简要概述使用 PyTorch 库训练 GAN 所涉及的步骤:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
使用 PyTorch 定义生成器和鉴别器网络:nn.Module
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# Define the layers of the generator networkdef forward(self, z):# Define the forward pass of the generator networkclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# Define the layers of the discriminator networkdef forward(self, x):# Define the forward pass of the discriminator network
定义超参数:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
num_epochs = 100
learning_rate = 2e-4
latent_size = 100
image_size = 28*28
加载 MNIST 数据集并创建数据加载器:
train_dataset = datasets.MNIST(root='data/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
定义损失函数和优化器:
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
训练 GAN:
for epoch in range(num_epochs):for batch_idx, (real_images, _) in enumerate(train_loader):# Train discriminator with real imagesreal_images = real_images.view(-1, image_size).to(device)real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# Train discriminator with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = generator(z)d_real_loss = criterion(discriminator(real_images), real_labels)d_fake_loss = criterion(discriminator(fake_images), fake_labels)d_loss = d_real_loss + d_fake_lossd_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# Train generatorz = torch.randn(batch_size, latent_size).to(device)fake_images = generator(z)g_loss = criterion(discriminator(fake_images), real_labels)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()
使用经过训练的生成器生成新图像:
z = torch.randn(64, latent_size).to(device)
generated_images = generator(z)
请注意,上面的代码只是一个简短的概述,对于 GAN 的特定用例,可能需要额外的步骤和修改。
让我们在代码中填写空白:)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# Define the generator network
class Generator(nn.Module):def __init__(self, input_size=100, output_size=784):super(Generator, self).__init__()self.input_size = input_sizeself.output_size = output_sizeself.fc1 = nn.Linear(input_size, 256)self.bn1 = nn.BatchNorm1d(256)self.fc2 = nn.Linear(256, 512)self.bn2 = nn.BatchNorm1d(512)self.fc3 = nn.Linear(512, 1024)self.bn3 = nn.BatchNorm1d(1024)self.fc4 = nn.Linear(1024, output_size)self.activation = nn.Tanh()def forward(self, x):x = self.fc1(x)x = self.bn1(x)x = self.activation(x)x = self.fc2(x)x = self.bn2(x)x = self.activation(x)x = self.fc3(x)x = self.bn3(x)x = self.activation(x)x = self.fc4(x)x = self.activation(x)return x# Define the discriminator network
class Discriminator(nn.Module):def __init__(self, input_size=784, output_size=1):super(Discriminator, self).__init__()self.input_size = input_sizeself.output_size = output_sizeself.fc1 = nn.Linear(input_size, 1024)self.activation = nn.LeakyReLU(0.2)self.fc2 = nn.Linear(1024, 512)self.fc3 = nn.Linear(512, 256)self.fc4 = nn.Linear(256, output_size)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)x = self.activation(x)x = self.fc3(x)x = self.activation(x)x = self.fc4(x)x = self.sigmoid(x)return x# Define the hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
num_epochs = 50
learning_rate = 0.0002
input_size = 100
image_size = 28 * 28# Load the MNIST dataset
train_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# Initialize the generator and discriminator networks
generator = Generator(input_size).to(device)
discriminator = Discriminator().to(device)# Define the loss functions and optimizers
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)# Train the GAN
for epoch in range(num_epochs):for batch_idx, (real_images, _) in enumerate(train_loader):real_images = real_images.view(-1, image_size).to(device)batch_size = real_images.shape[0]# Train the discriminator networkd_optimizer.zero_grad()# Train on real imagesreal_labels = torch.ones(batch