【RL】Wasserstein距离-GAN背后的直觉

一、说明

        在本文中,我们将阅读有关Wasserstein GANs的信息。具体来说,我们将关注以下内容:i)什么是瓦瑟斯坦距离?,ii)为什么要使用它?iii) 我们如何使用它来训练 GAN?

二、Wasserstein距离概念

        Wasserstein距离,又称为Earth Mover's Distance (EMD),是衡量两个概率分布之间的差异程度的一种数学方式。它考虑了分布之间的距离和它们之间的“传输成本”。

        简单来说,Wasserstein距离将两个分布看作“堆积在地图上的土堆”,并计算将一个堆移到另一个的最小成本。这个距离度量的优点是它能够处理非均匀分布,并且能够考虑分布的形状和结构。

        Wasserstein距离在机器学习领域中应用非常广泛,特别是在生成模型中用来评估生成器生成的图像与真实图像之间的差异。

图1:学习区分两个高斯时的最优判别器和批评者[1]。

2.1 瓦瑟施泰因距离

        Wasserstein 距离(地球移动器的距离)是给定度量空间上两个概率分布之间的距离度量。直观地说,它可以被视为将一个分布转换为另一个分布所需的最小功,其中功被定义为必须移动的分布的质量和要移动的距离的乘积。在数学上,它被定义为:

方程 1:瓦瑟斯坦 分布P_r和P_g之间的距离。

        在方程1中,Π(P_r,P_g)是x和y上所有联合分布的集合,使得边际分布等于P_r和P_g。 γ(x, y)可以看作是必须从x移动到y才能将P_r转换为P_g的质量量[1]。因此,瓦瑟斯坦距离是最佳运输计划的成本。

2.2 瓦瑟斯坦距离 vs. 詹森-香农分歧

        最初的GAN目标被证明是Jensen-Shannon分歧的最小化[2]。JS背离定义为:

方程 2:P_r 和 P_g 之间的 JS 背离 P_m = (P_r + P_g)/2

        

        与JS相比,Wasserstein距离具有以下优点:

  • Wasserstein 距离是连续的,几乎可以在任何地方微分,这使我们能够训练模型达到最佳状态。
  • 随着鉴别器的变好,JS散度局部饱和,因此梯度变为零并消失。
  • Wasserstein 距离是一个有意义的度量,即当分布彼此靠近时,它收敛到 0,当它们越来越远时发散。
  • 作为目标函数的 Wasserstein 距离比使用 JS 散度更稳定。当使用Wasserstein距离作为目标函数时,模式崩溃问题也得到了缓解。

        从图 1 我们清楚地看到,最佳GAN鉴别器饱和并导致梯度消失,而优化Wasserstein距离的WGAN评论家在整个过程中具有稳定的梯度。

        有关数学证明和更详细的研究,请查看此处的论文!

三、瓦瑟斯坦·GAN

        现在可以清楚地看到,优化 Wasserstein 距离比优化 JS 散度更有意义,还需要注意的是,方程 1 中定义的 Wasserstein 距离非常棘手[3],因为我们不可能计算所有 γ ∈Π(Pr ,Pg) 的下界(最大下界)。然而,从坎托罗维奇-鲁宾斯坦二元性中,我们有,

公式3:1-利普希茨条件下的瓦瑟斯坦距离。

        这里我们有 W(P_r, P_g) 作为所有 1-Lipschitz 函数 f: X → R 的上确界(最低上限)。

        K-利普希茨连续性:给定 2 个度量空间 (X, d_X) 和 (Y, d_Y),变换函数 f: X → Y 是 K-利普希茨连续的,如果

公式3:K-Lipschitz连续性。

        其中d_X和d_Y是各自度量空间中的距离函数。当一个函数是 K-Lipschitz 时,从方程 2 开始,我们最终得到 K ∙ W(P_r, P_g)。

        现在,如果我们有一系列参数化函数 {f_w},其中 w∈W 是 K-Lipschitz 连续的,我们可以有

公式 4

即,w∈W 最大化方程 4 给出瓦瑟斯坦距离乘以一个常数。

四、WGAN评论家

        为此,WGAN引入了一个批评者,而不是我们在GAN中了解到的鉴别器。批评者网络在设计上类似于判别器网络,但通过优化找到将最大化方程 4 的 w* 来预测 Wasserstein 距离。为此,批评家的客观功能如下:

公式5:批评家客观函数。

       在这里,为了在函数f上强制执行Lipschitz连续性,作者诉诸于将权重w限制在一个紧凑的空间内。这是通过将砝码夹紧到一个小范围(论文中的[-1e-2,1e-2][1])来完成的。

鉴别器和批评者之间的区别在于,鉴别器经过训练以正确识别P_r样本和P_g样本,批评家估计P_r和P_g之间的Wasserstein距离。

这是训练批评家的python代码。

for ix in n_critic_steps:opt_critic.zero_grad()real_images = data[0].float().to(device)# * Generate imagesnoise = sample_noise()fake_images = netG(noise)# * though they are name so, they are not logits!real_logits = netCritic(real_images)fake_logits = netCritic(fake_images)# * max E_{x~P_X}[C(x)] - E_{Z~P_Z}[C(g(z))]loss = -(real_logits.mean() - fake_logits.mean())loss.backward(retain_graph=True)opt_critic.step()# * Gradient clipplingfor p in netCritic.parameters():p.data.clamp_(-self.c, self.c)

五、WGAN生成器目标

        当然,发电机的目标是最小化P_r和P_g之间的瓦瑟斯坦距离。生成器试图找到最小化P_g和P_r之间的 Wasserstein 距离的 θ*。为此,生成器的目标函数如下:

        公式 6:生成器目标函数。

        在这里,WGAN生成器和标准生成器之间的主要区别再次在于,WGAN生成器试图最小化P_r和P_g之间的Wasserstein距离,而标准生成器试图用生成的图像欺骗鉴别器。

        以下是训练生成器的 python 代码:

opt_gen.zero_grad()noise = sample_noise()fake_images = netG(noise)# again, these are not logits.
fake_logits = netCritic(fake_images)# * - E_{Z~P_Z}[C(g(z))]
loss = -fake_logits.mean().view(-1)loss.backward()
opt_gen.step()

六、培训结果

fig2:WGAN训练的早期结果[3]。

        图例.2显示了训练WGAN的一些早期结果。请注意,图 2 中的图像是早期结果,一旦确认模型按预期训练,训练就会停止。

七、代码

        Wasserstein GAN的完整实现可以在这里找到[3]。

八、结论

        WGAN提供非常稳定的培训和有意义的培训目标。本文介绍并直观地解释了什么是 Wasserstein 距离,Wasserstein 距离相对于标准 GAN 使用的 Jensen-Shannon 散度的优势,以及如何使用 Wasserstein 距离来训练 WGAN。我们还看到了用于训练 Critic 和生成器的代码片段,以及早期训练模型的大量输出。尽管WGAN比标准GAN具有许多优势,但WGAN论文的作者明确承认,权重裁剪不是执行Lipschitz连续性的最佳方法[1]。为了解决这个问题,他们提出了带有梯度惩罚的Wasserstein GAN[4],我们将在后面的文章中讨论。

        如果您喜欢这个,请查看本系列的下一篇文章,其中讨论了 WGAN-GP!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29142.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Django】招聘面试管理01 创建项目运行项目

文章目录 前言一、创建项目二、运行项目三、访问后台管理页面四、配置项总结 前言 跟着视频学一学,记录一下。 一、创建项目 照着步骤创建虚拟环境,安装Django等依赖包,创建项目:【Django学习】01 项目创建、结构及命令 > d…

C字符串与C++ string 类:用法万字详解(上)

目录 引言 一、C语言字符串 1.1 创建 C 字符串 1.2 字符串长度 1.3 字符串拼接 1.4 比较字符串 1.5 复制字符串 二、C字符串string类 2.1 解释 2.2 string构造函数 2.2.1 string() 默认构造函数 2.2.2 string(const char* s) 从 C 风格字符串构造 2.2.3 string(co…

pytorch Stream 多流处理

CUD Stream https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#c-language-extensions 中指出在kenel的调用函数中最后一个可选参数表示该核函数处在哪个流之中。 - 参数Dg用于定义整个grid的维度和尺寸,即一个grid有多少个block。为dim3类型。…

无涯教程-Perl - foreach 语句函数

foreach 循环遍历列表值,并将控制变量(var)依次设置为列表的每个元素- foreach - 语法 Perl编程语言中的 foreach 循环的语法是- foreach var (list) { ... } foreach - 流程图 foreach - 示例 #!/usr/local/bin/perllist(2, 20, 30, 40, 50);# foreach loop ex…

【微信小程序创作之路】- 小程序远程数据请求、获取个人信息

【微信小程序创作之路】- 小程序远程数据请求、获取个人信息 第七章 小程序远程数据请求、获取个人信息 文章目录 【微信小程序创作之路】- 小程序远程数据请求、获取个人信息前言一、远程数据请求1.本地环境2.正式域名 二、获取用户个人信息1.展示当前用户的身份信息2.获取用…

Ubuntu安装docker

安装 要是之前安装过,可以进行卸载然后再安装,旧版本的 Docker 的名称为docker、docker.io或 docker-engine。安装新版本之前卸载任何此类旧版本 sudo apt-get remove docker docker-engine docker.io containerd runc使用存储库安装 在新主机上首次安…

kafka-保证数据不重复-生产者开启幂等性和事务的作用?

1. 生产者开启幂等性为什么能去重? 1.1 场景 适用于消息在写入到服务器日志后,由于网络故障,生产者没有及时收到服务端的ACK消息,生产者误以为消息没有持久化到服务端,导致生产者重复发送该消息,造成了消…

econml双机器学习实现连续干预和预测

连续干预 在这个示例中,我们使用LinearDML模型,使用随机森林回归模型来估计因果效应。我们首先模拟数据,然后模型,并使用方法来effect创建不同干预值下的效应(Conditional Average Treatment Effect,CATE&…

vue3—SCSS的安装、配置与使用

SCSS 安装 使用npm安装scss: npm install sass sass-loader --save-dev 配置 配置到全局 🌟附赠代码🌟 css: {preprocessorOptions: {scss: {additionalData:import "./src/Function/Easy_I_Function/Echarts/ToSeeEcharts/utill.…

Spring Boot Admin 环境搭建与基本使用

Spring Boot Admin 环境搭建与基本使用 一、Spring Boot Admin是什么二、提供了那些功能三、 使用Spring Boot Admin3.1搭建Spring Boot Admin服务pom文件yml配置文件启动类启动admin服务效果 3.2 common-apipom文件feignhystrix 3.3服务消费者pom文件yml配置文件启动类control…

Simulation 线性静力分析流程

有限元仿真分析软件有很多,但是分析的流程却是大同小异,今天给大家分享的是Simulation的线性静力分析流程。 1.构思分析方案。 确定研究对象,研究的方法、验证方案等等。听起来比较空洞,实践过程中我建议首先需要把目标和有限元分…

HDFS中的Trash垃圾桶回收机制

Trash垃圾桶回收机制 文件系统垃圾桶背景功能概述Trash Checkpoint Trash功能开启关闭HDFS集群修改core-site.xml删除文件到trash删除文件跳过从trash中恢复文件清空trash 文件系统垃圾桶背景 回收站(垃圾桶)是windows操作系统里的一个系统文件夹&#…

一起学SF框架系列7.1-spring-AOP-基础知识

AOP(Aspect-oriented Programming-面向切面编程)是一种编程模式,是对OOP(Object-oriented Programming-面向对象编程)一种有益补充。在OOP中,万事万物都是独立的对象,对象相互耦合关系是基于业务进行的;但在…

目标识别模型两种部署形态图

目标检测预训练模型基于新数据进行微调(训练)之后,得到一个权重文件。 在日常工业、车载等需求环境下,需要在嵌入式移动端的软件系统中调用该模型文件进行推断测试,软件系统追求性能经常使用C/C进行编码实现&#xff…

第十一次CCF计算机软件能力认证

第一题:打酱油 小明带着 N 元钱去买酱油。 酱油 10 块钱一瓶,商家进行促销,每买 3 瓶送 1 瓶,或者每买 5 瓶送 2 瓶。 请问小明最多可以得到多少瓶酱油。 输入格式 输入的第一行包含一个整数 N,表示小明可用于买酱油的…

【深度学习】【风格迁移】Visual Concept Translator,一般图像到图像的翻译与一次性图像引导,论文

General Image-to-Image Translation with One-Shot Image Guidance 论文:https://arxiv.org/abs/2307.14352 代码:https://github.com/crystalneuro/visual-concept-translator 文章目录 Abstract1. Introduction2. 相关工作2.1 图像到图像转换2.2. Di…

网络防御(2)

1. 什么是防火墙? 2. 状态防火墙工作原理? 3. 防火墙如何处理双通道协议? 一、什么是防火墙? 防火墙是一种网络安全设备或软件,用于保护计算机网络免受未经授权的访问,并管理网络流量。它作为一个安全边界…

Android中级——RemoteView

RemoteView RemoteView的应用NotificationWidgetPendingIntent RemoteViews内部机制模拟RemoteViews RemoteView的应用 Notification 如下开启一个系统的通知栏,点击后跳转到某网页 public class MainActivity extends AppCompatActivity {private static final …

【Linux取经路】进程的奥秘

文章目录 1、什么是进程?1.1 自己写一个进程 2、操作系统如何管理进程?2.1 描述进程-PCB2.2 组织进程2.3 深入理解进程 3、Linux环境下的进程3.1 task_struct3.2 task_struct内容分类3.3 组织进程3.4 查看进程属性 4、结语 1、什么是进程? 在…

软件单元测试

单元测试目的和意义 对于非正式的软件(其特点是功能比较少,后续也不有新特性加入,不用负责维护),我们可以使用debug单步执行,内存修改,检查对应的观测点是否符合要求来进行单元测试&#xff0c…