生成对抗网络(DCGAN)手写数字生成

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、创建模型
      • 1. 生成器
      • 2. 判别器
  • 四、定义损失函数和优化器
      • 1. 判别器损失
      • 2. 生成器损失
  • 五、定义训练循环
  • 六、训练模型
  • 七、创建 GIF

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras  import layers
from IPython           import display
import matplotlib.pyplot as plt
import numpy             as np
import glob,imageio,os,PIL,time
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')# 将图片标准化到 [-1, 1] 区间内
train_images = train_images / 127.5 - 1  
BUFFER_SIZE = 60000
BATCH_SIZE  = 256# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、创建模型

1. 生成器

生成器使用 tf.keras.layers.Conv2DTranspose (上采样)层来从种子(随机噪声)中产生图片。以一个使用该种子作为输入的 Dense 层开始,然后多次上采样直到达到所期望的 28x28x1 的图片尺寸。注意除了输出层使用 tanh 之外,其他每层均使用 tf.keras.layers.LeakyReLU 作为激活函数。

def make_generator_model():model = tf.keras.Sequential([layers.Dense(7*7*256, use_bias=False, input_shape=(100,)),layers.BatchNormalization(),layers.LeakyReLU(),layers.Reshape((7, 7, 256)),layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),layers.BatchNormalization(),layers.LeakyReLU(),layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),layers.BatchNormalization(),layers.LeakyReLU(),layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')])return modelgenerator = make_generator_model()
generator.summary()

2. 判别器

判别器是一个基于 CNN 的图片分类器。

def make_discriminator_model():model = tf.keras.Sequential([layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),layers.LeakyReLU(),layers.Dropout(0.3),layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),layers.LeakyReLU(),layers.Dropout(0.3),layers.Flatten(),layers.Dense(1)])return modeldiscriminator = make_discriminator_model()
discriminator.summary()

四、定义损失函数和优化器

为两个模型定义损失函数和优化器。

# 该方法返回计算交叉熵损失的辅助函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

1. 判别器损失

该方法量化判断真伪图片的能力。它将判别器对真实图片的预测值与值全为 1 的数组进行对比,将判别器对伪造(生成的)图片的预测值与值全为 0 的数组进行对比。

def discriminator_loss(real_output, fake_output):real_loss = cross_entropy(tf.ones_like(real_output), real_output)fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss

2. 生成器损失

生成器损失量化其欺骗判别器的能力。直观来讲,如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或 1)。这里我们将把判别器在生成图片上的判断结果与一个值全为 1 的数组进行对比。

def generator_loss(fake_output):return cross_entropy(tf.ones_like(fake_output), fake_output)

由于我们需要分别训练两个网络,判别器和生成器的优化器是不同的。

generator_optimizer     = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

五、定义训练循环

EPOCHS = 60
noise_dim = 100
num_examples_to_generate = 16# 我们将重复使用该种子(在 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练循环在生成器接收到一个随机种子作为输入时开始。该种子用于生产一张图片。判别器随后被用于区分真实图片(选自训练集)和伪造图片(由生成器生成)。针对这里的每一个模型都计算损失函数,并且计算梯度用于更新生成器与判别器。

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):# 生成噪音noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:generated_images = generator(noise, training=True)real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)# 计算lossgen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)#计算梯度gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)#更新模型generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):for epoch in range(epochs):start = time.time()for image_batch in dataset:train_step(image_batch)# 实时更新生成的图片display.clear_output(wait=True)generate_and_save_images(generator, epoch + 1, seed)print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# 最后一个 epoch 结束后生成图片display.clear_output(wait=True)generate_and_save_images(generator, epochs, seed)
def generate_and_save_images(model, epoch, test_input):# 注意 training` 设定为 False# 因此,所有层都在推理模式下运行(batchnorm)。predictions = model(test_input, training=False)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('./images/19/image_at_epoch_{:04d}.png'.format(epoch))plt.show()

六、训练模型

调用上面定义的 train() 方法来同时训练生成器和判别器。在训练之初,生成的图片看起来像是随机噪声。随着训练过程的进行,生成的数字将越来越真实。在大概 50 个 epoch 之后,这些图片看起来像是 MNIST 数字。

%%time:将会给出cell的代码运行一次所花费的时间。

%%time
train(train_dataset, EPOCHS)

在这里插入图片描述

七、创建 GIF

import imageio,pathlibdef compose_gif():# 图片地址data_dir = "./images/19"data_dir = pathlib.Path(data_dir)paths    = list(data_dir.glob('*'))gif_images = []for path in paths:gif_images.append(imageio.imread(path))imageio.mimsave("./pic_gif/MINST_DCGAN_19.gif",gif_images,fps=8)compose_gif()
print("GIF动图生成完成!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/187881.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“前端八股文背诵版“,终于整理完了,堪称最强!

随着互联网的快速发展,前端开发领域成为了IT行业中的热门领域之一。很多求职者都希望能够进入这个领域,但是面对着如此激烈的竞争,很多人都感到无从下手。为了帮助大家更好地掌握前端开发的相关知识,小编整理了一份前端面试题合集…

单片机怎么实现真正的多线程?

单片机怎么实现真正的多线程? 不考虑多核情况时,CPU在一个时间点只能做一件事,因为切换的速度快所以看起来好像是同时执行多个线程而已。 实际上就是用定时器来做时基,以时间片的方式分别执行来实现的,只不过实现起来细节比较复…

网络安全应急响应-Server2228(环境+解析)

网络安全应急响应 任务环境说明: 服务器场景:Server2228(开放链接)用户名:root,密码:p@ssw0rd123

【代码】考虑差异性充电模式的电动汽车充放电优化调度matlab-yalmip-cplex/gurobi

程序名称:考虑差异性充电模式的电动汽车充放电优化调度 实现平台:matlab-yalmip-cplex/gurobi 代码简介:提出了一种微电网中电动汽车的协调充电调度方法,以将负荷需求从高峰期转移到低谷期。在所提出的方法中,基于充…

RHEL8.9 静默安装Oracle19C

RHEL8.9 静默安装Oracle19C 甘肃圆角网络科技开发有限公司 说明(GUI):  1.实际业务场景中,Linux环境一般情况下是没有GUI的。没有GUI并不意味着没有安装图形界面。可能在部署Linux操作系统环境的时候安装了桌面环境,只是启动的时候设置了启动…

英国人工智能初创公司Stability AI面临卖身压力;深度学习中的检索增强生成简介

🦉 AI新闻 🚀 英国人工智能初创公司Stability AI面临卖身压力 摘要:多位知情人士透露,英国人工智能初创公司Stability AI正寻求出售公司,因为投资者对其财务状况的压力越来越大。管理层最近几周一直将自己标榜为收购…

MapInfo Pro错误提示:This operation requires elevated privileges……

尝试删除MapInfo Pro时出现错误“此操作需要提升权限。将产品DVD/CD插入媒体播放器并双击setup.exe文件重新启动。此安装程序现在将中止。”。 原因: 这可能是由于权限问题。 解决方式: 1.如果MapInfo Pro setup.exe可用,请执行以下步骤&…

笔记64:Bahdanau 注意力

本地笔记地址:D:\work_file\(4)DeepLearning_Learning\03_个人笔记\3.循环神经网络\第10章:动手学深度学习~注意力机制 a a a a a a a a a a a

FIORI /N/UI2/FLP 始终在IE浏览器中打开 无法在缺省浏览器中打开

在使用/N/UI2/FLP 打开fiori 启动面板的时候,总是会在IE浏览器中打开,无法在缺省浏览器打开 并且URL中包含myssocntl 无法正常打开 启动面板 这种情况可以取消激活ICF节点/sap/public/myssocntl

spring boot 3.2.0 idea从零开始

spring boot 3.2.0 idea从零开始 最新的spring initilizer 不再支持低版本java,只能选择17、21 。 我也被迫尝试下最新版本的java。 jdk下载地址 自定义好artifact和group之后点击下一步。 在这里选择需要的组件,我准备做web项目所以只选择spring web …

word模板导出word文件

前期准备工作word模板 右键字段如果无编辑域 ctrlF9 一下&#xff0c;然后再右键 wps 直接 ctrlF9 会变成编辑域 pom.xml所需依赖 <dependencies> <!--word 依赖--> <dependency><groupId>fr.opensagres.xdocreport</groupId><artifactId…

vim工具以及如何给用户加上sudo的权限

Linux开发工具之vim以及如何给用户配置sudo的权限文件的操作 1.vim概念的介绍 2.vim的多模式的介绍 3.vim的命令模式与低行模式的相关指令操作 4.vim如何配置 5.如何给普通用户配置sudo的权限 本文开始~~~~ 1. vim概念的介绍 vim是一款多模式的文本编辑器&#xff0c;简单…

C语言-指针讲解(4)

在上一篇博客中&#xff1a; C语言-指针讲解(3) 我们给大家介绍了指针进阶的用法 让下面我们来回顾一下讲了什么吧&#xff1a; 1.字符指针变量类型以及用法 2.数组指针本质上是一个指针&#xff0c;里面存放数组的地址。而指针数组本质上是个数组&#xff0c;里面存放的是指针…

iOS系统上待办事项提醒软件哪个好

在这个快节奏的生活中&#xff0c;各种待办事项充斥了我们的日常工作和生活&#xff0c;尤其对于像我这样的iPhone用户而言&#xff0c;一款能够在iOS系统上快速和准确记录和提醒待办事项的软件&#xff0c;显得至关重要。 正如前几天&#xff0c;我正沉浸在工作中的时突然被领…

【算法】Rabin-Karp 算法

目录 1.概述2.代码实现3.应用 更多数据结构与算法的相关知识可以查看数据结构与算法这一专栏。 有关字符串模式匹配的其它算法&#xff1a; 【算法】Brute-Force 算法 【算法】KMP 算法 1.概述 &#xff08;1&#xff09;Rabin-Karp 算法是由 Richard M. Karp 和 Michael O. R…

时间序列预测 — LSTM实现多变量多步负荷预测(Keras)

目录 1 数据处理 1.1 数据集简介 1.2 数据集处理 2 模型训练与预测 2.1 模型训练 2.2 模型多步预测 2.3 结果可视化 1 数据处理 1.1 数据集简介 实验数据集采用数据集6&#xff1a;澳大利亚电力负荷与价格预测数据&#xff08;下载链接&#xff09;&#xff0c;包括数…

2023年中国金融租赁行业研究报告

第一章 行业概况 1.1 定义 金融租赁是一种融资方式&#xff0c;其中租赁公司&#xff08;出租人&#xff09;为企业&#xff08;承租人&#xff09;购买所需设备&#xff0c;并在租赁期内由承租人使用。承租人负责支付租金&#xff0c;租赁期满后有权选择退租、续租或购买设备…

(二)Tiki-taka算法(TTA)求解无人机三维路径规划研究(MATLAB)

一、无人机模型简介&#xff1a; 单个无人机三维路径规划问题及其建模_IT猿手的博客-CSDN博客 参考文献&#xff1a; [1]胡观凯,钟建华,李永正,黎万洪.基于IPSO-GA算法的无人机三维路径规划[J].现代电子技术,2023,46(07):115-120 二、Tiki-taka算法&#xff08;TTA&#xf…

数据爬取+可视化实战_告白气球_词云展示----酷狗音乐

一、前言 歌词上做文本分析&#xff0c;数据存储在网页上&#xff0c;需要爬取数据下来&#xff0c;词云展示在工作中也变得日益重要&#xff0c;接下来将数据爬虫与可视化结合起来&#xff0c;做个词云展示案例。 二、代码 # -*- coding:utf-8 -*- # 酷狗音乐 通过获取每首歌…

使用com组件编辑word

一个普通的窗体应用&#xff0c;6个button using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms; u…