训练生成手写体数字 对抗神经网络

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100# 生成器
def build_generator():model = Sequential()model.add(Dense(256, input_dim=latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model# 判别器
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))return model# 编译判别器
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])# 编译生成器
generator = build_generator()# 生成器输入噪声并生成图像
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)# 对于组合模型,我们只训练生成器
discriminator.trainable = False# 判别器尝试区分真实和生成的图像
valid = discriminator(img)# 组合模型(叠加生成器和判别器)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
def train(epochs, batch_size=128, save_interval=50):half_batch = batch_size // 2for epoch in range(epochs):# 随机选择一半的图像idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成一半的新图像noise = np.random.normal(0, 1, (half_batch, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))valid_y = np.array([1] * batch_size)g_loss = combined.train_on_batch(noise, valid_y)# 打印进度print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 如果到了保存间隔,则保存生成的图像样本if epoch % save_interval == 0:save_imgs(epoch)# 保存图像
def save_imgs(epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 缩放图片从[-1, 1]到[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig(f"mnist_{epoch}.png")plt.close()# 训练GAN
train(epochs=10000, batch_size=32, save_interval=1000)

在这个代码中,首先定义了生成器和判别器的架构,然后将它们结合起来形成一个GAN网络。

train函数负责训练过程,它交替地训练判别器和生成器。save_imgs函数用于保存生成的图像,以便我们可以查看GAN在训练过程中的进步。

这是一个非常基础的GAN实现,对于实际应用,你可能需要进行很多调整和优化,包括更复杂的模型架构、更细致的训练过程控制、超参数调整等。

逐行注释的代码,解释了每一步的作用:

# 导入所需的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 数据预处理,将像素值归一化到[-1, 1]的范围
X_train = X_train / 127.5 - 1.0
# 增加一个维度,使图像有单通道,即形状为(batch_size, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)# 定义生成器模型
def build_generator():model = Sequential()model.add(Dense(256, input_shape=(100,)))  # 输入层,输入维度为100(噪声向量)model.add(LeakyReLU(alpha=0.2))  # 使用LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(1024))  # 第三层,1024个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # 输出层,输出与图像像素数相同的单元数model.add(Reshape((28, 28, 1)))  # 将输出重塑为28x28图像return model# 定义判别器模型
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))  # 输入层,将28x28图像展平model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(256))  # 第三层,256个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(1, activation='sigmoid'))  # 输出层,一个单元输出0到1之间的值return model# 编译判别器和生成器
discriminator = build_discriminator()
# 使用二元交叉熵作为损失函数,Adam优化器,以及准确度评估
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
generator = build_generator()# GAN模型组合
z = tf.keras.Input(shape=(100,))  # 输入层,100维噪声向量
img = generator(z)  # 生成器生成图像
discriminator.trainable = False  # 在训练生成器时冻结判别器的权重
valid = discriminator(img)  # 判别器对生成的图像进行评估
combined = tf.keras.Model(z, valid)  # 组合模型,输入是噪声,输出是判别器的评估结果
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
epochs = 10000  # 训练轮数
batch_size = 32  # 批量大小
save_interval = 1000  # 保存图片的间隔
noise_dim = 100  # 噪声向量的维度
half_batch = batch_size // 2  # 半批量大小
valid = np.ones((half_batch, 1))  # 真实图片标签
fake = np.zeros((half_batch, 1))  # 伪造图片标签for epoch in range(epochs):# 随机选择真实图片idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成噪声noise = np.random.normal(0, 1, (half_batch, noise_dim))# 使用噪声生成伪造图片gen_imgs = generator(noise, training=False)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 生成更多噪声noise = np.random.normal(0, 1, (batch_size, noise_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 如果达到保存间隔,打印损失并保存生成的图片if epoch % save_interval == 0:print("Epoch {}/{} [D loss: {:.4f}, acc.: {:.2f}%] [G loss: {:.4f}]".format(epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss))save_imgs(generator, epoch, noise_dim)# 定义函数以保存生成的手写数字图像
def save_imgs(generator, epoch, noise_dim):r, c = 5, 5  # 生成5x5网格的图片noise = np.random.normal(0, 1, (r * c, noise_dim))  # 生成噪声gen_imgs = generator(noise, training=False)  # 使用噪声生成图片gen_imgs = 0.5 * gen_imgs + 0.5  # 将图片的像素值从[-1, 1]缩放到[0, 1]fig, axs = plt.subplots(r, c)  # 创建子图cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')  # 显示生成的图片axs[i, j].axis('off')  # 关闭坐标轴cnt += 1fig.savefig("mnist_%d.png" % epoch)  # 保存生成的图片plt.close()  # 关闭图形显示窗口# 选择性地保存生成器模型
generator.save('mnist_generator.h5')

这样的注释有助于理解代码的每一步,特别是对于初学者来说,可以更好地理解GAN的工作原理和实现细节。

版权所有 © 2023 王一帆。除非另有说明,本作品采用[知识共享 署名-非衍生作品 4.0 国际许可协议](https://creativecommons.org/licenses/by-nd/4.0/)进行许可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/593270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Redis进行搜索

文章目录 构建反向索引 构建反向索引 在Begin-End区域编写 tokenize(content) 函数,实现文本标记化的功能,具体参数与要求如下: 方法参数 content 为待标记化的文本; 文本标记的实现:使用正则表达式提取全小写化后的…

初识Java并发,一问读懂Java并发知识文集(1)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

数据库索引、三范式、事务

索引 索引(Index)是帮助 MySQL 高效获取数据的数据结构。常见的查询算法,顺序查找,二分查找,二叉排序树查找,哈希散列法,分块查找,平衡多路搜索树 B 树(B-tree)。 常见索引原则有 选择唯一性索引:唯一性索引的值是唯…

树与二叉树笔记整理

摘自小红书 ## 树与二叉树 ## 排序总结

eclipse中更改jdk版本

文章目录 步骤1:installed JREs步骤2:选择已安装的jdk步骤3:项目配置 步骤1:installed JREs 在eclipse上方工具栏找到Window -->Preferences,如下图所示: 选择Installed JREs 点击 Add 按钮, 选择Stand…

优化Adams许可管理流程,提高仿真分析效率与合规性

在工程仿真领域,Adams软件是一款广泛使用的动力学分析工具。然而,随着项目的不断扩大和复杂化,如何优化Adams许可管理流程已成为了一个重要的问题。为了帮助用户更好地管理和维护Adams软件的许可,本文将介绍优化许可管理流程的方法…

【字典树Trie】LeetCode-139. 单词拆分

139. 单词拆分。 给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 。 注意:不要求字典中出现的单词全部都使用,并且字典中的单词可以重复使用。 示例 1: 输入: s "leetcode&q…

MySQL是如何保证数据一致性的?

文章目录 前言MySQL保证的一致性MySQL发生不一致环节并发冲突redolog不完整binlog&redolog不一致 MySQL解决不一致方案加锁解决并发冲突undolog解决redolog不完整XA两阶段提交解决binlog和redolog的不一致 总结 前言 通过上文《MySQL是如何保证数据不丢失的?》…

Ubuntu安装CUDA出在三个cuda相关文件夹?

按照网上的教程,在/usr/local中操作cuda文件夹,但是发现这里会出现不止一个cuda文件夹: 可以看大这里有cuda、cuda-11、cuda-11.8三个文件夹,实际上我安装的是11.8的cuda,那么第三个文件是好理解的,就是我…

Django Web框架

1、创建PyCharm项目 2、安装框架 pip install django4.2.0 3、查看安装的包列表 4、使用命令创建django项目 django-admin startproject web 5、目录结构 6、运行 cd web python manage.py runserver7、初始化后台登录的用户名密码 执行数据库迁移生成数据表 python man…

Mybatis-plus动态表名配置

一、pom文件依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.1</version></dependency> 二 、mybatis配置类 2.1 表名设置工具类 TableNameHelper pack…

React-hook-form-mui(二):表单数据处理

前言 在上一篇文章中&#xff0c;我们介绍了react-hook-form-mui的基础用法。本文将着表单数据处理。 react-hook-form-mui提供了丰富的表单数据处理功能&#xff0c;可以通过watch属性来获取表单数据。 Demo 下面是一个使用watch属性的例子&#xff1a; import React from…

【Redis交响乐】Redis中的数据类型/内部编码/单线程模型

文章目录 一. Redis中的数据类型和内部编码二. Redis的单线程模型面试题: redis是单线程模型,为什么效率之高,速度之快呢? 在上一篇博客中我们讲述了Redis中的通用命令,本篇博客中我们将围绕每个数据结构来介绍相关命令. 一. Redis中的数据类型和内部编码 type命令实际返回的…

【MATLAB】EMD_LSTM神经网络时序预测算法

有意向获取代码&#xff0c;请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 EMD-LSTM神经网络时序预测算法是一种结合了经验模态分解&#xff08;EMD&#xff09;和长短期记忆神经网络&#xff08;LSTM&#xff09;的时间序列预测方法。 EMD是一种处理非平稳信号的…

Linux的引导过程与服务控制

一.开机启动的完整过程 引导过程&#xff1a; 1.bios加电自检 检测硬件是否正常&#xff0c;然后根据bios中的启动项设置&#xff0c;去找内核文件 服务器主机开机以后&#xff0c;将根据主板BIOS中的设置对CPU、内存、显卡、键盘灯设备进行初步检测&#xff0c;检测成功后根…

stable diffusion 基础教程-图生图

界面 图生图大概有以下几个功能: 图生图涂鸦绘制局部绘制局部绘制(涂鸦蒙版)其常用的也就上面四个,接下来逐步讲解。 以图反推提示词 图生图可以根据反推提示词来获取相应图片的提示词,目前3种主流方式,如下: CLIP反推提示词:推导出的文本倾向于自然语言的描述方式,…

openmediavault(OMV) (26)网络(1)ddns-go

简介 "ddns-go" 是一个动态域名解析(Dynamic DNS)工具,用于更新域名的IP地址。它可以自动检测你的公共IP地址,并将其更新到指定的域名解析服务商,以确保你的域名始终与最新的IP地址相匹配。 安装 hub.docker.com上下载ddns-go镜像 配置compose文件 --- versio…

C++系列十一:C++指针

C指针 1. 指针的声明和初始化2. 指针的运算3. 指针与数组4. 指针与函数参数传递5. 指针与动态内存分配6. 指针与多维数组7. 指针与函数返回值8. 指针与内存管理9. 指针的高级应用 指针是C中一个非常重要的概念&#xff0c;它是指向变量、数组或对象的内存地址的引用。通过指针&…

LeetCode 466. 统计重复个数,循环字符串匹配优化

一、题目 1、题目描述 定义 str [s, n] 表示 str 由 n 个字符串 s 连接构成。 例如&#xff0c;str ["abc", 3] "abcabcabc" 。 如果可以从 s2 中删除某些字符使其变为 s1&#xff0c;则称字符串 s1 可以从字符串 s2 获得。 例如&#xff0c;根据定义&a…

在 sealos 上使用 redisinsight 完美管理 redis

先起一个 redis 集群&#xff0c;在 sealos 上可以点点鼠标就搞定&#xff1a; 简单两步&#xff0c;redis 集群搞定。 再启动 RedisInsight, 是一个 redis 的可视化管理工具。 就可以看到部署后的地址了。进去之后填写 redis 的链接信息即可&#xff1a; 链接信息在数据库的…