生成对抗网络(GAN)手写数字生成

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
  • 二、什么是生成对抗网络
    • 1. 简单介绍
    • 2. 应用领域
  • 三、网络结构
  • 四、构建生成器
  • 五、构建鉴别器
  • 六、训练模型
    • 1. 保存样例图片
    • 2. 训练模型
  • 七、生成动图

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(Inception-ResNet-v2)交通标志识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2Dimport matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是“骗过”判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的“假数据”。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

在这里插入图片描述

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

def build_generator():# ======================================= ##     生成器,输入一串随机数字生成图片# ======================================= #model = Sequential([layers.Dense(256, input_dim=latent_dim),layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数layers.BatchNormalization(momentum=0.8),   # BN 归一化layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(1024),layers.LeakyReLU(alpha=0.2),layers.BatchNormalization(momentum=0.8),layers.Dense(np.prod(img_shape), activation='tanh'),layers.Reshape(img_shape)])noise = layers.Input(shape=(latent_dim,))img = model(noise)return Model(noise, img)

五、构建鉴别器

def build_discriminator():# ===================================== ##   鉴别器,对输入的图片进行判别真假# ===================================== #model = Sequential([layers.Flatten(input_shape=img_shape),layers.Dense(512),layers.LeakyReLU(alpha=0.2),layers.Dense(256),layers.LeakyReLU(alpha=0.2),layers.Dense(1, activation='sigmoid')])img = layers.Input(shape=img_shape)validity = model(img)return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

def sample_images(epoch):"""保存样例图片"""row, col = 4, 4noise = np.random.normal(0, 1, (row*col, latent_dim))gen_imgs = generator.predict(noise)fig, axs = plt.subplots(row, col)cnt = 0for i in range(row):for j in range(col):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%05d.png" % epoch)plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

def train(epochs, batch_size=128, sample_interval=50):# 加载数据(train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()# 将图片标准化到 [-1, 1] 区间内   train_images = (train_images - 127.5) / 127.5# 数据train_images = np.expand_dims(train_images, axis=3)# 创建标签true = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 进行循环训练for epoch in range(epochs): # 随机选择 batch_size 张图片idx = np.random.randint(0, train_images.shape[0], batch_size)imgs = train_images[idx]      # 生成噪音noise = np.random.normal(0, 1, (batch_size, latent_dim))# 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)gen_imgs = generator.predict(noise)# 训练鉴别器 d_loss_true = discriminator.train_on_batch(imgs, true)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)# 返回loss值d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = combined.train_on_batch(noise, true)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# 保存样例图片if epoch % sample_interval == 0:sample_images(epoch)
train(epochs=30000, batch_size=256, sample_interval=200)

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

import imageiodef compose_gif():# 图片地址data_dir = "images_old"data_dir = pathlib.Path(data_dir)paths    = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("test.gif",gif_images,fps=2)compose_gif()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/188339.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM实现的图书管理系统

一、系统架构 前端:jsp | js | css | jquery | layui 后端:spring | springmvc | mybatis 环境:jdk1.7 | mysql | maven | tomcat 二、代码及数据库 三、功能介绍 01. 登录页 02. 首页 03. 借阅管理 04. 图书管理 05. 读者管理 06. 类型管理…

浏览器安全攻击与防御

前言 浏览器是我们访问互联网的主要工具,也是我们接触信息的主要渠道。但是,浏览器也可能成为攻击者利用的突破口,通过各种手段,窃取或篡改我们的数据,甚至控制我们的设备.本文将向大家介绍一些常见的浏览器安全的攻击…

UniTask异步解决方案

是一个高性能,0GC的async/await异步方案 协程缺点: 依赖monobehaviour 不能进行异常处理 方法返回值获取困难 c#原生Task: 优点: 不依赖monobehaviour 可以处理异常 缺点: Task消耗大,设计跨线程操作 uniTa…

【EI会议征稿】第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)

第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024) 2024 3rd International Conference on Cyber Security, Artificial Intelligence and Digital Economy 第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024&#…

c题目13:验证100以内的数是否满足哥德巴赫猜想。(任一大于2的偶数都可以写成两个质数之和)

每日小语 活下去的诀窍是:保持愚蠢,又不能知道自己有多蠢。——王小波 自己思考 即要让第一个质数与这个数减去第一个质数的值都为质数,所以要满足几个条件 1.abc 2.a为质数 3.b为质数 这里是否可以用到我之前刚学的自己设置的那个判断…

带头结点的双向循环链表

目录 带头结点的双向循环链表 1.存储定义 2.结点的创建 3.结点的初始化 4.尾插结点 5.尾删结点 6.头插结点 7.头删结点 8.查找并返回结点 9.在pos结点前插入结点 10.删除pos结点 11.打印链表 12.销毁链表 13.头插结点2.0版 14.尾插结点2.0版 前言: 当…

深入探究Python中的JSON、Pickle和Shelve模块:特性与区别

更多资料获取 📚 个人网站:ipengtao.com 在Python中,处理数据序列化和持久化是极其重要的。JSON、Pickle和Shelve是三种常用的模块,它们提供了不同的方法来处理数据的序列化和持久化。本文将深入研究这三个模块,探讨它…

conda和pip常用命令整理

文章目录 一、conda常用指令1. 更新2 .环境管理3. 包管理 二、pip常用命令1. 常用命令2. 国内镜像 一、conda常用指令 1. 更新 conda --version 或 conda -V #查看conda版本 conda update conda # 基本升级 conda update anaconda # 大的升级 conda upd…

VBA_MF系列技术资料1-232

MF系列VBA技术资料 为了让广大学员在VBA编程中有切实可行的思路及有效的提高自己的编程技巧,我参考大量的资料,并结合自己的经验总结了这份MF系列VBA技术综合资料,而且开放源码(MF04除外),其中MF01-04属于定…

动态规划:多重背包问题-一维滚动数组解法

题目描述 你是一名宇航员,即将前往一个遥远的行星。在这个行星上,有许多不同类型的矿石资源,每种矿石都有不同的重要性和价值。你需要选择哪些矿石带回地球,但你的宇航舱有一定的容量限制。 给定一个宇航舱,最大容量…

02_学习使用javax_ws_rs_下载文件

文章目录 1 前言2 Maven 依赖3 下载接口4 如何返回文件?5 感谢 1 前言 专栏上一篇,写了如何使用 javax.ws.rs 上传文件,那么必然的,我们得再学习学习如何下载文件😀 2 Maven 依赖 这个就不赘述了,和上一篇…

【高效开发工具系列】Hutool DateUtil工具类

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

网络爬虫与指纹浏览器:解析指纹浏览器对网络爬虫的作用

网络爬虫在信息搜集、数据挖掘等领域起着重要作用。然而,传统爬虫往往面临被目标网站封禁的风险。本文将介绍指纹浏览器对网络爬虫的作用,以及指纹浏览器如何帮助爬虫降低封禁风险。 网络爬虫面临的挑战 网络爬虫是一种自动化程序,用于从互联…

Python+Requests模块获取响应内容

Requests模块获取响应内容 响应包括响应行、响应头、响应正文内容,这些返回的响应信息都可以通过Requests模块获取。这些 获取到的响应内容也是接口测试执行得到的实际结果。 获取响应行 获取响应头 获取其它响应信息 代码示例: # 导入requests模块…

华为OD机试真题-分割均衡字符串-2023年OD统一考试(C卷)

题目描述: 均衡串定义:字符串只包含两种字符,且两种字符的个数相同。 给定一个均衡字符串,请给出可分割成新的均衡子串的最大个数。 约定字符串中只包含大写的X和Y两种字符。 输入描述: 均衡串:XXYYXY 字符…

算法通关村第十四关-青铜挑战认识堆

大家好我是苏麟 , 今天带大家认识认识堆 . 堆 堆是将一组数据按照完全二叉树的存储顺序,将数据存储在一个一维数组中的结构。 堆有两种结构,一种称为大顶堆,一种称为小顶堆 : 大顶堆 大顶堆的任何一个父节点的值,都大于或等于…

前端下拉框select标签的插件——select2.js

本文采用的是select2 版本:Select2 4.0.6-rc.1。 可以兼容IE8及以上。亲测过。 官网:Getting Started | Select2 - The jQuery replacement for select boxes 一、认识select2.js 1、使用插件,首先要引入别人的插件了,你可以选择离线(无网络)或者在线引用的(如果有网…

ios 逆向分分析,某业帮逆向算法(一)

用到工具: 爱思助手CrackerXL(砸壳软件)越狱手机ida反汇编软件分析login 的sign 签名算法中自己写算法 已知我们32位,我们不妨猜测是md5 ,那我们试图使用CC_MD5 ,这个是ios 中的标准库, 我们使用frida-trace 注入hook一下,看看有没有 经过 是经过了这个函数,密码也是…

新建的springboot项目中application.xml没有绿色小叶子(不可用)

经常有朋友会遇到新建了一个springboot项目,发现为啥我创建的application.xml配置文件不是绿色的??? 下面教大家如何解决,这也是博主在做测试的时候遇到的: 将当前位置application.xml删掉,重新…

Java,Stream API的使用

Stream是数据渠道,用于操作数据源(集合、数组等)所生成的元素序列。 Stream和Collection集合的区别:Collection是一种静态的内存数据结构,讲的是数据,而Stream是有关计算的,讲的是计算。集合主…