神经网络中的分位数回归和分位数损失

在使用机器学习构建预测模型时,我们不只是想知道“预测值(点预测)”,而是想知道“预测值落在某个范围内的可能性有多大(区间预测)”。例如当需要进行需求预测时,如果只储备最可能的需求预测量,那么缺货的概率非常的大。但是如果库存处于预测的第95个百分位数(需求有95%的可能性小于或等于该值),那么缺货数量会减少到大约20分之1。

获得这些百分位数值的机器学习方法有:

  • scikit-learn:GradientBoostingRegressor(loss='quantile, alpha=alpha)
  • LightGBM: LGBMRegressor(objective='quantile', alpha=alpha)
  • XGBoost: XGBoostRegressor(objective='reg:quantileerror', quantile_alpha=alpha) (version 2.0~)

这种”预测值落在某个范围内的可能性有多大(区间预测)”的方法都被称作分位数回归,上面的这些机器学习的方法是用了一种叫做Quantile Loss的损失。

Quantile loss是用于评估分位数回归模型性能的一种损失函数。在分位数回归中,我们不仅关注预测的中心趋势(如均值),还关注在分布的不同分位数处的预测准确性。Quantile loss允许我们根据所关注的分位数来量化预测的不确定性。

假设我们有一个预测问题,其中我们要预测一个连续型变量的分布,并且我们关注不同的分位数,例如中位数、0.25分位数、0.75分位数等。对于第q分位数,Quantile Loss定义为:

这里:

  • yy 是真实值。
  • yy 是模型的预测值。
  • qq 是目标分位数,取值范围为0,10,1。

这个损失函数的核心思想是,当模型的预测值超过真实值时,损失是预测值与真实值的差值乘以q。当预测值低于真实值时,损失是预测值与真实值的差值乘以1−q。这确保了对于不同的分位数,我们有不同的惩罚。如果我们更关心较小分位数(例如,中位数),我们会设定较小的q,反之亦然。

用Pytorch实现分位数损失

下面是一个使用Pytorch将分位数损失定义为自定义损失函数的示例。

 importtorchdefquantile_loss(y_true, y_pred, quantile):errors=y_true-y_predloss=torch.mean(torch.max((quantile-1) *errors, quantile*errors))returnloss

对于训练来说,跟正常的训练方法一样:

 for epoch in range(num_epochs):for batch_x, batch_y in dataloader:optimizer.zero_grad()outputs = model(batch_x)loss = quantile_loss(outputs, batch_y, quantile)loss.backward()optimizer.step()

让我们看看这个自定义的损失函数是否如预期的那样工作。

Pytorch分位数损失测试

首先,我们尝试为x生成均匀随机分布(-5~5),为y生成与x指数成比例的正态随机分布,看看是否可以从x预测y的分位数点。

 # Generate dummy datanum_samples = 10000shape = (num_samples, 1)torch.manual_seed(0)# x is uniform random from -5 to 5# y is random normal distribution * exp(scaled x)x_tensor = torch.rand(shape) * 10 - 5x_scaled = x_tensor / 5y_tensor = torch.randn(shape) * torch.exp(x_scaled)# Convert values to NumPy array (for graphs)x = x_tensor.numpy()y = y_tensor.numpy()

网络结构很简单,两个中间层64个节点+每层relu。在没有任何正则化或提前停止的情况下使用100次epoch。待预测的四分位数(百分位数)在列中为[0.500,0.700,0.950,0.990,0.995],在行中为批大小[1,4,16,64,256],总共有25个预测。在10,000个训练数据实例(蓝色)中,低于预测输出值(红色)的实例的比率在图中被标记为“实际”值。

低于指定百分位数值的样本百分比通常接近指定值,并且输出分位数预测的是非常直接的。

再考虑一个稍微复杂的例子,其中y=clip(x, - 2,2) + randn。其中clip(x, - 2,2)是剪辑函数(将值限制在指定范围内)。当数字超出给定范围时,该函数将其限制到最近的边界(如果将范围设置为-2到2,并输入-5的输入值,该函数将返回-2;如果输入10,它将返回2),而randn是遵循正态分布的随机数。网络结构和其他设置与前一种情况相同。

与前一种情况一样,低于指定百分位数值的样本百分比通常接近指定值。分位数预测的理想形状总是左上角图中红线的形状。它应该随着指定的百分位数的增加而平行向上移动。当移动到图的右下方时,预测的红线呈现出更线性的形状,这不是一个理想的结果。

让我们用一个更复杂的形状,我们的目标是y=2sin(x) + randn。其他设置与前一种情况相同。

可以看到低于指定百分位数值的样本百分比通常接近指定值。当向5x5图的右下方移动时,分位数预测的形状偏离了正弦形状。在图的右下方,预测值的红线变得更加线性。

如何选择Q

我们看到,如果设置过高的quantile,会得到扁平化的值,那么如何判断使用Quantile Loss得到的结果是否“扁平”,如何“避免扁平呢”?

检测“扁平化”的方法之一是一起计算第50、68和95个百分位值,并检查这些值之间的关系,即使要获得的最终值是99.5百分位值。如果样本分布服从正态分布,以μ为均值,σ为标准差

在μ±σ区间内的概率约为68;在μ±2σ区间内的概率约为95;在μ±3σ区间内的概率约为99.7

如果第68百分位-第50百分位、第95百分位-第50百分位和99.5百分位-第50百分位值的比值明显偏离1:2:3,我们可以确定偏离的百分位值已经“变平”。

避免扁平化”的第一种方法是减少批量大小,如上面的实验所示。较小的批量大小避免了这个问题,并且不太可能产生平坦的预测。但是减少批大小也有缺点,比如收敛不稳定和增加训练时间,所以它只是有时一个容易采用的选择。

第二种方法是在同一批次中收集相似的样本,而不是随机生成批次。这避免了“在批内低于和高于预测值的样本比例与指定的百分位数值之间的平衡”。

最后"扁平化"是无法避免的,我们只能进行缓解,下列符号用于下列方程。

  • P0:第50个百分位值
  • P1:第68个百分位值
  • P2:第95百分位值
  • P3: 99.5百分位值

使用上述变量,可以使用以下流程图获得适当的99.5%百分位数值。

总结

分位数回归是一种强大的统计工具,对于那些关注数据分布中不同区域的问题,以及需要更加灵活建模的情况,都是一种有价值的方法。

本文将介绍了在神经网络种自定义损失实现分位数回归,并且介绍了如何检测和缓解预测结果的"扁平化"问题。Quantile loss在一些应用中很有用,特别是在金融领域的风险管理问题中,因为它提供了一个在不同分位数下评估模型性能的方法。

https://avoid.overfit.cn/post/e64a72a342af4aeda08b249ebca2c214

作者:Shiro Matsumoto

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/587007.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习合集】深度生成模型 ->(个人学习记录笔记)

深度生成模型 深度生成模型基础 1. 监督学习与无监督学习 1.1 监督学习 定义 在真值标签Y的指导下,学习一个映射函数F,使得F(X)Y 判别模型 Discriminative Model,即判别式模型,又称为条件模型,或条件概率模型 生…

【Linux】chage命令使用

chage命令 chage用来更改linux用户密码到期信息,包括密码修改间隔最短、最长日期、密码失效时间等。 语法 chage [参数] 用户名 chage命令 -Linux手册页 选项及作用 执行令 : chage --help 执行命令结果 参数 -d, --lastday 最近日期 …

【Electron】webview 实现网页内嵌

实现效果: 当在输入框内输入某个网址后并点击button按钮 , 该网址内容就展示到下面 踩到的坑:之前通过web技术实现 iframe 标签内嵌会出现 同源策略,同时尝试过 vue.config.ts 内配置跨域项 那样确实 是实现啦 但不知道如何动态切换 tagert …

Cisco模拟器-交换机端口的隔离

设计要求将某台交换机的端口划分在不同的VLAN。以实现连接在相同VLAN端口上的计算机可以通信,而连接在不同VLAN端口上的计算机无法通信的目的。 通过设计,一方面可以加强计算机网络的安全,另一方面通过隔绝不同VLAN间的广播包也可以提高网络…

GcExcel:DsExcel 7.0 for Java Crack

GcExcel:DsExcel 7.0-高速 Java Excel 电子表格 API 库 Document Solutions for Excel(DsExcel,以前称为 GcExcel)Java 版允许您在 Java 应用程序中以编程方式创建、编辑、导入和导出 Excel 电子表格。几乎可以部署在任何地方。 创建、加载、…

numpy数组04-数组的轴和读取数据

一、数组的轴 在numpy中数组的轴可以理解为方向,使用0,1,2...数字表示。 对于一个一维数组,只有一个0轴,对于2维数组(如shape(2,2)),有0轴和1轴…

探索 Pinia:简化 Vue 状态管理的新选择(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

go的json数据类型处理

json对象转slice package mainimport ("encoding/json""fmt""github.com/gogf/gf/container/garray" )func main() {// JSON 字符串jsonStr : ["apple", "banana", "orange"]//方法一:// 解析 JSON 字…

visual studio + intel Fortran 错误解决

版本:VS2022 intel Fortran 2024.0.2 Package ID: w_oneAPI_2024.0.2.49896 共遇到三个问题。 1.rc.exe not found 2.kernel32.lib 无法打开 3.winres.h 无法打开 我安装时参考的教程:visual studio和intel oneAPI安装与编写fortran程序_visual st…

【赠书第15期】案例学Python(基础篇)

文章目录 前言 1 简介 2 功能列表 3 实现 3.1 学生类 3.2 学生管理系统类 3.3 使用示例 4 推荐图书 5 粉丝福利 前言 当涉及案例学 Python 时,可以选择一个具体的问题或场景,通过编写代码来解决或模拟这个问题。以下是一个例子,通过…

2024年数据管理预测:利用AI更好地利用非结构化数据

在数据存储和非结构化数据管理领域,过去 12 个月发生了很大变化。在不确定的经济环境下,随着成本上升和 IT 预算压力增加,云存储战略受到关注,生成式 AI 正在创造新的数据存储和治理要求,数据迁移越来越复杂&#xff0…

分库分表之Mycat应用学习二

3 Mycat 概念与配置 官网 http://www.mycat.io/ Mycat 概要介绍 https://github.com/MyCATApache/Mycat-Server 入门指南 https://github.com/MyCATApache/Mycat-doc/tree/master/%E5%85%A5%E9%97%A8%E6%8C%87%E5%8D%973.1 Mycat 介绍与核心概念 3.1.1 基本介绍 历史&#x…

骑砍战团MOD开发(29)-module_scenes.py游戏场景

骑砍1战团mod开发-场景制作方法_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1Cw411N7G4/ 一.骑砍游戏场景 骑砍战团中进入城堡,乡村,战斗地图都被定义为场景,由module_scenes.py进行管理。 scene(游戏场景) 天空盒(Skyboxes.py) 地形(terrain code) 场景物(scene_…

【华为数据之道学习笔记】8-2 数据质量规则

异常数据是不满足数据标准、不符合业务实质的客观存在的数据,如某位员工的国籍信息错误、某位客户的客户名称信息错误等。 数据在底层数据库多数是以二维表格的形式存储,每个数据格存储一个数据值。若想从众多数据中识别出异常数据,就需要通过…

【滑动窗口】C++算法:可见点的最大数目

作者推荐 动态规划 多源路径 字典树 LeetCode2977:转换字符串的最小成本 本题涉及知识点 滑动窗口 LeetCode 1610可见点的最大数目 给你一个点数组 points 和一个表示角度的整数 angle ,你的位置是 location ,其中 location [posx, posy] 且 point…

C#语言发展历程(1-7)

一、类型发展 C#1中是没有泛型的 在C#2中在逐渐推出泛型。C#2还引入了可空类型。 示例:C#泛型(详解)-CSDN博客 1 C#3:引入了匿名类型、和隐式的局部变量(var) 匿名类型:我们主要是使用在LIN…

openGauss学习笔记-179 openGauss 数据库运维-逻辑复制-发布订阅

文章目录 openGauss学习笔记-179 openGauss 数据库运维-逻辑复制-发布订阅179.1 发布179.2 订阅179.3 冲突处理179.4 限制179.5 架构179.6 监控179.7 安全性179.8 配置设置179.9 快速设置 openGauss学习笔记-179 openGauss 数据库运维-逻辑复制-发布订阅 发布和订阅基于逻辑复…

大模型推理部署:LLM 七种推理服务框架总结

自从ChatGPT发布以来,国内外的开源大模型如雨后春笋般成长,但是对于很多企业和个人从头训练预训练模型不太现实,即使微调开源大模型也捉襟见肘,那么直接部署这些开源大模型服务于企业业务将会有很大的前景。 本文将介绍七中主流的…

【eclipse】eclipse开发springboot项目使用入门

下载eclipse Eclipse downloads - Select a mirror | The Eclipse Foundation 安装eclipse 其他一步一步即可 我们是开发java web选择如下 界面修改 Window->Preferences-> 修改eclipse风格主题 Window->Preferences->General->Appearance 修改字体和大小…

基于 CefSharp 实现一个文件小工具

I’m not saying you can’t be financially successful I’m saying have a greater purpose in life well beyond the pursuit of financial success Your soul is screaming for you to answer your true calling You can change today if you redefine what success is to …