VAE理论上一篇已经详细讲完了,虽然VAE已经是过去的东西了,但是它对后面强大的生成模型是很有指导意义的。接下来,我们简单实现一下其代码吧。
1 VAE在minist数据集上的实现
完整的代码如下,没有什么特别好讲的。
import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary""" 就用线性层构造最简单的vae吧"""class VAE(nn.Module):def __init__(self, image_size=28*28, hidden1=400, hidden2=100, latent_dims=40):super().__init__()# encoderself.encoder = nn.Sequential(nn.Linear(image_size, hidden1),nn.ReLU(),nn.Linear(hidden1, hidden2),nn.ReLU(),)self.mu = nn.Sequential(nn.Linear(hidden2, latent_dims),)self.logvar = nn.Sequential(nn.Linear(hidden2, latent_dims),) # 由于方差是非负的,因此预测方差对数# decoderself.decoder = nn.Sequential(nn.Linear(latent_dims, hidden2),nn.ReLU(),nn.Linear(hidden2, hidden1),nn.ReLU(),nn.Linear(hidden1, image_size),nn.Tanh())# 重参数,为了可以反向传播def reparametrization(self, mu, logvar):# sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))std = 0.5 * torch.exp(logvar)# N(mu, std^2) = N(0, 1) * std + muz = torch.randn(std.size(), device=mu.device) * std + mureturn zdef forward(self, x):en = self.encoder(x)mu = self.mu(en)logvar = self.logvar(en)z = self.reparametrization(mu, logvar)return self.decoder(z), mu, logvardef loss_function(fake_imgs, real_imgs, mu, logvar):kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)reconstruction = ((real_imgs - fake_imgs)**2).sum()return kl, reconstructiondef train(num_epoch):write_fake = SummaryWriter(f'logs/fake')device = torch.device("cuda:0")trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)vae = VAE().to(device)optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)vae.train()step = 0for epoch in range(num_epoch):for batch_indx, (inputs, _) in enumerate(trainloader):inputs = inputs.to(device)real_imgs = torch.flatten(inputs, start_dim=1)fake_imgs, mu, logvar = vae(real_imgs)loss_kl, loss_re = loss_function(fake_imgs, real_imgs, mu, logvar)loss_all = loss_kl + loss_reoptimizer.zero_grad()loss_all.backward()optimizer.step()print(f"epoch:{epoch}, loss kl:{loss_kl.item()}, loss re:{loss_re.item()}, loss all:{loss_all.item()}")if batch_indx == 0:with torch.no_grad():x = torch.randn((32, 40)).to(device)fake = vae.decoder(x).reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)write_fake.add_image("Mnist Fake Image", img_grid_fake, global_step=step)step += 1if __name__ == "__main__":summary(VAE(), input_size=(1, 784))train(1000)
模型结构打印如下:
VAE [1, 784] –
├─Sequential: 1-1 [1, 100] –
│ └─Linear: 2-1 [1, 400] 314,000
│ └─ReLU: 2-2 [1, 400] –
│ └─Linear: 2-3 [1, 100] 40,100
│ └─ReLU: 2-4 [1, 100] –
├─Sequential: 1-2 [1, 40] –
│ └─Linear: 2-5 [1, 40] 4,040
├─Sequential: 1-3 [1, 40] –
│ └─Linear: 2-6 [1, 40] 4,040
├─Sequential: 1-4 [1, 784] –
│ └─Linear: 2-7 [1, 100] 4,100
│ └─ReLU: 2-8 [1, 100] –
│ └─Linear: 2-9 [1, 400] 40,400
│ └─ReLU: 2-10 [1, 400] –
│ └─Linear: 2-11 [1, 784] 314,384
│ └─Tanh: 2-12 [1, 784] –
训练结果,从结果上来看,是不如GAN的,主要原因在于其在KL散度和重建损失之间很难做到平衡,所以很难训练得好,当然原因是多方面的。
2 VAE的缺陷
变分自编码器(VAE, Variational Autoencoder)作为一种强大的深度学习模型,在生成建模领域有着广泛的应用,但它也存在一些缺陷,主要包括:
-
生成样本质量:与生成对抗网络(GANs)相比,VAE生成的样本可能显得较为模糊或缺乏清晰度。尽管VAE能够生成连续且有结构的潜在空间,其生成的样本在某些情况下可能不够真实或细节不够丰富。
-
潜在空间的连续性问题:虽然VAE设计用于学习连续的潜在空间,以允许插值和生成流畅的变化序列,但在实践中,这种连续性可能不如理论中那样完美。潜在空间中可能会出现空洞或不连贯区域,影响样本生成的质量和连续性变换的效果。
-
KL散度的平衡问题:VAE通过在其损失函数中加入KL散度项来约束潜在变量的分布,以确保它接近先验分布(通常是标准正态分布)。然而,KL散度的权重难以选择,如果设置不当,可能导致模型过分关注重构损失而忽视了潜在空间的平滑性和多样性,或者相反。
-
训练难度与稳定性:VAE的训练过程比一些其他模型更为复杂,涉及到优化 Evidence Lower Bound (ELBO),这可能导致训练过程较为不稳定,需要更多的计算资源和更长的训练时间。特别是优化过程中对似然的近似以及对数似然的下界处理增加了训练的复杂度。
-
表达能力与模型容量:由于VAE的编码器和解码器结构相对简单(通常为全连接层或简单的卷积层),在处理高度复杂的高维数据时,其表达能力可能受限,影响生成样本的质量和多样性。
这些缺陷提示研究者和实践者在使用VAE时需要仔细调整模型架构、损失函数的平衡以及训练策略,以最大化其生成能力和实用性。