本文为此系列的第六篇对GAN的评估,上一篇为Controllable GAN。文中使用训练好的分类模型的部分网络提取特征将真实分布与生成分布进行对比来评估模型的好坏,若有不懂的无监督知识点可以看本系列第一篇。
原理
1.评估模型的指标
一般来说,我们评估模型的好坏可以通过对测试集的错误率来体现:比如图像分类我们可以统计几张分错几张分对来量化错误率、目标检测我们可以通过比对每个框得到mAP从而量化错误率…但是我们怎么通过生成的图像来评估GAN的好坏呢?
我们总不能说,生成的某一个像素要更绿色一点比较好,或者某个像素要更黄色一点比较好吧?
先进行概括一下,全文主要围绕着生成质量(保真度fidelity)、多样性(diversity)进行讲解。
2. 图像对比有两种方法,pixel distance、feature distance。
第一种像素对比,直接做相减运算。这样做的缺点是尽管两张图片可能非常相似,但是每个像素的像素值会有一些细微的差异,即使我们肉眼看不出来,最终的差值也会非常大,太过于关注细节。
第二种则是特征对比,通俗的说是成片的像素区域进行对比是否相似,这样的对比更符合我们人眼观察标准。
那么,接下来的问题就是如何进行特征提取。
3. 特征提取的方法
我们训练好的分类器是一个很好的特征提取器,比如我们训练了一个识别猫狗的分类器,那它必然是学习到了猫狗的特征才会对他们进行分类。
直接将分类部分的最后一层分类层去掉,其余的都是对我们有价值的。我们一般选择的是连接最后一个全连接层的池化层作为输出特征的层,我们成为特征层,输出的特征我们称为embedding。
选择这个位置并不固定,只是选择的位置越后面,每个单元的感受野越大,所包含的信息就越多,更符合我们的要求。很前面的层获取到的特征可能只是一横或者一竖或者一个弧度等。
- 我们使用Inception v3作为我们的特征提取器,Inception使用超1400万张图片、2万多类别的ImageNet数据库作为训练集。提取详细流程如图:
对总的概括可以概括为一下流程:
最终我们就是对真实数据提取的特征于生成数据提取的特征进行对比。
4. Frechet Inception Distance(FID)
我们使用FID来量化真假特征的差异。
通俗来说Frechet Distance是用来衡量两条曲线之间的的最小距离,比如人狗同时走所需的最短牵引绳的长度。
严格来说,Frechet Distance是衡量两个分布之间的差异。
①我们可以使用以下公式来表示两个单维正态分布的Frechet Distance:
分别从真实数据和生成数据里面提取大量的特征,分别作为真实特征分布于生成特征分布,计算出各自的均值和标准差即可计算出真假之间的差值。
②两个多变量正态分布的Frechet Distance
我们可以为每个维度提供一个单变量的正态分布,假设是两个变量的(便于举例),如图:
协方差矩阵:
比如(x1,x2)代表第一变量的正态分布的随机变量与第二正态分布的随机变量之间的协方差。非对角线元素代表不同变量之间的协方差,即不同变量之间的相关性。若两个变量变化趋势一致则协方差为正值,反之负值,若没有线性关系则为0。上图就代表两个变量之间相互不影响相互独立,下图代表两变量之间负相关;
比如(x1,x1)代表第一变量的正态分布的方差。对角线元素代表每个变量分布的方差,即每个变量本身的变化程度。
由此可以计算我们的多变量正态分布之间的Frechet Distance,可以将单维正态分布之间的Frechet Distance公式展开进行对比发现他们之间其实是相似的:
Tr运算为矩阵的对角线元素之和,例如上面那个负相关的协方差矩阵的Tr运算结果为2+2=4。
将多变量正态分布之间的Frechet Distance应用于真假特征的分布就是FID了:
FID越小,就代表着真假分布就越接近,那么GAN就越好。
代码
import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!class Generator(nn.Module):def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 8),self.make_gen_block(hidden_dim * 8, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)else:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh(),)def forward(self, noise):x = noise.view(len(noise), self.z_dim, 1, 1)return self.gen(x)def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)z_dim = 64
image_size = 299
device = 'cuda'transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])dataset = CelebA(".", download=True, transform=transform)gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(f"pretrained_celeba.pth", map_location=torch.device(device))["gen"])
gen = gen.eval()from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=False)
inception_model.load_state_dict(torch.load("inception_v3_google-1a9a5a14.pth"))
inception_model.to(device)
inception_model = inception_model.eval() # Evaluation modeinception_model.fc = torch.nn.Identity()from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal[[1, 0],[0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()mean = torch.Tensor([0, 0])
covariance = torch.Tensor([[2, -1],[-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()import scipy
def matrix_sqrt(x):y = x.cpu().detach().numpy()y = scipy.linalg.sqrtm(y)return torch.Tensor(y.real, device=x.device)def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))def preprocess(img):img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)return imgimport numpy as np
def get_covariance(features):return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))fake_features_list = []
real_features_list = []n_samples = 512 # The total number of samples
batch_size = 4 # Samples per iterationdataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)cur_samples = 0
with torch.no_grad(): # You don't need to calculate gradients here, so you do this to save memorytry:for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batchreal_samples = real_examplereal_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPUreal_features_list.append(real_features)fake_samples = get_noise(len(real_example), z_dim).to(device)fake_samples = preprocess(gen(fake_samples))fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')fake_features_list.append(fake_features)cur_samples += len(real_samples)if cur_samples >= n_samples:breakexcept:print("Error in loop")fake_features_all = torch.cat(fake_features_list)
real_features_all = torch.cat(real_features_list)mu_fake = fake_features_all.mean(0)
mu_real = real_features_all.mean(0)
sigma_fake = get_covariance(fake_features_all)
sigma_real = get_covariance(real_features_all)indices = [2, 4, 5]
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))
real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()with torch.no_grad():print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())
代码中使用的生成器模型可以从上一篇当中下载,inception_v3_google-1a9a5a14.pth模型可以从这里下载。
代码解析
- 去掉分类层
inception_model.fc = torch.nn.Identity()
将最后一层的全连接层替换为恒等函数,它将输入的数据不做任何操作、原封不动地输出。
通常Inception模型的全连接层用于图像分类任务,它将提取的特征映射到类别预测上。然而我们不需要进行图像分类,而是想要利用Inception模型的前面部分来提取图像的特征。
这样就将Inception模型从原始的分类任务模型转变为一个特征提取器,从而不再执行图像分类任务,而是将图像转换为特征向量。
- 可视化多变量正态分布
from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal[[1, 0],[0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()mean = torch.Tensor([0, 0])
covariance = torch.Tensor([[2, -1],[-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()
首先定义均值和协方差矩阵(原理中举的两个例子),然后使用MultivariateNormal
构建一个多变量正态分布对象covariant_dist
。然后从这个分布中抽取了10000个样本,每个样本是一个shape为(samples, 2)的二维向量。最后将生成的样本可视化为二维核密度估计图(Kernel Density Estimate,KDE)。
- 计算矩阵的平方根
def matrix_sqrt(x):y = x.cpu().detach().numpy()y = scipy.linalg.sqrtm(y)return torch.Tensor(y.real, device=x.device)
首先将输入矩阵转移到CPU上并将其转换为NumPy数组。这是因为scipy.linalg.sqrtm函数只能接受NumPy数组作为输入,不能接受PyTorch张量,且在CPU上计算更高效。
然后使用scipy.linalg.sqrtm
函数计算平方根且返回一个复数矩阵,所以需要取其实部(real)部分,然后再转换为PyTorch张量。同时,函数还会确保新的张量与输入矩阵在相同的设备(device)上。
- 计算FID
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))
给定两个分布的均值和协方差矩阵,利用原理中的公式进行计算。
- 对生成图像进行处理
def preprocess(img):img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)return img
将输入的图像进行插值操作,插值方法使用双线性插值,参数align_corners=False
指示在进行插值操作时不对齐图像的角点,这在图像处理中常用于避免不必要的插值偏差。
- 计算协方差矩阵
def get_covariance(features):return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))
使用NumPy的np.cov()函数计算特征向量集合的协方差矩阵,rowvar=False
参数表示传递的数据中每一列代表一个特征向量的观测值,而不是每一行代表一个观测样本。
- 提取特征
for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batchreal_samples = real_examplereal_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPUreal_features_list.append(real_features)fake_samples = get_noise(len(real_example), z_dim).to(device)fake_samples = preprocess(gen(fake_samples))fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')fake_features_list.append(fake_features)cur_samples += len(real_samples)if cur_samples >= n_samples:break
使用预训练的Inception模型提取真实图像和生成图像的特征,并将这些特征存储在列表中,以备后续计算Fréchet Distance。
在这里需要对生成的图像进行preprocess()处理为299的宽高是因为真实数据的宽高为299,而生成数据的宽高为64。
我们可以将生成数据和preprocess处理后的数据显示出来看效果:
import matplotlib.pyplot as plt# 选择其中一个样本进行显示
sample_index = 0# 显示生成图像
fake_image = fake[sample_index].permute(1, 2, 0) # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()# 显示经过处理的图像
fake_image = fake_samples[sample_index].permute(1, 2, 0) # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()
可以看到插值操作后平滑很多。
- 可视化真实数据分布与生成数据分布,并计算FID
indices = [2, 4, 5]
import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()with torch.no_grad():print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())