文章目录
- 使用GAN生成手写数字样本
- 附:系列文章
使用GAN生成手写数字样本
生成对抗网络
GAN(Generative Adversarial Networks)生成对抗网络是一种深度学习模型架构,由深度生成网络(Generator)和深度鉴别网络(Discriminator)组成,并且利用对抗学习的方式训练。GAN最初由Ian Goodfellow在2014年提出,自提出以来一直受到学术界和工业界广泛的关注和研究。
GAN的主要思想是让生成网络从噪声中生成样本,并通过鉴别网络来评估生成的样本与真实数据的相似度。生成网络利用噪声输入生成样本,鉴别网络则根据输入的样本给出一个判断,判断这个样本是不是真实数据。生成网络和鉴别网络通过对抗学习的方式来互相学习和提高。在训练过程中,生成网络希望生成的样本能够欺骗鉴别网络,鉴别网络则希望能够区分真实数据和生成数据,从而达到提高样本质量的目的。
GAN的应用非常广泛,主要包括图像生成、视频生成、自然语言处理等领域。在图像生成方面,GAN可以用于生成各种样式的图片,例如人物头像、动物、食品等。在视频生成方面,GAN可以生成逼真的视频序列,包括人物动作、自然风景等。在自然语言处理方面,GAN可以生成逼真的对话、文章等。
GAN的训练过程相对其他深度学习模型更加复杂。生成网络和鉴别网络需要保持平衡,让生成网络生成的样本能够欺骗鉴别网络,同时鉴别网络也需要保持自己的准确率,判断生成的样本是否真实。由于GAN的训练过程极易出现训练不稳定、模式崩溃等问题,因此需要在使用时进行一定的调整和优化。
GAN的发展史上涌现出一系列的变体模型,例如Conditional GAN(CGAN)、CycleGAN、Pix2Pix等。这些变体模型在应用场景上有所不同,但是核心思想都是在GAN的基础上进行调整和改进。
GAN在学术界和工业界都受到了广泛的关注和研究,许多实际应用都对GAN有较高的需求。同时,GAN的研究也面临着一系列的问题和挑战,例如GAN的稳定性、样本多样性等。可以预见,在未来的发展中,GAN会继续得到广泛的关注和应用。
程序设计
# 导入相关库
from __future__ import print_function, division from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as plt
import numpy as np
import sys
import os
class GAN():def __init__(self):# 行28,列28,也就是mnist的shape# 通道为1,灰度图self.img_rows = 28self.img_cols = 28self.channels = 1# 28*28*1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)# 构造一个判别器self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])# 构造一个生成器self.generator = self.build_generator()gan_input = Input(shape=(self.latent_dim,))img = self.generator(gan_input)# 在训练generator的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)self.combined = Model(gan_input, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)# 定义生成器def build_generator(self):model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))#全连接层,28*28*1个神经元model.add(Dense(np.prod(self.img_shape), activation='tanh'))#变成图片的形状model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))#建立了从输入100维随机向量到28,28,1大小的图片生成模型img = model(noise)return Model(noise, img)# 定义判别器def build_discriminator(self):model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2)) # 判别真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)# 定义训练函数def train(self, epochs, batch_size=128, sample_interval=50):# 获取数据(X_train, _), (_,_) = mnist.load_data()# 进行标准化# 将图片像素值映射到-1到1X_train = X_train / 127.5 - 1X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# 先训练判别器,再训练生成器for epoch in range(epochs):# 随机选取batch_size个图片# 对discriminator进行训练# 从train训练集里面随机找出batch—size大小(这么多个)的索引值idx = np.random.randint(0, X_train.shape[0], batch_size)# 取出一个batch大小的图片imgs = X_train[idx] # 正态分布生成batch_size个100维向量作为输入noise = np.random.normal(0, 1, (batch_size, self.latent_dim))# 用生成model的predict方法(model内部方法)将输入进行生成输出gen_imgs = self.generator.predict(noise)# 输入真实图片和标签全1》》到判别model,》》计算判别模型的lossd_loss_real = self.discriminator.train_on_batch(imgs, valid) # 输入假的图片和标签全0》》到判别model,》计算判别模型的loss d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # 将两者损失结合作为总损失d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # 训练generatornoise = np.random.normal(0, 1, (batch_size, self.latent_dim))# 如果输入噪音的输出是1,则正确,输入噪音输出是0,则生成网络需要改进,所以loss累加g_loss = self.combined.train_on_batch(noise, valid)# D准确度越高,代表G生成的图片越离谱,准确率为0.5左右就可以以假乱真了print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) # 每sample_interval轮生成一个图片if epoch % sample_interval == 0 :self.sample_images(epoch)# 定义生成图片函数def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=10000, batch_size=256, sample_interval=200)
Using TensorFlow backend.WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.Instructions for updating:Use tf.where in 2.0, which has the same broadcast rule as np.whereWARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead./home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?'Discrepancy between trainable weights and collected trainable'0 [D loss: 0.986130, acc.: 26.17%] [G loss: 0.834596]/home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?'Discrepancy between trainable weights and collected trainable'1 [D loss: 0.403944, acc.: 83.98%] [G loss: 0.796327]2 [D loss: 0.347891, acc.: 83.01%] [G loss: 0.777482]3 [D loss: 0.344294, acc.: 81.45%] [G loss: 0.784850]4 [D loss: 0.340509, acc.: 82.42%] [G loss: 0.815181]5 [D loss: 0.323516, acc.: 86.52%] [G loss: 0.901500]6 [D loss: 0.292972, acc.: 93.75%] [G loss: 0.991136]7 [D loss: 0.257421, acc.: 97.27%] [G loss: 1.111775]8 [D loss: 0.231006, acc.: 98.05%] [G loss: 1.239357]9 [D loss: 0.194000, acc.: 99.80%] [G loss: 1.371341]10 [D loss: 0.173448, acc.: 100.00%] [G loss: 1.501673]11 [D loss: 0.154554, acc.: 100.00%] [G loss: 1.620853]12 [D loss: 0.142011, acc.: 99.61%] [G loss: 1.732671]13 [D loss: 0.124580, acc.: 99.80%] [G loss: 1.827322]14 [D loss: 0.116470, acc.: 99.80%] [G loss: 1.972561]15 [D loss: 0.105582, acc.: 100.00%] [G loss: 2.067226]16 [D loss: 0.093254, acc.: 100.00%] [G loss: 2.198446]17 [D loss: 0.087950, acc.: 100.00%] [G loss: 2.304677]18 [D loss: 0.073583, acc.: 100.00%] [G loss: 2.355863]19 [D loss: 0.072164, acc.: 100.00%] [G loss: 2.464585]20 [D loss: 0.065558, acc.: 99.80%] [G loss: 2.534361]21 [D loss: 0.059140, acc.: 100.00%] [G loss: 2.626909]22 [D loss: 0.057848, acc.: 100.00%] [G loss: 2.673893]23 [D loss: 0.052325, acc.: 100.00%] [G loss: 2.714813]24 [D loss: 0.052922, acc.: 100.00%] [G loss: 2.763450]25 [D loss: 0.046035, acc.: 100.00%] [G loss: 2.853940]26 [D loss: 0.049457, acc.: 100.00%] [G loss: 2.869173]27 [D loss: 0.042687, acc.: 100.00%] [G loss: 2.941574]28 [D loss: 0.039089, acc.: 100.00%] [G loss: 2.948203]29 [D loss: 0.036347, acc.: 100.00%] [G loss: 2.968413]30 [D loss: 0.038200, acc.: 100.00%] [G loss: 3.048651]31 [D loss: 0.039299, acc.: 100.00%] [G loss: 3.102673]32 [D loss: 0.033043, acc.: 100.00%] [G loss: 3.050264]33 [D loss: 0.035250, acc.: 100.00%] [G loss: 3.078978]34 [D loss: 0.037255, acc.: 100.00%] [G loss: 3.131599]35 [D loss: 0.033308, acc.: 100.00%] [G loss: 3.127816]36 [D loss: 0.035622, acc.: 100.00%] [G loss: 3.157865]37 [D loss: 0.038046, acc.: 100.00%] [G loss: 3.272691]38 [D loss: 0.037665, acc.: 100.00%] [G loss: 3.304567]39 [D loss: 0.029662, acc.: 100.00%] [G loss: 3.323656]40 [D loss: 0.031073, acc.: 100.00%] [G loss: 3.342812]41 [D loss: 0.031860, acc.: 100.00%] [G loss: 3.330144]42 [D loss: 0.033744, acc.: 100.00%] [G loss: 3.365006]43 [D loss: 0.030133, acc.: 100.00%] [G loss: 3.361420]44 [D loss: 0.032508, acc.: 100.00%] [G loss: 3.456270]45 [D loss: 0.030021, acc.: 100.00%] [G loss: 3.498577]46 [D loss: 0.029159, acc.: 100.00%] [G loss: 3.499414]47 [D loss: 0.031974, acc.: 100.00%] [G loss: 3.484164]48 [D loss: 0.033442, acc.: 99.80%] [G loss: 3.459633]49 [D loss: 0.030912, acc.: 100.00%] [G loss: 3.481130]50 [D loss: 0.033645, acc.: 100.00%] [G loss: 3.492231]51 [D loss: 0.034441, acc.: 100.00%] [G loss: 3.489124]52 [D loss: 0.034330, acc.: 100.00%] [G loss: 3.506902]53 [D loss: 0.034518, acc.: 100.00%] [G loss: 3.520910]54 [D loss: 0.030822, acc.: 100.00%] [G loss: 3.618950]55 [D loss: 0.034566, acc.: 99.80%] [G loss: 3.538144]56 [D loss: 0.032794, acc.: 100.00%] [G loss: 3.566177]57 [D loss: 0.037374, acc.: 99.61%] [G loss: 3.600816]58 [D loss: 0.037127, acc.: 100.00%] [G loss: 3.521185]59 [D loss: 0.039322, acc.: 100.00%] [G loss: 3.531039]60 [D loss: 0.030453, acc.: 100.00%] [G loss: 3.616879]61 [D loss: 0.044332, acc.: 99.02%] [G loss: 3.628755]62 [D loss: 0.037772, acc.: 99.80%] [G loss: 3.723062]63 [D loss: 0.041130, acc.: 99.61%] [G loss: 3.533709]64 [D loss: 0.044611, acc.: 99.41%] [G loss: 3.657721]65 [D loss: 0.037362, acc.: 99.61%] [G loss: 3.582735]66 [D loss: 0.050663, acc.: 99.02%] [G loss: 3.555587]67 [D loss: 0.039863, acc.: 99.41%] [G loss: 3.611456]68 [D loss: 0.051172, acc.: 99.02%] [G loss: 3.540278]69 [D loss: 0.052263, acc.: 98.63%] [G loss: 3.612799]70 [D loss: 0.056154, acc.: 99.41%] [G loss: 3.557292]71 [D loss: 0.055386, acc.: 99.22%] [G loss: 3.744767]72 [D loss: 0.096904, acc.: 97.66%] [G loss: 3.443518]73 [D loss: 0.070626, acc.: 98.05%] [G loss: 3.833835]74 [D loss: 0.180408, acc.: 93.55%] [G loss: 3.301687]75 [D loss: 0.074523, acc.: 98.44%] [G loss: 3.776305]76 [D loss: 0.057483, acc.: 99.02%] [G loss: 3.714150]77 [D loss: 0.141995, acc.: 95.12%] [G loss: 3.380850]78 [D loss: 0.067733, acc.: 98.63%] [G loss: 3.779586]79 [D loss: 0.303615, acc.: 87.89%] [G loss: 2.848376]80 [D loss: 0.145237, acc.: 94.14%] [G loss: 3.108039]81 [D loss: 0.046822, acc.: 99.22%] [G loss: 3.635069]82 [D loss: 0.108516, acc.: 96.48%] [G loss: 3.235212]83 [D loss: 0.105234, acc.: 96.48%] [G loss: 3.336948]84 [D loss: 0.233112, acc.: 90.82%] [G loss: 2.740180]85 [D loss: 0.118313, acc.: 94.92%] [G loss: 3.181991]86 [D loss: 0.300344, acc.: 87.30%] [G loss: 2.879515]87 [D loss: 0.106900, acc.: 96.48%] [G loss: 3.189476]88 [D loss: 0.381278, acc.: 84.38%] [G loss: 2.337953]89 [D loss: 0.252046, acc.: 88.28%] [G loss: 2.707138]90 [D loss: 0.087314, acc.: 97.07%] [G loss: 3.401120]91 [D loss: 0.260525, acc.: 90.62%] [G loss: 2.520348]92 [D loss: 0.148098, acc.: 93.36%] [G loss: 2.991073]93 [D loss: 0.141315, acc.: 96.09%] [G loss: 2.805464]94 [D loss: 0.288812, acc.: 89.45%] [G loss: 2.549888]95 [D loss: 0.143633, acc.: 94.14%] [G loss: 2.978777]96 [D loss: 0.584615, acc.: 78.32%] [G loss: 2.050247]97 [D loss: 0.328917, acc.: 83.01%] [G loss: 2.579935]98 [D loss: 0.111224, acc.: 97.66%] [G loss: 3.526271]99 [D loss: 0.702403, acc.: 68.95%] [G loss: 1.994847]100 [D loss: 0.335197, acc.: 84.96%] [G loss: 2.110721]101 [D loss: 0.147330, acc.: 93.55%] [G loss: 2.962312]102 [D loss: 0.091300, acc.: 98.44%] [G loss: 3.025173]103 [D loss: 0.304929, acc.: 87.70%] [G loss: 2.458197]104 [D loss: 0.199925, acc.: 90.43%] [G loss: 2.897576]105 [D loss: 0.335472, acc.: 87.30%] [G loss: 2.198746]106 [D loss: 0.235486, acc.: 88.09%] [G loss: 2.742341]107 [D loss: 0.346595, acc.: 84.77%] [G loss: 2.340909]108 [D loss: 0.211129, acc.: 91.60%] [G loss: 2.801579]109 [D loss: 0.361250, acc.: 84.96%] [G loss: 2.304583]110 [D loss: 0.183040, acc.: 93.16%] [G loss: 2.763792]111 [D loss: 0.365892, acc.: 82.62%] [G loss: 2.418060]112 [D loss: 0.197837, acc.: 92.19%] [G loss: 2.826400]113 [D loss: 0.413041, acc.: 81.05%] [G loss: 2.408184]114 [D loss: 0.198854, acc.: 91.80%] [G loss: 2.784730]115 [D loss: 0.395174, acc.: 81.45%] [G loss: 2.115457]116 [D loss: 0.189158, acc.: 90.04%] [G loss: 2.603389]117 [D loss: 0.237316, acc.: 92.97%] [G loss: 2.648600]118 [D loss: 0.285941, acc.: 87.89%] [G loss: 2.370326]119 [D loss: 0.208490, acc.: 90.43%] [G loss: 2.849175]120 [D loss: 0.454702, acc.: 80.08%] [G loss: 1.897220]121 [D loss: 0.217595, acc.: 89.06%] [G loss: 2.498424]122 [D loss: 0.173055, acc.: 94.92%] [G loss: 2.664538]123 [D loss: 0.262918, acc.: 90.82%] [G loss: 2.133595]124 [D loss: 0.190525, acc.: 91.02%] [G loss: 2.840866]125 [D loss: 0.292295, acc.: 87.11%] [G loss: 2.199357]126 [D loss: 0.215348, acc.: 88.87%] [G loss: 2.739654]127 [D loss: 0.365445, acc.: 84.96%] [G loss: 2.162226]128 [D loss: 0.200284, acc.: 89.65%] [G loss: 2.871504]129 [D loss: 0.450811, acc.: 79.10%] [G loss: 1.971582]130 [D loss: 0.200712, acc.: 90.82%] [G loss: 2.715580]131 [D loss: 0.310609, acc.: 85.94%] [G loss: 2.443402]132 [D loss: 0.234690, acc.: 89.65%] [G loss: 2.654381]133 [D loss: 0.449007, acc.: 79.30%] [G loss: 1.873044]134 [D loss: 0.233484, acc.: 87.89%] [G loss: 2.710910]135 [D loss: 0.274398, acc.: 87.70%] [G loss: 2.632379]136 [D loss: 0.295981, acc.: 87.11%] [G loss: 2.511465]137 [D loss: 0.247948, acc.: 89.65%] [G loss: 2.698283]138 [D loss: 0.490601, acc.: 75.20%] [G loss: 2.161157]139 [D loss: 0.215320, acc.: 90.43%] [G loss: 2.841792]140 [D loss: 0.564996, acc.: 73.63%] [G loss: 1.618642]141 [D loss: 0.270847, acc.: 86.52%] [G loss: 2.598266]142 [D loss: 0.210049, acc.: 93.75%] [G loss: 3.058943]143 [D loss: 0.462835, acc.: 76.17%] [G loss: 2.219012]144 [D loss: 0.213740, acc.: 89.84%] [G loss: 2.845963]145 [D loss: 0.518464, acc.: 73.63%] [G loss: 1.735387]146 [D loss: 0.273846, acc.: 87.11%] [G loss: 2.634973]……
附:系列文章
序号 | 文章目录 | 直达链接 |
---|---|---|
1 | 波士顿房价预测 | https://want595.blog.csdn.net/article/details/132181950 |
2 | 鸢尾花数据集分析 | https://want595.blog.csdn.net/article/details/132182057 |
3 | 特征处理 | https://want595.blog.csdn.net/article/details/132182165 |
4 | 交叉验证 | https://want595.blog.csdn.net/article/details/132182238 |
5 | 构造神经网络示例 | https://want595.blog.csdn.net/article/details/132182341 |
6 | 使用TensorFlow完成线性回归 | https://want595.blog.csdn.net/article/details/132182417 |
7 | 使用TensorFlow完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182496 |
8 | TensorBoard案例 | https://want595.blog.csdn.net/article/details/132182584 |
9 | 使用Keras完成线性回归 | https://want595.blog.csdn.net/article/details/132182723 |
10 | 使用Keras完成逻辑回归 | https://want595.blog.csdn.net/article/details/132182795 |
11 | 使用Keras预训练模型完成猫狗识别 | https://want595.blog.csdn.net/article/details/132243928 |
12 | 使用PyTorch训练模型 | https://want595.blog.csdn.net/article/details/132243989 |
13 | 使用Dropout抑制过拟合 | https://want595.blog.csdn.net/article/details/132244111 |
14 | 使用CNN完成MNIST手写体识别(TensorFlow) | https://want595.blog.csdn.net/article/details/132244499 |
15 | 使用CNN完成MNIST手写体识别(Keras) | https://want595.blog.csdn.net/article/details/132244552 |
16 | 使用CNN完成MNIST手写体识别(PyTorch) | https://want595.blog.csdn.net/article/details/132244641 |
17 | 使用GAN生成手写数字样本 | https://want595.blog.csdn.net/article/details/132244764 |
18 | 自然语言处理 | https://want595.blog.csdn.net/article/details/132276591 |