VAE中的“变分”什么

写在前面

        VAE(Variational Autoencoder),中文译为变分自编码器。其中AE(Autoencoder)很好理解。那“变分”指的是什么呢?—其实是“变分推断”。变分推断主要用在VAE的损失函数中,那变分推断是什么,VAE的损失函数又是什么呢?下面我就来说一说!

       可以先看一下 这篇文章,介绍了VAE的代码实现。

一、通俗理解损失函数

        这篇文章已经整体介绍了VAE,这里我详细介绍一下VAE的损失函数:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}

        每个变量的说明下面会有介绍,现在我们只关注VAE的损失函数有由两部分组成,第一部分是一个交叉熵,我们称之为“重构项”,其作用是确保训练时输入和输出间的相似性;第二部分是KL散度,叫做“KL散度项”,它其实是一个正则项,主要解决了两个AE模型的痛点,这也是VAE成功并流行的主要原因:

        1.潜在空间的结构化:AE的潜在空间往往是无规则的,这意味着编码器学到的表征可能杂乱无章,不便于后续操作。VAE通过添加KL散度项来惩罚潜在变量分布与预设先验分布(就是p(z),是一个标准高斯分布)之间的偏差,从而迫使潜在空间呈现出一定的结构,使潜在变量的分布更加合理和连贯。说人话就是:VAE可以输入标准高斯分布的采样数据,生成精美的图像。

        2.潜在空间的连续性:KL散度项要求潜在变量 z 的分布  q(z|x) 尽可能接近预设的先验分布  p(z) ,这个先验分布通常选择为标准正态分布。通过这种方式,潜在空间被组织成一个连续、平滑的多维空间,其中每一维上的值都能够自由变动而不产生剧烈变化。这种设计确保了在潜在空间中的小步长移动会导致解码结果的轻微变化,从而实现了连续性。说人话就是:VAE可以通过微调输入的采样数据,一定程度上修改生成图像的属性。这也是造成“抽卡”的原因之一。

        损失函数的这两项可以简单的这么理解,但是它其实是推导出来的,这就说来话长。感兴趣的小伙伴继续往下看。

二、边际似然

1.边际似然的定义

        VAE 是一种生成模型,生成模型的核心任务是计算在给定潜在变量 z 的情况下生成观测数据 x 的概率。我们希望模型能够生成与真实数据分布相似的新数据,这一目标可以通过边际似然 p(x) 来实现。

        其中z就是Latent;x是训练用的图像;p(x)是边际似然,也就是VAE的损失函数

        p(x)可以很好的衡量模型的生成能力。p(x)直接衡量了模型在生成数据方面的整体能力,因为它考虑了所有潜在的隐变量 z 对观测数据 x 的影响。高的p(x)意味着模型可以很好地解释数据,并且在生成新数据时表现出较强的能力。

        具体来说,如果模型的边际似然高,说明模型在所有可能的隐变量 z 下生成观测数据的概率累加起来后非常高,这意味着模型学到了数据的真实分布。

        边际似然 p(x) 表示给定模型情况下生成观测数据 x 的概率,定义为:

p(x)=\int p(x|z)p(z)dz  (1)

        其中,条件概率 p(x∣z):给定潜在变量 z 的情况下,生成观测数据 x 的概率。先验分布 p(z):潜在变量 z 的分布,反映了我们对 z 的先验知识。

2.边际似然的推导

        使用全概率公式,边际似然可以用全概率公式来定义,具体为:

p(x)=\int p(x,z)dz (2)

        这里 p(x,z)是 x 和 z 的联合分布。根据条件概率的定义,联合分布可以表示为:

p(x,z)=p(x|z)p(z) (3)

        因此,我们可以将边际似然表示为:

p(x)=\int p(x|z)p(z)dz (4)

        我们要做的就是最大化p(x),这里多说一句,最大化p(x)的目标是使得模型生成的总体概率分布 p(x) 更接近于真实数据分布。这样,模型生成的新样本就会与训练数据的分布一致。

        直观理解:假设我们在训练一个模型生成手写数字图片。如果真实的数据集中 80% 是“1”,20% 是“2”,那么一个好的生成模型应该能够生成 80% 的“1”和 20% 的“2”。而不是让p(x)趋近于1.

3.边际似然的挑战

        但是计算边际似然通常是一个复杂且困难的任务,原因包括:

        (1)高维积分:在实际的应用中,潜在变量 z 通常是高维的。例如,如果 z 是 100 维的向量,那么积分就需要在 100 维的空间上进行。这种高维积分是非常复杂的,解析解几乎不可能得到。

        (2)分布形式复杂:在生成模型中,条件分布 p(x∣z)和先验分布 p(z) 可能并不是简单的概率分布。例如,p(x∣z) 可能由一个深度神经网络参数化,计算时需要经过非线性激活函数和复杂的网络结构,这会让这个积分无法直接求解。

        (3)数值计算的困难:计算边际似然时,需要对 z 的所有可能值进行积分,也就是计算出在所有潜在表示 z 上,生成数据 x 的所有可能性。现实中,z 的范围非常大,即使是连续的,也可能取值无穷多个,直接求解所有 z 的可能性几乎是不可能的。

        举个例子,假设我们有一个简单的生成模型,其中:p(z) 是标准正态分布N(0,I)。p(x∣z) 是由一个深度神经网络生成的图像。直接计算边际似然意味着我们需要知道所有 z 的取值如何影响 x。如果 z 是 100 维向量,那么在 R^{100} 空间上对 z 进行积分(或采样)需要极大的计算资源。神经网络的非线性使得每个 p(x∣z) 的计算都很复杂,最终让直接计算积分变得不可行。

        为了解决上面的问题,让模型可以正常训练,我们引入变分推断。

三、变分推断

1.变分推断的定义

        变分推断是一种通过引入近似分布来解决无法直接计算复杂积分的问题的方法。在生成模型中,我们的目标是最大化观测数据的边际似然 p(x):

p(x)=\int p(x|z)p(z)dz (5)

        如前所述,这个积分通常很难直接计算,因此我们引入一个 近似后验分布(也叫变分分布,就是训练时模型的输出 q(z∣x),来代替无法直接求解的真实后验 p(z∣x)。变分推断的目标是让 q(z∣x) 尽可能地接近真实的 p(z∣x)。

\mathbf{p(x)=\int p(x|z)p(z)dz=\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (6)

        通过这种重写,我们引入了 q(z∣x) 作为一个权重,这样我们可以在期望的形式下进行优化。我们现在有一个可以计算的表达式:

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz} (7)

        尽管重写了表达式,计算 p(x)依然困难,因为积分本身依然难解。因此,我们应用 Jensen 不等式(log是凸函数),将对数操作从积分外移到期望内部(这里的期望是由积分转化来的):

\mathbf{\mathit{log}p(x)=\mathit{log}\int p(z|x)\frac{p(x|z)p(z)}{q(z|x)}dz\geq E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (8)

        其中,Eq(z∣x)[⋅]表示在 q(z∣x) 分布下对 z 取期望。这一不等式说明,我们得到了一个对数边际似然的下界,即变分下界 (ELBO)。

2.变分下界ELBO

        式子(8)右边的表达式即为变分下界(Evidence Lower Bound,),通常记作 ELBO,至此我们的目标也变成了最大化ELBO,从而间接地最大化边际似然 p(x)。式子(8)可以写成:

\mathbf{ELBO=E_{q(z|x)} \left [ log\frac{p(x|z)p(z)}{q(z|x)} \right ] } (9)

        式子(9)右边可以展开成:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+E_{q(z|x)}\left [ \textit{log}p(z) \right ]-E_{q(z|x)}\left [ \textit{log}q(z|x) \right ]} (10)

        因为KL散度公式:

\mathbf{KL(q(z|x)||p(z))=E_{q(z|x)}[log\frac{q(z|x)}{p(z)}]=E_{q(z|x)}[\textit{log}q(z|x)]-E_{q(z|x)}[\textit{log}p(z)]}(11)

        可以看到,式子(10)右边的第二项和第三项可以用KL散度代替:

\mathbf{-KL(q(z|x)||p(z))=E_{q(z|x)}[\textit{log}p(z)]-E_{q(z|x)}[\textit{log}q(z|x)]}(12)

        最终,ELBO 可以写成如下式子,这也是VAE需要优化的损失函数:

\mathbf{ELBO=E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]-KL(q(z|x)||p(z))} (13)

        ELBO 公式展示了两个部分:

        重构项:表示模型生成数据的能力。

        KL 散度项:作为正则化项,控制 q(z∣x) 和 p(z) 之间的差异。最小化这个项有助于使近似后验 q(z∣x) 尽量接近先验 p(z),从而促进模型的泛化能力。p(z)一般被设置成标准高斯分布。

最大化 ELBO 的意义:

        优化目标:最大化 ELBO 实际上是希望在重构能力和潜在分布的正则化之间取得平衡。通过调整这两个部分,可以确保模型既能够良好地重构输入数据,又能够学习到有意义的潜在空间。

        间接最大化边际似然:由于 ELBO 是边际似然的下界,最大化 ELBO 也会使得边际似然 p(x) 的值增加。

        ELBO 在 VAE 中扮演着至关重要的角色,它将生成模型的目标与优化过程结合起来,使得模型能够在重构能力和潜在空间的正则化之间找到最佳平衡。通过最大化 ELBO,VAE 能够学习到有效的潜在表示,从而生成新样本。

四、代码实现中的公式

        这篇文章介绍了VAE的代码实现,其中的损失函数是ELBO的具体实现,我们来看一下,具体是怎么实现的。

        我们的目标是最大化ELBO,相当于最小化其负值,因此 VAE 的损失函数可以表示为:

\mathbf{LOSS=-E_{q(z|x)}\left [ \textit{log}p(x|z) \right ]+KL(q(z|x)||p(z))}   (14)

1.重构项

        交叉熵的定义为:

H(p,q)=-E_p[log\textbf{q}]   (15)

        如果我们将 p(x∣z) 视为模型生成 x 的概率分布(对应代码中的recon_x,即模型的输出),而将真实数据的分布视为 q(x)(对应代码中的x,即GT),则ELBO的第一项可以写成:

\mathbf{E_{q(z|x)}[\textit{log}p(x|z)]=-H(x,q(z|x))}  (16)

        最大化 ELBO 的第一项(重构项)实际上是最小化交叉熵损失,代码如下:

BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

2.KL 散度项

        对于高斯分布 q(z|x)=N(\mu ,\sigma ^2)和标准正态分布p(x)=N(0,1),我们可以将 KL散度计算分解为以下几个步骤:

        (1)KL散度 的公式为:

\mathbf{KL\left [ q(z|x)||p(z) \right ]=\int q(z|x)log(\frac{q(z|x)}{p(z)})dz}(17)

        解释一下变量的意义:

        q(z∣x):这是给定输入 x时隐变量 z 的后验分布,通常由编码器生成。

        p(z):这是隐变量 z 的先验分布,通常是标准高斯分布 N(0,1)。

        比率 \mathbf{\frac{q(z|x)}{p(z)}}:这个比率表示后验分布与先验分布的相对关系,反映了后验分布相较于先验分布的“信息量”。

        对数项\mathbf{log(\frac{q(z|x)}{p(z)})}:量化了 q(z∣x) 相较于 p(z) 的信息增益。正值表示后验分布相对于先验分布的增加的信息,而负值则表示信息的损失。

        积分:通过对所有可能的 z进行积分,KL散度 计算了整个后验分布与先验分布之间的差异。

        (2)将q(z|x)=N(\mu ,\sigma ^2)p(z)=N(0,1)带入(17

KL\left [ q(z|x)||p(z) \right ]=\int N(\mu ,\sigma ^2)log(\frac{N(\mu ,\sigma ^2)}{N(0,1)})dz(18)

        (3)高斯分布的公式: 高斯分布的概率密度函数为:

\mathbf{N(z;\mu ,\sigma ^2)=\frac{1}{\sqrt{2\pi \sigma ^2}}exp[-\frac{(z-\mu )^2}{2\sigma ^2}]} (19)

        而标准正态分布为:

\mathbf{N(z;0,1)=\frac{1}{\sqrt{2\pi }}exp(-\frac{z^2}{2})}    (20)

        (4)计算 KL散度: 将这些代入 K散度的公式中,最终可以简化得到:

KL(q(z|x)||p(z))=-\frac{1}{2}(1+log(\sigma ^2)-\mu ^2-\sigma ^2)  (21)

        (5)简化: 进一步简化后,得到:

KL(q(z|x)||p(z))=-0.5(log(\sigma ^2)+1-\mu ^2-\sigma ^2)  (22)

        (6)用对数方差表示: 在实现中,通常使用对数方差 log(\sigma ^2) 来计算,这样可以避免数值稳定性问题,最终得到的 KL散度公式是:

KL(q(z|x)||p(z))=-0.5(1+log(\sigma ^2)-\mu ^2-\sigma ^2)(23)

        KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        其中log_var 是对数方差,使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu 是均值;\sigma ^2=exp(log\sigma ^2)

五、总结

        1.VAE中的“变分”指的是“变分推断”;

        2.VAE的损失函数值最大化边际似然;

        3.最大化边际似然几乎做不到,所以使用变分推断来简化计算;

        4.使用变分推断后,训练通过最大化ELBO实现;

        5.ELBO有两项:重构项和KL散度项。重构项的作用是确保训练时输入和输出间的相似性,就是传统的损失函数常用的东西;KL散度项是一个正则项,能确保潜在空间的结构化和连续性。

        VAE就介绍到这,关注不迷路(*^__^*) 

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/883511.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MobileNetV2实现实时口罩检测tensorflow

项目源码获取方式见文章末尾! 回复暗号:13,免费获取600多个深度学习项目资料,快来加入社群一起学习吧。 **《------往期经典推荐------》**项目名称 1.【Informer模型复现项目实战】 2.【卫星图像道路检测DeepLabV3Plus模型】 3.【…

著名AI人工智能的未来应用讲师培训师唐兴通数字经济大数据工业4.0数字化转型AIGC大模型培训讲师

《大数据与人工智能的未来应用》培训课程大纲 一、培训内容简介 本课程旨在帮助学员深度理解大数据与人工智能(AI)如何为未来商业和行业带来革命性变革。课程紧贴前沿科技,从数据采集、分析到AI应用开发,全方位解析大数据和AI如…

51c~目标检测~合集2

我自己的原文哦~ https://blog.51cto.com/whaosoft/12377509 一、总结 这里概述了基于深度学习的目标检测器的最新发展。同时,还提供了目标检测任务的基准数据集和评估指标的简要概述,以及在识别任务中使用的一些高性能基础架构,其还涵盖了…

Docker | images镜像的常用命令总结

命令总结 1. 帮助启动类命令基本命令systemctl status dockerdocker infodocker --help 2. 镜像命令docker images删除镜像出现错误 docker searchdocker pull xxx[:TAG]docker images -adocker images -qdocker system dfdocker rmi -f xxxxxdocker rmi -f $(docker images -q…

Qt 学习第十四天:线程与多线程

1024程序员快乐,如果这博客让你学习到了知识,请给我一个免费的赞❤️ 父子线程演示 一、创建界面文件 LCDnumber 二、创建mythread类,继承QObject 三、在MyThread.h文件做修改,并且加上函数声明 引入头文件,改变继…

实战:大数据冷热分析

实战:大数据冷热分析 冷热分析(Hot and Cold Data Analysis)的目的主要在于优化存储系统的性能和成本。通过识别并区分访问频率和存储需求不同的数据,可以采取适当的存储策略,进而提高系统的效率和用户体验。终极目的…

javaScript整数反转

function _reverse(number) { // 补全代码 return (number ).split().reverse().join(); } number :首先,将数字 number 转换为字符串。在 JavaScript 中,当你将一个数字与一个字符串相加时,JavaScript 会自动将数字转换为字符串…

PyTorch中如何进行向量微分、矩阵微分、计算雅各比行列式

文章目录 摘要Abstract 一、计算雅各比行列式二、向量微分三、矩阵微分总结 摘要 本文介绍了在PyTorch中进行向量微分、矩阵微分以及计算雅各比行列式的方法。通过对自动微分(Autograd)功能的讲解,展示了如何轻松实现复杂的数学运算&#xf…

代码编辑组件

代码编辑组件 文章说明核心代码运行演示源码下载 文章说明 拖了很久,总算是自己写了一个简单的代码编辑组件,虽然还有不少的bug,真的很难写,在写的过程中感觉自己的前端技术根本不够用,好像总是方案不够好;…

Flux 开源替代,他来了——Liberflux

LibreFLUX 是 FLUX.1-schnell 的 Apache 2.0 版本,它提供完整的 T5 上下文长度,使用注意力屏蔽,恢复了无分类器引导,并完全删除了 FLUX 美学微调/DPO 的大部分内容。 这意味着它比基本通量要难看得多,但它有可能更容易…

数据结构与算法汇总整理篇——数组与字符串双指针与滑动窗口的联系学习及框架思考

数组 数组精髓:循环不变量原则 数组是存放在连续内存空间上的相同类型数据的集合,通过索引(下标)访问元素,索引从0开始 随机访问快(O(1)时间复杂度);插入删除慢(需要移动元素);长度固定(部分语言中可动态调整) 其存…

解决电脑突然没有声音

问题描述:电脑突然没有声音了,最近没有怎么动过系统,没有安装或者卸载过什么软件,也没有安装或者卸载过驱动程序,怎么就没有声音了呢? 问题分析:仔细观察,虽然音量按钮那边看不到什…

索引的使用以及使用索引优化sql

索引就是一种快速查询和检索数据的数据结构,mysql中的索引结构有:B树和Hash。 索引的作用就相当于目录的作用,我么只需先去目录里面查找字的位置,然后回家诶翻到那一页就行了,这样查找非常快, 一、索引的使…

[Linux网络编程]06-I/O多路复用策略---select,poll分析解释,优缺点,实现IO多路复用服务器

一.I/O多路复用 I/O多路复用是一种用于提高系统性能的 I/O 处理机制。 它允许一个进程(或线程)同时监视多个文件描述符(可以是套接字、管道、终端设备等),等待这些文件描述符中出现读、写或异常状态。一旦有满足条件的…

ts:类的创建(class)

ts:类的创建(class) 一、主要内容说明二、例子class类的创建1.源码1 (class类的创建)2.源码1的运行效果 三、结语四、定位日期 一、主要内容说明 class创建类里主要有三部分组成,变量的声明,构…

ts:数组的常用方法(filter)

ts:数组的常用方法(filter) 一、主要内容说明二、例子filter方法(过滤)1.源码1 (push方法)2.源码1运行效果 三、结语四、定位日期 一、主要内容说明 ts中数组的filter方法,是筛选数…

【STM32】单片机ADC原理详解及应用编程

本篇文章主要详细讲述单片机的ADC原理和编程应用,希望我的分享对你有所帮助! 目录 一、STM32ADC概述 1、ADC(Analog-to-Digital Converter,模数转换器) 2、STM32工作原理 二、STM32ADC编程实战 (一&am…

C++STL之stack

1.stack的使用 函数说明 接口说明 stack() 构造空的栈 empty() 检测 stack 是否为空 size() 返回 stack 中元素的个数 top() 返回栈顶元素的引用 push() 将元素 val 压入 stack 中 pop() 将 stack 中尾部的元素弹出 2.stack的模拟实现 #include<vector> namespace abc { …

LeetCode 热题 100之普通数组

1.最大子数组和 思路分析&#xff1a;这个问题可以通过动态规划来解决&#xff0c;我们可以使用Kadane’s Algorithm&#xff08;卡登算法&#xff09;来找到具有最大和的连续子数组。 Kadane’s Algorithm 的核心思想是利用一个变量存储当前的累加和 currentSum&#xff0c;并…

MATLAB生物细胞瞬态滞后随机建模定量分析

&#x1f3af;要点 基于随机动态行为受化学主方程控制&#xff0c;定量分析单细胞瞬态效应。确定性常微分方程描述双稳态和滞后现象。通过随机性偏微分方程描述出暂时性滞后会逐渐达到平稳状态&#xff0c;并利用熵方法或截断方法计算平衡收敛速度的估计值。随机定量分析模型使…