【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度

同学你好!本文章于2021年末编写,获得广泛的好评!

故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现,

Pytorch深度学习·理论篇(2023版)目录地址为:

CSDN独家 | 全网首发 | Pytorch深度学习·理论篇(2023版)目录本专栏将通过系统的深度学习实例,从可解释性的角度对深度学习的原理进行讲解与分析,通过将深度学习知识与Pytorch的高效结合,帮助各位新入门的读者理解深度学习各个模板之间的关系,这些均是在Pytorch上实现的,可以有效的结合当前各位研究生的研究方向,设计人工智能的各个领域,是经过一年时间打磨的精品专栏!https://v9999.blog.csdn.net/article/details/127587345欢迎大家订阅(2023版)理论篇

以下为2021版原文~~~~

 

1 对抗神经简介

1.1 对抗神经网络的基本组成

1.1.1 基本构成

对抗神经网络(即生成式对抗网络,GAN)一般由两个模型组成:

  • 生成器模型(generator):用于合成与真实样本相差无几的模拟样本。
  • 判别器模型(discriminator):用于判断某个样本是来自真实世界还是模拟生成的。

1.1.2 不同模型的在GAN中的主要作用

生成器模型的目的是,让判别器模型将合成样本当成直实样本。
判别器模的目的是,将合成样本与真实样本区分开。

1.1.3 独立任务

若将两个模型放在一起同步训练,那么生成器模型生成的模拟样本会更加真实,判别器模型对样本的判断会更加精准。

  • 生成器模型可以当成生成式模型,用来独立处理生成式任务;
  • 判别器模型可以当成分类器模型,用来独立处理分类任务。

1.2 对抗神经网络的工作流程

1.2.1生成器模型

生成器模型的输入是一个随机编码向量,输出是一个复杂样本(如图片)。从训练数据中产生相同分布的样本。对于输入样本x,类别标签y,在生成器模型中估计其联合概率分布,生成与输入样本x更为相似的样本。

1.2.2 判别器模型

根据条件概率分布区分真假样本。它的输入是一个复杂样本,输出是一个概率。这个概率用来判定输入样本是真实样本还是生成器输出的模拟样本。

1.2.3 工作流程简介

生成器模型与判别器模型都采用监督学习方式进行训练。二者的训练目标相反,存在对抗关系。将二者结合后,将形成如下图所示的网络结构。

对抗神经网络结构对抗神经网络的训练方法各种各样,但其原理都是一样的,即在迭代练的优化过程中进行两个网络的优化。有的方法会在一个优化步骤中对两个网络进行优化、有的会对两个网络采取不同的优化步骤。
经过大量的迭代训练会使生成器模型尽可能模拟出“以假乱真”的样本,而判别模型会有更精确的鉴别真伪数据的能力,从而使整个对抗神经网络最终达到所谓的纳什均衡,即判别器模型对于生成器模型输出数据的鉴别结果为50%直、50%假。

1.3 对抗神经网络的功能

监督学习神经网络都属于判别器模型,自编码神经网络中,编码器部分就属于一个生成器模型

1.3.1 生成器模型的特性

  • 在应用数学和工程方面,能够有效地表征高维数据分布。
  • 在强化学习方面,作为一种技术手段有效表征强化学习模型中的状态。
  • 在半览督学习方面,能够在数据缺失的情况下训练模型、并给出相应的输出。

1.3.2 举例

在视频中,通过场景预测下一帧的场景,而判别器模型的输出是维度很低的判别结果和期望输出的某个预测值,无法训练出单输入多输出的模型。

1.4 Gan模型难以训练的原因

GAN中最终达到对抗的纳什均衡只是一个理想状态,而现实情况是,随着训练次数的增多,判别器D的效果渐好,从而总是可以将生成器G的输出与真实样本区分开。

1.4.1 现象剖析

因为生成器G是从低维空间向高维空间(复杂的样本空间)的映射,其生成的样本分布空间Pg难以充满整个真实样本的分布空间Pr,即两个分布完全没有重叠的部分,或者重叠的部分可忽略,这就使得判别器D可以将其分开。

1.4.2 生成样本与真实样本重叠部分可忽略的原因

在二维平面中,随机取两条曲线,两条曲线上的点可以代表二者的分布。要想让判别器无法分辨它们,需要两个分布融合在一起,也就是它们之间需要存在重叠的线段,然而这样的概率为0。即使它们很可能会存在交叉点,但是相比于两条曲线而言,交叉点比曲线低一个维度,也就是它只是一个点,代表不了分布情况,因此可将其忽略。

1.4.2 原因分析

假设先将判别器D训练得足够好,固定判别器D后再来训练生成器G,通过实验会发现G的loss值无法收敛到最小值,而是无限地接近一个特定值。这个值可以理解为模拟样本分布Pg与原始样本分布Pr两个样本分布之间的距离。对于loss值恒定(即表明生成器G的梯度为0)的情况,生成器G无法通过训练来优化自己。

在原始GAN的训练中判别器训练得太好,生成器梯度就会逍失,生成器的lossS值降不下去;

在原始GAN的训练中判别器训练得不好,生成器梯度不准,抖动较大。

只有判别器训练到中间状态,才是最好的,但是这个尺度很难把握,甚至在同一轮训练的不同阶段这个状态出现的时段都不一样,这是一个完全不可控的情况。

2 WGAN模型

WGAN的名字源于Wasserstein GAN,Vasserstein是指Wasserstein距离,又称Earth-Mover(EM)推土机距离。

2.1 WGAN模型的原理

WGAN的原理是将生成的模拟样本分布Pg与原始样本分布Pr组合起来,并作为所有可能的联合分布的集合,并计算出二者的距离和距离的期望值。

2.1.1 WGAN原理的优势

可以通过训练模型的方式,让网络沿着其该网络所有可能的联合分布期望值的下界方向进行优化,即将两个分布的集合拉到一起。此时,原来的判别器就不再具有判别真伪的功能,而获得计算两个分布集合距离的功能。因此,将其称为评论器会更加合适。最后一层的Sigmoid函数也需要去掉(不需要将值域控制在0~1)。

2.2 WGAN模型的实现

使用神经网络来计算Wasserstein距离,可以让神经网络直接拟合下式:

f(x)可以理解成神经网络的计算,让判别器实现将f(x1)与f(x2)的距离变换成x1-x2的绝对值乘以k(k≥0)。k代表函数f(x)的Lipschitz常数,这样两个分布集合的距离就可以表示成D(real)-D(G(x))的绝对值乘以k了。这个k可以理解成梯度,即在神经网络f(x)中乘以的梯度绝对值小于k。

将上式中的k忽略,经过整理后,可以得到二者分布的距离公式:

现在要做的就是将L当成目标来计算loss值。

判别器D的任务是区分它们,因为希望二者距离变大,所以loss值需要取反得到:

通过判别器D的losss值也可以看出生成器G的生成质量,即loss值越小,代表距离越近,生成的质量越高。

生成器G用来将希望模拟样本分布Pg越来越接近原始样本分布Pr,所以需要训练让距离L最小化。因为生成器G与第一项无关,所以G的loss值口可简化为:

2.4 WGAN的缺点

若原始WGAN的Lipschitz限制的施加方式不对,那么使用梯度截断方式太过生硬。每当更新完一次判别器的参数之后,就应检查判别器中所有参数的绝对值有没有超过阈值,有的话就把这些参数截断回[-0.01,0.01]范围内。

Lipschitz限制本意是当输入的样本稍微变化后,判别器给出的分数不能产生太过剧烈的变化。通过在训练过程中保证判别器的所有参数有界,可保证判别器不能对两个略微不同的样本给出天差地别的分数值,从而间接实现了Lipschitz限制。

这种期望与判别器本身的目的相矛盾。判别器中希望loss值尽可能大,这样才能拉大真假样本间的区别,但是这种情况会导致在判别器中,通过loss值算出来的梯度会沿着loss值越来越大的方向变化,然而经过梯度截断后每一个网络参数又被独立地限制了取值范圃(如[-0.01,0.01])。这种结果会使得所有的参数要么取最大值(如0.01),要么取最小值(如-0.01)。判别器没能充分利用自身的模型能力,经过它回传给生成器的梯度也会跟着变差。

如果判别器是一个多层网络,那么梯度截断还会导致梯度消失或者梯度“爆炸”问题。截断阀值设置得稍微低一点,那么每经过一层网络,梯度就会变小一点,多层之后就会呈指数衰减趋势。

反之截断阔值设置得稍大,每经过一层网络,梯度变大一点,则多层之后就会呈指数爆炸趋势。在实际应用中,很难做到设合适,让生或器获得恰到好处的回传梯度。

2.3 WGAN模型总结

WGAN引入了Wasserstein距离,由于它相对KL散度与JS散度具有优越的平滑特性,因此理论上可以解决梯度消失问题。再利用一个参数数值范围受限的判别器神经网络实现将Wasserstein距离数学变换写成可求解的形式的最大化,可近似得到Wasserstein距离。

在此近似最优判别器下,优化生成器使得Wasserstein距离缩小,这能有效拉近生成分布与真实分布。WGAN既解决了训练不稳定的问题,又提供了一个可靠的训练进程指标,而且该指标确实与生成样本的质量高度相关。

在实际训练过程中,WGAN直接使用截断(clipping)的方式来防止梯度过大或过小。但这个方式太过生硬,在实际应用中仍会出现问题,所以后来产生了其升级版WGAN-gp。

3 WGAN-gp模型(更容易训练的GAN模型)

WGAN-gp又称为具有梯度惩罚的WGAN,是WGAN的升级版,一般可以用来全面代替WGAN。

3.1 WGAN-gp介绍

WGAN-gp中的gp是梯度惩罚(gradient penalty)的意思,是替换weight clipping的一种方法。通过直接设置一个额外的梯度惩罚项来实现判别器的梯度不超过k。其表达公式为:

其中,MSE为平方差公式;X_inter为整个联合分布空间中的x取样,即梯度惩罚项gradent _penaltys为求整个联合分布空间中x对应D的梯度与k的平方差。

3.2 WGAN-gp的原理与实现

3.3 Tip

  • 因为要对每个样本独立地施加梯度惩罚,所以在判别器的模型架构中不能使用BN算法,因为它会引入同一个批次中不同样本的相互依赖关系。
  • 如果需要的话,那么可以选择其他归一化办法,如Layer Normalization、Weight Normalization、Instance Normalization等,这些方法不会引入样本之间的依赖。

4 条件GAN

条件GAN的作用是可以让GAN的生成器模型按照指定的类别生成模拟样本。

4.1 条件GAN的实现

条件GAN在GAN的生成器和判别器基础上各进行了一处改动;在它们的输入部分加入了一个标签向量(one_hot类型)。

4.2 条件GAN的原理

GAN的原理与条件变分自编码神经网络的原理一样。这种做法可以理解为给GAN增加一个条件,让网络学习图片分布时加入标签因素,这样可以按照标签的数值来生成指定的图片。

5 带有散度的GAN——WGAN-div

WGAN-div模型在WGAN-gp的基础上,从理论层面进行了二次深化。在WGAN-gp中,将判别器的梯度作为惩罚项加入判别器的loss值中。
在计算判别器梯度时,为了让X_inter从整个联合分布空间的x中取样,在真假样本之间采取随机取样的方式,保证采样区间属于真假样本的过渡区域。然而,这种方案更像是一种经验方案,没有更完备的理论支撑(使用个体采样代替整体分布,而没能从整体分布层面直接解决问题)。

3.1 WGAN-div模型的使用思路

WGAN-div模型与WGAN-gp相比,有截然不同的使用思路:不从梯度惩罚的角度去考虑,而通过两个样本间的分布距离来实现。
在WGAN-diⅳ模型中,引入了W散度用于度量真假样本分布之间的距离,并证明了中的W距离不是散度。这意味着WGAN-gp在训练判别器的时候,并非总会拉大两个分布间的距离,从而在理论上证明了WGAN-gp存在的缺陷一—会有训练失效的情况。
WGAN-div模型从理论层面对WGAN进行了补充。利用WGAN-div模型的理论所实现的loss值不再需要采样过程,并且所达到的训练效果也比WGAN-gp更胜一筹。

3.2 了解W散度

 3.3 WGAN-div的损失函数

3.4 W散度与W距离间的关系

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469326.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[haoi2011]防线修建

动态加点维护凸包。 论STL的熟练运用。 #include<cstdio> #include<algorithm> #include<cstring> #include<vector> #include<iostream> #include<string> #include<map> #include<set> #include<cstdlib> #include<…

【Pytorch神经网络实战案例】15 WGAN-gp模型生成Fashon-MNST模拟数据

1 WGAN-gp模型生成模拟数据案例说明 使用WGAN-gp模型模拟Fashion-MNIST数据的生成&#xff0c;会使用到WGAN-gp模型、深度卷积GAN(DeepConvolutional GAN&#xff0c;DCGAN)模型、实例归一化技术。 1.1 DCGAN中的全卷积 WGAN-gp模型侧重于GAN模型的训练部分&#xff0c;而DCG…

Android启动过程深入解析

转载自&#xff1a;http://blog.jobbole.com/67931/ 当按下Android设备电源键时究竟发生了什么&#xff1f;Android的启动过程是怎么样的&#xff1f;什么是Linux内核&#xff1f;桌面系统linux内核与Android系统linux内核有什么区别&#xff1f;什么是引导装载程序&#xff1…

android 解析网络数据(JSON)

解析json数据&#xff0c;获取你需要的信息 首先在manifest中添加允许访问网络的权限信息 <uses-permission android:name"android.permission.INTERNET"/> Main package com.chuanxidemo.shaoxin.demo08;import android.os.Bundle; import android.support.an…

【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据

1 条件GAN前置知识 条件GAN也可以使GAN所生成的数据可控&#xff0c;使模型变得实用&#xff0c; 1.1 实验描述 搭建条件GAN模型&#xff0c;实现向模型中输入标签&#xff0c;并使其生成与标签类别对应的模拟数据的功能&#xff0c;基于WGAN-gp模型改造实现带有条件的wGAN-…

Android bootchart(二)

这篇文章讲一下MTK8127开机启动的时间 MTK8127发布版本开机时间大约在&#xff12;&#xff10;秒左右&#xff0c;如果发现开机时间变长&#xff0c;大部分是因为加上了客户订制的东西&#xff0c;代码累赘太多了。 &#xff11;、下面看一下&#xff2d;&#xff34;&#…

Android Camera框架

总体介绍 Android Camera 框架从整体上看是一个 client/service 的架构, 有两个进程: client 进程,可以看成是 AP 端,主要包括 JAVA 代码与一些 native c/c++代码; service 进 程,属于服务端,是 native c/c++代码,主要负责和 linux kernel 中的 camera driver 交互,搜集 li…

【Pytorch神经网络实战案例】17 带W散度的WGAN-div模型生成Fashon-MNST模拟数据

1 WGAN-div 简介 W散度的损失函数GAN-dv模型使用了W散度来替换W距离的计算方式&#xff0c;将原有的真假样本采样操作换为基于分布层面的计算。 2 代码实现 在WGAN-gp的基础上稍加改动来实现&#xff0c;重写损失函数的实现。 2.1 代码实战&#xff1a;引入模块并载入样本-…

runtime如何实现weak属性

首先了解weak是一种非拥有关系,属性所值对象销毁时,属性值会情况(nil). Runtime对注册的类会进行布局,对于weak对象会放入hash表中,用weak指向的内存地址作为key,当对象引用计数器为0时会dealloc,假如weak指向的对象内存地址为a,那么就会以a为键,在这个weak表中搜索,找到以a为键…

计算机地址分配

1、计算机寻址 在linux下可以通过查看proc/ioport来看他们的地址段范围 weiqifa@weiqifa-Inspiron-3847:~/weiqifa/new_tm100/tm100$ cat /proc/ioports 0000-0cf7 : PCI Bus 0000:000000-001f : dma10020-0021 : pic10040-0043 : timer

【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

同学你好&#xff01;本文章于2021年末编写&#xff0c;获得广泛的好评&#xff01; 故在2022年末对本系列进行填充与更新&#xff0c;欢迎大家订阅最新的专栏&#xff0c;获取基于Pytorch1.10版本的理论代码(2023版)实现&#xff0c; Pytorch深度学习理论篇(2023版)目录地址…

angularjs控制器之间的数据共享与通信

1、可以写一个service服务&#xff0c;从而达到数据和代码的共享; var appangular.module(app,[]);app.service(ObjectService, [ObjectService]); function ObjectService() {var list {};return {get: function(id){return list[id];},set: function(id, v){list[id] v;}} …

【Pytorch神经网络实战案例】18 最大化深度互信信息模型DIM实现搜索最相关与最不相关的图片

图片搜索器分为图片的特征提取和匹配两部分&#xff0c;其中图片的特征提取是关键。将使用一种基于无监督模型的提取特征的方法实现特征提取&#xff0c;即最大化深度互信息&#xff08;DeepInfoMax&#xff0c;DIM&#xff09;方法。 1 最大深度互信信息模型DIM简介 在DIM模型…

linux tar 解压命令总结

把常用的tar解压命令总结下&#xff0c;当作备忘&#xff1a; tar -c: 建立压缩档案 -x&#xff1a;解压 -t&#xff1a;查看内容 -r&#xff1a;向压缩归档文件末尾追加文件 -u&#xff1a;更新原压缩包中的文件 这五个是独立的命令&#xff0c;压缩解压都要用到其中一…

经典c语言题

1. 用预处理指令#define 声明一个常数&#xff0c;用以表明1年中有多少秒&#xff08;忽略闰年问题&#xff09; #define SECONDS_PER_YEAR (60 * 60 * 24 * 365)UL 2. 写一个“标准”宏MIN&#xff0c;这个宏输入两个参数并返回较小的一个。 #define MIN(A,B) ((A) < (B) ?…

【Pytorch神经网络实战案例】19 神经网络实现估计互信息的功能

1 案例说明&#xff08;实现MINE正方法的功能&#xff09; 定义两组具有不同分布的模拟数据&#xff0c;使用神经网络的MINE的方法计算两个数据分布之间的互信息 2 代码编写 2.1 代码实战&#xff1a;准备样本数据 import torch import torch.nn as nn import torch.nn.fun…

当前进程(Linux Devices Driver)

尽管内核模块不像应用程序一样顺序执行,内核做的大部分动作是代表一个特定进程的,内核代码可以引用当前进程,通过存取全局项current,它在asm/current.h中定义,它产生一个指针指向结构task_struct,在linux/sched.h定义,current指针指向当前在运行的进程,在一个系统调用执行…

爬虫实战学习笔记_6 网络请求request模块:基本请求方式+设置请求头+获取cookies+模拟登陆+会话请求+验证请求+上传文件+超时异常

1 requests requests是Python中实现HTTP请求的一种方式&#xff0c;requests是第三方模块&#xff0c;该模块在实现HTTP请求时要比urlib、urllib3模块简化很多&#xff0c;操作更加人性化。 2 基本请求方式 由于requests模块为第三方模块&#xff0c;所以在使用requests模块时…

201521123044 《Java程序设计》第01周学习总结

1.本章学习总结 你对于本章知识的学习总结 1.了解了Java的发展史。 2.学习了什么是JVM,区分JRE与JDK,下载JDK。 3.从C语言的.c 到C的 .cpp再到Java的.java&#xff0c;每种语言编译程序各有不同&#xff0c;却有相似之处。 2. 书面作业 **Q1.为什么java程序可以跨平台运行&…

爬虫实战学习笔记_7 【实战】模拟下载页面视频(模板)

import requests # 导入requests模块 import re # 导入re模块 import os # 导入系统os模块# 实现发送网络请求&#xff0c;返回响应结果 def send_request(url,headers):response requests.get(urlurl,headersheaders) # 发送网络请求if response.st…