经典神经网络(8)GAN、CGAN、DCGAN、LSGAN及其在MNIST数据集上的应用

经典神经网络(8)GAN、CGAN、DCGAN、LSGAN及其在MNIST数据集上的应用

1 GAN的简述及其在MNIST数据集上的应用

  • GAN模型主导了生成式建模的前一个时代,但由于训练过程中的不稳定性,对GAN进行扩展需要仔细调整网络结构和训练考虑,因此GANs虽然在为单个或多个对象类别建模方面表现出色,但扩展到复杂的数据集上,非常具有挑战性。
  • 最近几年发布的一系列大型模型,如DALL-E系列、Imagen、Parti和Stable Diffusion,开创了图像生成的新时代,在图像质量和模型灵活性方面达到了前所未有的水平。
  • 目前占主导地位的范式扩散模型自回归模型,都依赖于迭代推理这把双刃剑,因为迭代方法能够以简单的目标进行稳定的训练,但在推理过程中会产生更高的计算成本。与此形成对比的是生成对抗网络(GAN),只需要一次forward pass即可生成图像,因此本质上是更高效的。
  • 虽然现在超大型的模型、数据和计算资源都主要集中在扩散模型和自回归模型上。但是,也有研究人员证明GAN仍然是文本生成图像的可行选择之一,例如:2023年提出的GigaGAN(https://arxiv.org/abs/2303.05511)。
  • 今天,我们来了解下生成式对抗网络GAN及其几个改进网络。

1.1 GAN的简述

  • GAN 是 Generative Adversarial Network 生成式对抗网络英文的缩写,由蒙特利尔大学的Ian Goodfellow在2014年提出。
  • GAN由两个部分组成:
    • 一个是生成器Generator,尽量去学习真实的数据分布,随机接收一个随机噪声来生成无限接近真实数据的图像。
    • 一个是鉴别器Discriminator,判断一张图像是不是“真实的”,输入是一张图像,输出是该图像为真实图像的概率,介于0-1之间,概率值越小认为生成图像不真实的可能性越大。
  • 生成器的目标是通过生成接近真实的图像来欺骗判别器,而判别器的目标是尽量辨别出生成器生成的假图像和真实图像的区别。生成器希望假图像更逼真判别概率高,而判别器希望假图像再逼真也可以判别概率低,通过这样的动态博弈过程,最终达到纳什均衡点,通过深度神经网络训练完成之后,生成器可以从一段随机数中生成逼真的图像。
  • 不过,GAN存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题,因此出现了一系列改进模型,如:CGAN、LSGAN、DCGAN、WGAN、WGAN-GP、BEGAN、CycleGAN等
  • 论文链接:https://arxiv.org/pdf/1406.2661.pdf

1.1.1 GAN的架构

在这里插入图片描述

  • 生成器G:尽量去学习真实的数据分布,生成无限接近真实数据的样本
  • 判别器D:尽量去判别输入数据是真实数据还是来自于生成器生成的数据
  • 主要过程为:
    1. 输入噪声(隐藏变量)z
    2. 通过生成部分G,得到 G ( z ) = x f a k e G(z)=x_{fake} G(z)=xfake
    3. 从真实数据集中取一部分真实数据 x r e a l x_{real} xreal
    4. 将两者混合 x = x f a k e + x r e a l x=x_{fake}+x_{real} x=xfake+xreal
    5. 将数据喂入判别部分D,给定标签 l a b e l f a k e = 0 , l a b e l r e a l = 1 label_{fake}=0,label_{real}=1 labelfake=0,labelreal=1(简单的二类分类器)
    6. 按照分类结果,回传loss
  • GAN的对抗生成思想主要由其目标函数实现,通过给定一个生成器G和一个判别器D,GAN的目标函数 V ( G , D ) V(G, D) V(G,D)具体公式如下所示:

在这里插入图片描述

我们可以分两部分开看这个公式,即判别器最大化生成器最小化

在判别器角度,我们希望最大化这个目标函数

  • 因为在公式的第一部分,其表示GT样本 ( x ~ p d a t a ) (x~p_{data}) (xpdata)输入判别器后输出的置信度,当然是越接近1越好。
  • 而公式的第二部分表示生成器输出的生成样本 G ( z ) G(z) G(z)再输入判别器中进行进行二分类判别,因为 l o g ( 1 − D ( G ( z ) ) ) < = 0 log(1-D(G(z)))<=0 log(1D(G(z)))<=0,那么输出的置信度当然是越接近0越好,所以 1 − D ( G ( z ) ) 1-D(G(z)) 1D(G(z))越接近1越好。

在生成器角度,我们希望最小化【判别器目标函数的最大值】

  • 判别器目标函数的最大值代表的是真实数据分布与生成数据分布的JS散度
  • JS散度可以度量分布的相似性,两个分布越接近,JS散度越小(JS散度是在初始GAN论文中被提出,实际应用中会发现有不足的地方,后来的论文陆续提出了很多的新损失函数来进行优化)。

生成器与判别器之间存在着对抗

  • 一方面,从生成器而言,希望 D ( G ( z ) ) D(G(z)) D(G(z))为1,提高自己的生成能力;
  • 另一方面,从判别器而言,希望 D ( G ( z ) ) D(G(z)) D(G(z))为0,提高自己的判别能力。
  • 作者经过理论证明,两者最终可以达到纳什均衡——处于此状态下,利益达到最大,双方都不愿意改变自己的状态

1.1.2 理论证明

作者在论文中,证明了生成器与判别器最终可以达到纳什均衡状态。证明的过程中,利用了KL散度的概念,KL散度可以参考:信息量、熵、KL散度、交叉熵概念理解。

  • 首先,我们在给定生成器的情况下,考虑最优化判别器D。和一般的基于Sigmoid的二分类模型训练一样,训练判别器D也是最小化交叉熵的过程,其损失函数为(二分类):
    O b j D ( θ D , θ G ) = − 1 2 E x ~ p d a t a ( x ) [ l o g D ( x ) ] − 1 2 E z ~ p z ( z ) [ l o g ( 1 − D ( g ( z ) ) ] Obj^D(\theta_D,\theta_G)=-\frac{1}{2}E_{x~p_{data}}(x)[logD(x)]-\frac{1}{2}E_{z~p_{z}(z)}[log(1-D(g(z))] ObjD(θD,θG)=21Expdata(x)[logD(x)]21Ezpz(z)[log(1D(g(z))]

  • 训练过程就是最小化损失函数的过程,在连续空间上我们进而写成

O b j D ( θ D , θ G ) = − 1 2 ∫ x p d a t a ( x ) l o g D ( x ) − 1 2 ∫ z p z ( z ) l o g ( 1 − D ( g ( z ) ) 我们考虑在优化 D 的时候 G 是不变的,并且假设,通过 G 生成的 g ( z ) 满足的分布为 p g ,因此上式改写为: = − 1 2 ∫ x [ p d a t a ( x ) l o g D ( x ) + p g ( x ) l o g ( 1 − D ( x ) ) ] Obj^D(\theta_D,\theta_G)=-\frac{1}{2}\int_xp_{data}(x)logD(x)-\frac{1}{2}\int_zp_{z}(z)log(1-D(g(z))\\ 我们考虑在优化D的时候G是不变的,并且假设,通过G生成的g(z)满足的分布为p_g,因此上式改写为: \\ =-\frac{1}{2}\int_x[p_{data}(x)logD(x)+p_{g}(x)log(1-D(x))] \\ ObjD(θD,θG)=21xpdata(x)logD(x)21zpz(z)log(1D(g(z))我们考虑在优化D的时候G是不变的,并且假设,通过G生成的g(z)满足的分布为pg,因此上式改写为:=21x[pdata(x)logD(x)+pg(x)log(1D(x))]

  • 去除常量-1/2,我们约定质量函数为 V ( G , D ) V(G,D) V(G,D)

V ( G , D ) = E x ~ p d a t a ( x ) [ l o g D ( x ) ] − E z ~ p z ( z ) [ l o g ( 1 − D ( g ( z ) ) ] = ∫ x [ p d a t a ( x ) l o g D ( x ) + p g ( x ) l o g ( 1 − D ( x ) ) ] 上式什么时候取最大呢? a l o g ( y ) + b l o g ( 1 − y ) 在 [ 0 , 1 ] 上当 y = a a + b 取最大值,因此上式取得最大值时: D G ∗ ( x ) = p d a t a p d a t a + p g ( x ) , 此即为判别器的最优解 V(G,D)=E_{x~p_{data}}(x)[logD(x)]-E_{z~p_{z}(z)}[log(1-D(g(z))]\\ =\int_x[p_{data}(x)logD(x)+p_{g}(x)log(1-D(x))] \\ 上式什么时候取最大呢?\\ alog(y)+blog(1-y)在[0,1]上当y=\frac{a}{a+b}取最大值,因此上式取得最大值时:\\ D^*_{G}(x)=\frac{p_{data}}{p_{data}+p_{g}(x)},此即为判别器的最优解 V(G,D)=Expdata(x)[logD(x)]Ezpz(z)[log(1D(g(z))]=x[pdata(x)logD(x)+pg(x)log(1D(x))]上式什么时候取最大呢?alog(y)+blog(1y)[0,1]上当y=a+ba取最大值,因此上式取得最大值时:DG(x)=pdata+pg(x)pdata,此即为判别器的最优解

  • 我们将判别器的最优解,代入到质量函数 V ( G , D ) V(G,D) V(G,D)

    在这里插入图片描述

  • KL散度是非负的,所以我们可以认为-log4是最小值

  • 为了证明 p d a t a = p g p_{data}=p_g pdata=pg是使上式取-log4的唯一点,这里可以使用JS散度的特性

    • 在这里插入图片描述

    • 因此,当且仅当 p d a t a = p g p_{data}=p_g pdata=pg,我们得到最优生成器,即生成器的概率密度函数等于真实数据的概率密度函数,也即生成的数据和真实数据是一样的;

    • 此时最优判别器 D ∗ = 1 2 D^*=\frac{1}{2} D=21,即判别器无法判断数据到底是来自真实样本,还是伪造的数据。

1.1.3 模型的训练过程

先训练判别器使判别器达到最优,再训练生成器使二者完成对抗优化,最终达到 p d a t a = p g p_{data}=p_g pdata=pg

在这里插入图片描述

如上图所示,生成对抗网络会训练并更新判别分布(即 D,蓝色的虚线),更新判别器后就能将数据真实分布(黑点组成的线)从生成分布(绿色实线)中判别出来。

下方的水平线代表采样域Z,其中等距线表示Z中的样本为均匀分布,上方的水平线代表真实数据X中的一部分。向上的箭头表示映射 x = G ( z ) x=G(z) x=G(z) 如何对噪声样本(均匀采样)施加一个不均匀的分布 p g p_g pg.

  • 在算法内部循环中训练 D 以从数据中判别出真实样本,该循环最终会收敛到

D G ∗ ( x ) = p d a t a p d a t a + p g ( x ) D^*_{G}(x)=\frac{p_{data}}{p_{data}+p_{g}(x)} DG(x)=pdata+pg(x)pdata

  • 随后固定判别器并训练生成器,在更新G之后,D的梯度会引导 G ( z ) G(z) G(z)流向更可能D分类为真实数据的方向。
  • 经过若干次训练后,如果G和D有足够的复杂度,那么它们就会到达一个均衡点,这个时候 p d a t a = p g p_{data}=p_g pdata=pg

1.1.4 GAN存在的问题

1、可解释性非常差

  • 所学到的数据分布,没有显示的表达式。
  • 它只是一个黑盒子一样的映射函数: 输入是一个随机变量,输出想要的一个数据分布。

2、训练不稳定

  • 难以保持生成器与判别器的平衡

3、生成器容易产生模式崩溃(Mode collapse)

  • 举个生成数字图像的例子:生成器要生成0-9之间的数字,而判别器只是要判断生成器生成的数据像不像真实数据。
  • 比如”1“是非常容易生成的一个数字,那么生成器可能就会拼命的去生成更多的真实的”1“,从而判别器就难以判别。对于其他的复杂一点的数字比如”8“,”9“,生成器可能就干脆不生成了,从而避免犯错,这就是生成器的一个大问题。

1.2 GAN在MNIST数据集上的应用

参考代码:PyTorch-GAN/implementations

1.2.1 生成器D和判别器G

  • 我们这里实现的生成对抗网络(GAN)十分简单,仅用了线性层搭建。
  • 生成器Generator将随机生成的噪声z通过多个线性层生成图片,注意生成器的最后一层是Tanh,所以我们生成的图片的取值范围为[-1,1],同理,我们会将真实图片归一化(normalize)到[-1,1]。
  • 判别器Discriminator是一个二分类器,通过多个线性层得到一个概率值来判别图片是"真实"或者是"生成"的,所以在Discriminator的最后是一个sigmoid,来得到图片是真实的概率。
  • 在所有的网络结构中我们都使用了LeakyReLU作为激活函数,除了G与D的最后一层。在层与层之间,我们还加入了BatchNormalization。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from PIL import Imageclass Generator(nn.Module):def __init__(self, image_size=32, latent_dim=100, output_channel=1):"""image_size: image with and heightlatent dim: the dimension of random noise zoutput_channel: the channel of generated image, for example, 1 for gray image, 3 for RGB image"""super(Generator, self).__init__()self.latent_dim = latent_dimself.output_channel = output_channelself.image_size = image_size# Linear layer: latent_dim -> 128 -> 256 -> 512 -> 1024 -> output_channel * image_size * image_size -> Tanhself.model = nn.Sequential(nn.Linear(latent_dim, 128),nn.BatchNorm1d(128),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, output_channel * image_size * image_size),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), self.output_channel, self.image_size, self.image_size)return imgclass Discriminator(nn.Module):def __init__(self, image_size=32, input_channel=1):"""image_size: image with and heightinput_channel: the channel of input image, for example, 1 for gray image, 3 for RGB image"""super(Discriminator, self).__init__()self.image_size = image_sizeself.input_channel = input_channel# Linear layer: input_channel * image_size * image_size -> 1024 -> 512 -> 256 -> 1 -> Sigmoidself.model = nn.Sequential(nn.Linear(input_channel * image_size * image_size, 1024),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)out = self.model(img_flat)return out

1.2.2 MNIST数据集的加载

  • MNIST是一个手写数字数据集,通常用于机器学习和计算机视觉领域的基准测试。每个样本都是一个28x28像素的灰度图像,表示从0到9的手写数字。
  • MNIST数据集共包含70000个图像,其中60000个用作训练集,10000个用作测试集。对于GAN而言,我们不需要测试集,仅使用训练集。
  • 我们将所有图片normalize到了[-1,1]之间。
def load_mnist_data():"""load mnist(0,1,2) dataset"""transform = torchvision.transforms.Compose([# transform to 1-channel gray image since we reading image in RGB modetransforms.Grayscale(1),# resize image from 28 * 28 to 32 * 32transforms.Resize(32),transforms.ToTensor(),# normalize with mean=0.5 std=0.5,transforms.Normalize(mean=(0.5,),std=(0.5,))])train_dataset = torchvision.datasets.MNIST(r"/root/autodl-fs/data/minist", download=False, train=True,transform=transform)return train_dataset
  • 通过下面代码,我们能够查看数据集中的20张随机真实图片
def denorm(x):# denormalizeout = (x + 1) / 2return out.clamp(0, 1)def show_train_dataset():train_dataset = load_mnist_data()trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=20, shuffle=True)grid = torchvision.utils.make_grid(denorm(next(iter(trainloader))[0]), nrow=5)os.makedirs("gan_minist", exist_ok=True)image_grid = Image.fromarray(grid.mul(255).permute(1, 2, 0).byte().numpy())image_grid.save(f"./gan_minist/init.jpg")

1.2.3 模型的训练

  • GAN的训练过程分为两步
    • 第一步将随机噪声z喂给生成器G生成图片,然后将真实图片和生成器G生成的图片喂给判别器D,然后使用对应的loss函数反向传播优化判别器D。
    • 第二步使用生成器G生成图片,并喂给判别器D,并使用对应的loss函数反向传播优化生成器G。
  • 对于判别器D,最大化其优化目标可以通过最小化一个BCEloss来实现,其真实图片的标签设置为1,而生成图片的标签设置为0。
  • 对于生成器G,也通过最小化一个BCEloss来实现,即将生成图片的标签设置为1即可。
  • 当模型训练时,我们需要查看G生成的图片效果,下面的visualize_results代码便实现了这块内容。需要注意的是,我们生成的图片都在[-1,1]。因此,我们需要将图片反向归一化(denorm)到[0,1]。
def visualize_results(epoch, G, device, z_dim, result_size=20):epoch = str(epoch).zfill(3)G.eval()z = torch.rand(result_size, z_dim).to(device)g_z = G(z)grid = torchvision.utils.make_grid(denorm(g_z.detach().cpu()), nrow=5)os.makedirs("gan_minist", exist_ok=True)image_grid = Image.fromarray(grid.mul(255).permute(1, 2, 0).byte().numpy())image_grid.save(f"./gan_minist/{epoch}.jpg")def run_gan(trainloader, G, D, G_optimizer, D_optimizer, loss_func, n_epochs, device, latent_dim):d_loss_hist = []g_loss_hist = []t_epochs = []for epoch in range(n_epochs):d_loss, g_loss = train_one_epoch(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device,z_dim=latent_dim)print('Epoch {}: Train D loss: {:.4f}, G loss: {:.4f}'.format(epoch, d_loss, g_loss))d_loss_hist.append(d_loss)g_loss_hist.append(g_loss)t_epochs.append(epoch)if epoch == 0 or (epoch + 1) % 10 == 0:# 每10个epoch 就可视化一下图像visualize_results(epoch + 1, G, device, latent_dim)return d_loss_hist, g_loss_hist, t_epochs
def train_one_epoch(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device, z_dim):"""train a GAN with model G and D in one epochArgs:trainloader: data loader to trainG: model GeneratorD: model DiscriminatorG_optimizer: optimizer of G(etc. Adam, SGD)D_optimizer: optimizer of D(etc. Adam, SGD)loss_func: loss function to train G and D. For example, Binary Cross Entropy(BCE) loss functiondevice: cpu or cuda devicez_dim: the dimension of random noise z"""# set train modeD.train()G.train()D_total_loss = 0G_total_loss = 0for i, (x, _) in enumerate(trainloader):# real label and fake labely_real = torch.ones(x.size(0), 1).to(device)y_fake = torch.zeros(x.size(0), 1).to(device)x = x.to(device)z = torch.rand(x.size(0), z_dim).to(device)# 1、训练判别器# D optimizer zero gradsD_optimizer.zero_grad()# D real loss from real imagesd_real = D(x)d_real_loss = loss_func(d_real, y_real)# D fake loss from fake images generated by Gg_z = G(z)d_fake = D(g_z)d_fake_loss = loss_func(d_fake, y_fake)# D backward and stepd_loss = d_real_loss + d_fake_lossd_loss.backward()D_optimizer.step()# 2、训练生成器# G optimizer zero gradsG_optimizer.zero_grad()# G lossg_z = G(z)d_fake = D(g_z)g_loss = loss_func(d_fake, y_real)# G backward and stepg_loss.backward()G_optimizer.step()D_total_loss += d_loss.item()G_total_loss += g_loss.item()return D_total_loss / len(trainloader), G_total_loss / len(trainloader)
  • 设置好超参数就可以开始训练,我们可以将训练过程中loss值记录下来方便画图
def save_loss2txt(x_values, y1_values, y2_values):# 打开文件进行写入with open('gan_minist/loss_data.txt', 'w') as file:for x, y1, y2 in zip(x_values, y1_values, y2_values):file.write(f'{x} {y1} {y2}\n')def plot_loss():# 然后使用matplotlib读取txt文件中的数据进行绘图x_values, y1_values, y2_values = [], [], []with open('gan_minist/loss_data.txt', 'r') as file:for line in file:parts = line.split()x_values.append(float(parts[0]))y1_values.append(float(parts[1]))y2_values.append(float(parts[2]))# 绘图plt.plot(x_values, y1_values, label='d_loss_hist')plt.plot(x_values, y2_values, label='g_loss_hist')plt.legend()plt.show()if __name__ == '__main__':# hyper params# z dimlatent_dim = 100# image size and channelimage_size = 32image_channel = 1# Adam lr and betaslearning_rate = 0.0002betas = (0.5, 0.999)# epochs and batch sizen_epochs = 200batch_size = 512# devicedevice = "cuda" if torch.cuda.is_available() else "cpu"# mnist dataset and dataloadertrain_dataset = load_mnist_data()trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)# use BCELoss as loss functionbceloss = nn.BCELoss().to(device)# G and D modelG = Generator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel).to(device)D = Discriminator(image_size=image_size, input_channel=image_channel).to(device)# G and D optimizer, use Adam or SGDG_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)d_loss_hist, g_loss_hist, t_epochs = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,n_epochs, device, latent_dim)# 保存Loss信息save_loss2txt(t_epochs, d_loss_hist, g_loss_hist)
  • 下面是训练第1、100、200轮时,随机生成的图像。
  • 可以看到,即使是一个简单的GAN在MNIST这种简单数据集上的生成效果还是不错的。

在这里插入图片描述

  • 训练过程中的损失函数图像如下所示。
  • 我们知道在训练过程中,一般损失曲线倾向于下降并最终收敛。然而,在生成对抗网络(GAN)模型中,当判别器(D_loss)降低时,生成器损失(G_loss)升高,反之亦然。
  • 这是因为在GAN中,生成器和判别器相互对抗,生成器希望生成的图像能够欺骗判别器,而判别器希望能够找到生成器的伪装,因此两者的表现往往是相反的。

在这里插入图片描述

2 CGAN的简述及其在MNIST数据集上的应用

2.1 CGAN的简述

  • 原始GAN的生成过程采用随机噪声就可以开始训练,不再需要一个假设的数据分布,但是这样自由散漫的方式对于较大的图像就不太可控了
  • CGAN(Conditional GAN)方法提出了一种带有条件约束的GAN,通过额外的信息对模型增加条件,来指导数据生成过程。
  • 将额外信息y输送给判别模型和生成模型,作为输入层的一部分,从而实现条件GAN,是在Mnist数据集上以类别标签为条件变量,生成指定类别的图像,把纯无监督的GAN变成有监督的模型。

在这里插入图片描述

  • 条件 GAN 的目标函数是带有条件概率的二人极小极大值博弈

在这里插入图片描述

  • 论文链接:https://arxiv.org/pdf/1411.1784.pdf

2.2 CGAN在MNIST数据集上的应用

  • 我们在GAN的基础上,利用nn.Embedding(10, label_latent_dim)将labels进行映射
  • 再利用torch.cat([z, label_embedding], dim=-1)拼接起来就得到了CGAN。
import torch
from tqdm import trange
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from PIL import Imageclass Generator(nn.Module):def __init__(self, image_size=32, latent_dim=100, output_channel=1, label_latent_dim=10):"""image_size: image with and heightlatent dim: the dimension of random noise zoutput_channel: the channel of generated image, for example, 1 for gray image, 3 for RGB image"""super(Generator, self).__init__()self.latent_dim = latent_dimself.output_channel = output_channelself.image_size = image_sizeself.embedding = nn.Embedding(10, label_latent_dim)# Linear layer: latent_dim -> 128 -> 256 -> 512 -> 1024 -> output_channel * image_size * image_size -> Tanhself.model = nn.Sequential(nn.Linear(latent_dim + label_latent_dim, 128),nn.BatchNorm1d(128),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, output_channel * image_size * image_size),nn.Tanh())def forward(self, z, labels):# concat 标签向量label_embedding = self.embedding(labels)z = torch.cat([z, label_embedding], dim=-1)img = self.model(z)img = img.view(img.size(0), self.output_channel, self.image_size, self.image_size)return imgclass Discriminator(nn.Module):def __init__(self, image_size=32, input_channel=1, label_latent_dim=10):"""image_size: image with and heightinput_channel: the channel of input image, for example, 1 for gray image, 3 for RGB image"""super(Discriminator, self).__init__()self.image_size = image_sizeself.input_channel = input_channelself.embedding = nn.Embedding(10, label_latent_dim)# Linear layer: input_channel * image_size * image_size -> 1024 -> 512 -> 256 -> 1 -> Sigmoidself.model = nn.Sequential(nn.Linear(input_channel * image_size * image_size + label_latent_dim, 1024),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img, labels):img_flat = img.view(img.size(0), -1)# concat 标签向量label_embedding = self.embedding(labels)img_flat = torch.cat([img_flat, label_embedding], dim=-1)out = self.model(img_flat)return out
  • 注意此时的训练函数中,需要传入lables信息了。
  • 其他函数,和GAN一致。
def train_one_epoch(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device, z_dim):"""train a CGAN with model G and D in one epochArgs:trainloader: data loader to trainG: model GeneratorD: model DiscriminatorG_optimizer: optimizer of G(etc. Adam, SGD)D_optimizer: optimizer of D(etc. Adam, SGD)loss_func: loss function to train G and D. For example, Binary Cross Entropy(BCE) loss functiondevice: cpu or cuda devicez_dim: the dimension of random noise z"""# set train modeD.train()G.train()D_total_loss = 0G_total_loss = 0for i, (x, labels) in enumerate(trainloader):# real label and fake labely_real = torch.ones(x.size(0), 1).to(device)y_fake = torch.zeros(x.size(0), 1).to(device)x = x.to(device)labels = labels.to(device)z = torch.rand(x.size(0), z_dim).to(device)# 1、训练判别器# D optimizer zero gradsD_optimizer.zero_grad()# D real loss from real imagesd_real = D(x, labels)d_real_loss = loss_func(d_real, y_real)# D fake loss from fake images generated by Gg_z = G(z, labels)d_fake = D(g_z, labels)d_fake_loss = loss_func(d_fake, y_fake)# D backward and stepd_loss = d_real_loss + d_fake_lossd_loss.backward()D_optimizer.step()# 2、训练生成器# G optimizer zero gradsG_optimizer.zero_grad()# G lossg_z = G(z, labels)d_fake = D(g_z, labels)g_loss = loss_func(d_fake, y_real)# G backward and stepg_loss.backward()G_optimizer.step()D_total_loss += d_loss.item()G_total_loss += g_loss.item()return D_total_loss / len(trainloader), G_total_loss / len(trainloader)
  • 下面是训练第1、100、200轮时,随机生成的图像。

在这里插入图片描述

3 DCGAN的简述及其在MNIST数据集上的应用

3.1 DCGAN的简述

  • DCGAN使用卷积层代替了全连接层,采用带步长的卷积代替上采样,更好的提取图像特征,判别器和生成器对称存在,极大的提升了GAN训练的稳定性和生成结果的质量。

  • 判别器中采用leakyRELU而不是RELU来防止梯度稀疏,而生成器仍然采用RELU,但输出层采用tanh。采用Adam优化器训练GAN,设置学习率为0.0002。

  • DCGAN并没有从根本上解决GAN训练不稳定的问题,训练的时候仍需要小心的平衡生成器和判别器的训练,往往是训练一个多次,训练另一个一次。

  • 论文链接:https://arxiv.org/pdf/1511.06434.pdf

3.2 DCGAN在MNIST数据集上的应用

  • 在DCGAN(Deep Convolution GAN)中,最大的改变是使用了CNN代替全连接层。

    • 在生成器G中,使用stride为2的转置卷积来生成图片同时扩大图片尺寸;
    • 而在判别器D中,使用stride为2的卷积来将图片进行卷积并下采样。
  • 除此之外,DCGAN加入了在层与层之间BatchNormalization(虽然我们在普通的GAN中就已经添加),在G中使用ReLU作为激活函数,而在D中使用LeakyReLU作为激活函数

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from PIL import Imageclass DCGenerator(nn.Module):def __init__(self, image_size=32, latent_dim=64, output_channel=1):super(DCGenerator, self).__init__()self.image_size = image_sizeself.latent_dim = latent_dimself.output_channel = output_channelself.init_size = image_size // 8# fc: Linear -> BN -> ReLUself.fc = nn.Sequential(nn.Linear(latent_dim, 512 * self.init_size ** 2),nn.BatchNorm1d(512 * self.init_size ** 2),nn.ReLU(inplace=True))# deconv: ConvTranspose2d(4, 2, 1) -> BN -> ReLU ->#         ConvTranspose2d(4, 2, 1) -> BN -> ReLU ->#         ConvTranspose2d(4, 2, 1) -> Tanhself.deconv = nn.Sequential(nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.ConvTranspose2d(128, output_channel, 4, stride=2, padding=1),nn.Tanh(),)def forward(self, z):out = self.fc(z)out = out.view(out.shape[0], 512, self.init_size, self.init_size)img = self.deconv(out)return imgclass DCDiscriminator(nn.Module):def __init__(self, image_size=32, input_channel=1, sigmoid=True):super(DCDiscriminator, self).__init__()self.image_size = image_sizeself.input_channel = input_channelself.fc_size = image_size // 8# conv: Conv2d(3,2,1) -> LeakyReLU#       Conv2d(3,2,1) -> BN -> LeakyReLU#       Conv2d(3,2,1) -> BN -> LeakyReLUself.conv = nn.Sequential(nn.Conv2d(input_channel, 128, 3, 2, 1),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, 3, 2, 1),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, 3, 2, 1),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),)# fc: Linear -> Sigmoidself.fc = nn.Sequential(nn.Linear(512 * self.fc_size * self.fc_size, 1),)if sigmoid:self.fc.add_module('sigmoid', nn.Sigmoid())def forward(self, img):out = self.conv(img)out = out.view(out.shape[0], -1)out = self.fc(out)return out
  • 同样使用mnist数据集对DCGAN进行训练,训练代码只需要修改G、D模型分别为DCGenerator、DCDiscriminator。
  • 其他代码和GAN一致。
if __name__ == '__main__':# hyper params# z dimlatent_dim = 100# image size and channelimage_size = 32image_channel = 1# Adam lr and betaslearning_rate = 0.0002betas = (0.5, 0.999)# epochs and batch sizen_epochs = 200batch_size = 512# devicedevice = "cuda" if torch.cuda.is_available() else "cpu"# mnist dataset and dataloadertrain_dataset = load_mnist_data()trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)# use BCELoss as loss functionbceloss = nn.BCELoss().to(device)# G and D modelG = DCGenerator(image_size=image_size, latent_dim=latent_dim, output_channel=image_channel).to(device)D = DCDiscriminator(image_size=image_size, input_channel=image_channel).to(device)# G and D optimizer, use Adam or SGDG_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)d_loss_hist, g_loss_hist, t_epochs = run_gan(trainloader, G, D, G_optimizer, D_optimizer, bceloss,n_epochs, device, latent_dim)# 保存Loss信息save_loss2txt(t_epochs, d_loss_hist, g_loss_hist)
  • 下面是训练第1、100、200轮时,随机生成的图像。

在这里插入图片描述

4 LSGAN的简述及其在MNIST数据集上的应用

4.1 LSGAN的简述

  • LSGAN(最小二乘GAN)采用最小二乘损失函数代替原始GAN的交叉熵损失函数
  • 主要针对原始GAN生成器生成的图像质量不高和训练过程不稳定两个问题
    • 作者认为以交叉熵作为损失,会使得生成器不会再优化那些被判别器识别为真实图片的生成图片,即使这些生成图片距离判别器的决策边界仍然很远,也就是距真实数据比较远。这意味着生成器的生成图片质量并不高。
    • 为什么生成器不再优化优化生成图片呢?这是因为生成器已经完成我们为它设定的目标——尽可能地混淆判别器,所以交叉熵损失已经很小了。
    • 而最小二乘就不一样了,要想最小二乘损失比较小,在混淆判别器的前提下还得让生成器把距离决策边界比较远的生成图片拉向决策边界。
  • 损失函数定义如下:

在这里插入图片描述

  • sigmoid交叉熵损失很容易就达到饱和状态(饱和是指梯度为0),而最小二乘损失只在一点达到饱和,因此LSGAN使得GAN的训练更加稳定。
    在这里插入图片描述

  • 论文链接:https://arxiv.org/pdf/1611.04076.pdf

4.2 LSGAN在MNIST数据集上的应用

  • 我们在CGAN基础上,修改为LSGAN,只修改一行代码即可。
# bceloss = nn.BCELoss().to(device)
mseloss = nn.MSELoss().to(device)

下面是训练第1、100、200轮时,随机生成的图像。
在这里插入图片描述

训练过程中的损失函数如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/837949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

金融科技员工一年赚多少钱?富如恒生电子,穷如长亮科技

在本篇文章中&#xff0c;我们继续分析四家金融科技公司的数据&#xff0c;本次站在员工视角来看&#xff08;链接&#xff0c;这四家公司的更详细内容&#xff09;。 先说结论。 2023年&#xff0c;如果你是恒生电子的普通员工&#xff0c;那年薪在35万元&#xff1b;如果你在…

OSPF多区域通信基础实验

基本概念了解&#xff1a; 网络设备接口速率、 Ethernet 100MB GE 1000MB serial接口 1.544MB area 0 骨干区域&#xff08;backbone&#xff09; area 0area0.0.0.0 不是area 0 非骨干区域 ar area 256area 0.0.1.0 实验拓扑如下&#xff1a;要求PC1能够访问PC2 &#xff08;跨…

oracle多条重复数据,取最新的

1、原理讲解-可直接看2 筛选出最新的 SELECT * FROM ( SELECT t.*, ROW_NUMBER() OVER (PARTITION BY LOCALAUTHID ORDER BY LASTUPDATETIME DESC) AS rn FROM USER_LOCALAUTH_STATE t ) t WHERE t.rn 1; 解释&#xff1a; 这个序号是基于[LOCALAUTHID]字段进行分…

外网ip地址怎么获取?快解析

大家都清楚互联网是通过ip地址通信的&#xff0c;ip地址又分内网ip和外网ip。内网ip只能在内网使用&#xff1b;而外网ip作为电脑唯一标识&#xff0c;可在公网使用。那么外网ip地址怎么获取呢&#xff1f; 外网ip是网络运营商分配给用户的。目前最常见的两种上网方式一个是拉…

图文详解JUC:Wait与Sleep的区别与细节

目录 一.Wait() 二.Sleep() 三.总结Wait()与Sleep()的区别 一.Wait() 在Java中&#xff0c;wait() 方法是 Object类中的一个方法&#xff0c;用于线程间的协作。当一个线程调用wait() 方法时&#xff0c;它会释放对象的锁并进入等待状态&#xff0c;直到其他线程调用相同对…

JVM调优-调优原则和原理分析

1.写在前面 对于JVM调优这个话题&#xff0c;可能大部分程序员都听过这个名词。 但是绝大多数程序员&#xff0c;都没有真真实实去干过&#xff0c;都没有真实的实践过。也不懂得如何调优&#xff1f;不知道要调成怎么样&#xff1f; 那今天咋们就对这个话题来展开描述一下&…

洛谷 P3372:线段树 1 ← 分块算法模板(区间更新、区间查询)

【题目来源】https://www.luogu.com.cn/problem/P3372【题目描述】 如题&#xff0c;已知一个数列&#xff0c;你需要进行下面两种操作&#xff1a; &#xff08;1&#xff09;将某区间每一个数加上 k。 &#xff08;2&#xff09;求出某区间每一个数的和。【输入格式】 第一行…

二叉树——初解

二叉树 树树的概念树的性质 二叉树二叉树的概念二叉树的性质二叉树的实现方式数组构建左孩子右兄弟法构建指针构建 树 树的概念 在计算机科学中&#xff0c;树&#xff08;Tree&#xff09;是一种重要的非线性数据结构&#xff0c;它由若干节点&#xff08;Node&#xff09;组…

Chromium 调试指南2024 Windows11篇-调试变量监视(十)

1. 前言 设置断点和监视变量是调试过程中常用的两种技术手段。通过设置断点&#xff0c;我们可以暂停程序的执行并检查程序的内部状态&#xff0c;而监视变量则可以帮助我们实时查看程序中关键变量的值。本文将介绍如何在Chromium项目中进行断点设置和变量监视&#xff0c;帮助…

java内容快速回顾+SSM+SpringBoot简要概述

文章目录 java基础知识基本知识列表面对对象堆与栈的关系值修改与引用修改异常&#xff1a;错误异常 SSMspringMVCServletSpringMVC&#xff1a;基于 Servlet的 Spring Web 框架&#xff0c; spring控制反转 IoC(Inversion of Control)面向切面 Aop MybatisJDBCMybatis SpringB…

Git 基础使用(1) 入门指令

文章目录 Git 作用Git 安装Git 使用Git 仓库配置Git 工作原理Git 修改添加Git 查看日志Git 修改查询Git 版本回退 概念补充 Git 作用 Git 是一种分布式版本控制系统&#xff0c;它旨在追踪文件和文件夹的更改&#xff0c;并协助多人协作开发项目。 Git 安装 &#xff08;Lin…

17.多线程

多线程 程序、进程、线程的概念 程序&#xff1a;是指令和数据的有序集合&#xff0c;是一个静态的概念。比如&#xff0c;在电脑中&#xff0c;打开某个软件&#xff0c;就是启动程序。 进程&#xff1a;是执行程序的一次执行过程&#xff0c;是一个动态的概念&#xff0c;…

基于SSM的“口腔护理网站”的设计与实现(源码+数据库+文档)

基于SSM的“口腔护理网站”的设计与实现&#xff08;源码数据库文档) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SSM 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 首页 用户注册页面 医生信息查看模块 口腔护理预约模块 后台首页面…

分享如何通过定时任务调用lighthouse前端测试脚本+在持续集成测试中调用lighthouse前端测试脚本

最近写了个小工具来优化lighthouse在实际工作中的使用&#xff0c;具体实现了&#xff1a;通过定时任务调用前端测试脚本在持续集成测试中调用前端测试脚本。由于在公司中已经应用&#xff0c;所以就不能提供源码了&#xff0c;这里简单说一下实现思路&#xff0c;希望可以帮助…

Java 循环结构 - for, while 及 do...while

Java 循环结构 - for, while 及 do…while 顺序结构的程序语句只能被执行一次。 如果您想要同样的操作执行多次&#xff0c;就需要使用循环结构。 Java中有三种主要的循环结构&#xff1a; while 循环 do…while 循环 for 循环 在 Java5 中引入了一种主要用于数组的增强型 f…

OUC图书馆电脑开启无线网络,连接手机热点,解决联网但无法访问网络的问题

OUC图书馆电脑连手机热点 前言手动脚本&#xff08;暂未测试&#xff09;注意 前言 【中国海洋大学】OUC图书馆电脑默认只能有线连校园网&#xff0c;这让没有校园网的人很是头疼&#xff08;手机流量太多了&#xff0c;根本用不完&#xff0c;需要大流量卡的可以私信我&#…

在Android设备丢失数据后恢复数据的4个方法

了解 Android 媒体存储 媒体存储是下载、查看、播放和流式传输视频文件、音频文件、图像和其他媒体文件时所需的过程。此服务无法从手机桌面访问&#xff0c;因此您需要按照以下步骤通过安卓手机访问此系统服务。 步骤1&#xff1a;导航到手机设置&#xff0c;然后转到应用程…

初识鸿蒙之ArkTS基础

前言 学习一种应用程序开发&#xff0c;需要从这种程序的开发语言开始&#xff0c;比如说Android开发从入门到放弃&#xff0c;肯定是从Java基础或者是Kotlin语言基础开始学习的&#xff0c;IOS程序开发也肯定是从object-c开始学习的。鸿蒙软件开发也不例外&#xff0c;如果做…

Vue3+TS实现将html或富文本编辑器转为Word并下载

说明&#xff1a;我用的富文本编辑器是wangEditor&#xff1a; wangEditor官网 安装 yarn add wangeditor/editor # 或者 npm install wangeditor/editor --save yarn add wangeditor/editor-for-vuenext # 或者 npm install wangeditor/editor-for-vuenext --save yarn add …

金万维动态域名小助手怎么用?

金万维动态域名小助手是一个域名检测工具&#xff0c;使用此工具可以进行检测域名解析是否正确、清除DNS缓存、修改DNS服务器地址及寻找在线客服&#xff08;仅支持付费用户&#xff09;等操作。对不懂网络的用户是一个很好的检测域名的工具&#xff0c;下面我就讲解一下金万维…