使用Pytorch实现频谱归一化生成对抗网络(SN-GAN)

自从扩散模型发布以来,GAN的关注度和论文是越来越少了,但是它们里面的一些思路还是值得我们了解和学习。所以本文我们来使用Pytorch 来实现SN-GAN

谱归一化生成对抗网络是一种生成对抗网络,它使用谱归一化技术来稳定鉴别器的训练。谱归一化是一种权值归一化技术,它约束了鉴别器中每一层的谱范数。这有助于防止鉴别器变得过于强大,从而导致不稳定和糟糕的结果。

SN-GAN由Miyato等人(2018)在论文“生成对抗网络的谱归一化”中提出,作者证明了sn - gan在各种图像生成任务上比其他gan具有更好的性能。

SN-GAN的训练方式与其他gan相同。生成器网络学习生成与真实图像无法区分的图像,而鉴别器网络学习区分真实图像和生成图像。这两个网络以竞争的方式进行训练,它们最终达到一个点,即生成器能够产生逼真的图像,从而欺骗鉴别器。

以下是SN-GAN相对于其他gan的优势总结:

  • 更稳定,更容易训练
  • 可以生成更高质量的图像
  • 更通用,可以用来生成更广泛的内容。

模式崩溃

模式崩溃是生成对抗网络(GANs)训练中常见的问题。当GAN的生成器网络无法产生多样化的输出,而是陷入特定的模式时,就会发生模式崩溃。这会导致生成的输出出现重复,缺乏多样性和细节,有时甚至与训练数据完全无关。

GAN中发生模式崩溃有几个原因。一个原因是生成器网络可能对训练数据过拟合。如果训练数据不够多样化,或者生成器网络太复杂,就会发生这种情况。另一个原因是生成器网络可能陷入损失函数的局部最小值。如果学习率太高,或者损失函数定义不明确,就会发生这种情况。

以前有许多技术可以用来防止模式崩溃。比如使用更多样化的训练数据集。或者使用正则化技术,例如dropout或批处理归一化,使用合适的学习率和损失函数也很重要。

Wassersteian损失

Wasserstein损失,也称为Earth Mover’s Distance(EMD)或Wasserstein GAN (WGAN)损失,是一种用于生成对抗网络(GAN)的损失函数。引入它是为了解决与传统GAN损失函数相关的一些问题,例如Jensen-Shannon散度和Kullback-Leibler散度。

Wasserstein损失测量真实数据和生成数据的概率分布之间的差异,同时确保它具有一定的数学性质。他的思想是最小化这两个分布之间的Wassersteian距离(也称为地球移动者距离)。Wasserstein距离可以被认为是将一个分布转换为另一个分布所需的最小“成本”,其中“成本”被定义为将概率质量从一个位置移动到另一个位置所需的“工作量”。

Wasserstein损失的数学定义如下:

对于生成器G和鉴别器D, Wasserstein损失(Wasserstein距离)可以表示为:

Jensen-Shannon散度(JSD): Jensen-Shannon散度是一种对称度量,用于量化两个概率分布之间的差异

对于概率分布P和Q, JSD定义如下:

 JSD(P∥Q)=1/2(KL(P∥M)+KL(Q∥M))

M为平均分布,KL为Kullback-Leibler散度,P∥Q为分布P与分布Q之间的JSD。

JSD总是非负的,在0和1之间有界,并且对称(JSD(P|Q) = JSD(Q|P))。它可以被解释为KL散度的“平滑”版本。

Kullback-Leibler散度(KL散度):Kullback-Leibler散度,通常被称为KL散度或相对熵,通过量化“额外信息”来测量两个概率分布之间的差异,这些“额外信息”需要使用另一个分布作为参考来编码一个分布。

对于两个概率分布P和Q,从Q到P的KL散度定义为:KL(P∥Q)=∑x P(x)log(Q(x)/P(x))。KL散度是非负非对称的,即KL(P∥Q)≠KL(Q∥P)。当且仅当P和Q相等时它为零。KL散度是无界的,可以用来衡量分布之间的不相似性。

1-Lipschitz Contiunity

1- lipschitz函数是斜率的绝对值以1为界的函数。这意味着对于任意两个输入x和y,函数输出之间的差不超过输入之间的差。

数学上函数f是1-Lipschitz,如果对于f定义域内的所有x和y,以下不等式成立:

 |f(x) — f(y)| <= |x — y|

在生成对抗网络(GANs)中强制Lipschitz连续性是一种用于稳定训练和防止与传统GANs相关的一些问题的技术,例如模式崩溃和训练不稳定。在GAN中实现Lipschitz连续性的主要方法是通过使用Lipschitz约束或正则化,一种常用的方法是Wasserstein GAN (WGAN)。

在标准gan中,鉴别器(也称为WGAN中的批评家)被训练来区分真实和虚假数据。为了加强Lipschitz连续性,WGAN增加了一个约束,即鉴别器函数应该是Lipschitz连续的,这意味着函数的梯度不应该增长得太大。在数学上,它被限制为:

 ∥∣D(x)−D(y)∣≤K⋅∥x−y∥

其中D(x)是评论家对数据点x的输出,D(y)是y的输出,K是Lipschitz 常数。

WGAN的权重裁剪:在原始的WGAN中,通过在每个训练步骤后将鉴别器网络的权重裁剪到一个小范围(例如,[-0.01,0.01])来强制执行该约束。权重裁剪确保了鉴别器的梯度保持在一定范围内,并加强了利普希茨连续性。

WGAN的梯度惩罚: WGAN的一种变体,称为WGAN-GP,它使用梯度惩罚而不是权值裁剪来强制Lipschitz约束。WGAN-GP基于鉴别器的输出相对于真实和虚假数据之间的随机点的梯度,在损失函数中添加了一个惩罚项。这种惩罚鼓励了Lipschitz约束,而不需要权重裁剪。

谱范数

从符号上看矩阵𝑊的谱范数通常表示为:对于神经网络𝑊矩阵表示网络层中的一个权重矩阵。矩阵的谱范数是矩阵的最大奇异值,可以通过奇异值分解(SVD)得到。

奇异值分解是特征分解的推广,用于将矩阵分解为

其中𝑈,q为正交矩阵,Σ为其对角线上的奇异值矩阵。注意Σ不一定是正方形的。

其中𝜎1和𝑛分别为最大奇异值和最小奇异值。更大的值对应于一个矩阵可以应用于另一个向量的更大的拉伸量。依此表示,𝜎(𝑊)=𝜎1.

SVD在谱归一化中的应用

为了对权矩阵进行频谱归一化,将矩阵中的每个值除以它的频谱范数。谱归一化矩阵可以表示为

计算𝑊is的SVD非常昂贵,所以SN-GAN论文的作者做了一些简化。它们通过幂次迭代来近似左、右奇异向量𝑢和𝑣,分别为:𝑊)≈𝑢

代码实现

现在我们开始使用Pytorch实现

 import torchfrom torch import nnfrom tqdm.auto import tqdmfrom torchvision import transformsfrom torchvision.datasets import MNISTfrom torchvision.utils import make_gridfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plttorch.manual_seed(0)def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()

生成器:

 class Generator(nn.Module):def __init__(self,z_dim=10,im_chan = 1,hidden_dim = 64):super(Generatoe,self).__init__()self.gen = nn.Sequential(self.make_gen_block(z_dim,hidden_dim * 4),self.make_gen_block(hidden_dim*4,hidden_dim * 2,kernel_size = 4,stride =1),self.make_gen_block(hidden_dim * 2,hidden_dim),self.make_gen_block(hidden_dim,im_chan,kernel_size=4,final_layer = True),)def make_gen_block(self,input_channels,output_channels,kernel_size=3,stride=2,final_layer = False):if not final_layer :return nn.Sequential(nn.ConvTranspose2D(input_layer,output_layer,kernel_size,stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace = True),)else:return nn.Sequential(nn.ConvTranspose2D(input_layer,output_layer,kernel_size,stride),nn.Tanh(),)def unsqueeze_noise():return noise.view(len(noise), self.z_dim, 1, 1)def forward(self,noise):x = self.unsqueeze_noise(noise)return self.gen(x)def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)

鉴频器

对于鉴别器,我们可以使用spectral_norm对每个Conv2D 进行处理。除了𝑊之外,还引入了𝑢、𝑣、和其他的参数,这样在运行时就可以计算出𝑊𝑆的二进制二进制运算符:𝑢、y、y、y、y

因为Pytorch还提供 nn.utils. spectral_norm,nn.utils. remove_spectral_norm函数,所以我们操作起来很方便。

我们只在推理期间将nn.utils. remove_spectral_norm应用于卷积层,以提高运行速度。

值得注意的是,谱范数并不能消除对批范数的需要。谱范数影响每一层的权重,批范数影响每一层的激活度。

 class Discriminator(nn.Module):def __init__(self, im_chan=1, hidden_dim=16):super(Discriminator, self).__init__()self.disc = nn.Sequential(self.make_disc_block(im_chan, hidden_dim),self.make_disc_block(hidden_dim, hidden_dim * 2),self.make_disc_block(hidden_dim * 2, 1, final_layer=True),)def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True),)else: return nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),)def forward(self, image):disc_pred = self.disc(image)return disc_pred.view(len(disc_pred), -1)

训练

我们这里使用MNIST数据集,bcewithlogitsloss()函数计算logit和目标标签之间的二进制交叉熵损失。二值交叉熵损失是对两个分布差异程度的度量。在二元分类中,这两种分布分别是逻辑的分布和目标标签的分布。

 criterion = nn.BCEWithLogitsLoss()n_epochs = 50z_dim = 64display_step = 500batch_size = 128# A learning rate of 0.0002 works well on DCGANlr = 0.0002beta_1 = 0.5 beta_2 = 0.999device = 'cuda'transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])dataloader = DataLoader(MNIST(".", download=True, transform=transform),batch_size=batch_size,shuffle=True)

创建生成器和鉴别器

 gen = Generator(z_dim).to(device)gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))disc = Discriminator().to(device) disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))# initialize the weights to the normal distribution# with mean 0 and standard deviation 0.02def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)gen = gen.apply(weights_init)disc = disc.apply(weights_init)

下面是训练步骤

 cur_step = 0mean_generator_loss = 0mean_discriminator_loss = 0for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)## Update Discriminator ##disc_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()## Update Generator ##gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)disc_fake_pred = disc(fake_2)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))gen_loss.backward()gen_opt.step()# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step## Visualization code ##if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

训练结果如下:

总结

本文我们介绍了SN-GAN的原理和简单的代码实现,SN-GAN已经被广泛应用于图像生成任务,包括图像合成、风格迁移和超分辨率等领域。它在改善生成模型的性能和稳定性方面取得了显著的成果,所以学习他的代码对我们理解会更有帮助。

https://avoid.overfit.cn/post/0c52f9dc7d124cb3998c95360b745463

作者:DhanushKumar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/108042.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLO目标检测——跌倒摔倒数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;公共安全监控、智能家居、工业安全等活动区域无监管情况下的人员摔倒事故数据集说明&#xff1a;YOLO目标检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富。使用lableimg标注软件标注&#xff0c;标注框质量高&#xff0c;含…

【Eclipse】取消按空格自动补全,以及出现没有src的解决办法

【Eclipse】设置自动提示 教程 根据上方链接&#xff0c;我们已经知道如何设置Eclipse的自动补全功能了&#xff0c;但是有时候敲变量名的时候按空格&#xff0c;本意是操作习惯&#xff0c;不需要自动补全&#xff0c;但是它却给我们自动补全了&#xff0c;这就造成了困扰&…

诚迈科技董事长王继平出席中国(太原)人工智能大会并发表演讲

10月14日—15日&#xff0c;2023中国&#xff08;太原&#xff09;人工智能大会在山西省太原市举办。诚迈科技在大会上全面展示了其在人工智能领域的一系列创新技术与解决方案&#xff0c;诚迈科技董事长、统信软件董事长王继平受邀出席产业数字化转型论坛并发表主题演讲&#…

面试58同城!面试官问我redis 雪崩、穿透、击穿怎么处理?

一、Redis 缓存雪崩 1.1 缓存雪崩的概念 缓存雪崩指的是在某个时间点&#xff0c;缓存中的大量数据同时失效&#xff0c;导致大量请求直接落到数据库上&#xff0c;造成数据库压力过大&#xff0c;甚至引发系统崩溃。 1.2 缓存雪崩发生的原因 缓存雪崩通常是由以下原因引起…

[已解决]llegal target for variable annotation

llegal target for variable annotation 问题 变量注释的非法目标 思路 复制时编码错误&#xff0c;自己敲一遍后正常运行 #** 将垂直知识加入prompt&#xff0c;以使其准确回答 **# prompt_templates { # "recommand":"用户说&#xff1a;__INPUT__ …

紫光展锐荣评“5G技术创新力企业”,5G赋能千行百业

近日&#xff0c;2023年第十七届中国通信产业榜隆重发布&#xff0c;紫光展锐凭借多年以来在通信和芯片技术上的积累&#xff0c;从众多参选者中脱颖而出&#xff0c;荣评“5G技术创新力企业”&#xff0c;并蝉联2023年通信产业榜“中国通信设备技术服务供应商100强”。 作为一…

5.1 加载矢量图层(ogr,gpx)

文章目录 前言加载矢量(vector)图层ogrShapefileQGis导入.shp文件代码导入 gpxQGis导入GPX文件代码导入 gpkgQGis导入GPKG文件代码导入 geojsonQGis导入GeoJson文件代码导入 gmlQGis导入GML代码导入 kml/kmzQGis导入Kml代码导入 dxf/dwgQGis导入dxf代码导入 CoverageQGis导入Co…

JMeter安装及环境配置

1. JMeter 介绍 Apache组织开发的基于Java的压力测试工具 100%纯Java开发、完全的可移植性 可以用于测试静态和动态资源 多协议—HTTP/FTP/socket/Java/数据库(JDBC) 完全多线程 高可扩展性 2. 安装jdk并配置jdk环境 因为jmeter运行依赖jdk环境&#xff0c;所以在安装j…

lvgl 页面管理器

lv_scr_mgr lvgl 界面管理器 适配 lvgl 8.3 降低界面之间的耦合使用较小的内存&#xff0c;界面切换后会自动释放内存内存泄漏检测 使用方法 在lv_scr_mgr_port.h 中创建一个枚举&#xff0c;用于界面ID为每个界面创建一个页面管理器句柄将界面句柄添加到 lv_scr_mgr_por…

发电机组负载测试的必要性

发电机组负载测试是确保发电机组能够在实际运行中稳定工作的重要步骤&#xff0c;负载测试可以模拟发电机组在不同负载条件下的工作情况&#xff0c;评估其性能和稳定性。负载测试可以验证发电机组在不同负载条件下的性能表现&#xff0c;通过模拟实际使用情况评估发电机组的输…

css3自动吸附scroll-snap

我们希望可以一块一块的滚动&#xff0c;比如当前一个块滚出去了一部分并且后一个块滚进来一部分的时候&#xff0c;实现后一个块自动滚入或者前一个块回弹到初始位置这种效果&#xff0c;以前的时候用js需要写比较复杂的判断逻辑&#xff0c;后来有了一个css scroll snap这个方…

Java打印二进制

&#x1f495;"把握未定&#xff0c;宜绝迹尘嚣&#xff0c;使此心不见可欲而不乱&#xff0c;以澄悟吾静体。"&#x1f495; 作者&#xff1a;Mylvzi 文章主要内容&#xff1a;Java打印二进制 Java中打印二进制的方法有很多&#xff0c;这里介绍三种方式 1.利用In…

CSI2与CDPHY学习

注意&#xff1a;本文是基于CSI2-V3.0 spec。 其中CPHY为 V2.0 DPHY为V2.5 本文主要在packet级别介绍CSI2与对应的CDPHY&#xff0c;需要注意的是&#xff1a; CDPHY的HS burst数据和LPDT都是以packet为单位传输数据。 其中LPDT包括Escape和ALP的LPDT 1.CSI-CPHY 1.1CPH…

传输机房的基本结构

文章目录 传输机房主要结构 传输机房主要结构 ODF &#xff08;Optical Distribution Frame&#xff09;&#xff0c;光纤配线架&#xff0c;是专为光纤通信机房设计的光纤配线设备&#xff0c;具有光缆固定和保护功能、光缆终接功能、调线功能&#xff0c;完成从设备间纤缆连…

Linux学习——进程状态

目录 一&#xff0c;进程状态 1&#xff0c;进程状态的分类 2.状态的本质 3.进程状态详解 1.运行状态 2.阻塞状态 3.挂起状态 4.Linux内核中的状态分类 一&#xff0c;进程状态 1&#xff0c;进程状态的分类 如下图&#xff1a; 在计算机中我们的状态的分类便如下图所示…

【Java实战】Mysql读写分离主从复制搭建保姆级教程

MySQL 的数据同步通常采用主从复制&#xff08;Master-Slave&#xff09;的方式。 主从复制基于二进制日志&#xff08;binlog&#xff09;。主服务器&#xff08;Master&#xff09;在 binlog 中记录数据更改&#xff0c;从服务器&#xff08;Slave&#xff09;将这些日志读取…

Git版本控制管理

Git基础_环境配置 当安装Git后首先要做的事情是设置用户名称和email地址。这是非常重要的&#xff0c;因为每次Git提交都会使用该用户信息。 设置用户信息 git config --global user.name "Bandits" git config --global user.email "gb010704163.com"查…

Java反射调用jar包实现多态

上一篇实现了反射调用jar包&#xff0c;但是没有实现多态&#xff0c;这次先给自己的jar包类抽象一个接口&#xff0c;然后实现类实现接口。最后调用放反射得到的对像转换成接口类型调用执行。 定义接口&#xff0c;指定包为ZLZJar package ZLZJar;public interface ITest {p…

FPGA中的LUT查找表工作原理。

在RAM中填入1110,后续的不同AB组合Y输出对应的值&#xff0c;实现上面逻辑表达式的功能。

计算机视觉开源代码汇总

1.【基础网络架构】Regularization of polynomial networks for image recognition 论文地址&#xff1a;https://arxiv.org/pdf/2303.13896.pdf 开源代码:https://github.com/grigorisg9gr/regularized_polynomials 2.【目标检测&#xff1a;域自适应】2PCNet: Two-Phase Cons…