使用pytorch构建带梯度惩罚的Wasserstein GAN(WGAN-GP)网络模型

本文为此系列的第三篇WGAN-GP,上一篇为DCGAN。文中仍然不会过多详细的讲解之前写过的,只会写WGAN-GP相对于之前版本的改进点,若有不懂的可以重点看第一篇比较详细。

原理

具有梯度惩罚的 Wasserstein GAN (WGAN-GP)可以解决 GAN 的一些稳定性问题。 具体来说,使用W-loss 作为损失函数替代传统的 BCE 等 loss,并使用梯度惩罚来防止 mode collapse。

  • WGAN-GP 使用了 Wasserstein distance(也成为Earth Mover’s distance, EMD)作为训练 GAN 模型的目标函数,Wasserstein distance is a function of amount and distance,体现的是生成的数据的分布移动到真实数据的分布之间所需的距离与量。
    在这里插入图片描述
    随着判别器训练的越来越好,使用 BCE loss 的话会让鉴别器给出接近于 0 或者接近于 1 的极端值,如下为 sigmoid 曲线,极端值的梯度无限接近于 0,这样判别器就没有太多有用的信息反馈给生成器让它学习,导致梯度消失或 model collapse。使用距离的方式可以有效解决,分布距离再远都不再限制。
    在这里插入图片描述
    在这里插入图片描述
  • BCE loss 本质是一个 minimax game, d 即 discriminator 希望尽可能的 minimize,g 即 generator 希望尽可能的 maximize(意味着造出来的假东西对于鉴别器来说看起来很真实),可以进行如下的简化:
    在这里插入图片描述
    基于 Wasserstein distance 的 W-loss 的的式子与其简化版进行对比:
    在这里插入图片描述
    在 Wasserstein GAN 中不再是 discriminator 了,因为输出不再是 0-1 之间来进行分类,既然不分类了就不是 discriminator 了,而是 critic,所以这里使用 c 代表 critic。critic 希望其尽可能的 maximize,因为希望让 real 和 feak 的距离尽可能的大,起到划清界限的目的;generator 希望其尽可能的minimize,减小两者之间的距离,达到以假乱真的目的。
  • mode collapse 即模式崩溃,当生成器学会从单个类生成特征来欺骗鉴别器时,就会发生 mode collapse(陷入一种模式出不来),跟 cnn 的局部最优是一个概念。这会导致输出出现重复,缺乏多样性和细节。

但在使用 W-loss 训练 GAN 时需要对 critic 有一定的条件 —— critic 需要 1-L(1-Lipschitz)连续:
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2 ∣ |f(x_1)-f(x_2)|\le k|x_1-x_2\ | f(x1)f(x2)kx1x2 
这里的 k = 1,也就是 critic 的 nn 函数曲线的梯度(斜率)始终在 -1 到 1 之间,即梯度的 L2 范数不超过1:

在这里插入图片描述
如图:
在这里插入图片描述
曲线的每个点的斜率都是在绿色区域内,很显然这个曲线并不符合。像如下这个曲线就是符合的:
在这里插入图片描述
达到 1-L 连续有两种方法:weigh clipping、gradient penalty。

  • weigh clipping 将权重裁剪到固定范围内,从而限制 critic 的学习能力。但是这样有缺点,可能让所有参数走极端,要么取最大值要么取最小值,critic 会非常倾向于学习一个简单的映射函数。
  • gradient penalty 则是添加一个正则项在 loss function 中,相比 weigh clipping 更加柔和对critic参数的限制更加灵活,通常不会导致梯度消失或梯度爆炸问题。
    在这里插入图片描述
    这里的 λ \lambda λ 为超参值,reg 等于 critic 神经网络梯度范数 -1 的平方,即:
    在这里插入图片描述
    当 critic 神经网络梯度范数 >1 时正则化项发挥作用。平方的作用是为了让其偏离越大,惩罚越大。
    这里的 x ^ \hat{x} x^ 为真实数据与生成数据之间随机取样得到的中间数据,随机值 ϵ \epsilon ϵ 作为权重值,假设 ϵ \epsilon ϵ 为0.3,那么 1- ϵ \epsilon ϵ 为0.7。
    在这里插入图片描述

代码

model.py

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)else:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh(),)def forward(self, noise):x = noise.view(len(noise), self.z_dim, 1, 1)return self.gen(x)class Critic(nn.Module):def __init__(self, im_chan=1, hidden_dim=64):super(Critic, self).__init__()self.crit = nn.Sequential(self.make_crit_block(im_chan, hidden_dim),self.make_crit_block(hidden_dim, hidden_dim * 2),self.make_crit_block(hidden_dim * 2, 1, final_layer=True),)def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True),)else:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),)def forward(self, image):crit_pred = self.crit(image)return crit_pred.view(len(crit_pred), -1)

train.py

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)def get_gradient(crit, real, fake, epsilon):# Mix the images togethermixed_images = real * epsilon + fake * (1 - epsilon)# Calculate the critic's scores on the mixed imagesmixed_scores = crit(mixed_images)# Take the gradient of the scores with respect to the imagesgradient = torch.autograd.grad(inputs=mixed_images,outputs=mixed_scores,# These other parameters have to do with the pytorch autograd engine worksgrad_outputs=torch.ones_like(mixed_scores),create_graph=True,retain_graph=True,)[0]return gradientdef gradient_penalty(gradient):# Flatten the gradients so that each row captures one imagegradient = gradient.view(len(gradient), -1)# Calculate the magnitude of every rowgradient_norm = gradient.norm(2, dim=1)# Penalize the mean squared distance of the gradient norms from 1penalty = torch.mean((gradient_norm - 1) ** 2)return penaltydef get_gen_loss(crit_fake_pred):gen_loss = -1. * torch.mean(crit_fake_pred)return gen_lossdef get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gpreturn crit_losscur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)mean_iteration_critic_loss = 0for _ in range(crit_repeats):### Update critic ###crit_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)crit_fake_pred = crit(fake.detach())crit_real_pred = crit(real)epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)gradient = get_gradient(crit, real, fake.detach(), epsilon)gp = gradient_penalty(gradient)crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)# Keep track of the average critic loss in this batchmean_iteration_critic_loss += crit_loss.item() / crit_repeats# Update gradientscrit_loss.backward(retain_graph=True)# Update optimizercrit_opt.step()critic_losses += [mean_iteration_critic_loss]### Update generator ###gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)crit_fake_pred = crit(fake_2)gen_loss = get_gen_loss(crit_fake_pred)gen_loss.backward()# Update the weightsgen_opt.step()# Keep track of the average generator lossgenerator_losses += [gen_loss.item()]### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:gen_mean = sum(generator_losses[-display_step:]) / display_stepcrit_mean = sum(critic_losses[-display_step:]) / display_stepprint(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")show_tensor_images(fake)show_tensor_images(real)step_bins = 20num_examples = (len(generator_losses) // step_bins) * step_binsplt.plot(range(num_examples // step_bins),torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),label="Generator Loss")plt.plot(range(num_examples // step_bins),torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),label="Critic Loss")plt.legend()plt.show()cur_step += 1

在这里插入图片描述

代码讲解

网络模型与上一篇的DCGAN没有变动。
在这里插入图片描述
这个模块进行梯度计算,即上文原理中正则项公式里面的梯度L2范数里的梯度。首先计算真实数据与生成数据之间随机取样的混合数据,然后输入 critic,最后计算出其梯度。
在这里插入图片描述
梯度惩罚模块,即上文原理中的整个正则项公式,梯度范数 -1 的平方。
在这里插入图片描述
critic 的 loss function 公式如下,generator 因为和真实数据无关,且与正则项也无关,所以只有中间一项。
在这里插入图片描述————————————————————————————————————————————

总之,WGAN-GP 不一定要提高 GAN 的整体性能,但会很好的提高稳定性并避免模式崩溃。

下一篇条件生成GAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/781053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

十四、Spring源码学习之createBean方法

AbstractAutowireCapableBeanFactory#createBean()方法 protected Object createBean(String beanName, RootBeanDefinition mbd, Nullable Object[] args)throws BeanCreationException {if (logger.isTraceEnabled()) {logger.trace("Creating instance of bean "…

【unity】认识unity Hub的主要功能

这里我们主要讲解unity Hub中的【项目】和【安装】功能,其他对应的功能栏相信大家根据文字就可以知道相应的作用。 首先是介绍【项目】功能,在这里我们可以创建本地项目和云端项目,作为初学者我们创建本地项目皆可,当然如果你是多…

UE4_碰撞_使用蓝图控制物体移动时如何让被阻挡

当我们这样设置蓝图时: 运行效果: 利用蓝图更改一个物体的位置,发现本来两个应该相互阻挡的物体被穿过去了。为了不让相互阻挡的物体被穿过去,我们需要设置好蓝图节点的参数Sweep。 勾选之后 墙的蓝图我们这样设置: 运…

C#-ConcurrentDictionary用于多线程并发字典

ConcurrentDictionary 是 .NET Framework 中用于多线程并发操作的一种线程安全的字典集合类。它提供了一种在多个线程同时访问和修改字典时保持数据一致性的机制。 以下是 ConcurrentDictionary 类的一些重要特性和用法: 线程安全性:ConcurrentDictiona…

【软件工程】需求分析

1. 导言 1.1. 需求文档的目的 该文档是关于用户对于“学生成绩管理系统”的功能和性能的要求,重点描述了“学生成绩管理系统”的设计需求,将作为对该工具在概要设计阶段的设计输入。编写本文档的目的在于说明软件工程管理系统的业务需求内容&#xff0…

30-3 越权漏洞 - 水平越权(横向越权)

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 一、定义 攻击者可以访问和操作与其拥有同级权限的用户资源。 示例: 学生A在教务系统上正常只能修改自己的作业内容,但由于不合理的权限校验规则等原因,学生A可以修改学生B的内…

点云从入门到精通技术详解100篇-基于3D点云的盘类元件识别与定位

目录 前言 2 3D视觉机器人抓取系统方案设计 2.1 系统硬件方案选型 2.1.1 相机选取方案

记录C++中,vector的迭代器在push_back以后扩容导致迭代器失效的问题

前言 vector是我们用到最多的数据结构,其底层数据结构是单端动态数组,由于数组的特点,vector也具有以下特性: ①O(1)时间的快速访问; ②顺序存储,所以插入到非尾结点位置所需时间复杂度为O(n),删…

动态规划--(递推2(最长上升子序列,格子染色,斐波那切数列,奇数塔问题,最长子段和))

1281&#xff1a;最长上升子序列 【题目描述】 一个数的序列bi &#xff0c;当b1<b2<…<bS 的时候&#xff0c;我们称这个序列是上升的。对于给定的一个序列(a1,a2,…,aN) &#xff0c;我们可以得到一些上升的子序列(ai1,ai2,…,aiK) &#xff0c;这里1≤i1<i2<…

uniapp开发微信小程序设置分包,简单易学

文章目录 前言一、在 manifest.json文件中的源码试图中配置二、配置pages.json 前言 我们使用uniapp开发微信小程序的时候&#xff0c;当我们的包体积过大的时候&#xff0c;无法真机模拟。 因为小程序单个包只支持2MB&#xff08;现已支持预览4MB&#xff09;&#xff0c;所以…

Docker 学习总结(81)—— 冷门而又实用的 Docker 使用技巧总结

1、docker top 这个命令是用来查看一个容器里面的进程信息的,比如你想查看一个 nginx 容器里面有几个 nginx 进程的时候,就可以这么做。 ➜ ~ docker top 3b307a09d20d UID PID PPID C STIME …

JAVA面试大全之开发基础篇

目录 1、常用类库 1.1、平时常用的开发工具库有哪些? 1.2、Java常用的JSON库有哪些?有啥注意点? 1.3、Lombok工具库用来解决什么问题?

AI:155-基于深度学习的股票价格预测模型

本文收录于专栏:精通AI实战千例专栏合集 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 每一个案例都附带关键代码,详细讲解供大家学习,希望可以帮到大家。正在不断更新中~ 一.基于深度学习的股票价格预测模型 …

基于k8s的web服务器构建

文章目录 k8s综合项目1、项目规划图2、项目描述3、项目环境4、前期准备4.1、环境准备4.2、ip划分4.3、静态配置ip地址4.4、修改主机名4.5、部署k8s集群4.5.1、关闭防火墙和selinux4.5.2、升级系统4.5.3、每台主机都配置hosts文件&#xff0c;相互之间通过主机名互相访问4.5.4、…

深入解析大数据Scala面试题及参考答案(持续更新)

Scala,作为一种多范式编程语言,因其强大的功能性和与Java的互操作性,在大数据和并发编程领域备受青睐。本文将深入探讨10个常见的Scala面试题,并提供详尽的参考答案,以期帮助读者在面试中展现其Scala编程的深厚功底。 目录 1. Scala的基本特性是什么? 2. 什么是函数式…

总结IP协议各类知识点

前言 本篇博客博主将详解IP协议中的各类知识点&#xff0c;坐好板凳发车啦~ 一.IP协议格式 1.1 4位版本号&#xff08;version&#xff09; 指定IP协议的版本&#xff0c;对于IPv4来说&#xff0c;就是4。 1.2 4位头部长度&#xff08;header length&#xff09; IP头部的…

HarmonyOS像素转换-如何使用像素单位设置组件的尺寸。

1 卡片介绍 基于像素单位&#xff0c;展示了像素单位的基本知识与像素转换API的使用。 2 标题 像素转换&#xff08;ArkTS&#xff09; 3 介绍 本篇Codelab介绍像素单位的基本知识与像素单位转换API的使用。通过像素转换案例&#xff0c;向开发者讲解了如何使用像素单位设…

【LeetCode热题100】【多维动态规划】最长公共子序列

我昨天面了天美L1的游戏客户端开发&#xff0c;面了我100分钟&#xff0c;问完实习、项目、计算机图形学和C后给了我两道算法题做&#xff0c;一道是最长公共子序列&#xff0c;一道是LRU缓存&#xff0c;我知道是经典的题目&#xff0c;但是我都没敲过&#xff0c;最长公共子序…

大数据-Hadoop---基础配置案例

VMware17创建新虚拟机&#xff1a; 1.静态设置与关闭防火墙 在终端命令行依次输入&#xff1a; 1&#xff09;cd /etc 2) ls 3) cd sysconfig/ 4) cd network-scripts/ 5) ls 6) vi ifcfg-nes33 在cmd命令栏输入&#xff1a;ncpa.cpl,是找网络适配器的命令 IPADDR&qu…

elementui el-input输入框类型为textarea时,将输入的数据保存换行和空格,并展示换行和空格

el-input输入框类型为textarea时&#xff0c;如果不做数据处理&#xff0c;是不会保存换行和空格的说输入了换行&#xff0c;但是保存数据后不会进行换行&#xff0c;需要保存输入的换行。 1、效果图 输入状态&#xff1a; 显示时&#xff1a; 2、实现代码 2.1、html部分&am…