生成对抗网络(GAN)手写数字生成

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、网络结构
  • 四、构建生成器
  • 五、构建鉴别器
  • 六、训练模型
    • 1. 保存样例图片
    • 2. 训练模型
  • 七、生成动图

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

在这里插入图片描述

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

def build_generator():# ======================================= ##     生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数layers.BatchNormalization(momentum=0.8),   # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)

五、构建鉴别器

def build_discriminator():# ===================================== ##   鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内   train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx]      # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths    = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/188339.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM实现的图书管理系统

一、系统架构 前端:jsp | js | css | jquery | layui 后端:spring | springmvc | mybatis 环境:jdk1.7 | mysql | maven | tomcat 二、代码及数据库 三、功能介绍 01. 登录页 02. 首页 03. 借阅管理 04. 图书管理 05. 读者管理 06. 类型管理…

【EI会议征稿】第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)

第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024) 2024 3rd International Conference on Cyber Security, Artificial Intelligence and Digital Economy 第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024&#…

c题目13:验证100以内的数是否满足哥德巴赫猜想。(任一大于2的偶数都可以写成两个质数之和)

每日小语 活下去的诀窍是:保持愚蠢,又不能知道自己有多蠢。——王小波 自己思考 即要让第一个质数与这个数减去第一个质数的值都为质数,所以要满足几个条件 1.abc 2.a为质数 3.b为质数 这里是否可以用到我之前刚学的自己设置的那个判断…

带头结点的双向循环链表

目录 带头结点的双向循环链表 1.存储定义 2.结点的创建 3.结点的初始化 4.尾插结点 5.尾删结点 6.头插结点 7.头删结点 8.查找并返回结点 9.在pos结点前插入结点 10.删除pos结点 11.打印链表 12.销毁链表 13.头插结点2.0版 14.尾插结点2.0版 前言: 当…

深入探究Python中的JSON、Pickle和Shelve模块:特性与区别

更多资料获取 📚 个人网站:ipengtao.com 在Python中,处理数据序列化和持久化是极其重要的。JSON、Pickle和Shelve是三种常用的模块,它们提供了不同的方法来处理数据的序列化和持久化。本文将深入研究这三个模块,探讨它…

VBA_MF系列技术资料1-232

MF系列VBA技术资料 为了让广大学员在VBA编程中有切实可行的思路及有效的提高自己的编程技巧,我参考大量的资料,并结合自己的经验总结了这份MF系列VBA技术综合资料,而且开放源码(MF04除外),其中MF01-04属于定…

【高效开发工具系列】Hutool DateUtil工具类

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Python+Requests模块获取响应内容

Requests模块获取响应内容 响应包括响应行、响应头、响应正文内容,这些返回的响应信息都可以通过Requests模块获取。这些 获取到的响应内容也是接口测试执行得到的实际结果。 获取响应行 获取响应头 获取其它响应信息 代码示例: # 导入requests模块…

算法通关村第十四关-青铜挑战认识堆

大家好我是苏麟 , 今天带大家认识认识堆 . 堆 堆是将一组数据按照完全二叉树的存储顺序,将数据存储在一个一维数组中的结构。 堆有两种结构,一种称为大顶堆,一种称为小顶堆 : 大顶堆 大顶堆的任何一个父节点的值,都大于或等于…

前端下拉框select标签的插件——select2.js

本文采用的是select2 版本:Select2 4.0.6-rc.1。 可以兼容IE8及以上。亲测过。 官网:Getting Started | Select2 - The jQuery replacement for select boxes 一、认识select2.js 1、使用插件,首先要引入别人的插件了,你可以选择离线(无网络)或者在线引用的(如果有网…

ios 逆向分分析,某业帮逆向算法(一)

用到工具: 爱思助手CrackerXL(砸壳软件)越狱手机ida反汇编软件分析login 的sign 签名算法中自己写算法 已知我们32位,我们不妨猜测是md5 ,那我们试图使用CC_MD5 ,这个是ios 中的标准库, 我们使用frida-trace 注入hook一下,看看有没有 经过 是经过了这个函数,密码也是…

新建的springboot项目中application.xml没有绿色小叶子(不可用)

经常有朋友会遇到新建了一个springboot项目,发现为啥我创建的application.xml配置文件不是绿色的??? 下面教大家如何解决,这也是博主在做测试的时候遇到的: 将当前位置application.xml删掉,重新…

在Spring Boot中使用@Async实现一个异步调用

在使用异步注解之前,我们需要先了解,什么是异步调用? 异步调用对应的事同步调用,同步调用是值程序按照我们定义的顺序依次执行,每一行程序都必须等待上一行的程序执行完成之后才执行,而异步是指在顺序执行…

YOLOv8 第Y7周 水果识别

1.创建文件夹: YOLOv8开源地址 -- ultralytics-main文件下载链接:GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite 其余文件由代码生成。 数据集下载地址:Frui…

使用NVM管理多个版本的node.js

1、nvm介绍: nvm全英文也叫node.js version management,是一个nodejs的版本管理工具。nvm是node.js版本管理工具,为了解决node.js各种版本存在不兼容现象可以通过它可以安装和切换不同版本的node.js 2、下载nvm地址: https://d…

Mybatis如何执行批量操作

文章目录 Mybatis如何执行批量操作使用foreach标签 使用ExecutorType.BATCH如何获取生成的主键 Mybatis如何执行批量操作 使用foreach标签 foreach的主要用在构建in条件中,它可以在SQL语句中进行迭代一个集合。foreach标签的属性主要有item,index&…

iPhone苹果手机如何将词令网页添加到苹果iPhone手机桌面快捷打开?

iPhone苹果手机如何将词令网页添加到苹果iPhone手机桌面快捷打开? 1、在iPhone苹果手机上找到「Safari浏览器」,并点击打开; 2、打开Safari浏览器后,输入词令官方网站地址:ciling.cn ; 3、打开词令官网后,点击Safari…

Maven的配置亲测有效

文章目录 前言一、maven网址二、操作步骤三.配置环境变量四.配置本地仓库五.找到mirror 和配置JDK六.胜利七.提醒⏰;总结 前言 (我讲一下什么是maven,不想看跳到下一步就行了,也没必要看) Maven(Apache Maven&#x…

使用策略模式彻底消除if-else

文章目录 使用策略模式彻底消除if-else1. 场景描述2. if-else方式3. 策略模式 使用策略模式彻底消除if-else 如果一个对象有很多的行为,如果不用恰当的模式,这些行为就只好使用多重的条件选择语句来实现,这样会显得代码逻辑很臃肿&#xff0c…

【广州华锐视点】机械零件拆装VR仿真教学系统

随着科技的不断发展,虚拟现实(VR)技术已经逐渐走进我们的生活。在教育领域,VR技术的应用也日益广泛,为学生提供了更加生动、直观的学习体验。广州华锐视点开发的机械零件拆装VR仿真教学系统作为一种新兴的教学方式&…