Contents
1--CVAE Model
2--Code Example
1--CVAE Model
Brief introduction:
A CVAE works like a VAE, except that the model's input is a fusion of the image and a condition. The fused result is mapped by an encoder to the parameters (mean and variance) of a Gaussian that is regularized toward the standard normal distribution; a sample is drawn from this distribution, the sample is again fused with the condition, and finally a decoder reconstructs the image.
Because the input fuses the image with the condition, the model learns to generate images conditioned on it.
The loss is computed between the source image and the reconstructed image; for the derivation of the loss function, see: Variational Autoencoder (VAE).
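As a reminder (a standard result, not derived here), the total loss combines a reconstruction term with the closed-form KL divergence that pulls the encoder's Gaussian toward the standard normal prior. The code in the next section implements this with the constant term dropped, which does not change the gradients:

```latex
\mathcal{L}
= \underbrace{\lVert x - \hat{x} \rVert^2}_{\text{reconstruction}}
+ \underbrace{\frac{1}{2}\sum_{i=1}^{d}\left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right)}_{D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\,\mathrm{diag}(\sigma^2))\,\middle\|\,\mathcal{N}(0,\,I)\right)}
```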
2--Code Example
The CVAE below uses the simplest fusion method (concat): the condition Y is concatenated with the input X to form X_given_Y, and likewise Y is concatenated with the latent sample z to form z_given_Y.
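As a quick sanity check of this concat fusion (a toy sketch with made-up shapes: 4 flattened 28×28 images and one scalar label each, so y_size = 1):

```python
import torch

# Hypothetical toy batch: 4 flattened 28x28 images, one scalar label each
X = torch.randn(4, 784)
y = torch.tensor([0.0, 1.0, 2.0, 3.0])

# Appending the label as one extra feature grows the input
# dimension from 784 to 784 + y_size, matching in_features + y_size
X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)
print(X_given_Y.shape)  # torch.Size([4, 785])
```

This is why the encoder's first layer takes in_features + y_size inputs rather than in_features.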
import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, in_features, latent_size, y_size=0):
        super(VAE, self).__init__()
        self.latent_size = latent_size
        self.encoder_forward = nn.Sequential(  # encoder
            nn.Linear(in_features + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, self.latent_size * 2)
        )
        self.decoder_forward = nn.Sequential(  # decoder
            nn.Linear(self.latent_size + y_size, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Linear(in_features, in_features),
            nn.Sigmoid()
        )

    def encoder(self, X):  # encode
        out = self.encoder_forward(X)  # the encoder outputs mean and log-variance
        mu = out[:, :self.latent_size]  # first half of the output is the mean
        log_var = out[:, self.latent_size:]  # second half is the log-variance
        return mu, log_var

    def decoder(self, z):  # decode
        mu_prime = self.decoder_forward(z)
        return mu_prime

    def reparameterization(self, mu, log_var):  # reparameterization trick
        epsilon = torch.randn_like(log_var)
        z = mu + epsilon * torch.sqrt(log_var.exp())  # std = sqrt(exp(log_var))
        return z

    def loss(self, X, mu_prime, mu, log_var):  # compute loss
        reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
        latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
        return reconstruction_loss + latent_loss

    def forward(self, X, *args, **kwargs):
        mu, log_var = self.encoder(X)  # encode
        z = self.reparameterization(mu, log_var)  # sample z via the reparameterization trick
        mu_prime = self.decoder(z)  # decode
        return mu_prime, mu, log_var


class CVAE(VAE):
    def __init__(self, in_features, latent_size, y_size):
        super(CVAE, self).__init__(in_features, latent_size, y_size)

    def forward(self, X, y=None, *args, **kwargs):
        y = y.to(next(self.parameters()).device)
        X_given_Y = torch.cat((X, y.unsqueeze(1)), dim=1)  # fuse input with condition
        mu, log_var = self.encoder(X_given_Y)
        z = self.reparameterization(mu, log_var)
        z_given_Y = torch.cat((z, y.unsqueeze(1)), dim=1)  # fuse latent sample with condition
        mu_prime_given_Y = self.decoder(z_given_Y)
        return mu_prime_given_Y, mu, log_var
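The reparameterization step can be checked in isolation: sampling z = mu + epsilon * sqrt(exp(log_var)) with made-up target parameters should yield samples whose empirical mean and standard deviation match those targets (a sketch with 100,000 draws):

```python
import torch

torch.manual_seed(0)
# Made-up targets: mean 2.0, variance 0.25 (std 0.5)
mu = torch.full((100000, 1), 2.0)
log_var = torch.log(torch.full((100000, 1), 0.25))

# Same formula as in reparameterization() above
epsilon = torch.randn_like(log_var)
z = mu + epsilon * torch.sqrt(log_var.exp())

print(z.mean().item(), z.std().item())  # approximately 2.0 and 0.5
```

Because the randomness lives in epsilon rather than in z itself, gradients can flow through mu and log_var during backpropagation.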
The loss-computation code on its own:
def loss(self, X, mu_prime, mu, log_var):  # compute loss
    reconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))
    latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))
    return reconstruction_loss + latent_loss
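Note that the latent term drops the constant −1 of the analytic KL divergence; the difference is a constant 0.5 per latent dimension and therefore does not affect gradients. A quick numerical check with made-up tensors (batch of 8, latent size 4):

```python
import torch

torch.manual_seed(0)
mu = torch.randn(8, 4)       # hypothetical batch of 8, latent_size = 4
log_var = torch.randn(8, 4)

# Latent loss as written in the code (constant -1 dropped)
latent_loss = torch.mean(0.5 * (log_var.exp() + mu.square() - log_var).sum(dim=1))
# Analytic KL divergence to N(0, I)
kl = torch.mean(0.5 * (log_var.exp() + mu.square() - 1 - log_var).sum(dim=1))

print((latent_loss - kl).item())  # approximately 0.5 * 4 = 2.0
```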
For the complete code, see: liujf69/VAE