gumbel-softmax如何实现离散分布可微+torch代码+原理+证明

文章目录

    • 背景
    • 方法通俗理解
    • 什么是重参数化
    • gumbel-softmax
    • 为什么是gumbel
    • torch实现
    • 思考

在这里插入图片描述

背景

这里举一个简单的情况,当前我们有p1, p2, p3三个概率,我们需要得到最优的一个即max(p1, p2, p3),例如当前p3 = max(p1, p2, p3),那么理想输出应当为[0, 0, 1],然后应用于下游的优化目标,这种场景在搜索等场景经常出现。
如果暴力的进行clip或者mask操作转化为独热向量的话会导致在梯度反向传播的时候无法更新上游网络。因为p1和p2对应的梯度一定为0。

方法通俗理解

针对上述情况,采用重参数化的思路可以解决。
即然每次前向传播理想情况下是0-1独热向量向量,但同时能保证[p1, p2, p3]这个分布能被根据概率被更新。于是采用了一种重参数化的方法,即从每次都从一个分布中采样一个u,这个u属于一个均匀分布,从这个均匀分布通过转换变成[p1, p2, p3]这个分布。这样就能即保证梯度可以反向传播,同时根据每次采样来实现按照[p1, p2, p3]这个分布更新,而不是每次只能更新最大的一个。
而这种方法就是重参数化。

什么是重参数化

Reparameterization,重参数化,这是一个方法论,是一种技巧。
我们首先可以抽象出来它的数学表达形式:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))
注意:在有些时候 θ ′ ∈ θ \theta' \in \theta θθ或者 θ ′ = θ \theta' = \theta θ=θ
如何理解:这里我们的优化目标是 L θ L_{\theta} Lθ,其中 f θ ( ) f_{\theta}() fθ()一般是我们的模型,而计算 z z z是从分布 p θ ′ ( z ) p_{\theta'}(z) pθ(z)中采样得到的。但是问题是我们不能把一个分布输入到 f θ ( ) f_{\theta}() fθ()中去,只能从选择一个特定的 z z z,但是这样就没法更新 θ ′ \theta' θ
综上,重参数化就是从给定分布中采样得到一个 z z z,同时保证了梯度可以更新 θ ′ \theta' θ,这种保证采样分布和给定分布无损转换的采样策略叫做重参数化。(个人理解,欢迎大佬指正)

由于我们现在解决的是gumbel-softmax问题,所以只关注当 p θ ′ ( z ) p_{\theta'}(z) pθ(z)是离散的情况下,此时:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) = ∑ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) = \sum p_{\theta'}(z)(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))=pθ(z)(fθ(z))
这也就是gumbel-softmax要解决的数学形式。

gumbel-softmax

gumbel-softmax给出的采样方案,叫做gumbel max:
从原来的 a r g m a x i ( [ p 1 , p 2 , . . . ] ) argmax_i([p1, p2, ...]) argmaxi([p1,p2,...]) a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) , ϵ i ∈ U [ 0 , 1 ] argmax_i(log(p_i)-log(-log(\epsilon_i))), \epsilon_i \in U[0, 1] argmaxi(log(pi)log(log(ϵi))),ϵiU[0,1]

也就是先算出各个概率的对数 l o g ( p i ) log(p_i) log(pi),然后从均匀分布 U U U中采样随机数 ϵ i \epsilon_i ϵi,把 − l o g ( − l o g ϵ i ) −log(−log\epsilon_i) log(logϵi)加到 l o g ( p i ) log(p_i) log(pi),然后再进行后续操作。
这里可以理解为通过 ϵ \epsilon ϵ的采样将随机性增加。有的人会疑问,为什么格式变得这么复杂,各种算log,这是为什么?这个就涉及到下一节了,具体原因就是来保证数学的变换正确性,即我增加了随机性,但是保证分布的期望仍然是和原始[p1, p2, p3]是一致的,这个证明在下一节,是有比较严谨的数学证明的。

但是这里还有一个问题,就是argmax或者说onehot操作仍然会丢失梯度,所以采用带超参 τ \tau τ的softmax,来进行平滑:
s o f t m a x ( ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) / τ ) \begin{equation} softmax((log(p_i)-log(-log(\epsilon_i)))/\tau) \end{equation} softmax((log(pi)log(log(ϵi)))/τ)
其中 τ \tau τ也被称为退火参数,用来调整平滑的程度: τ \tau τ越小,越接近onhot向量。

这里也解释清楚了所谓gumbel-softmax是通过gumbel max实现重参数化,通过带退火参数的softmax实现梯度反向传递。

为什么是gumbel

这就涉及到一个gumbel max的证明了。
目标是证明针对 l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) log(p_i)-log(-log(\epsilon_i)) log(pi)log(log(ϵi)),当 a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) = 1 argmax_i(log(p_i)-log(-log(\epsilon_i))) = 1 argmaxi(log(pi)log(log(ϵi)))=1时,其概率为 p 1 p_1 p1

假设:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大

则:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2))
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 3 ) − l o g ( − l o g ( ϵ 3 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_3)-log(-log(\epsilon_3)) log(p1)log(log(ϵ1))>log(p3)log(log(ϵ3))

l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2)) ->
ϵ 1 p 2 / p 1 > ϵ 2 \epsilon_1^{p_2/p_1} > \epsilon_2 ϵ1p2/p1>ϵ2
所以: p 1 p1 p1 大于 p 2 p_2 p2 的概率是 ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1

同理:
p 1 p1 p1 大于 p 3 p_3 p3 的概率是 ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1

所以 l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大的概率是:
ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1 * ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1 * … = ϵ 1 ( 1 − p 1 ) / p 1 \epsilon_1^{(1-p_1)/p_1} ϵ1(1p1)/p1
E ( ϵ 1 ( 1 − p 1 ) / p 1 ) E(\epsilon_1^{(1-p_1)/p_1}) E(ϵ1(1p1)/p1) = ∫ 0 1 ϵ 1 ( 1 − p 1 ) / p 1 d ϵ \int_{0}^{1}\epsilon_1^{(1-p_1)/p_1} d\epsilon 01ϵ1(1p1)/p1dϵ = ∫ 0 1 ϵ 1 ( 1 / p 1 ) − 1 d ϵ 1 \int_{0}^{1}\epsilon_1^{(1/p_1)-1} d\epsilon_1 01ϵ1(1/p1)1dϵ1 = ( p 1 ( ϵ 1 1 / p 1 ) ) ∣ 0 1 (p_1(\epsilon_1^{1/p_1}))|^1_0 (p1(ϵ11/p1))01 = p 1 p_1 p1

证明假设成立

torch实现

def sample_gumbel(shape, eps=1e-20):U = torch.rand(shape)U = U.cuda()return -torch.log(-torch.log(U + eps) + eps)def gumbel_softmax_sample(logits, temperature=0.5):y = torch.log(logits) + sample_gumbel(logits.size())return F.softmax(y / temperature, dim=-1)def gumbel_softmax(logits, temperature=1, hard=False):"""input: [B, n_class]return: [B, n_class] an one-hot vector"""y = gumbel_softmax_sample(logits, temperature)if not hard:return yshape = y.size()_, ind = y.max(dim=-1)y_hard = torch.zeros_like(y).view(-1, shape[-1])y_hard.scatter_(1, ind.view(-1, 1), 1)y_hard = y_hard.view(*shape)# Set gradients w.r.t. y_hard gradients w.r.t. yy_hard = (y_hard - y).detach() + yreturn y_hard

思考

为什么gumbel-softmax和softmax的输出是不一样的?
为什么argmax(gumbel-softmax) 和 argmax(softmax)的结果也不一定一样?这是正常的吗?

大家共勉~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/732688.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【从部署服务器到安装autodock vina】

注意:服务器 linux系统选用ubuntu 登录系统,如果没有图形化见面可以先安装图形化界面 可以参考该视频 --> linux安装图形化界面 非阿里云ubuntu 依次执行以下命令 sudo apt-get update sudo apt-get install gnome sudo reboot阿里云ubuntu 需多执…

分布式解决方案

目录 1. 分布式ID1-1. 传统方案1-2. 分布式ID特点1-3. 实现方案1-4. 开源组件 1. 分布式ID 1-1. 传统方案 时间戳UUID 1-2. 分布式ID特点 全局唯一高并发高可用 1-3. 实现方案 方案总结: 号段模式 有两台服务器,给第一台服务器分配0-100&#xff0…

前端手册-实现挂坠灯笼效果

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

c#触发事件

Demo1 触发事件 <Window x:Class"WPFExample.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"Title"WPF Example" Height"600" Wi…

如何在Linux上为PyCharm创建和配置Desktop Entry

在Linux操作系统中&#xff0c;.desktop 文件是一种桌面条目文件&#xff0c;用于在图形用户界面中添加程序快捷方式。本文将指导您如何为PyCharm IDE创建和配置一个 .desktop 文件&#xff0c;从而能够通过应用程序菜单或桌面图标快速启动PyCharm。 步骤 1: 确定PyCharm安装路…

鸿蒙应用开发学习:使用视频播放(Video)组件播放视频和音频文件

一、前言 播放音视频是手机的重要功能之一&#xff0c;近期我学习了在鸿蒙系统应用开发中实现音视频的播放功能&#xff0c;应用中使用到了视频播放(Video)组件&#xff0c;ohos.file.picker&#xff08;选择器&#xff09;。特撰此文分享一下我的学习经历。 二、参考资料 本…

【设计】基于web的会员管理系统

1、引言 设计结课作业,课程设计无处下手&#xff0c;网页要求的总数量太多&#xff1f;没有合适的模板&#xff1f;数据库&#xff0c;java&#xff0c;python&#xff0c;vue&#xff0c;html作业复杂工程量过大&#xff1f;毕设毫无头绪等等一系列问题。你想要解决的问题&am…

Elasticsearch 单节点部署教程,以及踩坑记录

1、简介 Elasticsearch 作为分布式搜索引擎&#xff0c;在生产环境中使用集群部署&#xff0c;对于学习者而言我们只需要掌握如何使用即可&#xff0c;后续更高级的集群部署配置将在以后博客中更新。 Elasticsearch 更新迭代速度非常快&#xff0c;并且不同版本有着很大区别&am…

外包干了30天,技术明显退步。。

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 这次来聊一个大家可能也比较关心的问题&#xff0c;那就是就业城市选择的问题。而谈到这个问题&a…

scrapy的基本使用介绍

创建项目 ### 1. 创建虚拟环境 conda create -n spiderScrapy python3.9 ### 2. 安装scrapy pip install scrapy2.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple### 3. 生成一个框架 scrapy startproject my_spider### 4. 生成项目 scrapy genspider baidu https://www.b…

基于springboot+vue实现高校学生党员发展管理系统项目【项目源码+论文说明】

基于springboot实现高校学生党员发展管理系统演示 摘要 随着高校学生规模的不断扩大&#xff0c;高校内的党员统计及发展管理工作面临较大的压力&#xff0c;高校信息化建设的不断优化发展也进一步促进了系统平台的应用&#xff0c;借助系统平台可以实现更加高效便捷的党员信息…

吴恩达机器学习-可选实验:使用ScikitLearn进行线性回归(Linear Regression using Scikit-Learn)

文章目录 实验一目标工具梯度下降加载数据集缩放/规范化训练数据创建并拟合回归模型查看参数作出预测绘制结果 恭喜 实验二目标工具线性回归&#xff0c;闭式解加载数据集创建并拟合模型查看参数作出预测 第二个例子恭喜 有一个开源的、商业上可用的机器学习工具包&#xff0c;…

2024蓝桥杯每日一题(双指针)

一、第一题&#xff1a;牛的学术圈 解题思路&#xff1a;双指针贪心 仔细思考可以知道&#xff0c;写一篇综述最多在原来的H指数的基础上1&#xff0c;所以基本方法可以是先求出原始的H指数&#xff0c;然后分类讨论怎么样提升H指数。 【Python程序代码】 n,l map(int,…

GO: 快速升级Go版本

由于底层依赖升级了&#xff0c;那我们也要跟着升&#xff0c;go老版本已经不足满足需求了&#xff0c;必须要将版本升级到1.22.0以上 查看当前Go版本 命令查看go版本 go version[rootlocalhost local]# go version go version go1.21.4 linux/amd64 [rootlocalhost local]# …

一篇文章带你了解Python数据分析

目录 一、什么是数据分析&#xff1f; 二、为什么学习数据分析&#xff1f; 三、数据分析实现流程 一、什么是数据分析&#xff1f; 是把隐藏在一些看似杂乱无章的数据背后的信息提炼出来&#xff0c;总结出所研究对象的内在规律。 使得数据的价值最大化 指定促销活动的方…

【网络原理】使用Java基于UDP实现简单客户端与服务器通信

目录 &#x1f384;API介绍&#x1f338;DatagramSocket&#x1f338;DatagramPacket&#x1f338;InetSocketAddress &#x1f333;回显客户端与服务器&#x1f338;建立回显服务器&#x1f338;回显客户端 ⭕总结 我们用Java实现UDP数据报套接字编程&#xff0c;需要借用以下…

yolo模型中神经节点Mul与Sigmoid 和 Conv、Concat、Add、Resize、Reshape、Transpose、Split

yolo模型中神经节点Mul与Sigmoid 和 Conv、Concat、Add、Resize、Reshape、Transpose、Split 在YOLO&#xff08;You Only Look Once&#xff09;模型中&#xff0c;具体作用和用途的解释&#xff1a;

Claude 3 Sonnet 模型现已在亚马逊云科技的 Amazon Bedrock 正式可用!

今天&#xff0c;我们宣布一个激动人心的里程碑&#xff1a;Anthropic 的 Claude 3 Sonnet 模型现已在亚马逊云科技的 Amazon Bedrock 正式可用。 下一代 Claude (Claude 3) 的三个模型 Claude 3 Opus、Claude 3 Sonnet 和 Claude 3 Haiku 将陆续登陆 Amazon Bedrock。Amazon …

二叉树遍历(前中后序的递归/非递归遍历、层序遍历)

二叉树的遍历 1. 二叉树的前序、中序、后序遍历 前、中、后序遍历又叫深度优先遍历 注&#xff1a;严格来说&#xff0c;深度优先遍历是先访问当前节点再继续递归访问&#xff0c;因此&#xff0c;只有前序遍历是严格意义上的深度优先遍历 首先需要知道下面几点&#xff1a; …

【排序】详解插入排序

一、思想 插入排序是通过构建有序序列&#xff0c;对于未排序数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。具体步骤如下&#xff0c;将数组下标为0的元素视为已经排序的部分&#xff0c;从1开始遍历数组&#xff0c;在遍历的过程中当前元素从…