初学机器学习:直观解读KL散度的数学概念

初学机器学习:直观解读KL散度的数学概念

转自:初学机器学习:直观解读KL散度的数学概念

译自:https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8

解读自:https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained

代码:https://github.com/thushv89/exercises_thushv_dot_com/blob/master/kl_divergence.ipynb

作者:Thushan Ganegedara,机器之心编译。

本文修正了一些错误,优化了排版。

基础概念

首先让我们确立一些基本规则。我们将会定义一些我们需要了解的概念。

分布(distribution)

分布可能指代不同的东西,比如数据分布或概率分布。我们这里所涉及的是概率分布。假设你在一张纸上画了两根轴(即 XXXYYY),我可以将一个分布想成是落在这两根轴之间的一条线。其中 XXX 表示你有兴趣获取概率的不同值。YYY 表示观察 XXX 轴上的值时所得到的概率。即 y=p(x)y=p(x)y=p(x)。下图即是某个分布的可视化。

在这里插入图片描述

这是一个连续概率分布。比如,我们可以将 XXX 轴看作是人的身高,YYY 轴是找到对应身高的人的概率。

如果你想得到离散的概率分布,你可以将这条线分成固定长度的片段并以某种方式将这些片段水平化。然后就能根据这条线的每个片段创建边缘互相连接的矩形。这就能得到一个离散概率分布。

事件(event)

对于离散概率分布而言,事件是指观察到 XXX 取某个值(比如 X=1X=1X=1)的情况。我们将事件 X=1X=1X=1 的概率记为 P(X=1)P(X=1)P(X=1)。在连续空间中,你可以将其看作是一个取值范围(比如 0.95<X<1.050.95<X<1.050.95<X<1.05)。注意,事件的定义并不局限于在 XXX 轴上取值。但是我们后面只会考虑这种情况。

回到 KL 散度

从这里开始,我将使用来自这篇博文的示例:https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained。这是一篇很好的 KL 散度介绍文章,但我觉得其中某些复杂的解释可以更详细的阐述。好了,让我们继续吧。

我们想要解决的问题

上述博文中所解决的核心问题是这样的:假设我们是一组正在广袤无垠的太空中进行研究的科学家。我们发现了一些太空蠕虫,这些太空蠕虫的牙齿数量各不相同。现在我们需要将这些信息发回地球。但从太空向地球发送信息的成本很高,所以我们需要用尽量少的数据表达这些信息。我们有个好方法:我们不发送单个数值,而是绘制一张图表,其中 XXX 轴表示所观察到的不同牙齿数量 (0,1,2…)(0,1,2…)(0,1,2)YYY 轴是看到的太空蠕虫具有 xxx 颗牙齿的概率(即具有 xxx 颗牙齿的蠕虫数量/蠕虫总数量)。这样,我们就将观察结果转换成了分布。

发送分布比发送每只蠕虫的信息更高效。但我们还能进一步压缩数据大小。我们可以用一个已知的分布来表示这个分布(比如均匀分布、二项分布、正态分布)。举个例子,假如我们用均匀分布来表示真实分布,我们只需要发送两段数据就能恢复真实数据;均匀概率和蠕虫数量。但我们怎样才能知道哪种分布能更好地解释真实分布呢?这就是 KL 散度的用武之地。

直观解释:KL 散度是一种衡量两个分布(比如两条线)之间的匹配程度的方法。

让我们对示例进行一点修改

为了能够检查数值的正确性,让我们将概率值修改成对人类更友好的值(相比于上述博文中的值)。我们进行如下假设:假设有 100 只蠕虫,各种牙齿数的蠕虫的数量统计结果如下。

牙齿颗数 iii012345678910
蠕虫数2351416151281087
概率 pip_ipi0.020.030.050.140.160.150.120.080.100.080.07

快速做一次完整性检查!确保蠕虫总数为 100,且概率总和为 1.0.

  • 蠕虫总数 = 2+3+5+14+16+15+12+8+10+8+7 = 100
  • 概率总和 = 0.02+0.03+0.05+0.14+0.16+0.15+0.12+0.08+0.1+0.08+0.07 = 1.0

可视化结果为:

在这里插入图片描述

尝试 1:使用均匀分布建模

我们首先使用均匀分布来建模该分布。均匀分布只有一个参数:均匀概率;即给定事件发生的概率。
puniform=1totalevents=111=0.0909p_{uniform}=\frac{1}{total\ events}=\frac{1}{11}=0.0909 puniform=total events1=111=0.0909
均匀分布和我们的真实分布对比:

在这里插入图片描述

先不讨论这个结果,我们再用另一种分布来建模真实分布。

尝试 2:使用二项分布建模

你可能计算过抛硬币正面或背面向上的概率,这就是一种二项分布概率。我们可以将同样的概念延展到我们的问题上。对于有两个可能输出的硬币,我们假设硬币正面向上的概率为 ppp,并且进行了 nnn 次尝试,那么其中成功 kkk 次的概率为:
P(X=k)=(nk)pk(1−p)n−kP(X=k)=\begin{pmatrix} n \\ k \end{pmatrix} p^k(1-p)^{n-k} P(X=k)=(nk)pk(1p)nk
公式解读

这里说明一下二项分布中每一项的含义。第一项是 pkp^kpk。我们想成功 kkk 次,其中单次成功的概率为 ppp;那么成功 kkk 次的概率为 pkp^kpk。另外要记得我们进行了 nnn 次尝试。因此,其中失败的次数为 n−kn-knk,对应失败的概率为 (1−p)(1-p)(1p)。所以成功 k 次的概率即为联合概率 pk(1−p)n−kp^k(1-p)^{n-k}pk(1p)nk。到此还未结束。在 nnn 次尝试中,kkk 次成功会有不同的排列方式。在数量为 nnn 的空间中 kkk 个元素的不同排列数量为
(nk)=n!k!(n−k)!\begin{pmatrix} n \\ k \end{pmatrix} =\frac{n!}{k!(n-k)!} (nk)=k!(nk)!n!
将所有这些项相乘就得到了成功 kkk 次的二项概率。

二项分布的均值和方差

我们还可以定义二项分布的均值和方差:mean=npmean=npmean=npvar=np(1−p)var=np(1-p)var=np(1p)

均值是什么意思?均值是指你进行 nnn 次尝试时的期望(平均)成功次数。如果每次尝试成功的概率为 ppp,那么可以说 nnn 次尝试的成功次数为 npnpnp

方差又是什么意思?它表示真实的成功尝试次数偏离均值的程度。为了理解方差,让我们假设 n=1n=1n=1,那么等式就成了「方差= p(1−p)p(1-p)p(1p)」。那么当 p=0.5p=0.5p=0.5 时(正面和背面向上的概率一样),方差最大;当 p=1p=1p=1p=0p=0p=0 时(只能得到正面或背面中的一种),方差最小。

回来继续建模

现在我们已经理解了二项分布,接下来回到我们之前的问题。首先让我们计算蠕虫的牙齿的期望值:
∑i=011=0×p0+1×p1+…10×p10=5.44\sum_{i=0}^{11}=0\times p_0+1\times p_1+\dots 10\times p_{10}=5.44 i=011=0×p0+1×p1+10×p10=5.44
有了均值,我们可以计算 ppp 的值:
mean=np5.44=10pp=0.544mean=np\\ 5.44=10p\\ p=0.544 mean=np5.44=10pp=0.544
注意,这里的 nnn 是指在蠕虫中观察到的最大牙齿数。你可能会问我们为什么不把蠕虫总数(即 100)或总事件数(即 11)设为 nnn。我们很快就将看到原因。有了这些数据,我们可以按如下方式定义任意牙齿数的概率。

鉴于牙齿数的取值最大为 10,那么看见 kkk 颗牙齿的概率是多少(这里看见一颗牙齿即为一次成功尝试)?

从抛硬币的角度看,这就类似于:

假设我抛 10 次硬币,观察到 kkk 次正面向上的概率是多少?

从形式上讲,我们可以计算所有不同 kkk 值的概率 pkbip_k^{bi}pkbi。其中 kkk 是我们希望观察到的牙齿数量,pkbip_k^{bi}pkbi 是第 k 个牙齿数量位置(即 0 颗牙齿、1 颗牙齿……)的二项概率。所以,计算结果如下:
p0bi=(10!/(0!10!))0.5440(1–0.544)10=0.0004p1bi=(10!/(1!9!))0.5441(1–0.544)9=0.0046p2bi=(10!/(2!8!))0.5442(1–0.544)8=0.0249…p9bi=(10!/(9!1!))0.5449(1–0.544)1=0.0190p10bi=(10!/(10!0!))0.54410(1–0.544)0=0.0023p0^{bi} = (10!/(0!10!)) 0.544⁰ (1–0.544)^{10} = 0.0004\\ p1^{bi} = (10!/(1!9!)) 0.544¹ (1–0.544)⁹ = 0.0046\\ p2^{bi} = (10!/(2!8!)) 0.544² (1–0.544)⁸ = 0.0249\\ …\\ p9^{bi} = (10!/(9!1!)) 0.544⁹ (1–0.544)¹ = 0.0190\\ p10^{bi} = (10!/(10!0!)) 0.544^{10} (1–0.544)⁰ = 0.0023\\ p0bi=(10!/(0!10!))0.5440(1–0.544)10=0.0004p1bi=(10!/(1!9!))0.5441(1–0.544)9=0.0046p2bi=(10!/(2!8!))0.5442(1–0.544)8=0.0249p9bi=(10!/(9!1!))0.5449(1–0.544)1=0.0190p10bi=(10!/(10!0!))0.54410(1–0.544)0=0.0023
我们的真实分布和二项分布的比较如下:

在这里插入图片描述

总结已有情况

现在回头看看我们已经完成的工作。首先,我们理解了我们想要解决的问题。我们的问题是将特定类型的太空蠕虫的牙齿数据统计用尽量小的数据量发回地球。为此,我们想到用某个已知分布来表示真实的蠕虫统计数据,这样我们就可以只发送该分布的参数,而无需发送真实统计数据。我们检查了两种类型的分布,得到了以下结果。

  • 均匀分布——概率为 0.0909
  • 二项分布——n=10n=10n=10p=0.544p=0.544p=0.544kkk 取值在 0 到 10 之间。

让我们在同一个地方可视化这三个分布:

在这里插入图片描述

我们如何定量地确定哪个分布更好?

经过这些计算之后,我们需要一种衡量每个近似分布与真实分布之间匹配程度的方法。这很重要,这样当我们发送信息时,我们才无需担忧「我是否选择对了?」毕竟太空蠕虫关乎我们每个人的生命。

这就是 KL 散度的用武之地。KL 散度在形式上定义如下:
DKL(p∣∣q)=∑i=1Np(xi)log⁡p(xi)q(xi)D_{KL}(p||q)=\sum_{i=1}^Np(x_i)\log\frac{p(x_i)}{q(x_i)} DKL(p∣∣q)=i=1Np(xi)logq(xi)p(xi)
其中 q(x)q(x)q(x) 是近似分布,p(x)p(x)p(x) 是我们想要用 q(x)q(x)q(x) 匹配的真实分布。直观地说,这衡量的是给定任意分布偏离真实分布的程度。如果两个分布完全匹配,那么DKL(p∣∣q)=0D_{KL}(p||q)=0DKL(p∣∣q)=0 ,否则它的取值应该是在 0 到 ∞\infty 之间。KL 散度越小,真实分布与近似分布之间的匹配就越好。

KL 散度的直观解释

让我们看看 KL 散度各个部分的含义。首先看看
logp(xi)q(xi)log\frac{p(x_i)}{q(x_i)} logq(xi)p(xi)
项。如果 q(xi)q(x_i)q(xi) 大于 p(xi)p(x_i)p(xi) 会怎样呢?此时这个项的值为负,因为小于 1 的值的对数为负。另一方面,如果 q(xi)q(x_i)q(xi) 总是小于 p(xi)p(x_i)p(xi),那么该项的值为正。如果 p(xi)=q(xi)p(x_i)=q(x_i)p(xi)=q(xi) 则该项的值为 0。然后,为了使这个值为期望值,你要用 p(xi)p(x_i)p(xi) 来给这个对数项加权。也就是说,p(xi)p(x_i)p(xi) 有更高概率的匹配区域比低 p(xi)p(x_i)p(xi) 概率的匹配区域更加重要。

直观而言,优先正确匹配近似分布中真正高可能性的事件是有实际价值的。从数学上讲,这能让你自动忽略落在真实分布的支集(支集(support)是指分布使用的 XXX 轴的全长度)之外的分布区域。另外,这还能避免计算 log(0)log(0)log(0) 的情况——如果你试图计算落在真实分布的支集之外的任意区域的这个对数项,就可能出现这种情况。

计算 KL 散度

我们计算一下上面两个近似分布与真实分布之间的 KL 散度。首先来看均匀分布:
D(True∣∣Uniform)=0.02log⁡(0.02/0.0909)+⋯+0.07log⁡(0.07/0.0909)=0.136D(True||Uniform)=0.02\log(0.02/0.0909)+\dots+0.07\log(0.07/0.0909)=0.136 D(True∣∣Uniform)=0.02log(0.02/0.0909)++0.07log(0.07/0.0909)=0.136
再看看二项分布:
D(True∣∣Binomila)=0.02log⁡(0.02/0.0004)+⋯+0.07log⁡(0.07/0.0023)=0.427D(True||Binomila)=0.02\log(0.02/0.0004)+\dots+0.07\log(0.07/0.0023)=0.427 D(True∣∣Binomila)=0.02log(0.02/0.0004)++0.07log(0.07/0.0023)=0.427

玩一玩 KL 散度

现在,我们来玩一玩 KL 散度。首先我们会先看看当二元分布的成功概率变化时 KL 散度的变化情况。不幸的是,我们不能使用均匀分布做同样的事,因为 nnn 固定时均匀分布的概率不会变化。

在这里插入图片描述

可以看到,当我们远离我们的选择(红点)时,KL 散度会快速增大。实际上,如果你显示输出我们的选择周围小 Δ\DeltaΔ 数量的 KL 散度值,你会看到我们选择的成功概率的 KL 散度最小。

现在让我们看看 DKL(P∣∣Q)D_{KL}(P||Q)DKL(P∣∣Q)DKL(Q∣∣P)D_{KL}(Q||P)DKL(Q∣∣P) 的行为方式。如下图所示:

在这里插入图片描述

看起来有一个区域中的 DKL(P∣∣Q)D_{KL}(P||Q)DKL(P∣∣Q)DKL(Q∣∣P)D_{KL}(Q||P)DKL(Q∣∣P) 之间有最小的距离。让我们绘出两条线之间的差异(虚线),并且放大我们的概率选择所在的区域。

在这里插入图片描述

有最低差异的区域(但并不是最低差异的区域)。但这仍然是一个很有意思的发现。我不确定出现这种情况的原因是什么。如果有人知道,欢迎讨论。

结论

现在我们有些可靠的结果了。尽管均匀分布看起来很简单且信息不多而二项分布带有更有差别的信息,但实际上均匀分布与真实分布之间的匹配程度比二项分布的匹配程度更高。说老实话,这个结果实际上让我有点惊讶。因为我之前预计二项分布能更好地建模这个真实分布。因此,这个实验也能告诉我们:不要只相信自己的直觉!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL应用安装_mysql安装和应用

1.下载mysql安装包2.安装mysql&#xff0c;自定义->修改路径3.配置mysql&#xff0c;选择自定义->server模式->500访问量->勾选控制台->设置gbk->设置密码和允许root用户远程登录等等。以管理员权限&#xff0c;在控制台输入&#xff1a;net start MySQL, 启…

mysql 商品规格表_商品规格分析

产品表每次更新商品都会变动的&#xff0c;ID不能用&#xff0c;可是购物车还是用了&#xff0c;这就导致每次保存商品&#xff0c;哪怕什么都没有改动&#xff0c;也会导致用户的购物车失效。~~~其实可以考虑不是每次更新商品就除所有的SKU&#xff0c;毕竟有时什么都没修改呢…

huggingface NLP工具包教程1:Transformers模型

huggingface NLP工具包教程1&#xff1a;Transformers模型 原文&#xff1a;TRANSFORMER MODELS 本课程会通过 Hugging Face 生态系统中的一些工具包&#xff0c;包括 Transformers&#xff0c; Datasets&#xff0c; Tokenizers&#xff0c; Accelerate 和 Hugging Face Hub。…

隐马尔可夫模型HMM推导

隐马尔可夫模型HMM推导 机器学习-白板推导系列(十四)-隐马尔可夫模型HMM&#xff08;Hidden Markov Model&#xff09; 课程笔记 背景介绍 介绍一下频率派和贝叶斯派两大流派发展出的建模方式。 频率派 频率派逐渐发展成了统计机器学习&#xff0c;该流派通常将任务建模为一…

使用randomaccessfile类将一个文本文件中的内容逆序输出_Java 中比较常用的知识点:I/O 总结...

Java中I/O操作主要是指使用Java进行输入&#xff0c;输出操作. Java所有的I/O机制都是基于数据流进行输入输出&#xff0c;这些数据流表示了字符或者字节数据的流动序列。数据流是一串连续不断的数据的集合&#xff0c;就象水管里的水流&#xff0c;在水管的一端一点一点地供水…

huggingface NLP工具包教程2:使用Transformers

huggingface NLP工具包教程2&#xff1a;使用Transformers 引言 Transformer 模型通常非常大&#xff0c;由于有数百万到数百亿个参数&#xff0c;训练和部署这些模型是一项复杂的任务。此外&#xff0c;由于几乎每天都有新模型发布&#xff0c;而且每个模型都有自己的实现&a…

mysql精讲_Mysql 索引精讲

开门见山&#xff0c;直接上图&#xff0c;下面的思维导图即是现在要讲的内容&#xff0c;可以先有个印象&#xff5e;常见索引类型(实现层面)索引种类(应用层面)聚簇索引与非聚簇索引覆盖索引最佳索引使用策略1.常见索引类型(实现层面)首先不谈Mysql怎么实现索引的,先马后炮一…

RT-Smart 官方 ARM 32 平台 musl gcc 工具链下载

前言 RT-Smart 的开发离不开 musl gcc 工具链&#xff0c;用于编译 RT-Smart 内核与用户态应用程序 RT-Smart musl gcc 工具链代码当前未开源&#xff0c;但可以下载到 RT-Thread 官方编译好的最新的 musl gcc 工具链 ARM 32位 平台 比如 RT-Smart 最好用的 ARM32 位 qemu 平…

OpenAI Whisper论文笔记

OpenAI Whisper论文笔记 OpenAI 收集了 68 万小时的有标签的语音数据&#xff0c;通过多任务、多语言的方式训练了一个 seq2seq &#xff08;语音到文本&#xff09;的 Transformer 模型&#xff0c;自动语音识别&#xff08;ASR&#xff09;能力达到商用水准。本文为李沐老师…

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 转自&#xff1a;【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 作者&#xff1a;潘小小 知识蒸馏是一种模型压缩方法&#xff0c;是一种基于“教师-学生网络思想”的训练方法&#xff0c;由于其简单&#xf…

深度学习三大谜团:集成、知识蒸馏和自蒸馏

深度学习三大谜团&#xff1a;集成、知识蒸馏和自蒸馏 转自&#xff1a;https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA 原文&#xff08;英&#xff09;&#xff1a;https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge…

在墙上找垂直线_墙上如何快速找水平线

在装修房子的时候&#xff0c;墙面的面积一般都很大&#xff0c;所以在施工的时候要找准水平线很重要&#xff0c;那么一般施工人员是如何在墙上快速找水平线的呢&#xff1f;今天小编就来告诉大家几种找水平线的方法。一、如何快速找水平线1、用一根透明的软管&#xff0c;长度…

Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer&#xff08;ViT&#xff09;PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来&#xff0c;屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文&#xff0c;及其PyTorch实现&#xff0c;将整个ViT的代码做一…

Linux下的ELF文件、链接、加载与库(含大量图文解析及例程)

Linux下的ELF文件、链接、加载与库 链接是将将各种代码和数据片段收集并组合为一个单一文件的过程&#xff0c;这个文件可以被加载到内存并执行。链接可以执行与编译时&#xff0c;也就是在源代码被翻译成机器代码时&#xff1b;也可以执行于加载时&#xff0c;也就是被加载器加…

java 按钮 监听_Button的四种监听方式

Button按钮设置点击的四种监听方式注&#xff1a;加粗放大的都是改变的代码1.使用匿名内部类的形式进行设置使用匿名内部类的形式&#xff0c;直接将需要设置的onClickListener接口对象初始化&#xff0c;内部的onClick方法会在按钮被点击的时候执行第一个活动的java代码&#…

linux查看java虚拟机内存_深入理解java虚拟机(linux与jvm内存关系)

本文转载自美团技术团队发表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux与进程内存模型要理解jvm最重要的一点是要知道jvm只是linux的一个进程,把jvm的视野放大,就能很好的理解JVM细分的一些概念下图给出了硬件系统进程三个层面内存之间的关系.从硬件上…

java function void_Java8中你可能不知道的一些地方之函数式接口实战

什么时候可以使用 Lambda&#xff1f;通常 Lambda 表达式是用在函数式接口上使用的。从 Java8 开始引入了函数式接口&#xff0c;其说明比较简单&#xff1a;函数式接口(Functional Interface)就是一个有且仅有一个抽象方法&#xff0c;但是可以有多个非抽象方法的接口。 java8…

java jvm内存地址_JVM--Java内存区域

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域&#xff0c;如图&#xff1a;1.程序计数器可以看作是当前线程所执行的字节码的行号指示器&#xff0c;通俗的讲就是用来指示执行哪条指令的。为了线程切换后能恢复到正确的执行位置Java多线程是…

java情人节_情人节写给女朋友Java Swing代码程序

马上又要到情人节了&#xff0c;再不解风情的人也得向女友表示表示。作为一个程序员&#xff0c;示爱的时候自然也要用我们自己的方式。这里给大家上传一段我在今年情人节的时候写给女朋友的一段简单的Java Swing代码&#xff0c;主要定义了一个对话框&#xff0c;让女友选择是…

java web filter链_filter过滤链:Filter链是如何构建的?

在一个Web应用程序中可以注册多个Filter程序&#xff0c;每个Filter程序都可以针对某一个URL进行拦截。如果多个Filter程序都对同一个URL进行拦截&#xff0c;那么这些Filter就会组成一个Filter链(也叫过滤器链)。Filter链用FilterChain对象来表示&#xff0c;FilterChain对象中…