使用pytorch构建GAN网络并实现FID评估

上一篇文章介绍了GAN的详细理论,只要掌握了GAN,对于后面各种GAN的变形都变得很简单,基础打好了,盖大楼自然就容易了。既然有了理论,实践也是必不可少的,这篇文章将使用mnist数据集来实现简单的GAN网络,并附带使用FID来评估生成质量。

1. FID评估方法

1.1 计算方法

Fréchet Inception Distance (FID),是一种用于评估生成模型生成图像质量的指标,通常用于比较生成图像与真实图像之间的相似度,FID的数值越低表示生成的图像质量越好。具体来源可自行百度一下,这里不在介绍。FID是通过计算两组图像的均值,方差的距离,从而计算两组图像分布的相似读。直接看公式:
F I D ( r e a l , g e n ) = ∣ ∣ μ r e a l − μ g e n ∣ ∣ 2 2 + T r ( C r e a l + C g e n − 2 ( C r e a l C g e n ) 1 / 2 ) FID(real,gen) = ||\mu_{real}-\mu_{gen}||_2^2 + Tr(C_{real} + C_{gen} - 2(C_{real}C_{gen})^{1/2}) FID(real,gen)=∣∣μrealμgen22+Tr(Creal+Cgen2(CrealCgen)1/2)
其中 μ r e a l , μ g e n \mu_{real},\mu_{gen} μreal,μgen是real数据和gen数据分布的均值, C r e a l , C g e n C_{real},C_{gen} Creal,Cgen表示real和gen各自特征向量的各自的协方差;Tr表示矩阵的迹 T r ( A ) = ∑ i = 1 n A i i Tr(A)=\sum_{i=1}^nA_{ii} Tr(A)=i=1nAii(方阵对角线元素之和)。
这里需要注意到是,一般情况real数据和gen数据是经过inception V3模型提取图像特征后的结果,并非真实输入图片。

1.2 代码实现

虽然有些库里面集成了FID函数,为了更好理解,我们手动来实现这个代码。
主要分为三个部分来计算:

  • inception V3 特征提取
  • 均值计算、协方差计算
  • FID计算

具体我们来看一下完整代码实现。

import torch
import torchvision.models as models
import numpy as np
from scipy import linalg"""
FID 测试一般3000~5000张图片,
FID小于50:生成质量较好,可以认为生成的图像与真实图像相似度较高。
FID在50到100之间:生成质量一般,生成的图像与真实图像相似度一般。
FID大于100:生成质量较差,生成的图像与真实图像相似度较低。
"""# 加载预训练inception v3模型, 并移除top层,第一次运行会下载模型到cache里面
def load_inception():model = models.inception_v3(weights='IMAGENET1K_V1')model.eval()# 将fc用Identity()代替,即去掉fc层model.fc = torch.nn.Identity()return model# inception v3 特征提取
def extract_features(images, model):# images = images / 255.0with torch.no_grad():feat = model(images)return feat.numpy()# FID计算
def cal_fid(images1, images2):"""images1, images2: nchw 归一化,且维度resize到[N,3,299,299]"""model = load_inception()#1. inception v3 特征feats1 = extract_features(images1, model)feats2 = extract_features(images2, model)#2. 均值协方差feat1_mean, feat1_cov = np.mean(feats1, axis=0), np.cov(feats1, rowvar=False)feat2_mean, feat2_cov = np.mean(feats2, axis=0), np.cov(feats2, rowvar=False)#3. Fréchet距离sqrt_trace_cov = linalg.sqrtm(feat1_cov @ feat2_cov)fid = np.sum((feat1_mean - feat2_mean) ** 2) + np.trace(feat1_cov + feat2_cov - 2 * sqrt_trace_cov)return fid.realif __name__ == '__main__':f = cal_fid(torch.rand(1000, 3, 299, 299), torch.rand(1000, 3, 299, 299))print(f)

2. 构建GAN网络

参考:
https://github.com/growvv/GAN-Pytorch/blob/main/README.md

2.1 使用全连接构建一个最简单的GAN网络

2.1.1 网络结构

import torch
import torch.nn as nn
from torchinfo import summaryclass Discriminator(nn.Module):def __init__(self, in_features):super().__init__()self.disc = nn.Sequential(nn.Linear(in_features, 256),  # 784 -> 256nn.LeakyReLU(0.2),  #nn.Linear(256, 256), # 256 -> 256nn.LeakyReLU(0.2),nn.Linear(256, 1),  # 255 -> 1nn.Sigmoid(),   # 将实数映射到[0,1]区间)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, image_dim):super().__init__()self.gen = nn.Sequential(nn.Linear(z_dim, 256),   # 64 升至 256维nn.ReLU(True),nn.Linear(256, 256),   # 256 -> 256nn.ReLU(True),nn.Linear(256, image_dim), # 256 -> 784nn.Tanh(),  # Tanh使得生成数据范围在[-1, 1],因为真实数据经过transforms后也是在这个区间)def forward(self, x):return self.gen(x)if __name__ == "__main__":gnet = Generator(64, 784)dnet = Discriminator(784)summary(gnet, input_data=[torch.randn(10, 64)])summary(dnet, input_data=[torch.randn(10, 784)])

网络结构运行以上代码,可以查看模型结构:

在这里插入图片描述

2.1.2 训练代码

以下是训练代码,直接可以运行

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from simplegan import Generator, Discriminator# 超参数
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 100Disc = Discriminator(image_dim).to(device)
Gen = Generator(z_dim, image_dim).to(device)
opt_disc = optim.Adam(Disc.parameters(), lr=lr)
opt_gen = optim.Adam(Gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # 单目标二分类交叉熵函数transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)fixed_noise = torch.randn((batch_size, z_dim)).to(device)
write_fake = SummaryWriter(f'logs/fake')
write_real = SummaryWriter(f'logs/real')
step = 0for epoch in range(num_epochs):for batch_idx, (real, _) in enumerate(loader):real = real.view(-1, 784).to(device)batch_size = real.shape[0]## D: 目标:真的判断为真,假的判断为假## 训练Discriminator: max log(D(x)) + log(1-D(G(z)))disc_real = Disc(real)#.view(-1)  # 将真实图片放入到判别器中lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # 真的判断为真noise = torch.randn(batch_size, z_dim).to(device)fake = Gen(noise)  # 将随机噪声放入到生成器中disc_fake = Disc(fake).view(-1)  # 识别器判断真假lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # 假的应该判断为假lossD = (lossD_real + lossD_fake) / 2  # loss包括判真损失和判假损失Disc.zero_grad()   # 在反向传播前,先将梯度归0lossD.backward(retain_graph=True)  # 将误差反向传播opt_disc.step()   # 更新参数# G: 目标:生成的越真越好## 训练生成器: min log(1-D(G(z))) <-> max log(D(G(z)))output = Disc(fake).view(-1)   # 生成的放入识别器lossG = criterion(output, torch.ones_like(output))  # 与“真的”的距离,越小越好Gen.zero_grad()lossG.backward()opt_gen.step()# 输出一些信息,便于观察if batch_idx == 0:print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)}' \loss D: {lossD:.4f}, loss G: {lossG:.4f}")with torch.no_grad():fake = Gen(fixed_noise).reshape(-1, 1, 28, 28)data = real.reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)img_grid_real = torchvision.utils.make_grid(data, normalize=True)write_fake.add_image("Mnist Fake Image", img_grid_fake, global_step=step)write_real.add_image("Mnist Real Image", img_grid_real, global_step=step)step += 1

使用 tensorboard --logdir=./log/fake 查看生成的质量, 这个是41个epoch的结果,想要质量更好一点,可以继续训练。
在这里插入图片描述

2.2 DCGAN网络

DCGAN只是把全连接替换成全卷积的结构,思路完全一样,没什么变换

2.2.1 DCGAN网络结构

"""
Discriminator and Generator implementation from DCGAN paper
"""import torch
import torch.nn as nn
from torchinfo import summaryclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super().__init__()self.disc = nn.Sequential(self._block(channels_img, features_d, kernel_size=4, stride=2, padding=1),self._block(features_d, features_d * 2, 4, 2, 1),self._block(features_d * 2, features_d * 4, 4, 2, 1),self._block(features_d * 4, features_d * 8, 4, 2, 1),nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, channels_noise, channels_img, features_g):super().__init__()self.gen = nn.Sequential(self._block(channels_noise, features_g * 16, 4, 1, 0),self._block(features_g * 16, features_g * 8, 4, 2, 1),self._block(features_g * 8, features_g * 4, 4, 2, 1),self._block(features_g * 4, features_g * 2, 4, 2, 1),nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):## initilialize weight according to paperfor m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d,)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 1, 64, 64noise_dim = 100x = torch.randn(N, in_channels, H, W)disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"gen = Generator(noise_dim, in_channels, 8)initialize_weights(gen)z = torch.randn(N, noise_dim, 1, 1)assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"if __name__ == "__main__":gnet = Generator(100, 1, 64)dnet = Discriminator(1, 64)summary(gnet, input_data=[torch.randn(10, 100, 1, 1)])summary(dnet, input_data=[torch.randn(10, 1, 64, 64)])

2.2.2 训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from dcgan import Generator, Discriminator, initialize_weights
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvisionLEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
NUM_EPOCHS = 5
CHANNELS_IMG = 1
NOISE_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)write_fake = SummaryWriter(f'log/fake')
write_real = SummaryWriter(f'log/real')def train(NUM_EPOCHS, gpuid):device = torch.device(f"cuda:{gpuid}")# 数据load# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)dataset = MNIST(root='./data', train=True, download=True, transform=transforms)dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)initialize_weights(gen)initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)writer_real = SummaryWriter(f"logs2/real")writer_fake = SummaryWriter(f"logs2/fake")step = 0gen.train()disc.train()for epoch in range(NUM_EPOCHS):# 不需要目标的标签,无监督for batch_id, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)# Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_real + loss_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()# Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)), 先训练一个epoch 的Dif epoch >= 0:output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()if batch_id % 20 == 0:print(f'Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_id}/{len(dataloader)} Loss D: {loss_disc}, loss G: {loss_gen}')with torch.no_grad():fake = gen(fixed_noise)img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real Image", img_grid_real, global_step=step)writer_fake.add_image("Fake Image", img_grid_fake, global_step=step)step += 1if __name__ == "__main__":train(100, 0)

同样使用tensorboard --logdir=./logs2/fake 查看生成的质量,大概10个epoch的结果

在这里插入图片描述

结论

FID指标可自行测试。GAN的基本训练思路是完全按照论文来做的,包括损失函数设计完全跟论文一致。具体理论可仔细看上一篇博客。如有不足,错误请指出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/8027.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么使用httpClient发送x-www-form-urlencoded类型的请求时,必须要使用MultiValueMap来传参

大家好&#xff0c;我是G探险者。 今天主要介绍一下MultiValueMap和HashMap的区别。 事情起因是这样的&#xff0c;在我们项目code review的时候,客户方提了一个问题&#xff0c;说&#xff0c;你们在用restTemplate进行远程调用的时候&#xff0c;为啥使用MultiValueMap来传…

数据结构-线性表-应用题-2.2-7

将两个有序顺序表合并为一个新的有序顺序表&#xff0c;并由函数返回结果顺序表 使用了归并排序的思想 按顺序将两个顺序表表头较小的节点存入新的顺序表中&#xff0c;若一个表用完了&#xff0c;就把另一个表的剩下的部分加到新表中去 bool Merge(SeqList A,SeqList B,Seq…

docker部署elasticsearch7.7.0级拼音(pinyin)插件和分词(ik)插件

拉取并启动es docker run -d --namees -p 9200:9200 -p 9300:9300 -e "discovery.typesingle-node" elasticsearch:7.7.0安装pinyin插件 下载pinyin插件 下载ik插件 上传插件到服务器 docker cp /path/to/elasticsearch-analysis-pinyin-7.7.0.zip elasticsearch…

免费分享一套微信小程序在线订餐(点餐)配送系统(SpringBoot+Vue),帅呆了~~

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序在线订餐(点餐)配送系统(SpringBootVue)&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序在线订餐(点餐)配送系统(SpringBootVue) Java毕业设计_哔哩哔哩_bilibili【免费】微信小程序在…

Django中如何实现单元测试覆盖率报告?

在 Django 中可以使用 coverage 模块来实现单元测试覆盖率报告。下面是一个实现的步骤&#xff1a; 首先&#xff0c;在项目的根目录下&#xff0c;安装 coverage 模块&#xff1a; pip install coverage创建一个 .coveragerc 文件&#xff0c;用于配置 coverage 的一些参数。在…

91、动态规划-不同的路径

思路&#xff1a; 首先我们可以使用暴力递归解法&#xff0c;无非就是每次向下或者向右看看是否有解法&#xff0c;代码如下&#xff1a; public class Solution {public int uniquePaths(int m, int n) {return findPaths(0, 0, m, n);}private int findPaths(int i, int j,…

企业防泄露如何做到安全有效

随着信息时代的急速演进&#xff0c;企业的重要商业机密越来越多地以电子文档的形式存在。常见的CAD图纸、Office文档承载着公司的核心价值和竞争优势&#xff0c;同时也面临着前所未有的数据安全威胁。确保这些重要信息的文档安全已经成为每个企业必须直面的挑战。在这样的背景…

绝地求生:新型小队对决系统或将择日上线?

就在刚才&#xff0c;PUBG官博发布了一则短视频&#xff0c;视频内容为两只小队通过竞争积分排名产生不断地变化。 原文官博 视频内容 在这里我猜测为之前官方在2024工作计划视频中介绍过的新型小队对决系统&#xff1a; 据当时的介绍称&#xff1a;这个系统中&#xff0c;己方…

【牛客】【模板】差分

原题链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 差分模板。 b[0]a[0]; b[1]a[1]-a[0]; b[2]a[2]-a[1]; ...... b[n-1]a[n-1]-a[n-2]; b[n]a[n]-a[n-1]; 差分标记&#xff1a;b[l]k,b…

【DevOps】深入剖析Elasticsearch的分片与副本对性能的影响

目录 一、分片 (Shards) 1、什么是分片&#xff1f; 2、分片的类型 3、分片对性能的影响 二、副本 (Replicas) 1、什么是副本&#xff1f; 2、副本对性能的影响 三、最佳实践 1、主分片数量的选择 2、副本分片的设置 3、监控和调整 4、考虑使用 Shrink 和 Split AP…

vue3中用 let a= b赋值会改变b的值

注意b是reactive const paramreactive({ pageSize:12, }) 在JavaScript中&#xff0c;基本类型&#xff08;比如String&#xff0c;Number&#xff0c;Boolean&#xff0c;undefined&#xff0c;null&#xff09;是按值传递的&#xff0c;这意味着当你将一个基本类型变量赋值给…

k8s部署Kubeflow v1.7.0

文章目录 环境介绍部署访问kubeflow ui问题记录 环境介绍 K8S版本&#xff1a;v1.23.17&#xff0c;需要配置默认的sc 参考&#xff1a;https://github.com/kubeflow/manifests/tree/v1.7.0 部署 #获取安装包 wget https://github.com/kubeflow/manifests/archive/refs/tag…

faiss 原理和使用总结

FAISS&#xff08;Facebook AI Similarity Search&#xff09;是由 Facebook AI Research 开发的一个高效的相似性搜索和密集向量索引库。它主要用于大规模向量搜索和高维数据的聚类。下面&#xff0c;我将为你概述 FAISS 的工作原理和使用方法。 ### 原理 1. **向量量化&…

致远M3 Session 敏感信息泄露漏洞复现

0x01 产品简介 M3移动办公是致远互联打造的一站式智能工作平台,提供全方位的企业移动业务管理,致力于构建以人为中心的智能化移动应用场景,促进人员工作积极性和创造力,提升企业效率和效能,是为企业量身定制的移动智慧协同平台。 0x02 漏洞概述 致远M3 server多个日志文…

函数练习.

1.打印乘法口诀表 口诀表的行数和列数自己指定如&#xff1a;输入9&#xff0c;输出99口诀表&#xff0c;输出12&#xff0c;输出1212的乘法口诀表。 multiplication(int index) { ​if (index 9) { ​int i 0; ​for (i 1; i < 10; i) { ​int j 0; ​for (j 1; j &…

《系统架构设计师教程(第2版)》第10章-软件架构的演化和维护-04-软件架构演化原则

文章目录 1. 演化成本控制原则2. 进度可控原则3. 风险可控原则4. 主体维持原则5. 系统总体结构优化原则6. 平滑演化原则7. 目标一致原则8. 模块独立演化原则9. 影响可控原则10. 复杂性可控原则11. 有利于重构原则12. 有利于重用原则13. 设计原则遵从性原则14. 适应新技术原则15…

Django 4.x 智能分页get_elided_page_range

Django智能分页 分页效果 第1页的效果 第10页的效果 带输入框的效果 主要函数 # 参数解释 # number: 当前页码&#xff0c;默认&#xff1a;1 # on_each_side&#xff1a;当前页码前后显示几页&#xff0c;默认&#xff1a;3 # on_ends&#xff1a;首尾固定显示几页&#…

html5基础知识——表单

表单由三个部分组成&#xff0c;分别是表单域、表单控件、提示信息&#xff08;也就是默认显示的内容&#xff09; 表单域 使用form标签定义 将所有的元素信息定义在一块区域中 用于将表单中的所有元素信息提交给服务器 其中&#xff1a;action表示该表单将要提交到哪个地址&am…

5. DNS 记录和报文

DNS 服务器中以资源记录的形式存储信息&#xff0c;每一个 DNS 响应报文一般包含多条资源记录。一条资源记录的具体的格式为 &#xff08;Name&#xff0c;Value&#xff0c;Type&#xff0c;TTL&#xff09; 其中 TTL 是资源记录的生存时间&#xff0c;它定义了资源记录能够…

LWIP+TCP客户端

一、TCP API函数 其中tcp_poll()函数的第三个参数表示隔几秒调用一次这个周期性函数 二、修改服务器的IP 三、TCP客户端编程思路 申请套接字绑定服务器IP和端口号等待客户端连接 进入连接回调函数在连接回调函数中 配置一些回调函数&#xff0c;如接收回调函数&#xff0c;周期…