梯度提升决策树(GBDT)

GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器(树),并把多颗决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。

原理

GBDT的核心思想是将多个弱学习器(通常是决策树)组合成一个强大的预测模型。具体而言,GBDT的定义如下:

  • 初始化:首先GBDT使用一个常数(通常是目标变量的平均值,也可以是其他合适的初始值。初始预测值代表了模型对整体数据的初始估计。)作为初始预测值。这个初始预测值代表了我们对目标变量的初始猜测。
  • 迭代训练:GBDT是一个迭代算法,通过多轮迭代来逐步改进模型。在每一轮迭代中,GBDT都会训练一棵新的决策树,目标是减少前一轮模型的残差(或误差)。残差是实际观测值与当前模型预测值之间的差异,新的树将学习如何纠正这些残差。
  • 1)计算残差:在每轮迭代开始时,计算当前模型对训练数据的预测值与实际观测值之间的残差。这个残差代表了前一轮模型未能正确预测的部分。
  • 2):训练新的决策树:使用计算得到的残差作为新的目标变量,训练一棵新的决策树。这棵树将尝试纠正前一轮模型的错误,以减少残差。
  • 3):更新模型:将新训练的决策树与之前的模型进行组合。具体地,将新树的预测结果与之前模型的预测结果相加,得到更新后的模型。
  • 集成:最终,GBDT将所有决策树的预测结果相加,得到最终的集成预测结果。这个过程使得模型能够捕捉数据中的复杂关系,从而提高了预测精度。

每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
GBDT的弱分类器使用的是树模型。
在这里插入图片描述
如图是一个非常简单的帮助理解的示例,我们用GBDT去预测年龄:

  • 第一个弱分类器(第一棵树)预测一个年龄(如20岁),计算发现误差有10岁;
  • 第二棵树预测拟合残差,预测值6,计算发现差距还有4岁;
  • 第三棵树继续预测拟合残差,预测值3,发现差距只有1岁了;
  • 第四课树用1岁拟合剩下的残差,完成。
    最终,四棵树的结论加起来,得到30岁这个标注答案(实际工程实现里,GBDT是计算负梯度,用负梯度近似残差)。

GBDT的优势

1)高精度预测能力
GBDT以其强大的集成学习能力而闻名,能够处理复杂的非线性关系和高维数据。它通常能够在分类和回归任务中取得比单一决策树或线性模型更高的精度。
2)对各种类型数据的适应性
GBDT对不同类型的数据(数值型、类别型、文本等)具有很好的适应性,不需要对数据进行特别的预处理。这使得它在实际应用中更易于使用。

  • 处理混合数据类型
    在现实世界的数据挖掘任务中,常常会遇到混合数据类型的情况。例如,在房价预测问题中,特征既包括数值型(如房屋面积和卧室数量),还包括类别型(如房屋位置和建筑类型)和文本型(如房屋描述)数据。GBDT能够直接处理这些混合数据,无需将其转换成统一的格式。这简化了数据预处理的步骤,节省了建模时间。
  • 不需要特征缩放
    与某些机器学习算法(如支持向量机和神经网络)不同,GBDT不需要对特征进行缩放或归一化。这意味着特征的尺度差异不会影响模型的性能。在一些算法中,特征的尺度不一致可能导致模型无法正确学习,需要进行繁琐的特征缩放操作。而GBDT能够直接处理原始特征,减轻了数据预处理的负担。

3)在数据不平衡情况下的优势

  • 加权损失函数
    GBDT使用的损失函数允许对不同类别的样本赋予不同权重。这意味着模型可以更关注少数类别,从而提高对不平衡数据的处理能力。
  • 逐步纠正错误
    GBDT的迭代训练方式使其能够逐步纠正前一轮模型的错误。在处理不平衡数据时,模型通常会在多轮迭代中重点关注难以分类的少数类别样本。通过逐步纠正错误,模型逐渐提高了对少数类别的分类能力,从而改善了预测结果。

4)鲁棒性与泛化能力
GBDT在处理噪声数据和复杂问题时表现出色。其鲁棒性使得它能够有效应对数据中的异常值或噪声,不容易受到局部干扰而产生较大的预测误差。
5)特征重要性评估
GBDT可以提供有关特征重要性的信息,帮助用户理解模型的决策过程。通过分析每个特征对模型预测的贡献程度,用户可以识别出哪些特征对于问题的解决最为关键。这对于特征选择、模型解释和问题理解非常有帮助。
6)高效处理大规模数据
尽管GBDT通常是串行训练的,每棵树依赖于前一棵树的结果,但它可以高效处理大规模数据。这得益于GBDT的并行化实现和轻量级的决策树结构。此外,GBDT在处理大规模数据时可以通过特征抽样和数据抽样来加速训练过程,而不会牺牲太多预测性能。

关键参数与调优

参数解释

n_estimators:迭代次数,即最终模型中弱学习器的数量。
learning_rate(学习率):每次迭代时,新决策树对预测结果的贡献权重。
max_depth:决策树的最大深度,控制着树的复杂度。
min_samples_split:节点分裂所需的最小样本数。
subsample:用于训练每棵树的样本采样比例,小于1时可实现随机梯度提升。
loss:即我们GBDT算法中的损失函数。分类模型和回归模型的损失函数是不一样的。1)对于分类模型,有对数似然损失函数"deviance"和指数损失函数"exponential"两者输入选择。默认是对数似然损失函数"deviance"。2)对于回归模型,有均方差"ls", 绝对损失"lad", Huber损失"huber"和分位数损失“quantile”。默认是均方差"ls"。一般来说,如果数据的噪音点不多,用默认的均方差"ls"比较好。如果是噪音点较多,则推荐用抗噪音的损失函数"huber"。而如果我们需要对训练集进行分段预测的时候,则采用“quantile”。
subsample:用于训练每个弱学习器的样本比例。减小该参数可以降低方差,但也可能增加偏差。

调优策略

学习率与迭代次数的平衡:较低的学习率通常需要更多的迭代次数来达到较好的性能,但能减少过拟合的风险。
树的深度与样本采样:合理限制树的深度和采用子采样可以提高模型的泛化能力。
早停机制:在验证集上监控性能,一旦性能不再显著提升,则提前终止训练。

为了解决GBDT的效率问题,LightGBM和XGBoost等先进框架被提出,它们通过优化算法结构(如直方图近似)、并行计算等方式显著提高了训练速度。

python实现

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressorGradientBoostingClassifier(*, loss='deviance', learning_rate=0.1, n_estimators=100, subsample=1.0, criterion='friedman_mse', min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,max_depth=3, min_impurity_decrease=0.0, min_impurity_split=None, init=None, random_state=None,max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, tol=0.0001, ccp_alpha=0.0)GradientBoostingRegressor(*, loss='ls', learning_rate=0.1, n_estimators=100, subsample=1.0, criterion='friedman_mse', min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3, min_impurity_decrease=0.0, min_impurity_split=None, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='deprecated', validation_fraction=0.1, n_iter_no_change=None, tol=0.0001, ccp_alpha=0.0) 

回归实现

# 导入必要的库
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error# 加载波士顿房价数据集
boston = load_boston()
X, y = boston.data, boston.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 初始化GBDT回归器
gbdt_reg = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)# 训练模型
gbdt_reg.fit(X_train, y_train)# 预测
y_pred = gbdt_reg.predict(X_test)# 计算并打印均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")

GBDT正则化

针对GBDT正则化,我们通过子采样比例方法和定义步长v方法来防止过拟合。

  • 子采样比例:通过不放回抽样的子采样比例(subsample),取值为(0,1]。如果取值为1,则全部样本都使用。如果取值小于1,利用部分样本去做GBDT的决策树拟合。选择小于1的比例可以减少方差,防止过拟合,但是会增加样本拟合的偏差。因此取值不能太低,推荐在[0.5, 0.8]之间。
  • 定义步长v:针对弱学习器的迭代,我们定义步长v,取值为(0,1]。对于同样的训练集学习效果,较小的v意味着我们需要更多的弱学习器的迭代次数。通常我们用步长和迭代最大次数一起来决定算法的拟合效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/25333.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Valgo,类型安全,表达能⼒强的go验证器

valgo 是一个为 Go 语言设计的类型安全、表达性强且可扩展的验证库。该库的特点包括: github.com/cohesivestack/valgo 类型安全:利用 Go 语言的泛型特性(从 Go 1.18 版本开始支持),确保验证逻辑的类型安全。表达性&a…

关于Stream.toList()方法使用小记

对照示例 public static void main(String[] args) {final List<String> list new ArrayList<>();list.add("aa");list.add("bb");list.add("cc");list.remove("cc");System.out.println(list);}结果&#xff1a; Stre…

【谣传】不能完全取代HR

https://arxiv.org/pdf/2405.18113 这份研究论文提出了 MockLLM&#xff0c;一个利用大型语言模型&#xff08;LLM&#xff09;角色扮演能力来促进招聘场景中人和职位匹配的框架。它通过模拟面试过程来生成额外的匹配证据&#xff0c;从而提高匹配的准确性。 主要问题和挑战&am…

使用python绘制季节图

使用python绘制季节图 季节图效果代码 季节图 季节图&#xff08;Seasonal Plot&#xff09;是一种数据可视化图表&#xff0c;用于展示时间序列数据的季节性变化。它通过将每个时间段&#xff08;如每个月、每个季度&#xff09;的数据绘制在同一张图表上&#xff0c;使得不同…

移动安全赋能化工能源行业智慧转型

随着我国能源化工企业的不断发展&#xff0c;化工厂中经常存在火灾爆炸的危险&#xff0c;特别是生产场所&#xff0c;约有80%以上生产场所区域存在爆炸性物质。而目前我国化工危险场所移动通信设备的普及率高&#xff0c;但是对移动通信设备的安全防护却有所忽视&#xff0c;包…

关系数据库标准查询语言-SQL-SQL语言概述

一、SQL(Structured Query Language)语言 1、是高度非过程化的语言 2、关系数据库管理系统(RDBMS)都支持SQL标准 3、具有定义、查询、更新、控制四大功能 4、数据库对象由数据库&#xff08;Database&#xff09;、基本表&#xff08;Table&#xff09;、视图&#xff08;V…

string经典题目(C++)

文章目录 前言一、最长回文子串1.题目解析2.算法原理3.代码编写 二、字符串相乘1.题目解析2.算法原理3.代码编写 总结 前言 一、最长回文子串 1.题目解析 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 示例 1&#xff1a; 输入&#xff1a;s “babad” 输出&am…

自动化测试-Selenium-元素定位

一.元素定位 因为使用selenium进行自动化测试&#xff0c;元素定位是必不可少的&#xff0c;所以这篇文章用于自动化测试中的selenium中的元素定位法。 1.根据id属性进行定位&#xff08;id是唯一的&#xff09; id定位要求比较高&#xff0c;要求这个元素的id必须是固定且唯…

Java的自动装箱和自动拆箱

自动装箱和拆箱在Java开发中的应用与注意事项 在Java开发中&#xff0c;自动装箱&#xff08;Autoboxing&#xff09;和自动拆箱&#xff08;Unboxing&#xff09;是指基本数据类型与其对应的包装类之间的自动转换。这些特性可以使代码更加简洁和易读&#xff0c;但在实际项目…

CANoe-Trace窗口无法解析SOME/IP报文、Demo License激活方式改变

1、Trace窗口无法解析SOME/IP报文 在文章《如何让CANoe或Wireshark自动解析应用层协议》中,我们通过设置指定端口号为SOME/IP报文的方式,可以让CANoe中的Trace窗口对此端口号的报文当成是SOME/IP报文进行解析。 Trace窗口就可以根据传输层端口号对payload数据按照SOME/IP协议…

linuxDNS域名解析

文章目录 DNS 是域名系统的简称正向解析反向解析主从服务器解析bond网卡 DNS 是域名系统的简称 域名和IP地址之间的映射关系 互联网中&#xff0c;IP地址是通信的唯一标识&#xff0c;逻辑地址 访问网站 域名解析的目的就是为了实现&#xff0c;访问域名就等于访问IP地址 …

JS(JavaScript)的引用方式介绍与代码演示

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

SpringBoot+Vue体育馆管理系统(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 角色对应功能 学生管理员 功能截图

Linux安装MySQL教程【带图文命令巨详细】

巨详细Linux安装MySQL 1、查看是否有自带数据库或残留数据库信息1.1检查残留mysql1.2检查并删除残留mysql依赖1.3检查是否自带mariadb库 2、下载所需MySQL版本&#xff0c;上传至系统指定位置2.1创建目录2.2下载MySQL压缩包 3、安装MySQL3.1创建目录3.2解压mysql压缩包3.3安装解…

DBeaver无法连接Clickhouse,连接失败

DBeaver默认下载的是0.2.6版本的驱动&#xff0c;但是一直连接失败&#xff1a; 报错提示 解决办法 点击上图中的Open Driver Configuration点击库 - 重置为默认状态在弹出的窗口中修改驱动版本号为0.2.4或者其他版本&#xff08;我没有试用过其他版本&#xff09;&#xff0…

vscode软件上安装 Fitten Code插件及使用

一. 简介 前面几篇文章学习了 Pycharm开发工具上安装 Fitten Code插件&#xff0c;以及 Fitten Code插件的使用。 Fitten Code插件是是一款由非十大模型驱动的 AI 编程助手&#xff0c;它可以自动生成代码&#xff0c;提升开发效率&#xff0c;帮您调试 Bug&#xff0c;节省…

FPGA通过移位相加实现无符号乘法器(参数化,封装成IP可直接调用)

目录 1.前言2.原理3.移位无符号乘法器实现&#xff0c;并参数化 微信公众号获取更多FPGA相关源码&#xff1a; 1.前言 在硬件设计中&#xff0c;乘法器是非常重要的一个器件&#xff0c;乘法器的种类繁多&#xff0c;常见的有并行乘法器、移位相加乘法器和查找表乘法器。 并…

Java——简单图书管理系统

前言&#xff1a; 一、图书管理系统是什么样的&#xff1f;二、准备工作分析有哪些对象&#xff1f;画UML图 三、实现三大模块用户模块书架模块管理操作模块管理员操作有这些普通用户操作有这些 四、Test测试类五、拓展 哈喽&#xff0c;大家好&#xff0c;我是无敌小恐龙。 写…

Spark作业运行异常慢的问题定位和分析思路

一直很慢 &#x1f422; 运行中状态、卡住了&#xff0c;可以从以下两种方式入手&#xff1a; 如果 Spark UI 上&#xff0c;有正在运行的 Job/Stage/Task&#xff0c;看 Executor 相关信息就好。 第一步&#xff0c;如果发现卡住了&#xff0c;直接找到对应的 Executor 页面&a…

模糊控制器实现对某个对象追踪输入

MATLAB是一个十分便捷的软件&#xff0c;里面提供了许多集成的组件&#xff0c;本文利用simulink实现模糊控制器实现对某个对象追踪输入。 这里的对象根据自己的需求可以修改&#xff0c;那么搭建一个闭环控制系统并不是难事儿&#xff0c;主要是对于模糊控制器参数的设置&…