从零开始 - 在Python中构建和训练生成对抗网络(GAN)模型

生成对抗网络(GANs)是一种强大的生成模型,可以合成新的逼真图像。通过完整的实现过程,读者将对GANs在幕后的工作原理有深刻的理解。本教程首先导入必要的库并加载将用于训练GAN的Fashion-MNIST数据集。然后,提供了构建GAN核心组件(生成器和判别器模型)的代码示例。接下来的部分解释了如何构建一个组合模型,该模型训练生成器以欺骗判别器,以及如何设计一个训练函数来优化对抗过程。

目录:

1. 导入库和下载数据集

2. 构建生成器模型

3. 构建判别器模型

4. 构建组合模型

5. 构建训练函数

6. 训练和观察结果

  1. 导入库和下载数据集

让我们首先导入本文中将使用的重要库:

from __future__ import print_function, division
from keras.datasets import fashion_mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

在本文中,您将在Fashion-MNIST数据集上训练DCGAN。Fashion-MNIST包含60,000个用于训练的灰度图像和一个包含10,000个图像的测试集。每个28×28的灰度图像与10个类别中的一个标签相关联。Fashion-MNIST旨在作为原始MNIST数据集的直接替代品,用于对比机器学习算法的性能。与三通道的彩色图像相比,灰度图像在一通道上训练卷积网络时需要更少的计算能力,这使您更容易在没有GPU的个人计算机上进行训练。

a43e74d2137f4a31ce4d40fe66ab7a52.jpeg

数据集分为10个时尚类别。类别标签如下:

760b0174d7592e71606bec49bf3407a5.jpeg

您可以使用以下代码加载数据集:

(training_data, _), (_, _) = fashion_mnist.load_data()
X_train = training_data / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

要可视化数据集中的图像,可以使用以下代码:

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else="" 'black')=""  =""  
fig = plt.figure(figsize = (12,12))
ax = fig.add_subplot(111)
visualize_input(training_data[3343], ax)We also use batch normalization and a ReLU activation.
For each of these layers, the general scheme is convolution ⇒ batch normalization
⇒ ReLU. We keep stacking up layers like this until we get the final transposed
convolution layer with shape 28 × 28 × 1:

b001bcb6986483ef65aa3f19ef9b657e.jpeg

2. 构建生成器模型

正如我们在前面的文章中所探讨的,GANs由两个主要组件组成,即生成器和判别器。在这一部分中,我们将构建生成器模型,其输入将是一个噪声向量(z)。生成器的架构如下图所示。

第一层是一个全连接层,然后被重新塑造成深而窄的层,在原始的DCGAN论文中,作者将输入重新塑造为4×4×1024。在这里,我们将使用7×7×128。然后,我们使用上采样层将特征映射的维度从7×7加倍到14×14,然后再次加倍到28×28。在这个网络中,我们使用了三个卷积层。我们还将使用批归一化和ReLU激活。

对于每个层,通用方案是卷积 ⇒ 批归一化 ⇒ ReLU。我们不断地堆叠这样的层,直到得到最终的转置卷积层,形状为28×28×1。

4fabaa16f62175b0c474ff334293c279.jpeg

以下是构建上述生成器模型的Keras代码:

def build_generator():generator = Sequential()generator.add(Dense(6272, activation="relu", input_dim=100)) # Add dense layergenerator.add(Reshape((7, 7, 128)))  # reshape the imagegenerator.add(UpSampling2D()) # Upsampling layer to double the size of the imagegenerator.add(Conv2D(128, kernel_size=3, padding="same", activation="relu"))generator.add(BatchNormalization(momentum=0.8))generator.add(UpSampling2D())# convolutional + batch normalization layersgenerator.add(Conv2D(64, kernel_size=3, padding="same", activation="relu"))generator.add(BatchNormalization(momentum=0.8))# convolutional layer with filters = 1generator.add(Conv2D(1, kernel_size=3, padding="same", activation="relu"))generator.summary() # prints the model summary"""We don't add upsampling here because the image size of 28 × 28 is equal to the image size in the MNIST dataset. You can adjust this for your own problem."""noise = Input(shape=(100,))fake_image = generator(noise)# Returns a model that takes the noise vector as an input and outputs the fake imagereturn Model(inputs=noise, outputs=fake_image)

3. 构建判别器模型

GANs的第二个主要组件是判别器。判别器只是一个传统的卷积分类器。判别器的输入是28×28×1的图像。我们希望有一些卷积层,然后是输出的全连接层。

与之前一样,我们希望得到一个Sigmoid输出,并且我们需要返回logits。对于卷积层的深度,我们可以从第一层开始使用32或64个过滤器,然后在添加层时将深度加倍。在这个实现中,我们将从64层开始,然后是128,然后是256。对于降采样,我们不使用池化层。相反,我们只使用步幅卷积层进行降采样,类似于Radford等人的实现。

我们还使用批归一化和dropout来优化训练。对于四个卷积层的每一层,通用方案是卷积 ⇒ 批归一化 ⇒ 泄漏的ReLU。

c99ea77aec1203923646688e02c6e1d6.jpeg

现在,让我们构建build_discriminator函数:

def build_discriminator():discriminator = Sequential()discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28,28,1), padding="same"))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(64, kernel_size=3, strides=2,padding="same"))discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))discriminator.add(BatchNormalization(momentum=0.8))discriminator.add(LeakyReLU(alpha=0.2))discriminator.add(Dropout(0.25))discriminator.add(Flatten())discriminator.add(Dense(1, activation='sigmoid'))img = Input(shape=(28,28,1))probability = discriminator(img)return Model(inputs=img, outputs=probability)

4. 构建组合模型

正如本系列的第二篇文章中所解释的,为了训练生成器,我们需要构建一个包含生成器和判别器的组合网络。组合模型以噪声信号(z)作为输入,并将判别器的预测输出作为虚假或真实输出。

e90e9c2335ae20998fab73b192b20485.jpeg

重要的是要记住,我们希望在组合模型中禁用判别器的训练,正如本系列的第二篇文章中所解释的那样。在训练生成器时,我们不希望判别器更新权重,但我们仍然希望将判别器模型包含在生成器训练中。因此,我们创建一个包含两个模型的组合网络,但在组合网络中冻结判别器模型的权重:

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
discriminator.trainable = False# Build the generator
generator = build_generator()
z = Input(shape=(100,))
img = generator(z)
valid = discriminator(img)
combined = Model(inputs=z, outputs=valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

5. 构建训练函数

为了训练GAN模型,我们训练两个网络:判别器和我们在前面部分创建的组合网络。让我们构建train函数,该函数接受以下参数:

  • epoch

  • batch size 大小

  • save_interval,以指定多久保存一次结果

def train(epochs, batch_size=128, save_interval=50):valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):  # Train Discriminator networkidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, 100))gen_imgs = generator.predict(noise)d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)g_loss = combined.train_on_batch(noise, valid)# printing progressprint("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %(epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % save_interval == 0:plot_generated_images(epoch, generator)

我们还将创建另一个函数`plot_generated_images()` 来绘制生成的图像。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),figsize=(10, 10)):noise = np.random.normal(0, 1, size=[examples, latent_dim])generated_images = generator.predict(noise)generated_images = generated_images.reshape(examples, 28, 28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')plt.axis('off')plt.tight_layout()plt.savefig('gan_generated_image_epoch_%d.png' % epoch

最后,让我们为训练GAN模型定义重要的变量和参数:

# Input shape
img_shape = (28,28,1)
channels = 1
latent_dim = 100
optimizer = Adam(0.0002, 0.5)# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Build the generator
generator = build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(latent_dim,))
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid = discriminator(img)
# The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

6. 训练和观察结果

此时,代码实现已经完成,我们准备开始DCGAN的训练。要训练模型,请运行以下代码行:

train(epochs=1000, batch_size=32, save_interval=50)

这将在1,000个epochs上运行训练,并每50个epochs保存一次图像。当运行`train()` 函数时,训练进度将如下所示:

86d990d67af3b9ee259b9424b3e1e521.jpeg

如下图所示,在epoch = 0时,图像只是随机噪声,没有明确的模式或有意义的数据。到了第50个epoch,图案已经开始形成。

80fb00ada0dc22c60488b9d4fda559aa.jpeg

在训练过程的后期,到了第1,000个epoch,您可以看到清晰的形状,可能能够猜测输入到GAN模型的训练数据的类型。

49de38a46bd9065cb03bb8125b1a990e.jpeg

再快进到第10,000个epoch,您会发现生成器已经非常擅长重新创建训练数据集中不存在的新图像。

de6db2898ea32036dd85c216a275c842.jpeg

·  END  ·

HAPPY LIFE

aeccadbe0b4d2dc12a3db6eea9e70b49.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/592039.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

反转链表、链表的中间结点、合并两个有序链表【LeetCode刷题日志】

一、反转链表 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 思路一&#xff1a;翻转单链表指针方向 这里解释一下三个指针的作用&#xff1a; n1&#xff1…

linux支持的零拷贝类型以及java对应的支持

在之前整理的零拷贝文章基础上 https://blog.csdn.net/zlpzlpzyd/article/details/135321197 https://blog.csdn.net/zlpzlpzyd/article/details/135317834 得出如下 因为开发的程序很多运行在 linux 操作系统上&#xff0c;所以用 linux 进行讲解 linux 调用方式 dma复制次数…

Jupyter Notebook的10个常用扩展介绍

Jupyter Notebook&#xff08;前身为IPython Notebook&#xff09;是一种开源的交互式计算和数据可视化的工具&#xff0c;广泛用于数据科学、机器学习、科学研究和教育等领域。它提供了一个基于Web的界面&#xff0c;允许用户创建和共享文档&#xff0c;这些文档包含实时代码、…

Stable Diffusion 本地部署详细教程

目录 一、前言二、系统和硬件要求三、安装前说明四、安装步骤5、升级pip(这是管理python环境软件工具),并把资源库换成国内地址为清华镜像。一、前言 虽然MJ和SD都可以生成图像,但是为什么我们要考虑使用本地SD部署呢?原因其实很简单:首先,本地部署的使用成本更低,且更加…

(15)Linux 进程创建与终止函数forkslab 分派器

前言&#xff1a;本章我们主要讲解进程的创建与终止&#xff0c;最后简单介绍一下 slab 分派器。 一、进程创建&#xff08;Process creation&#xff09; 1、分叉函数 fork 在 中&#xff0c; fork 函数是非常重要的函数&#xff0c;它从已存在进程中创建一个新的进程。 …

Python异常处理TypeError: translation() got an unexpected keyword argument ‘codeset‘

背景 学习graphql-python安装好依赖后执行命令 python manage.py migrate python manage.py runserver仅接着出现下列错误&#xff0c;主要提示是 「TypeError: translation() got an unexpected keyword argument ‘codeset’」 ntribute_to_classself.remote_field.throug…

1885页学习资料。一本在手,python不愁!

python3.11即将于下半年发布&#xff0c;新的版本速度提升2倍&#xff0c;以弥补与其他编程语言在速度上的缺陷。可以预见Python语言在未来的应用范围会越来越广。 python学习方向建议&#xff1a; 如果你是本科及以下学历&#xff0c;建议你学习以下两个方向 1、爬虫。简单…

Matplotlib基础

目录&#xff1a; 一、绘制yx^2图像&#xff1a; 一、绘制yx^2图像&#xff1a; from matplotlib import pyplot as plt import numpy as np #生成&#xff08;-50,50&#xff09;的数组 x np.arange(-50,50) #计算因变量y的值 y x ** 2 #根据x、y数组绘制图形yx^2 plt.plot…

一文带你玩转Superset!大数据可视化框架学习网站大盘点!

介绍&#xff1a;Superset是一款由Airbnb开源的现代化企业级BI工具&#xff0c;它主要用于数据分析和可视化工作。作为Apache孵化器项目的一部分&#xff0c;它在处理复杂的数据分析需求上表现出色&#xff0c;并支持多种数据源和丰富的图表类型。 这款工具的主要特点包括自助分…

PE解释器之PE文件结构

PE文件是由许许多多的结构体组成的&#xff0c;程序在运行时就会通过这些结构快速定位到PE文件的各种资源&#xff0c;其结构大致如图所示&#xff0c;从上到下依次是Dos头、Nt头、节表、节区和调试信息(可选)。其中Dos头、Nt头和节表在本文中统称为PE文件头(因为SizeOfHeaders…

大数据毕业设计:基于python淘宝数据采集分析可视化系统 商品销量数据分析 计算机毕业设计(附源码+文档)✅

毕业设计&#xff1a;2023-2024年计算机专业毕业设计选题汇总&#xff08;建议收藏&#xff09; 毕业设计&#xff1a;2023-2024年最新最全计算机专业毕设选题推荐汇总 &#x1f345;感兴趣的可以先收藏起来&#xff0c;点赞、关注不迷路&#xff0c;大家在毕设选题&#xff…

【第31例】IPD产品开发计划阶段详解

目录 简介 详细内容 作者简介 简介 今天继续更新 IPD 进阶专栏。 这节内容主要来谈谈 IPD 产品开发计划阶段。 计划阶段的主要目标是回答“怎么做”的问题。 具体就是要: 清晰定义产品及竞争优势; 理解业务计划; 制定项目计划,以及资源计划; 确保风险可以被合理管理…

微软真是活菩萨,面向初学者的机器学习、数据科学、AI、LLM课程统统免费

微软真是活菩萨&#xff0c;面向初学者的机器学习、数据科学、AI、LLM课程统统免费 大家好&#xff0c;我是老章 推荐几个质量上乘且完全免费的微软开源课程 面向初学者的机器学习课程 **地址&#xff1a;**https://microsoft.github.io/ML-For-Beginners/#/ 学习经典机器学…

Mysql 下载与安装教程(详细介绍与总结)

一&#xff1a;版本介绍 首先&#xff0c;我们需要先进入官网进行下载&#xff0c;在官网中有好几个版本&#xff0c;那么这里我分别简述一下MySQL各个版本区别&#xff1a; 1&#xff1a;企业版&#xff0c;MySQL Enterprise Edition 需要付费的&#xff0c;可以免费试用30天…

超声波传感器(附:c语言测距代码)

一、引言 超声波传感器是一种利用超声波进行检测的装置&#xff0c;具有非接触、高精度、抗干扰能力强等优点。在工业自动化、医疗诊断、环境监测等领域&#xff0c;超声波传感器发挥着重要的作用。本文将深入探讨超声波传感器的原理、应用&#xff0c;并通过C语言代码示例来展…

Windows使用IIS服务搭建WebDAV站点结合内网穿透公网访问

文章目录 1. 安装IIS必要WebDav组件2. 客户端测试3. cpolar内网穿透3.1 打开Web-UI管理界面3.2 创建隧道3.3 查看在线隧道列表3.4 浏览器访问测试 4. 安装Raidrive客户端4.1 连接WebDav服务器4.2 连接成功4.2 连接成功总结&#xff1a; 自己用Windows Server搭建了家用NAS主机&…

ffmpeg与SDL结合使用

FFmpeg 使用了 SDL 库来处理音频和视频数据的显示。SDL 提供了一套跨平台的图形显示库&#xff0c;它可以在多个操作系统上提供硬件加速的视频输出功能&#xff0c;并且支持多种常用的视频编解码格式&#xff0c;这些特性使得它成为 FFmpeg 中的一个重要组件。 在 FFmpeg 中&a…

OpenGL FXAA抗锯齿算法(Qt,Quality版本)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 将FXAA添加到现有渲染器中很简单:它作为最终渲染通道[1]应用,仅将渲染图像作为输入,并输出抗锯齿版本。其主要思想是检测渲染图像中的边缘并使其平滑。这种方法快速有效,但会模糊纹理上的细节。该算法有两个版本…

汉诺塔问题

问题&#xff1a; Hanoi(汉诺)塔问题。这时一个古典的数学问题&#xff0c;是一个递归方法解题的典型例子。问题是这样的&#xff1a;古代有一个梵塔&#xff0c;塔内有3个座 A,B,C&#xff08;如下图&#xff09;。开始时A座上有64个盘子&#xff0c;盘子大小不等&#xff0c…

返利机器人的实现原理:从技术到收益的全面解析

返利机器人的实现原理&#xff1a;从技术到收益的全面解析 大家好&#xff0c;我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;在电商时代&#xff0c;许多消费者对返利机器人并不陌生。…