使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

本文为此系列的第三篇WGAN-GP,上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的,只会写WGAN-GP相对于之前版本的改进点,若有不懂的可以重点看第一篇比较详细。

原理

具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性问题。 具体来说,使用W-loss 作为损失函数替代传统的 BCE 等 loss,并使用梯度惩罚来防止 mode collapse。

  • WGAN-GP 使用了 Wasserstein distance(也成为Earth Mover’s distance, EMD)作为训练 GAN 模型的目标函数,Wasserstein distance is a function of amount and distance,体现的是生成的数据的分布移动到真实数据的分布之间所需的距离与量。
    在这里插入图片描述
    随着判别器训练的越来越好,使用 BCE loss 的话会让鉴别器给出接近于 0 或者接近于 1 的极端值,如下为 sigmoid 曲线,极端值的梯度无限接近于 0,这样判别器就没有太多有用的信息反馈给生成器让它学习,导致梯度消失或 model collapse。使用距离的方式可以有效解决,分布距离再远都不再限制。
    在这里插入图片描述
    在这里插入图片描述
  • BCE loss 本质是一个 minimax game, d 即 discriminator 希望尽可能的 minimize,g 即 generator 希望尽可能的 maximize(意味着造出来的假东西对于鉴别器来说看起来很真实),可以进行如下的简化:
    在这里插入图片描述
    基于 Wasserstein distance 的 W-loss 的的式子与其简化版进行对比:
    在这里插入图片描述
    在 Wasserstein GAN 中不再是 discriminator 了,因为输出不再是 0-1 之间来进行分类,既然不分类了就不是 discriminator 了,而是 critic,所以这里使用 c 代表 critic。critic 希望其尽可能的 maximize,因为希望让 real 和 feak 的距离尽可能的大,起到划清界限的目的;generator 希望其尽可能的minimize,减小两者之间的距离,达到以假乱真的目的。
  • mode collapse 即模式崩溃,当生成器学会从单个类生成特征来欺骗鉴别器时,就会发生 mode collapse(陷入一种模式出不来),跟 cnn 的局部最优是一个概念。这会导致输出出现重复,缺乏多样性和细节。

但在使用 W-loss 训练 GAN 时需要对 critic 有一定的条件 —— critic 需要 1-L(1-Lipschitz)连续:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)|\le k|x_1-x_2\ | f(x1)f(x2)kx1x2 
这里的 k = 1,也就是 critic 的 nn 函数曲线的梯度(斜率)始终在 -1 到 1 之间,即梯度的 L2 范数不超过1:

在这里插入图片描述
如图:
在这里插入图片描述
曲线的每个点的斜率都是在绿色区域内,很显然这个曲线并不符合。像如下这个曲线就是符合的:
在这里插入图片描述
达到 1-L 连续有两种方法:weigh clipping、gradient penalty。

  • weigh clipping 将权重裁剪到固定范围内,从而限制 critic 的学习能力。但是这样有缺点,可能让所有参数走极端,要么取最大值要么取最小值,critic 会非常倾向于学习一个简单的映射函数。
  • gradient penalty 则是添加一个正则项在 loss function 中,相比 weigh clipping 更加柔和对critic参数的限制更加灵活,通常不会导致梯度消失或梯度爆炸问题。
    在这里插入图片描述
    这里的 λ \lambda λ 为超参值,reg 等于 critic 神经网络梯度范数 -1 的平方,即:
    在这里插入图片描述
    当 critic 神经网络梯度范数 >1 时正则化项发挥作用。平方的作用是为了让其偏离越大,惩罚越大。
    这里的 x ^ \hat{x} x^ 为真实数据与生成数据之间随机取样得到的中间数据,随机值 ϵ \epsilon ϵ 作为权重值,假设 ϵ \epsilon ϵ 为0.3,那么 1- ϵ \epsilon ϵ 为0.7。
    在这里插入图片描述

代码

model.py

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)else:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh(),)def forward(self, noise):x = noise.view(len(noise), self.z_dim, 1, 1)return self.gen(x)class Critic(nn.Module):def __init__(self, im_chan=1, hidden_dim=64):super(Critic, self).__init__()self.crit = nn.Sequential(self.make_crit_block(im_chan, hidden_dim),self.make_crit_block(hidden_dim, hidden_dim * 2),self.make_crit_block(hidden_dim * 2, 1, final_layer=True),)def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True),)else:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),)def forward(self, image):crit_pred = self.crit(image)return crit_pred.view(len(crit_pred), -1)

train.py

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)def get_gradient(crit, real, fake, epsilon):# Mix the images togethermixed_images = real * epsilon + fake * (1 - epsilon)# Calculate the critic's scores on the mixed imagesmixed_scores = crit(mixed_images)# Take the gradient of the scores with respect to the imagesgradient = torch.autograd.grad(inputs=mixed_images,outputs=mixed_scores,# These other parameters have to do with the pytorch autograd engine worksgrad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]return gradientdef gradient_penalty(gradient):# Flatten the gradients so that each row captures one imagegradient = gradient.view(len(gradient), -1)# Calculate the magnitude of every rowgradient_norm = gradient.norm(2, dim=1)# Penalize the mean squared distance of the gradient norms from 1penalty = torch.mean((gradient_norm - 1) ** 2)return penaltydef get_gen_loss(crit_fake_pred):gen_loss = -1. * torch.mean(crit_fake_pred)return gen_lossdef get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gpreturn crit_losscur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)mean_iteration_critic_loss = 0for _ in range(crit_repeats):### Update critic ###crit_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)crit_fake_pred = crit(fake.detach())crit_real_pred = crit(real)epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)gradient = get_gradient(crit, real, fake.detach(), epsilon)gp = gradient_penalty(gradient)crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)# Keep track of the average critic loss in this batchmean_iteration_critic_loss += crit_loss.item() / crit_repeats# Update gradientscrit_loss.backward(retain_graph=True)# Update optimizercrit_opt.step()critic_losses += [mean_iteration_critic_loss]### Update generator ###gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)crit_fake_pred = crit(fake_2)gen_loss = get_gen_loss(crit_fake_pred)gen_loss.backward()# Update the weightsgen_opt.step()# Keep track of the average generator lossgenerator_losses += [gen_loss.item()]### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:gen_mean = sum(generator_losses[-display_step:]) / display_stepcrit_mean = sum(critic_losses[-display_step:]) / display_stepprint(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")show_tensor_images(fake)show_tensor_images(real)step_bins = 20num_examples = (len(generator_losses) // step_bins) * step_binsplt.plot(range(num_examples // step_bins),torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),label="Generator Loss")plt.plot(range(num_examples // step_bins),torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),label="Critic Loss")plt.legend()plt.show()cur_step += 1

在这里插入图片描述

代码讲解

网络模型与上一篇的DCGAN没有变动。
在这里插入图片描述
这个模块进行梯度计算,即上文原理中正则项公式里面的梯度L2范数里的梯度。首先计算真实数据与生成数据之间随机取样的混合数据,然后输入 critic,最后计算出其梯度。
在这里插入图片描述
梯度惩罚模块,即上文原理中的整个正则项公式,梯度范数 -1 的平方。
在这里插入图片描述
critic 的 loss function 公式如下,generator 因为和真实数据无关,且与正则项也无关,所以只有中间一项。
在这里插入图片描述————————————————————————————————————————————

总之,WGAN-GP 不一定要提高 GAN 的整体性能,但会很好的提高稳定性并避免模式崩溃。

下一篇条件生成GAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/781053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【unity】认识unity Hub的主要功能

这里我们主要讲解unity Hub中的【项目】和【安装】功能,其他对应的功能栏相信大家根据文字就可以知道相应的作用。 首先是介绍【项目】功能,在这里我们可以创建本地项目和云端项目,作为初学者我们创建本地项目皆可,当然如果你是多…

UE4_碰撞_使用蓝图控制物体移动时如何让被阻挡

当我们这样设置蓝图时: 运行效果: 利用蓝图更改一个物体的位置,发现本来两个应该相互阻挡的物体被穿过去了。为了不让相互阻挡的物体被穿过去,我们需要设置好蓝图节点的参数Sweep。 勾选之后 墙的蓝图我们这样设置: 运…

【软件工程】需求分析

1. 导言 1.1. 需求文档的目的 该文档是关于用户对于“学生成绩管理系统”的功能和性能的要求,重点描述了“学生成绩管理系统”的设计需求,将作为对该工具在概要设计阶段的设计输入。编写本文档的目的在于说明软件工程管理系统的业务需求内容&#xff0…

30-3 越权漏洞 - 水平越权(横向越权)

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 一、定义 攻击者可以访问和操作与其拥有同级权限的用户资源。 示例: 学生A在教务系统上正常只能修改自己的作业内容,但由于不合理的权限校验规则等原因,学生A可以修改学生B的内…

记录C++中,vector的迭代器在push_back以后扩容导致迭代器失效的问题

前言 vector是我们用到最多的数据结构,其底层数据结构是单端动态数组,由于数组的特点,vector也具有以下特性: ①O(1)时间的快速访问; ②顺序存储,所以插入到非尾结点位置所需时间复杂度为O(n),删…

uniapp开发微信小程序设置分包,简单易学

文章目录 前言一、在 manifest.json文件中的源码试图中配置二、配置pages.json 前言 我们使用uniapp开发微信小程序的时候,当我们的包体积过大的时候,无法真机模拟。 因为小程序单个包只支持2MB(现已支持预览4MB),所以…

AI:155-基于深度学习的股票价格预测模型

本文收录于专栏:精通AI实战千例专栏合集 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 每一个案例都附带关键代码,详细讲解供大家学习,希望可以帮到大家。正在不断更新中~ 一.基于深度学习的股票价格预测模型 …

基于k8s的web服务器构建

文章目录 k8s综合项目1、项目规划图2、项目描述3、项目环境4、前期准备4.1、环境准备4.2、ip划分4.3、静态配置ip地址4.4、修改主机名4.5、部署k8s集群4.5.1、关闭防火墙和selinux4.5.2、升级系统4.5.3、每台主机都配置hosts文件,相互之间通过主机名互相访问4.5.4、…

总结IP协议各类知识点

前言 本篇博客博主将详解IP协议中的各类知识点,坐好板凳发车啦~ 一.IP协议格式 1.1 4位版本号(version) 指定IP协议的版本,对于IPv4来说,就是4。 1.2 4位头部长度(header length) IP头部的…

HarmonyOS像素转换-如何使用像素单位设置组件的尺寸。

1 卡片介绍 基于像素单位,展示了像素单位的基本知识与像素转换API的使用。 2 标题 像素转换(ArkTS) 3 介绍 本篇Codelab介绍像素单位的基本知识与像素单位转换API的使用。通过像素转换案例,向开发者讲解了如何使用像素单位设…

大数据-Hadoop---基础配置案例

VMware17创建新虚拟机: 1.静态设置与关闭防火墙 在终端命令行依次输入: 1)cd /etc 2) ls 3) cd sysconfig/ 4) cd network-scripts/ 5) ls 6) vi ifcfg-nes33 在cmd命令栏输入:ncpa.cpl,是找网络适配器的命令 IPADDR&qu…

elementui el-input输入框类型为textarea时,将输入的数据保存换行和空格,并展示换行和空格

el-input输入框类型为textarea时,如果不做数据处理,是不会保存换行和空格的说输入了换行,但是保存数据后不会进行换行,需要保存输入的换行。 1、效果图 输入状态: 显示时: 2、实现代码 2.1、html部分&am…

在jupyter notebook中使用conda环境

在jupyter notebook中使用conda环境 1. 环境配置 conda activate my-conda-env # this is the environment for your project and code conda install ipykernel conda deactivateconda activate base # could be also some other environment conda install nb_cond…

在新能源充电桩、智能充电枪、储能等产品领域得到广泛应用的两款微功耗轨至轨运算放大器芯片——D8541和D8542

D8541和D8542是我们推荐的两款微功耗轨至轨运算放大器芯片,其中D8541为单运放, D8542为双运放,它特别适用于NTC温度采集电路、ADC基准电压电路、有源滤波器、电压跟随器、信号放大器等电路应用,在新能源充电桩、智能充电枪、…

网络编程--高并发服务器(二)

这里写目录标题 线程池高并发服务器UDP服务器TCP与UDP机制的对比TCP与UDP优缺点比较UDP的C/S模型实现思路模型分析实现思路(对照TCP的C/S模型) 二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二…

【跟着CHATGPT学习硬件外设 | 05】I2C

本文根据博主设计的Prompt由CHATGPT生成,形成极简外设概念。 🚀 1. 概念揭秘 I2C(Inter-Integrated Circuit),也被称为IIC或双线接口,是一种用于微控制器(Microcontrollers)和外设…

用navicat进行mysql表结构同步

用navicat进行mysql表结构同步 前言新增一个列然后进行表结构同步删除一个列然后进行表结构同步把Int列转成TinyInt列,看数字溢出的情况下能不能表结构同步总结 前言 从同事那边了解到还能用navicat进行表结构同步,他会在发布更新的时候,直接…

Python基础之Class类的定义、继承、多态

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、class类1.类属性操作(增删改)2.类方法操作 二、类的继承1、语法2、方法重写 二、类的多态 一、class类 、三部分组成 1、类名&#xff…

MySQl on和where条件的区别?

MySQ L on和where条件的区别? on会生成临时表,不满足条件会置空 where 过滤数据,不满足的数据不会显示

木地板 VS 瓷砖,不同风格应该怎么选?福州中宅装饰,福州装修

不同装修风格应该怎么选择地板铺贴材质?是选择木地板还是瓷砖?以下分点阐述: ①现代简约风格 推荐使用瓷砖。因为瓷砖的表面光滑,能反射出灯光的倒影,营造出简洁明亮的视觉效果。同时,瓷砖耐磨、易清洁&am…