AIGC实战——WGAN(Wasserstein GAN)

AIGC实战——WGAN

    • 0. 前言
    • 1. WGAN-GP
      • 1.1 Wasserstein 损失
      • 1.2 Lipschitz 约束
      • 1.3 强制 Lipschitz 约束
      • 1.4 梯度惩罚损失
      • 1.5 训练 WGAN-GP
    • 2. GAN 与 WGAN-GP 的关键区别
    • 3. WGAN-GP 模型分析
    • 小结
    • 系列链接

0. 前言

原始的生成对抗网络 (Generative Adversarial Network, GAN) 在训练过程中面临着模式坍塌和梯度消失等问题,为了解决这些问题,研究人员提出了大量的关键技术以提高GAN模型的整体稳定性,并降低了上述问题出现的可能性。例如 WGAN (Wasserstein GAN) 和 WGAN-GP (Wasserstein GAN-Gradient Penalty) 等,通过对原始生成对抗网络 (Generative Adversarial Network, GAN) 框架进行了细微调整,就能够训练复杂GAN。在本节中,我们将学习 WGANWGAN-GP,两者都对原始 GAN 框架进行了细微调整,以改善图像生成过程的稳定性和质量。

1. WGAN-GP

WGAN (Wasserstein GAN) 是提高 GAN 训练稳定性方面的一次巨大进步,在经过一些简单改动后 GAN 就能够实现以下两个特点:

  • 与生成器的收敛度和生成样本质量相关的损失度量
  • 优化过程的稳定性得到提高

具体来说,WGAN 针对判别器和生成器提出了一种新的损失函数 (Wasserstein Loss),用这种损失函数代替二元交叉熵就可以让 GAN 的收敛更加稳定。
在本节中,我们将构建一个 WGAN-GP (Wasserstein GAN-Gradient Penalty),利用 CelebA 数据集训练模型以生成人脸图像。

1.1 Wasserstein 损失

首先我们来回顾一下二元交叉嫡, 在训练 DCGAN 判别器和生成器时采用了这种损失函数:
− 1 n ∑ i = 1 n ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) -\frac 1 n \sum_{i=1}^n(y_ilog(p_i)+(1-y_i)log(1-p_i)) n1i=1n(yilog(pi)+(1yi)log(1pi))
为了训练 GAN 的判别器 D,我们根据以下两者计算损失:真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,以及生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi))与标签 y i = 0 y_i=0 yi=0 之间的误差。因此,对于 GAN 的判别器来说,损失函数最小化的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ log ⁡ D ( x ) ] + E z ∼ p Z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \mathop {\min} \limits_{D}-(\mathbb E_{x\sim p_X}[\log D(x)]+\mathbb E_{z\sim p_Z}[\log (1-D(G(z)))]) Dmin(ExpX[logD(x)]+EzpZ[log(1D(G(z)))])
为了训练 GAN 的生成器 G,我们根据生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 的误差计算损失。因此,对于 GAN 的生成器来说,将损失函数最小化的过程可以表示为:
min ⁡ G − ( E z ∼ p Z [ log ⁡ D ( G ( z ) ) ] ) \mathop {\min}\limits_{G}-(\mathbb E_{z\sim p_Z}[\log D(G(z))]) Gmin(EzpZ[logD(G(z))])
接下来,我们比较上述损失函数与 Wasserstein 损失函数。
Wasserstein 损失 (Wasserstein Loss) 是用于 Wasserstein GAN (WGAN) 的一种损失函数。与传统的二元交叉熵损失函数不同,Wasserstein 损失引入了标签 1-1,将判别器的输出从概率值转变为分数 (score),因此,WGAN 的判别器通常也被称为评论家 (critic),并要求判别器是 1-Lipschitz 连续函数。
具体来说,Wasserstein 损失使用标签 y i = 1 y_i=1 yi=1 y i = − 1 y_i=-1 yi=1 代替 y i = 1 y_i=1 yi=1 y i = 0 y_i=0 yi=0,同时还需要移除判别器最后一层的 Sigmoid激活函数,如此一来预测结果 p i p_i pi 就不一定在 [ 0 , 1 ] [0,1] [0,1] 范围内了,它可以是 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任何值。Wasserstein 损失的定义如下:
− 1 n ∑ i = 1 n ( y i p i ) -\frac 1 n∑_{i=1}^n(y_ip_i) n1i=1n(yipi)
在训练 WGAN 的判别器 D 时,我们将计算以下损失:判别器对真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = − 1 y_i=-1 yi=1 之间的误差。因此,对于 WGAN 判别器,最小化损失函数的过程可以表示为:
min ⁡ D − ( E x ∼ p X [ D ( x ) ] − E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ D - (\mathbb E_{x\sim p_X}[D(x)] - \mathbb E_{z\sim p_Z}[D(G(z))]) Dmin(ExpX[D(x)]EzpZ[D(G(z))])
换句话说,WGAN 判别器试图最大化其对真实图像的预测和生成图像的预测之间的差异,且真实图像的得分更高。
而对于 WGAN 生成器 G 的训练,我们根据判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 计算损失。因此,对于 WGAN 生成器,最小化损失函数可以表示为:
min ⁡ G − ( E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ G - (\mathbb E_{z\sim p_Z}[D(G(z))]) Gmin(EzpZ[D(G(z))])
换句话说,WGAN 生成器试图生成被判别器以极高分数判定为真实图像的图像(即,令判别器认为它们是真实的)。

1.2 Lipschitz 约束

由于我们允许判别器输出 [ − ∞ , ∞ ] [-∞,∞] [,] 范围内的任意值,而不是按照 Sigmoid 函数那样将输出限制在 [ 0 , 1 ] [0,1] [0,1] 范围内,因此 Wasserstein 损失可能会非常大。因此,为了使 Wasserstein 损失函数正常工作,需要对判别器进行额外约束,即 1-Lipschitz 连续性约束。判别器是一个将图像转换为预测的函数 D,如果对于任意两个输人图像 x 1 x_1 x1 x 2 x_2 x2,判别器函数 D 满足以下不等式,则该函数为 1-Lipschitz 连续:
∣ D ( x 1 ) − D ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ 1 \frac {|D(x_1) - D(x_2)|}{|x_1 - x_2|} ≤ 1 x1x2D(x1)D(x2)1
其中, ∣ x 1 − x 2 ∣ |x_1 - x_2| x1x2 表示两个图像的平均像素之差的绝对值, ∣ D ( x 1 ) − D ( x 2 ) ∣ |D(x_1) - D(x_2)| D(x1)D(x2) 表示判别器预测之间的绝对值。这意味着判别器的预测变化速率在任何情况下都是有界的(即梯度的绝对值不能大于 1)。可以在下图中的 Lipschitz 连续的一维函数中看到,无论将圆锥放在任何位置,曲线都不会进入圆锥内部。换句话说,曲线上任何一点的上升或下降速度都是有限的。

Lipschitz 连续

1.3 强制 Lipschitz 约束

在原始的 WGAN 论文中,作者通过在每个训练结束后将判别器的权重裁剪到一个较小范围内 [ − 0.01 , 0.01 ] [-0.01, 0.01] [0.01,0.01] 来强制执行 Lipschitz 约束。
由于我们裁剪了判别器的权重,判别器的学习能力大大降低,因此,事实上,权重裁剪并不是一种理想的强制 Lipschitz 约束的方式。一个强大的判别器对于 WGAN 的成功至关重要,因为如果没有准确的梯度,生成器无法学习如何调整其权重以产生更好的样本。
因此,研究人员提出了许多其他方法来强制执行 Lipschitz 约束,并提高 WGAN 学习复杂特征的能力。其中一种方法是带有梯度惩罚 (Gradient Penalty) 的 Wasserstein GAN
通过在判别器的损失函数中包含一个梯度惩罚项来直接强制执行 Lipschitz 约束,如果梯度范数偏离 1 时,该项会惩罚模型,从而使训练过程更加稳定。
接下来,将这个额外的梯度惩罚项加入到判别器损失函数中。

1.4 梯度惩罚损失

下图展示了 WGAN-GP 判别器的训练过程,与原始判别器的训练过程进行比较,我们可以看到关键的改进是将梯度惩罚损失作为整体损失函数的一部分,并与来自真实图像和生成图像的 Wasserstein 损失一起使用。

WGAN-GP

梯度惩罚损失衡量了预测关于输入图像的梯度范数与 1 之间的平方差。模型倾向于找到能够使梯度惩罚项最小化的权重,从而鼓励模型符合 Lipschitz 约束。
在训练过程中,每一处的计算梯度是非常困难的,因此WGAN-GP 只在少数几个点处评估梯度。为了确保平衡的,我们使用一组插值图像,在真实图像与伪造图像之间的随机位置逐像素进行插值 (Interpolation) 以生成一些图像。

插值图像

使用 Keras 计算梯度惩罚项:

    def gradient_penalty(self, batch_size, real_images, fake_images):# 批数据中的每个图像都会得到一个 0~1 之间的随机数字,存储到向量 alpha 中alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)# 计算一组插值图像diff = fake_images - real_imagesinterpolated = real_images + alpha * diffwith tf.GradientTape() as gp_tape:gp_tape.watch(interpolated)# 使用判别器对每个插值图像进行评分pred = self.critic(interpolated, training=True)# 计算插值图像 (y_pred) 的预测对于输入 interpolated_samples) 的梯度grads = gp_tape.gradient(pred, [interpolated])[0]# 计算这个向量的 L2 范数(即欧几里得长度)norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))# 函数返回 L2 范数与 1 之差的平方的均值gp = tf.reduce_mean((norm - 1.0) ** 2)return gp

1.5 训练 WGAN-GP

使用 Wasserstein 损失函数的一个优点是,不再需要担心平衡判别器和生成器的训练。事实上,在使用 Wasserstein 损失时,必须在更新生成器之前将判别器训练到收敛,以确保生成器更新的梯度准确无误。这与标准 GAN 相反,标准 GAN 中重要的是不要让判别器变得过强。
因此,使用 Wasserstein GAN,我们可以简单地在生成器更新之间多次训练判别器,以确保它接近收敛。通常每次生成器更新一次,判别器更新三到五次。
了解了 WGAN-GP 的两个关键概念 (Wasserstein 损失和梯度惩罚项)后,使用 Keras 实现 WGAN-GP

    def train_step(self, real_images):batch_size = tf.shape(real_images)[0]# 对判别器进行三次更新for i in range(self.critic_steps):random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:fake_images = self.generator(random_latent_vectors, training=True)fake_predictions = self.critic(fake_images, training=True)real_predictions = self.critic(real_images, training=True)# 计算判别器的 Wasserstein 损失c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)# 计算梯度惩罚项c_gp = self.gradient_penalty(batch_size, real_images, fake_images)# 判别器损失函数是 Wasserstein 损失和梯度惩罚的加权和c_loss = c_wass_loss + c_gp * self.gp_weightc_gradient = tape.gradient(c_loss, self.critic.trainable_variables)# 更新判别器的权重self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:fake_images = self.generator(random_latent_vectors, training=True)fake_predictions = self.critic(fake_images, training=True)# 计算生成器的 Wasserstein 损失g_loss = -tf.reduce_mean(fake_predictions)gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)# 更新生成器的权重self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))self.c_loss_metric.update_state(c_loss)self.c_wass_loss_metric.update_state(c_wass_loss)self.c_gp_metric.update_state(c_gp)self.g_loss_metric.update_state(g_loss)return {m.name: m.result() for m in self.metrics}

在训练 WGAN-GP 之前,需要注意的最后一点是判别器不应该使用批量归一化。这是因为批归一化会在同一批图像之间创建相关性,从而使梯度惩罚损失的效果降低。实验证明,即使在判别器中没有批归一化, WGAN-GP 仍然可以输出出色的结果。

2. GAN 与 WGAN-GP 的关键区别

总而言之,标准 GANWGAN-GP 之间存在以下:

  • WGAN-GP 使用 Wasserstein 损失
  • WGAN-GP 使用 1 表示真实图像标签,使用 -1 表示伪造图像的标签
  • 判别器的最后一层没有使用 sigmoid 激活
  • 在判别器的损失函数中包含梯度惩罚项
  • 每训练一次生成器更新权重,需要多次训练判别器
  • 判别器中没有批归一化层

3. WGAN-GP 模型分析

训练 25epoch 后,WGAN-GP 模型的生成器能够生成合理图像:

面部生成结果

该模型已经学习到了面部的重要高级特征,且没有出现模式坍塌的迹象。
如果我们将 WGAN-GP 的输出与变分自编码器 (Variational Autoencoder, VAE) 的输出进行比较,可以看到 WGAN-GP 生成的图像通常更清晰。总的来说,VAE 倾向于产生颜色边界模糊的图像,而 GAN 产生的图像更加清晰合理。GAN 通常比 VAE 更难训练,需要更长的时间才能获得满意的数据质量。

小结

在本节中,我们学习了如何使用 Wasserstein 损失函数以解决经典 GAN 训练过程中的模式坍塌和梯度消失等问题,使得 GAN 的训练更加可预测和可靠。WGAN-GP 通过在损失函数中添加一个令梯度范数指向 1 的项,为训练过程施加 1-Lipschitz 约束。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入探索C语言中的二叉树:数据结构之旅

引言 在计算机科学领域,数据结构是基础中的基础。在众多数据结构中,二叉树因其在各种操作中的高效性而脱颖而出。二叉树是一种特殊的树形结构,每个节点最多有两个子节点:左子节点和右子节点。这种结构使得搜索、插入、删除等操作…

【React Hooks】useReducer()

useReducer 的三个参数是可选的,默认就是initialState,如果在调用的时候传递第三个参数那么他就会改变为你传递的参数,实际开发不建议这样写。会增加代码的不可读性。 使用方法: 必须将 useReducer 的第一个参数(函数…

MySQL - 并发控制与事务的隔离级别

目录 第1关:并发控制与事务的隔离级别 第2关:读脏 第3关:不可重复读 第4关:幻读 第5关:主动加锁保证可重复读 第6关:可串行化 第1关:并发控制与事务的隔离级别 任务描述 本关任务&#…

linux初级学习

(420条消息) 红帽认证-RHCSA_rhcsa红帽认证_yyyzf的博客-CSDN博客 OS:用户和机器的接口,UI:CMD,GUI shell: 通用格式 命令 选项(调控功能) 参数(操作对象)参数 省略参数对象一般使用当前目录作为参数对…

vue3日常知识点学习归纳

1&#xff0c;父子组件传递&#xff1a; 父组件传递参数 <template><div><!-- 子组件 参数&#xff1a;num 、nums --><child :num"nums.num" :doubleNum"nums.doubleNum" increase"handleIncrease"></child>&l…

JAVA全栈开发 day19_JDBC

一、JDBC 1.JDBC概述 1.1什么是jdbc Java DataBase Connectivity是一种用于执行SQL语句的Java API&#xff0c;它由一组用Java语言编写的类和接口组成。通过这些类和接口&#xff0c;JDBC把SQL语句发送给不同类型的数据库进行处理并接收处理结果。 1.2jdbc的作用 提供java…

【目标检测从零开始】torch搭建yolov3模型

用torch从0简单实现一个的yolov3模型&#xff0c;主要分为Backbone、Neck、Head三部分 目录 Backbone&#xff1a;DarkNet53结构简介代码实现Step1&#xff1a;导入相关库Step2&#xff1a;搭建基本的Conv-BN-LeakyReLUStep3&#xff1a;组成残差连接块Step4&#xff1a;搭建Da…

思维模型 色彩心理效应

本系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知色彩影响情绪。 1 色彩心理效应的应用 1.1 色彩心理效应在营销中的应用 1 可口可乐公司的“红色”营销 可口可乐公司是全球最著名的饮料品牌之一&#xff0c;其标志性的红色包装已经成为了…

Constraining Async Clock Domain Crossing

Constraining Async Clock Domain Crossing 我们在normal STA中只会去check 同步clock之间的timing,但是design中往往会存在很多CDC paths,这些paths需要被正确约束才能保证design function正确,那么怎么去约束这些CDC paths呢? 以下面的design为例,如下图所示 这里clk…

小红书蒲公英平台开通后,有哪些注意的地方,以及如何进行报价?

今天来给大家聊聊当小红书账号过1000粉后&#xff0c;开通蒲公英需要注意的事项。 蒲公英平台是小红书APP中的一个专为内容创作者设计的平台。它为品牌和创作者提供了一个完整的服务流程&#xff0c;包括内容的创作、推广、互动以及转换等多个方面。 2.蒲公英平台的主要功能 &…

【C语言】vfprintf函数

vfprintf 是 C 语言中的一个函数&#xff0c;它是 fprintf 函数的变体&#xff0c;用于格式化输出到文件中。vfprintf 函数接受一个格式化字符串和一个指向可变参数列表的指针&#xff0c;这个列表通常是通过 va_list 类型来传递的。vfprintf 函数的主要用途是在需要处理不定数…

远传智能水表一般应用于哪些场景?

远传智能水表是一种在水表领域应用广泛的创新技术&#xff0c;它利用物联网和无线通信技术使水表具备了远程监测和数据传输的能力。这种智能水表的应用场景多种多样&#xff0c;可适用于各个领域和环境。那么&#xff0c;远传智能水表一般应用于哪些场景呢&#xff1f; 首先&am…

9.关于Java的程序设计-基于Springboot的家政平台管理系统设计与实现

摘要 随着社会的进步和生活水平的提高&#xff0c;家政服务作为一种重要的生活服务方式逐渐受到人们的关注。本研究基于Spring Boot框架&#xff0c;设计并实现了一种家政平台管理系统&#xff0c;旨在提供一个便捷高效的家政服务管理解决方案。系统涵盖了用户注册登录、家政服…

mybatis数据输出-map类型输出

1、建库建表 create table emp (empNo varchar(10) null,empName varchar(100) null,sal int null,deptno varchar(10) null ); 2、pom.xml <dependencies><dependency><groupId>org.mybatis</groupId><artifactId>mybatis<…

Elasticsearch 8.9 flush刷新缓存中的数据到磁盘源码

一、相关API的handler1、接收HTTP请求的hander2、每一个数据节点(node)执行分片刷新的action是TransportShardFlushAction 二、对indexShard执行刷新请求1、首先获取读锁&#xff0c;再获取刷新锁&#xff0c;如果获取不到根据参数决定是否直接返回还是等待2、在刷新之后transl…

Android Audio实战——音频链路分析(二十五)

在 Android 系统的开发过程当中,音频异常问题通常有如下几类:无声、调节不了声音、爆音、声音卡顿和声音效果异常(忽大忽小,低音缺失等)等。尤其声音效果这部分问题通常从日志上信息量较少,相对难定位根因。想要分析此类问题,便需要对声音传输链路有一定的了解,能够在链…

【论文解读】:大模型免微调的上下文对齐方法

本文通过对alignmenttuning的深入研究揭示了其“表面性质”&#xff0c;即通过监督微调和强化学习调整LLMs的方式可能仅仅影响模型的语言风格&#xff0c;而对模型解码性能的影响相对较小。具体来说&#xff0c;通过分析基础LLMs和alignment-tuned版本在令牌分布上的差异&#…

100多种视频转场素材|专业胶片,抖动,光效电影转场特效PR效果预设

100多种 Premiere Pro 效果预设&#xff0c;包含&#xff1a;“胶片框架”、“胶片烧录”、“彩色LUT”、“相机抖动”、“电影Vignette”和“胶片颗粒”。非常适合制作复古风格的视频&#xff0c;添加独特的色彩。包括视频教程。 来自PR模板网&#xff1a;https://prmuban.com…

git 本地有改动,远程也有改动,且文件是自动生成的配置文件

在改动过的地方 文件是.lock文件&#xff0c;自动生成的。想切到远程的分支&#xff0c;但是远程的分支也有改动过。这时候就要解决冲突&#xff0c;因为这是两个分支&#xff0c;代码都是不一样的&#xff0c;要先把这改动的代码提交在本地或者提交在本分支的远程才可以切到其…

ke13--10章-1数据库JDBC介绍

注册数据库(两种方式),获取连接,通过Connection对象获取Statement对象,使用Statement执行SQL语句。操作ResultSet结果集 ,回收数据库资源. 需要语句: 1Class.forName("DriverName");2Connection conn DriverManager.getConnection(String url, String user, String…