WGAN - 瓦萨斯坦生成对抗网络

1. 背景与问题

生成对抗网络(Generative Adversarial Networks, GANs)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。它包括两个主要部分:生成器(Generator)和判别器(Discriminator),两者通过对抗训练的方式,彼此不断改进,生成器的目标是生成尽可能“真实”的数据,而判别器的目标是区分生成的数据和真实数据

虽然传统GAN在多个领域取得了巨大成功,但它们也存在一些显著的问题,尤其是训练不稳定性和模式崩溃(Mode Collapse)。为了克服这些问题,Wasserstein Generative Adversarial Network(WGAN)应运而生,提出了一种新的损失函数,基于Wasserstein距离来衡量生成数据和真实数据之间的差异,从而提高训练的稳定性和生成效果。

推荐阅读:DenseNet-密集连接卷积网络

2. 传统GAN的局限性

在传统的GAN中,生成器和判别器之间的对抗过程是通过最小化生成器的损失函数来实现的。GAN的损失函数通常使用交叉熵来衡量生成数据与真实数据的差异,公式如下:

  • 生成器的损失:

    在这里插入图片描述

  • 判别器的损失:

在这里插入图片描述

问题:

  • 梯度消失:如果判别器过强,它会变得非常接近0或1,导致生成器的梯度几乎消失,训练陷入停滞。
  • 模式崩溃(Mode Collapse):生成器可能只生成非常有限的几种样本,无法覆盖真实数据的所有模式。
  • 训练不稳定:在某些情况下,生成器和判别器之间的博弈可能导致不收敛,难以调节超参数。
    在这里插入图片描述

3. WGAN简介

WGAN的提出旨在通过引入Wasserstein距离来解决传统GAN中的上述问题。Wasserstein距离是一种度量两个分布之间距离的方法,它可以有效地避免传统GAN中存在的梯度消失问题,并且提供更加稳定的训练过程。

WGAN的核心思想是在判别器中不使用标准的sigmoid激活函数,而是采用线性输出,并用Wasserstein距离来作为损失函数。Wasserstein距离的引入,使得生成器和判别器的训练变得更加平滑,且训练过程更为稳定。

4. WGAN的理论基础:Wasserstein距离

Wasserstein距离,也称为地球搬运人距离(Earth Mover’s Distance, EMD),是用于度量两个概率分布之间差异的一种方法。在生成对抗网络中,Wasserstein距离可以用来衡量生成数据分布和真实数据分布之间的距离。

Wasserstein距离的定义

给定两个分布PP和QQ,Wasserstein距离可以定义为:

W(P,Q)=inf⁡γ∈Π(P,Q)E(x,y)∼γ[∥x−y∥]W(P, Q) = \inf_{\gamma \in \Pi(P,Q)} \mathbb{E}_{(x,y) \sim \gamma} [ |x - y| ]

其中,Π(P,Q)\Pi(P,Q)表示所有可能的联合分布γ\gamma,其边缘分布分别是PP和QQ,而∥x−y∥|x - y|是样本之间的距离。

在WGAN中,Wasserstein距离的引入使得训练更加稳定,且相比于交叉熵损失函数,它能够提供更加有效的梯度信息。

证明Wasserstein距离的优势

WGAN的一个关键优势是,它避免了传统GAN中出现的梯度消失问题。具体来说,WGAN中的判别器(称为批量判别器)并不输出概率值,而是输出一个实数值,因此在优化过程中能够提供更加稳定的梯度信号。

5. WGAN的架构与优化

网络架构

WGAN的架构与传统GAN基本相同,主要包括两个网络:生成器和判别器。区别在于,WGAN中的判别器不再是一个概率分类器,而是一个逼近Wasserstein距离的网络。

生成器(Generator)

生成器的目标是生成能够尽可能接近真实数据的样本。它通过一个隐空间向量zz生成样本,输出与真实数据分布相似的样本。

判别器(Discriminator)

判别器的任务是区分真实数据和生成数据的差异,但它并不输出概率值,而是输出一个实数值,表示样本的Wasserstein距离

WGAN的损失函数

WGAN中的损失函数非常简单。生成器的目标是最小化Wasserstein距离,而判别器的目标是最大化Wasserstein距离。WGAN的损失函数如下:

  • 生成器的损失:

    LG=−Ez∼pz(z)[D(G(z))]\mathcal{L}G = - \mathbb{E}{z \sim p_z(z)} [D(G(z))]

  • 判别器的损失:

    LD=Ex∼pdata(x)[D(x)]−Ez∼pz(z)[D(G(z))]\mathcal{L}D = \mathbb{E}{x \sim p_{data}(x)} [D(x)] - \mathbb{E}_{z \sim p_z(z)} [D(G(z))]

判别器的权重剪切

为了确保Wasserstein距离的有效性,WGAN要求判别器的参数满足1-Lipschitz条件。为此,WGAN采用了权重剪切(weight clipping)的方法,即在每次训练判别器时,都将其权重限制在一个小的范围内。例如,假设权重剪切的最大值为cc,则每次更新判别器时都会将其权重强制限制在区间[−c,c][-c, c]内。

# 伪代码:判别器权重剪切
for p in discriminator.parameters():p.data.clamp_(-c, c)

这种操作是WGAN的关键所在,它确保了判别器的权重满足Lipschitz连续性,从而使得Wasserstein距离能够有效地度量生成数据和真实数据之间的差异。

6. WGAN的训练技巧

判别器与生成器的训练

WGAN的训练过程与传统GAN类似,但有以下几点不同:

  • 判别器训练:在每次更新判别器时,WGAN要求进行多个步骤的训练。一般来说,判别器的训练次数会比生成器的训练次数多。这是因为判别器需要更好地逼近真实数据和生成数据之间的Wasserstein距离。

    for i in range(n_critic):D.zero_grad()real_data = get_real_data()fake_data = generator(z)loss_d = discriminator_loss(real_data, fake_data)loss_d.backward()optimizer_d.step()clip_weights(discriminator)
    
  • 生成器训练:生成器的更新则是根据判别器的输出进行的。通过反向传播,生成器可以最小化其生成数据与真实数据之间的Wasserstein距离。

    G.zero_grad()
    fake_data = generator(z)
    loss_g = generator_loss(fake_data)
    loss_g.backward()
    optimizer_g.step()
    

权重剪切的局限性

虽然权重剪切可以保证Lipschitz条件,但它也有一定的局限性。过度的权重剪切可能导致判别器的能力受限,进而影响生成效果。因此,研究

人员提出了**梯度惩罚(Gradient Penalty)**作为改进方法,这将在后续部分讨论。

7. WGAN改进:WGAN-GP (Gradient Penalty)

WGAN-GP的动机

WGAN的一个问题在于权重剪切可能导致网络不稳定或训练过慢。为了解决这个问题,提出了WGAN-GP(Wasserstein GAN with Gradient Penalty)方法,它引入了梯度惩罚来代替权重剪切,从而保持Wasserstein距离的有效性。

WGAN-GP损失函数

WGAN-GP的损失函数相比WGAN有所改进,加入了梯度惩罚项,具体如下:

  • 判别器损失: LD=Ex∼pdata(x)[D(x)]−Ez∼pz(z)[D(G(z))]+λEx∼px[(∥∇xD(x)∥2−1)2]\mathcal{L}D = \mathbb{E}{x \sim p_{data}(x)} [D(x)] - \mathbb{E}{z \sim p_z(z)} [D(G(z))] + \lambda \mathbb{E}{\hat{x} \sim p_{\hat{x}}} \left[ (|\nabla_{\hat{x}} D(\hat{x})|_2 - 1)^2 \right]

其中,x^\hat{x}是从真实数据和生成数据之间的插值中采样得到的,λ\lambda是梯度惩罚项的系数。

训练过程

WGAN-GP的训练过程与WGAN相似,只是判别器的更新方式有所不同。具体来说,我们需要计算梯度惩罚,并将其加到判别器的损失函数中:

# 计算梯度惩罚
def compute_gradient_penalty(D, real_data, fake_data):alpha = torch.rand(real_data.size(0), 1, 1, 1).to(real_data.device)interpolated = alpha * real_data + (1 - alpha) * fake_datainterpolated.requires_grad_(True)d_interpolated = D(interpolated)grad_outputs = torch.ones_like(d_interpolated)gradients = torch.autograd.grad(outputs=d_interpolated, inputs=interpolated, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty

优势与效果

WGAN-GP的引入梯度惩罚后,训练过程显著更加稳定,避免了WGAN中因权重剪切带来的不稳定性和训练速度较慢的问题。WGAN-GP已成为生成对抗网络中常用的变体之一。

8. WGAN应用案例

WGAN和WGAN-GP已被广泛应用于图像生成、文本生成、音乐生成等多个领域。以下是一些实际的应用案例:

  1. 图像生成:WGAN常用于高分辨率图像的生成,尤其是在超分辨率图像生成、图片到图片的转换等任务中表现优异。
  2. 文本生成:WGAN也可以用于自然语言处理领域,通过生成器生成自然语言文本,判别器判断文本的质量。
  3. 数据增强:WGAN被用作数据增强技术,通过生成更多的训练数据来提高模型的泛化能力。

9. WGAN与传统GAN对比

优点

  • 训练稳定性:WGAN通过引入Wasserstein距离,使得训练过程更加稳定,避免了梯度消失和模式崩溃的问题。
  • 优化效果:WGAN优化过程中生成器和判别器之间的博弈更加平衡,从而生成质量更高的样本。

缺点

  • 计算成本:WGAN的计算成本较传统GAN更高,尤其是在判别器训练阶段,计算Wasserstein距离和梯度惩罚需要更多的计算资源。
  • 收敛速度:尽管WGAN的训练稳定性较强,但它的收敛速度可能比其他类型的GAN稍慢。

10. 总结与展望

WGAN为生成对抗网络的训练提供了一种新的优化策略,通过引入Wasserstein距离来替代传统的交叉熵损失函数,显著提高了训练的稳定性和生成质量。尽管WGAN在许多方面具有优势,但仍存在一些计算成本和收敛速度上的挑战。

未来,随着硬件的进步和算法的优化,WGAN及其变种(如WGAN-GP)有望在更广泛的应用中得到进一步的推广与发展。同时,研究人员也在不断探索新的方法来优化WGAN的训练过程,进一步提升其在生成任务中的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/68696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt 5.14.2 学习记录 —— 십칠 窗口和菜单

文章目录 1、Qt窗口2、菜单栏设置快捷键添加子菜单添加分割线和菜单图标 3、工具栏 QToolBar4、状态栏 QStatusBar5、浮动窗口 QDockWidget 1、Qt窗口 QWidget,即控件,是窗口的一部分。在界面中创建控件组成界面时,Qt自动生成了窗口&#xf…

SpringCloud系列教程:微服务的未来(十四)网关登录校验、自定义过滤器GlobalFilter、GatawayFilter

前言 在微服务架构中,API 网关扮演着至关重要的角色,负责路由请求、执行安全验证、流量控制等任务。Spring Cloud Gateway 作为一个强大的网关解决方案,提供了灵活的方式来实现这些功能。 本篇博客将重点介绍如何在 Spring Cloud Gateway 中…

Redis源码-redisObject

解释 redis中,所有的数据类型最终都转换成了redisObject,该结构体的定义,在文件server.h中。 参数说明 参数名说明unsigned type:4对象对应的数据类型unsigned encoding:4对象的编码方式unsigned lru:LRU_BITSLRU算法清空对象&#xff0c…

为什么相关性不是因果关系?人工智能中的因果推理探秘

目录 一、背景 (一)聚焦当下人工智能 (二)基于关联框架的人工智能 (三)基于因果框架的人工智能 二、因果推理的基本理论 (一)因果推理基本范式:因果模型&#xff0…

兼职全职招聘系统架构与功能分析

2015工作至今,10年资深全栈工程师,CTO,擅长带团队、攻克各种技术难题、研发各类软件产品,我的代码态度:代码虐我千百遍,我待代码如初恋,我的工作态度:极致,责任&#xff…

js重要知识点

目录 一、冒泡排序的计算方法 二、数组forEach方法 三、Number(null)和Number(undefined) 四、es6中的set 一、冒泡排序的计算方法 冒泡排序的重点:两次循环,外层循环是总共要进行的躺数,为数组总长度-1,内层循环则是每个元素在每一次循环中需要比较的次数&#xff…

Chrome 132 版本新特性

Chrome 132 版本新特性 一、Chrome 132 版本浏览器更新 1. 在 iOS 上使用 Google Lens 搜索 在 Chrome 132 版本中,开始在所有平台上推出这一功能。 1.1. 更新版本: Chrome 126 在 ChromeOS、Linux、Mac、Windows 上:在 1% 的稳定版用户…

2024微短剧行业生态洞察报告汇总PDF洞察(附原数据表)

原文链接: https://tecdat.cn/?p39072 本报告合集洞察从多个维度全面解读微短剧行业。在行业发展层面,市场规模与用户规模双增长,创造大量高收入就业岗位并带动产业链升级。内容创作上,精品化、品牌化趋势凸显,题材走…

基于GRU实现股价多变量时间序列预测(PyTorch版)

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记…

Python基于Django的社区爱心养老管理系统设计与实现【附源码】

博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…

基于OpenCV和Python的人脸识别系统_django

开发语言:Python框架:djangoPython版本:python3.7.7数据库:mysql 5.7数据库工具:Navicat11开发软件:PyCharm 系统展示 管理员登录 管理员功能界面 用户管理 公告信息管理 操作日志管理 用户登录界面 用户…

吴恩达深度学习——神经网络编程的基础知识

文章内容来自BV11H4y1F7uH,仅为个人学习所用。 文章目录 二分分类一些符号说明 逻辑斯蒂回归传统的线性回归函数 y ^ w T x b \hat{y}w^T\boldsymbol{x}b y^​wTxbSigmoid激活函数逻辑斯蒂回归损失函数损失函数成本函数与损失函数的关系 梯度下降法计算图逻辑斯蒂…

调试Hadoop源代码

个人博客地址:调试Hadoop源代码 | 一张假钞的真实世界 Hadoop版本 Hadoop 2.7.3 调试模式下启动Hadoop NameNode 在${HADOOP_HOME}/etc/hadoop/hadoop-env.sh中设置NameNode启动的JVM参数,如下: export HADOOP_NAMENODE_OPTS"-Xdeb…

通过Ukey或者OTP动态口令实现windows安全登录

通过 安当SLA(System Login Agent)实现Windows安全登录认证,是一种基于双因素认证(2FA)的解决方案,旨在提升 Windows 系统的登录安全性。以下是详细的实现方法和步骤: 1. 安当SLA的核心功能 安…

从前端视角看设计模式之结构型模式篇

上篇我们介绍了 设计模式之创建型模式篇,接下来介绍设计模式之结构型模式篇 适配器模式 适配器模式旨在解决接口不兼容的问题,它通过创建一个适配器类,将源对象的接口转换成目标接口,从而使得不兼容的接口能够协同工作。简单来说…

彻底讲清楚 单体架构、集群架构、分布式架构及扩展架构

目录 什么是系统架构 单体架构 介绍 示例图 优点 缺点 集群架构 介绍 示意图 优点 缺点 分布式架构 示意图 优点 缺点 生态扩展 介绍 示意图 优点 缺点 扩展:分布式服务解析 纵切拆服务 全链路追踪能力 循环依赖 全链路日志(En…

编辑器Vim基本模式和指令 --【Linux基础开发工具】

文章目录 一、编辑器Vim 键盘布局二、Linux编辑器-vim使用三、vim的基本概念正常/普通/命令模式(Normal mode)插入模式(Insert mode)末行模式(last line mode) 四、vim的基本操作五、vim正常模式命令集插入模式从插入模式切换为命令模式移动光标删除文字复制替换撤销上一次操作…

ChatGPT被曝存在爬虫漏洞,OpenAI未公开承认

OpenAI的ChatGPT爬虫似乎能够对任意网站发起分布式拒绝服务(DDoS)攻击,而OpenAI尚未承认这一漏洞。 本月,德国安全研究员Benjamin Flesch通过微软的GitHub分享了一篇文章,解释了如何通过向ChatGPT API发送单个HTTP请求…

成就与远见:2024年技术与思维的升华

个人主页:chian-ocean 前言: 2025年1月17日,2024年博客之星年度评选——创作影响力评审的入围名单公布。我很荣幸能够跻身Top 300,虽然与顶尖博主仍有一定差距,但这也为我提供了更加明确的发展方向与指引。展望崭新的2025年&…

【前端】CSS学习笔记(2)

目录 CSS3新特性圆角阴影动画keyframes 创建动画animation 执行动画timing-function 时间函数direction 播放方向过渡动画(transition) 媒体查询设置meta标签媒体查询语法 雪碧图字体图标 CSS3新特性 圆角 使用CSS3border-radius属性,你可以…