神经网络中的分位数回归和分位数损失

在使用机器学习构建预测模型时,我们不只是想知道“预测值(点预测)”,而是想知道“预测值落在某个范围内的可能性有多大(区间预测)”。例如当需要进行需求预测时,如果只储备最可能的需求预测量,那么缺货的概率非常的大。但是如果库存处于预测的第95个百分位数(需求有95%的可能性小于或等于该值),那么缺货数量会减少到大约20分之1。

获得这些百分位数值的机器学习方法有:

  • scikit-learn:GradientBoostingRegressor(loss='quantile, alpha=alpha)
  • LightGBM: LGBMRegressor(objective='quantile', alpha=alpha)
  • XGBoost: XGBoostRegressor(objective='reg:quantileerror', quantile_alpha=alpha) (version 2.0~)

这种”预测值落在某个范围内的可能性有多大(区间预测)”的方法都被称作分位数回归,上面的这些机器学习的方法是用了一种叫做Quantile Loss的损失。

Quantile loss是用于评估分位数回归模型性能的一种损失函数。在分位数回归中,我们不仅关注预测的中心趋势(如均值),还关注在分布的不同分位数处的预测准确性。Quantile loss允许我们根据所关注的分位数来量化预测的不确定性。

假设我们有一个预测问题,其中我们要预测一个连续型变量的分布,并且我们关注不同的分位数,例如中位数、0.25分位数、0.75分位数等。对于第q分位数,Quantile Loss定义为:

这里:

  • yy 是真实值。
  • yy 是模型的预测值。
  • qq 是目标分位数,取值范围为0,10,1。

这个损失函数的核心思想是,当模型的预测值超过真实值时,损失是预测值与真实值的差值乘以q。当预测值低于真实值时,损失是预测值与真实值的差值乘以1−q。这确保了对于不同的分位数,我们有不同的惩罚。如果我们更关心较小分位数(例如,中位数),我们会设定较小的q,反之亦然。

用Pytorch实现分位数损失

下面是一个使用Pytorch将分位数损失定义为自定义损失函数的示例。

 importtorchdefquantile_loss(y_true, y_pred, quantile):errors=y_true-y_predloss=torch.mean(torch.max((quantile-1) *errors, quantile*errors))returnloss

对于训练来说,跟正常的训练方法一样:

 for epoch in range(num_epochs):for batch_x, batch_y in dataloader:optimizer.zero_grad()outputs = model(batch_x)loss = quantile_loss(outputs, batch_y, quantile)loss.backward()optimizer.step()

让我们看看这个自定义的损失函数是否如预期的那样工作。

Pytorch分位数损失测试

首先,我们尝试为x生成均匀随机分布(-5~5),为y生成与x指数成比例的正态随机分布,看看是否可以从x预测y的分位数点。

 # Generate dummy datanum_samples = 10000shape = (num_samples, 1)torch.manual_seed(0)# x is uniform random from -5 to 5# y is random normal distribution * exp(scaled x)x_tensor = torch.rand(shape) * 10 - 5x_scaled = x_tensor / 5y_tensor = torch.randn(shape) * torch.exp(x_scaled)# Convert values to NumPy array (for graphs)x = x_tensor.numpy()y = y_tensor.numpy()

网络结构很简单,两个中间层64个节点+每层relu。在没有任何正则化或提前停止的情况下使用100次epoch。待预测的四分位数(百分位数)在列中为[0.500,0.700,0.950,0.990,0.995],在行中为批大小[1,4,16,64,256],总共有25个预测。在10,000个训练数据实例(蓝色)中,低于预测输出值(红色)的实例的比率在图中被标记为“实际”值。

低于指定百分位数值的样本百分比通常接近指定值,并且输出分位数预测的是非常直接的。

再考虑一个稍微复杂的例子,其中y=clip(x, - 2,2) + randn。其中clip(x, - 2,2)是剪辑函数(将值限制在指定范围内)。当数字超出给定范围时,该函数将其限制到最近的边界(如果将范围设置为-2到2,并输入-5的输入值,该函数将返回-2;如果输入10,它将返回2),而randn是遵循正态分布的随机数。网络结构和其他设置与前一种情况相同。

与前一种情况一样,低于指定百分位数值的样本百分比通常接近指定值。分位数预测的理想形状总是左上角图中红线的形状。它应该随着指定的百分位数的增加而平行向上移动。当移动到图的右下方时,预测的红线呈现出更线性的形状,这不是一个理想的结果。

让我们用一个更复杂的形状,我们的目标是y=2sin(x) + randn。其他设置与前一种情况相同。

可以看到低于指定百分位数值的样本百分比通常接近指定值。当向5x5图的右下方移动时,分位数预测的形状偏离了正弦形状。在图的右下方,预测值的红线变得更加线性。

如何选择Q

我们看到,如果设置过高的quantile,会得到扁平化的值,那么如何判断使用Quantile Loss得到的结果是否“扁平”,如何“避免扁平呢”?

检测“扁平化”的方法之一是一起计算第50、68和95个百分位值,并检查这些值之间的关系,即使要获得的最终值是99.5百分位值。如果样本分布服从正态分布,以μ为均值,σ为标准差

在μ±σ区间内的概率约为68;在μ±2σ区间内的概率约为95;在μ±3σ区间内的概率约为99.7

如果第68百分位-第50百分位、第95百分位-第50百分位和99.5百分位-第50百分位值的比值明显偏离1:2:3,我们可以确定偏离的百分位值已经“变平”。

避免扁平化”的第一种方法是减少批量大小,如上面的实验所示。较小的批量大小避免了这个问题,并且不太可能产生平坦的预测。但是减少批大小也有缺点,比如收敛不稳定和增加训练时间,所以它只是有时一个容易采用的选择。

第二种方法是在同一批次中收集相似的样本,而不是随机生成批次。这避免了“在批内低于和高于预测值的样本比例与指定的百分位数值之间的平衡”。

最后"扁平化"是无法避免的,我们只能进行缓解,下列符号用于下列方程。

  • P0:第50个百分位值
  • P1:第68个百分位值
  • P2:第95百分位值
  • P3: 99.5百分位值

使用上述变量,可以使用以下流程图获得适当的99.5%百分位数值。

总结

分位数回归是一种强大的统计工具,对于那些关注数据分布中不同区域的问题,以及需要更加灵活建模的情况,都是一种有价值的方法。

本文将介绍了在神经网络种自定义损失实现分位数回归,并且介绍了如何检测和缓解预测结果的"扁平化"问题。Quantile loss在一些应用中很有用,特别是在金融领域的风险管理问题中,因为它提供了一个在不同分位数下评估模型性能的方法。

https://avoid.overfit.cn/post/e64a72a342af4aeda08b249ebca2c214

作者:Shiro Matsumoto

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/587007.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BDD - Python Behave 用户自定义命令行选项 -D

BDD - Python Behave 用户自定义命令行选项 -D 引言behave -Dbehave -D 应用feature 文件behave.ini 配置文件step 文件执行 引言 日常运行测试用例,有时需要自定义命令行参数,比如不同环境的对应的配置是不一样的,这样就需要传一个环境参数…

【机器学习合集】深度生成模型 ->(个人学习记录笔记)

深度生成模型 深度生成模型基础 1. 监督学习与无监督学习 1.1 监督学习 定义 在真值标签Y的指导下,学习一个映射函数F,使得F(X)Y 判别模型 Discriminative Model,即判别式模型,又称为条件模型,或条件概率模型 生…

mysql哪些情况下不走索引?

mysql哪些情况下不走索引? MySQL是一种常用的关系型数据库,它使用索引来提高查询性能。然而,并非所有的SQL语句都能充分利用索引。在本文中,我们将介绍几个无法使用到索引的MySQL SQL语句。 1. 使用函数:当SQL语句中…

【Linux】chage命令使用

chage命令 chage用来更改linux用户密码到期信息,包括密码修改间隔最短、最长日期、密码失效时间等。 语法 chage [参数] 用户名 chage命令 -Linux手册页 选项及作用 执行令 : chage --help 执行命令结果 参数 -d, --lastday 最近日期 …

【Electron】webview 实现网页内嵌

实现效果: 当在输入框内输入某个网址后并点击button按钮 , 该网址内容就展示到下面 踩到的坑:之前通过web技术实现 iframe 标签内嵌会出现 同源策略,同时尝试过 vue.config.ts 内配置跨域项 那样确实 是实现啦 但不知道如何动态切换 tagert …

Cisco模拟器-交换机端口的隔离

设计要求将某台交换机的端口划分在不同的VLAN。以实现连接在相同VLAN端口上的计算机可以通信,而连接在不同VLAN端口上的计算机无法通信的目的。 通过设计,一方面可以加强计算机网络的安全,另一方面通过隔绝不同VLAN间的广播包也可以提高网络…

GcExcel:DsExcel 7.0 for Java Crack

GcExcel:DsExcel 7.0-高速 Java Excel 电子表格 API 库 Document Solutions for Excel(DsExcel,以前称为 GcExcel)Java 版允许您在 Java 应用程序中以编程方式创建、编辑、导入和导出 Excel 电子表格。几乎可以部署在任何地方。 创建、加载、…

numpy数组04-数组的轴和读取数据

一、数组的轴 在numpy中数组的轴可以理解为方向,使用0,1,2...数字表示。 对于一个一维数组,只有一个0轴,对于2维数组(如shape(2,2)),有0轴和1轴…

java编程SimpleDateFormat详解

java编程SimpleDateFormat详解 大家好,我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天,我们将深入研究Java编程中的日期与时间处理工具——SimpleDateForma…

探索 Pinia:简化 Vue 状态管理的新选择(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

go的json数据类型处理

json对象转slice package mainimport ("encoding/json""fmt""github.com/gogf/gf/container/garray" )func main() {// JSON 字符串jsonStr : ["apple", "banana", "orange"]//方法一:// 解析 JSON 字…

2312d,d的sql构建器

原文 项目 该项目在我工作项目中广泛使用,它允许自动处理联接方式动态构建SQL语句. 还会自动直接按表示数据库行结构序化.它在dconf2022在线演讲中介绍了:建模一切. 刚刚添加了对sqlite的支持.该API还不稳定,但仍非常有用.这是按需构建,所以虽然有个计划外表,但满足了我的需要…

spring 核心技术依赖注入 DI 详细使用教程包含例子

DI Dependency Injection 依赖注入个解:DI是IOC的一种实现方式,bean文件定义数据,通过构造函数或set方式注入到java类中构造函数注入 数据部分:bean constructor-arg节点逻辑部分:构造函数注入<bean id="helloService" class="io.spring.hello.HelloSer…

visual studio + intel Fortran 错误解决

版本&#xff1a;VS2022 intel Fortran 2024.0.2 Package ID: w_oneAPI_2024.0.2.49896 共遇到三个问题。 1.rc.exe not found 2.kernel32.lib 无法打开 3.winres.h 无法打开 我安装时参考的教程&#xff1a;visual studio和intel oneAPI安装与编写fortran程序_visual st…

【赠书第15期】案例学Python(基础篇)

文章目录 前言 1 简介 2 功能列表 3 实现 3.1 学生类 3.2 学生管理系统类 3.3 使用示例 4 推荐图书 5 粉丝福利 前言 当涉及案例学 Python 时&#xff0c;可以选择一个具体的问题或场景&#xff0c;通过编写代码来解决或模拟这个问题。以下是一个例子&#xff0c;通过…

2024年数据管理预测:利用AI更好地利用非结构化数据

在数据存储和非结构化数据管理领域&#xff0c;过去 12 个月发生了很大变化。在不确定的经济环境下&#xff0c;随着成本上升和 IT 预算压力增加&#xff0c;云存储战略受到关注&#xff0c;生成式 AI 正在创造新的数据存储和治理要求&#xff0c;数据迁移越来越复杂&#xff0…

分库分表之Mycat应用学习二

3 Mycat 概念与配置 官网 http://www.mycat.io/ Mycat 概要介绍 https://github.com/MyCATApache/Mycat-Server 入门指南 https://github.com/MyCATApache/Mycat-doc/tree/master/%E5%85%A5%E9%97%A8%E6%8C%87%E5%8D%973.1 Mycat 介绍与核心概念 3.1.1 基本介绍 历史&#x…

【Yii2】数据库查询方法总结

目录 1.查找单个记录&#xff1a; 2.查找多个记录&#xff1a; 3.条件查询&#xff1a; 4.关联查询&#xff1a; 假设User模型有一个名为orders的多对一关联关系。 5.排序和分组&#xff1a; 6.数据操作&#xff1a; 7.事务处理&#xff1a; 8.命令查询&#xff1a; 9…

MongoDB聚合:$out

$out阶段将聚合管道产生的文档写入到指定的集合&#xff0c;从MongoDB4.4开始&#xff0c;支持指定数据库。$out阶段必须放在聚合管道的最后&#xff0c;支持聚合结果任意大小的数据集。 警告&#xff1a; 如果指定的集合已经存在则会被替换。 语法 用法 1&#xff1a; 定数…

骑砍战团MOD开发(29)-module_scenes.py游戏场景

骑砍1战团mod开发-场景制作方法_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1Cw411N7G4/ 一.骑砍游戏场景 骑砍战团中进入城堡,乡村,战斗地图都被定义为场景,由module_scenes.py进行管理。 scene(游戏场景) 天空盒(Skyboxes.py) 地形(terrain code) 场景物(scene_…