使用pytorch构建GAN模型的评估

本文为此系列的第六篇对GAN的评估,上一篇为Controllable GAN。文中使用训练好的分类模型的部分网络提取特征将真实分布与生成分布进行对比来评估模型的好坏,若有不懂的无监督知识点可以看本系列第一篇。

原理

1.评估模型的指标
一般来说,我们评估模型的好坏可以通过对测试集的错误率来体现:比如图像分类我们可以统计几张分错几张分对来量化错误率、目标检测我们可以通过比对每个框得到mAP从而量化错误率…但是我们怎么通过生成的图像来评估GAN的好坏呢?
在这里插入图片描述
我们总不能说,生成的某一个像素要更绿色一点比较好,或者某个像素要更黄色一点比较好吧?
先进行概括一下,全文主要围绕着生成质量(保真度fidelity)、多样性(diversity)进行讲解。
在这里插入图片描述
2. 图像对比有两种方法,pixel distance、feature distance。
第一种像素对比,直接做相减运算。这样做的缺点是尽管两张图片可能非常相似,但是每个像素的像素值会有一些细微的差异,即使我们肉眼看不出来,最终的差值也会非常大,太过于关注细节。
在这里插入图片描述
第二种则是特征对比,通俗的说是成片的像素区域进行对比是否相似,这样的对比更符合我们人眼观察标准。
在这里插入图片描述
那么,接下来的问题就是如何进行特征提取。
3. 特征提取的方法
我们训练好的分类器是一个很好的特征提取器,比如我们训练了一个识别猫狗的分类器,那它必然是学习到了猫狗的特征才会对他们进行分类。
在这里插入图片描述
直接将分类部分的最后一层分类层去掉,其余的都是对我们有价值的。我们一般选择的是连接最后一个全连接层的池化层作为输出特征的层,我们成为特征层,输出的特征我们称为embedding。
选择这个位置并不固定,只是选择的位置越后面,每个单元的感受野越大,所包含的信息就越多,更符合我们的要求。很前面的层获取到的特征可能只是一横或者一竖或者一个弧度等。

  • 我们使用Inception v3作为我们的特征提取器,Inception使用超1400万张图片、2万多类别的ImageNet数据库作为训练集。提取详细流程如图:
    在这里插入图片描述

对总的概括可以概括为一下流程:
在这里插入图片描述
最终我们就是对真实数据提取的特征于生成数据提取的特征进行对比。
4. Frechet Inception Distance(FID)
我们使用FID来量化真假特征的差异。
通俗来说Frechet Distance是用来衡量两条曲线之间的的最小距离,比如人狗同时走所需的最短牵引绳的长度。
在这里插入图片描述
严格来说,Frechet Distance是衡量两个分布之间的差异。
在这里插入图片描述
①我们可以使用以下公式来表示两个单维正态分布的Frechet Distance:
在这里插入图片描述
分别从真实数据和生成数据里面提取大量的特征,分别作为真实特征分布于生成特征分布,计算出各自的均值和标准差即可计算出真假之间的差值。
②两个多变量正态分布的Frechet Distance
我们可以为每个维度提供一个单变量的正态分布,假设是两个变量的(便于举例),如图:
在这里插入图片描述

协方差矩阵:
比如(x1,x2)代表第一变量的正态分布的随机变量与第二正态分布的随机变量之间的协方差。非对角线元素代表不同变量之间的协方差,即不同变量之间的相关性。若两个变量变化趋势一致则协方差为正值,反之负值,若没有线性关系则为0。上图就代表两个变量之间相互不影响相互独立,下图代表两变量之间负相关;
比如(x1,x1)代表第一变量的正态分布的方差。对角线元素代表每个变量分布的方差,即每个变量本身的变化程度。
在这里插入图片描述
由此可以计算我们的多变量正态分布之间的Frechet Distance,可以将单维正态分布之间的Frechet Distance公式展开进行对比发现他们之间其实是相似的:
在这里插入图片描述
Tr运算为矩阵的对角线元素之和,例如上面那个负相关的协方差矩阵的Tr运算结果为2+2=4。
将多变量正态分布之间的Frechet Distance应用于真假特征的分布就是FID了:
在这里插入图片描述
FID越小,就代表着真假分布就越接近,那么GAN就越好。

代码

import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!class Generator(nn.Module):def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 8),self.make_gen_block(hidden_dim * 8, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)else:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh(),)def forward(self, noise):x = noise.view(len(noise), self.z_dim, 1, 1)return self.gen(x)def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)z_dim = 64
image_size = 299
device = 'cuda'transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])dataset = CelebA(".", download=True, transform=transform)gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(f"pretrained_celeba.pth", map_location=torch.device(device))["gen"])
gen = gen.eval()from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=False)
inception_model.load_state_dict(torch.load("inception_v3_google-1a9a5a14.pth"))
inception_model.to(device)
inception_model = inception_model.eval() # Evaluation modeinception_model.fc = torch.nn.Identity()from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal[[1, 0],[0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()mean = torch.Tensor([0, 0])
covariance = torch.Tensor([[2, -1],[-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()import scipy
def matrix_sqrt(x):y = x.cpu().detach().numpy()y = scipy.linalg.sqrtm(y)return torch.Tensor(y.real, device=x.device)def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))def preprocess(img):img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)return imgimport numpy as np
def get_covariance(features):return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))fake_features_list = []
real_features_list = []n_samples = 512 # The total number of samples
batch_size = 4 # Samples per iterationdataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)cur_samples = 0
with torch.no_grad(): # You don't need to calculate gradients here, so you do this to save memorytry:for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batchreal_samples = real_examplereal_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPUreal_features_list.append(real_features)fake_samples = get_noise(len(real_example), z_dim).to(device)fake_samples = preprocess(gen(fake_samples))fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')fake_features_list.append(fake_features)cur_samples += len(real_samples)if cur_samples >= n_samples:breakexcept:print("Error in loop")fake_features_all = torch.cat(fake_features_list)
real_features_all = torch.cat(real_features_list)mu_fake = fake_features_all.mean(0)
mu_real = real_features_all.mean(0)
sigma_fake = get_covariance(fake_features_all)
sigma_real = get_covariance(real_features_all)indices = [2, 4, 5]
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))
real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()with torch.no_grad():print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

代码中使用的生成器模型可以从上一篇当中下载,inception_v3_google-1a9a5a14.pth模型可以从这里下载。

代码解析

  • 去掉分类层
inception_model.fc = torch.nn.Identity()

将最后一层的全连接层替换为恒等函数,它将输入的数据不做任何操作、原封不动地输出。
通常Inception模型的全连接层用于图像分类任务,它将提取的特征映射到类别预测上。然而我们不需要进行图像分类,而是想要利用Inception模型的前面部分来提取图像的特征。
这样就将Inception模型从原始的分类任务模型转变为一个特征提取器,从而不再执行图像分类任务,而是将图像转换为特征向量。

  • 可视化多变量正态分布
from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal[[1, 0],[0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()mean = torch.Tensor([0, 0])
covariance = torch.Tensor([[2, -1],[-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

首先定义均值和协方差矩阵(原理中举的两个例子),然后使用MultivariateNormal构建一个多变量正态分布对象covariant_dist。然后从这个分布中抽取了10000个样本,每个样本是一个shape为(samples, 2)的二维向量。最后将生成的样本可视化为二维核密度估计图(Kernel Density Estimate,KDE)。
在这里插入图片描述
在这里插入图片描述

  • 计算矩阵的平方根
def matrix_sqrt(x):y = x.cpu().detach().numpy()y = scipy.linalg.sqrtm(y)return torch.Tensor(y.real, device=x.device)

首先将输入矩阵转移到CPU上并将其转换为NumPy数组。这是因为scipy.linalg.sqrtm函数只能接受NumPy数组作为输入,不能接受PyTorch张量,且在CPU上计算更高效。
然后使用scipy.linalg.sqrtm函数计算平方根且返回一个复数矩阵,所以需要取其实部(real)部分,然后再转换为PyTorch张量。同时,函数还会确保新的张量与输入矩阵在相同的设备(device)上。

  • 计算FID
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

给定两个分布的均值和协方差矩阵,利用原理中的公式进行计算。

  • 对生成图像进行处理
def preprocess(img):img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)return img

将输入的图像进行插值操作,插值方法使用双线性插值,参数align_corners=False指示在进行插值操作时不对齐图像的角点,这在图像处理中常用于避免不必要的插值偏差。
在这里插入图片描述

  • 计算协方差矩阵
def get_covariance(features):return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

使用NumPy的np.cov()函数计算特征向量集合的协方差矩阵,rowvar=False参数表示传递的数据中每一列代表一个特征向量的观测值,而不是每一行代表一个观测样本。

  • 提取特征
for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batchreal_samples = real_examplereal_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPUreal_features_list.append(real_features)fake_samples = get_noise(len(real_example), z_dim).to(device)fake_samples = preprocess(gen(fake_samples))fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')fake_features_list.append(fake_features)cur_samples += len(real_samples)if cur_samples >= n_samples:break

使用预训练的Inception模型提取真实图像和生成图像的特征,并将这些特征存储在列表中,以备后续计算Fréchet Distance。
在这里需要对生成的图像进行preprocess()处理为299的宽高是因为真实数据的宽高为299,而生成数据的宽高为64。
我们可以将生成数据和preprocess处理后的数据显示出来看效果:

import matplotlib.pyplot as plt# 选择其中一个样本进行显示
sample_index = 0# 显示生成图像
fake_image = fake[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()# 显示经过处理的图像
fake_image = fake_samples[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

在这里插入图片描述
在这里插入图片描述
可以看到插值操作后平滑很多。

  • 可视化真实数据分布与生成数据分布,并计算FID
indices = [2, 4, 5]
import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()with torch.no_grad():print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

出海抢先看!必应bing广告投放的运营策略,助力企业获客爆单!

当今出海已成为众多企业拓展业务、寻求新增长点的重要战略,面对广阔的海外市场,选择合适的推广渠道对于提升品牌影响力、实现高效获客至关重要。必应Bing作为全球第二大搜索引擎,覆盖了大量高质量用户群体,尤其是在北美和欧洲市场…

HarmonyOS-静态库(SDK)的创建和使用

文章目录 一、静态库(SDK)二、创建静态库1.新建静态库模块2. 开发静态库内容3. 编译静态库 三、使用静态库1. 配置项目依赖2. 在应用中使用静态库3. 注意事项 四、打包错误1. library引用本地har包错误 一、静态库(SDK) 在Harmon…

IPEmotion轻松解决急停设备的控制与数据存储问题

一 背景 众所周知,急停操作在各种工业领域中都扮演着非常重要的角色。在一个个紧急情况下,及时采取急停操作可节省宝贵时间,避免人身伤害或设备损坏,降低安全风险,尤其是在新能源测试中,出于对高压电性能方…

Linux使用Libevent库实现一个网页服务器---C语言程序

Web服务器 这一个库的实现 其他的知识都是这一个专栏里面的文章 实际使用 编译的时候需要有一个libevent库 gcc httpserv.c -o httpserv -levent实际使用的时候需要指定端口以及共享的目录 ./httpserv 80 .这一个函数会吧这一个文件夹下面的所有文件共享出去 实际的效果, 这…

AI大模型重塑新媒体变现格局:智能写作技术助力腾飞!

文章目录 一、AI大模型:新媒体变革的引擎二、智能写作:内容生产的新范式三、精准推送:增强用户粘性的关键四、新媒体变现:插上AI翅膀的飞跃五、挑战与机遇并存:AI与新媒体的未来展望AI智能写作: 巧用AI大模型让新媒体变…

Python-GEE遥感云大数据分析、管理与可视化

原文链接:Python-GEE遥感云大数据分析、管理与可视化https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247601238&idx2&sn6b0557cf61451eaff65f025d648da869&chksmfa820db1cdf584a76de953b96519704177e6206d4ecd47a2f2fabbcac2f7ea619b0bce184…

idm序列号永久激活码2023免费可用 IDM软件破解版下载 最新版Internet Download Manager 网络下载加速必备神器 IDM设置中文

IDM是一款多线程下载工具,全称Internet Download Manager。IDM的多线程加速功能,能够充分利用宽带,所以下载速度会比较快,而且它支持断点续传。它的网站音视频捕获、站点抓取、静默下载等功能,也特别实用。 idm使用技…

数组:最值,反转数组,打乱顺序

文章目录 最值反转数组打乱顺序 位置 最值 package com.zhang.demo; /*这个是求最大值 * * */ public class Test1 {public static void main(String[] args) {int[] arr {13,77,89,333,2,99};int max arr[0];for(int i 1;i < arr.length-1;i){if(max < arr[i]){maxa…

JavaEE 初阶篇-深入了解网络通信相关的基本概念(三次握手建立连接、四次挥手断开连接)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 网络通信概述 1.1 基本的通信架构 2.0 网络通信三要素 3.0 网络通信三要素 - IP 地址 3.1 查询 IP 地址 3.2 IP 地址由谁供应&#xff1f; 3.3 IP 域名 3.4 IP 分…

智慧城市标准化白皮书(2022版)发布

2022年7月25日&#xff0c;国家智慧城市标准化总体组2022年度全体会议召开期间&#xff0c;《智慧城市标准化白皮书&#xff08;2022版&#xff09;》正式发布。 城市作为一个复杂巨系统&#xff0c;是多元主体融合及多元活动集聚的复杂综合体。城市的运行发展关联 到发展、治…

Maven基础篇6

Idea环境中资源上传与下载 具体问题本地仓库如何与私服打交道&#xff1b; 本地仓库向私服上传文件&#xff0c;上传的文件位置在哪里&#xff1f; 访问私服配置相关信息&#xff1a;用户名密码&#xff1b; 下载东西&#xff0c;需要的各种信息&#xff0c;需要的仓库组的…

MES(生产管理系统)开发岗人才定向培养来啦

定向就业培养&#xff0c;职等你来 《中国制造2025》&#xff0c;是我国实施制造强国战略第一个十年的行动纲领&#xff0c;按照“四个全面”战略布局要求&#xff0c;实施制造强国战略&#xff0c;加强 统筹规划和前瞻部署。围绕重点行业转型升级和新一代信息技术、智能制造、…

stripe.js踩坑日记

stripe.js踩坑日记 先附上代码【选择支付方式并唤起对应支付后重定向到支付结果页面】 先安装依赖包 npm install stripe/stripe-js代码【vue3语法】 <template><div class"stripe-pay-ment-box"><div id"payment-element"></div…

【数据库】三、数据库SQL语言命令(基础从入门到入土)

【全文两万多字&#xff0c;涵盖大部分常见情况&#xff0c;建议点赞收藏】 目录 文章目录 目录安装SQL语言1.使用2.DATABASE查看所有库新建数据库修改数据库删除数据库连接数据库 3.TABLE创建表查看库所有表删除表查看表信息重命名表修改表字段&#xff08;列&#xff09;表中…

GUI测试首推!TestComplete 帮助有效缩短 40-50% 测试时长!

TestComplete 是一款自动化UI测试工具&#xff0c;这款工具目前在全球范围内被广泛应用于进行桌面、移动和Web应用的自动化测试。 TestComplete 集成了一种精心设计的自动化引擎&#xff0c;可以自动记录和回放用户的操作&#xff0c;方便用户进行UI&#xff08;用户界面&…

七分钟“手撕”三大特性<多态>

目录 一、学习多态之前需要的知识储备 二、重写 1.什么是重写 2.重写可以干嘛 3.怎么书写重写 4.重载与重写的区别 三、向上转型 1.什么是向上转型&#xff1f; 2.向上转型的语法 3.向上转型的使用场景 四、多态是什么 六、多态实现 七、多态的好处 八、多态的缺…

zookeeper安装原生开发 C API接口时报错

报出的错误&#xff1a;error: %d directive writing between 1 and 5 bytes into a region of size be 问题原因 %d 格式说明符用于格式化有符号十进制整数。它需要一个与要格式化的整数大小相匹配的缓冲区。如果缓冲区太小&#xff0c;则会导致缓冲区溢出&#xff0c;从而可…

码头船只出行及配套货柜码放管理系统-毕设

毕业设计说明书 码头船只出行及配套货柜码放 管理系统 码头船只出行及配套货柜码放管理系统 摘要 伴随着全球化的发展&#xff0c;码头的物流和客运增多&#xff0c;码头业务迎来新的高峰。然而码头业务的增加&#xff0c;导致了人员成本和工作量的增多。为了解决这一基本问题&…

Redis篇:缓存更新策略最佳实践

前景&#xff1a; 缓存更新是redis为了节约内存而设计出来的一个东西&#xff0c;主要是因为内存数据宝贵&#xff0c;当我们向redis插入太多数据&#xff0c;此时就可能会导致缓存中的数据过多&#xff0c;所以redis会对部分数据进行更新&#xff0c;或者把他叫为淘汰更合适&a…

开放式耳机怎样选性价比高?五大性能出色爆款推荐!

在今年的耳机市场&#xff0c;开放式耳机如雨后春笋般涌现&#xff0c;为消费者提供了更多的选择。在这样一个产品繁多的市场中&#xff0c;如何挑选出一款音质上乘、性能卓越的开放式耳机&#xff0c;确实是一个值得探讨的问题。相较于长时间佩戴传统入耳式耳机可能带来的耳朵…