VAE学习笔记
-
普通的编码器可以将图像这类信息编码成为特征向量.
-
但通常这些特征向量不具有空间上的连续性.
-
VAE(变分自编码器)可以将图像信息编码成为具有空间连续性的特征向量.
-
方法是向编码器和解码器中加入统计信息,即特征向量代表的的是一个高斯分布,强迫特征向量服从高斯分布.
-
编码器是将图片信息编码成为一个高斯分布.
-
解码器则是从特征空间中进行采样,再经过全连接层,反卷积层,卷积层等恢复成一张与输入图片大小相等的图片.
-
损失函数有两个目标:即(1)拟合原始图片以及(2)使得特征空间具有良好的结构及降低过拟合.因此,我们的损失函数由两部分构成.其中第二部分需要使得编码出的正态分布围绕在标准正态分布周围.
实现代码
import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as npimg_shape = (28,28,1)
batch_size = 16
latent_dim = 2##### 图片编码器部分
input_img = keras.Input(shape = img_shape)x = layers.Conv2D(32,3,padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)shape_before_flatting = K.int_shape(x)x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)#输入图像最终被编码为如下两个参数
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
# 需要注意的是 z_log_var = 2log(sigma)
##### 编码器部分结束##### 采样函数,用于在给定的正态分布中进行采样,这也就是编码器加入统计信息的地方.
def sampling(args):z_mean, z_log_var = argsepsilon = K.random_normal(shape = (K.shape(z_mean)[0],latent_dim),mean = 0.,stddev = 1.)return z_mean + K.exp(0.5*z_log_var) * epsilon##### VAE解码器部分
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flatting[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flatting[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)
##### 到这里x被恢复成为一张图片##### 下面两句话将编码器和解码器通过上采样函数连接到了一起
decoder = Model(decoder_input,x)
z = layers.Lambda(sampling)([z_mean,z_log_var])
z_decoder = decoder(z)###### 自定义损失函数层
def vae_loss(y_true,y_pred,e = 0.1):x = K.flatten(y_true)z_decoded = K.flatten(y_pred)xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),axis = -1)return K.mean(xent_loss + kl_loss)from keras.datasets import mnistvae = Model(input_img,z_decoder)
vae.compile(optimizer = 'rmsprop',loss = vae_loss)
vae.summary()##### 训练模型
from keras.datasets import mnist
(x_train,_),(x_test,y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))vae.fit(x = x_train,y = x_train,shuffle = True,epochs = 10,batch_size = batch_size,validation_data = (x_test,x_test))##### 在特征空间进行连续采样,观察输出图片
import matplotlib.pyplot as plt
from scipy.stats import norm
n = 24
digit_size = 28
figure = np.zeros((digit_size*n,digit_size*n))
grid_x = norm.ppf(np.linspace(0.02,0.98,n))
grid_y = norm.ppf(np.linspace(0.02,0.98,n))print(batch_size)for i,yi in enumerate(grid_x):for j,xi in enumerate(grid_y):z_sample = np.array([[xi,yi]])z_sample = np.tile(z_sample,1).reshape(1, 2)x_decoded = decoder.predict(z_sample, batch_size = 1)digit = x_decoded[0].reshape(digit_size,digit_size)figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digitplt.figure(figsize = (15,15))
plt.imshow(figure,cmap = 'Greys_r')
plt.show()