XGBoost算法详解

XGBoost算法详解

XGBoost(Extreme Gradient Boosting)是一种高效的梯度提升决策树(GBDT)实现,因其高性能和灵活性在机器学习竞赛中广泛使用。本文将详细介绍XGBoost算法的原理,并展示其在实际数据集上的应用。

XGBoost算法原理

XGBoost是一种集成学习方法,通过逐步建立多个决策树,每棵树都在前一棵树的基础上进行改进。XGBoost的基本思想是逐步减少损失函数值,使模型的预测能力不断提高。

算法步骤

  1. 初始化模型:使用常数模型初始化,比如回归问题中可以用目标值的均值初始化模型。
  2. 计算残差:计算当前模型的残差,即预测值与真实值之间的差异。
  3. 拟合残差:用新的决策树拟合残差,并更新模型。
  4. 更新模型:将新决策树的预测结果加到模型中,以减少残差。
  5. 重复步骤2-4:直到达到预设的迭代次数或损失函数值足够小。

公式推理

初始化模型:
F 0 ( x ) = arg ⁡ min ⁡ γ ∑ i = 1 n L ( y i , γ ) F_0(x) = \arg\min_{\gamma} \sum_{i=1}^{n} L(y_i, \gamma) F0(x)=argminγi=1nL(yi,γ)

对于每一次迭代 m = 1 , 2 , … , M m = 1, 2, \ldots, M m=1,2,,M

  1. 计算负梯度(残差): r i m = − [ ∂ L ( y i , F ( x i ) ) ∂ F ( x i ) ] F ( x ) = F m − 1 ( x ) r_{im} = -\left[ \frac{\partial L(y_i, F(x_i))}{\partial F(x_i)} \right]_{F(x) = F_{m-1}(x)} rim=[F(xi)L(yi,F(xi))]F(x)=Fm1(x)
  2. 拟合一个新的决策树来预测残差: h m ( x ) = arg ⁡ min ⁡ h ∑ i = 1 n ( r i m − h ( x i ) ) 2 h_m(x) = \arg\min_{h} \sum_{i=1}^{n} (r_{im} - h(x_i))^2 hm(x)=argminhi=1n(rimh(xi))2
  3. 更新模型: F m ( x ) = F m − 1 ( x ) + ν h m ( x ) F_m(x) = F_{m-1}(x) + \nu h_m(x) Fm(x)=Fm1(x)+νhm(x)
    其中, ν \nu ν是学习率,控制每棵树对最终模型的贡献。

损失函数与正则化

XGBoost的损失函数包含两部分:训练误差和正则化项。训练误差衡量模型预测值与真实值之间的差距,正则化项则用于控制模型复杂度,以避免过拟合。

损失函数形式如下:
L ( F ) = ∑ i = 1 n L ( y i , F ( x i ) ) + ∑ k = 1 K Ω ( f k ) \mathcal{L}(F) = \sum_{i=1}^{n} L(y_i, F(x_i)) + \sum_{k=1}^{K} \Omega(f_k) L(F)=i=1nL(yi,F(xi))+k=1KΩ(fk)
其中, Ω ( f k ) \Omega(f_k) Ω(fk)是第k棵树的正则化项,通常包括叶子节点数和叶子节点权重的平方和:
Ω ( f ) = γ T + 1 2 λ ∑ j = 1 T w j 2 \Omega(f) = \gamma T + \frac{1}{2} \lambda \sum_{j=1}^{T} w_j^2 Ω(f)=γT+21λj=1Twj2

树结构的构建

XGBoost采用启发式算法来构建树结构。在每个节点分裂时,选择能最大程度上减少损失函数的特征和分割点。具体过程如下:

  1. 计算增益:对于每个特征,计算在不同分割点上的增益,增益表示分裂前后损失函数的变化。
  2. 选择分割点:选择增益最大的特征和分割点进行节点分裂。
  3. 递归构建树:对分裂后的每个子节点重复上述过程,直到达到预设的树深度或其他停止条件。

并行和分布式计算

XGBoost通过并行和分布式计算大大提高了训练速度。其核心思想是将特征按列存储,允许在计算增益时并行处理不同特征。此外,XGBoost还支持分布式计算,能够在多台机器上分布式训练模型。

缺失值处理

XGBoost在训练过程中能够自动处理缺失值。在分裂节点时,针对缺失值分别计算增益,选择最佳策略。通常采用两种方法处理缺失值:默认方向法和分布估计法。

学习率与子采样

XGBoost通过学习率和子采样来控制每棵树对最终模型的贡献。学习率 ν \nu ν用于缩小每棵树的预测值,防止模型过拟合。子采样则通过随机选择训练样本和特征,进一步提高模型的泛化能力。

XGBoost算法的特点

  1. 高效性:XGBoost通过并行处理和分布式计算大大提高了训练速度。
  2. 灵活性:XGBoost可以处理回归、分类和排序任务,并且可以使用各种损失函数。
  3. 鲁棒性:XGBoost对数据的噪声和异常值有一定的鲁棒性。
  4. 可解释性:通过特征重要性等方法可以解释XGBoost模型。

XGBoost参数说明

以下是XGBoost常用参数及其详细说明的表格形式:

参数名称描述默认值示例
n_estimators树的棵数,提升迭代的次数100n_estimators=200
learning_rate学习率,控制每棵树对最终模型的贡献0.1learning_rate=0.05
max_depth树的最大深度,控制每棵树的复杂度6max_depth=4
min_child_weight叶子节点最小权重,控制过拟合1min_child_weight=3
subsample样本采样比例,用于控制过拟合1.0subsample=0.8
colsample_bytree每棵树的特征采样比例1.0colsample_bytree=0.8
gamma节点分裂所需的最小损失函数下降值0gamma=0.1
lambdaL2正则化项系数1lambda=2
alphaL1正则化项系数0alpha=0.1
scale_pos_weight正样本的权重比例,用于处理类别不平衡1scale_pos_weight=10
objective要优化的目标函数reg:squarederrorobjective='binary:logistic'
eval_metric评估指标rmseeval_metric='auc'
seed随机数种子,用于结果复现0seed=42
silent是否静默模式,0表示打印运行信息,1表示不打印1silent=0
nthread线程数,控制并行计算所有可用线程nthread=4
max_delta_step每棵树权重估计的最大步长,如果类别极度不平衡,可以设置较高的值0max_delta_step=1
booster要使用的提升类型,可以是gbtreegblineardartgbtreebooster='dart'
tree_method构建树的方法,可以是autoexactapproxhistgpu_histautotree_method='hist'
predictor用于预测的算法类型,可以是cpu_predictorgpu_predictorautopredictor='gpu_predictor'

通过合理调整这些参数,可以优化XGBoost模型在特定任务和数据集上的性能。

XGBoost算法在回归问题中的应用

在本节中,我们将使用合成数据集来展示如何使用XGBoost算法进行回归任务。

导入库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
from sklearn.metrics import mean_squared_error, r2_score

生成和预处理数据

使用 make_regression 函数生成一个合成的回归数据集:

# 生成合成回归数据集
X, y = make_regression(n_samples=1000, n_features=20, noise=0.1, random_state=42)# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

训练XGBoost模型

# 训练XGBoost模型
xgb_regressor = xgb.XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
xgb_regressor.fit(X_train, y_train)

预测与评估

# 预测
y_pred = xgb_regressor.predict(X_test)# 评估
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Mean Squared Error: {mse:.2f}')
print(f'R^2 Score: {r2:.2f}')

特征重要性

# 特征重要性
feature_importances = xgb_regressor.feature_importances_
plt.barh(range(X.shape[1]), feature_importances, align='center')
plt.yticks(np.arange(X.shape[1]), [f'Feature {i}' for i in range(X.shape[1])])
plt.xlabel('Feature Importance')
plt.ylabel('Feature')
plt.title('Feature Importances in XGBoost')
plt.show()

在这里插入图片描述

XGBoost算法在分类问题中的应用

在本节中,我们将使用 make_classification 函数生成一个合成的分类数据集,来展示如何使用XGBoost算法进行分类任务。

导入库

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

生成和预处理数据

# 生成合成分类数据集
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=42)# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

训练XGBoost模型

# 训练XGBoost模型
xgb_classifier = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
xgb_classifier.fit(X_train, y_train)

预测与评估

# 预测
y_pred = xgb_classifier.predict(X_test)# 评估
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print('Confusion Matrix:')
print(conf_matrix)# 分类报告
class_report = classification_report(y_test, y_pred)
print('Classification Report:')
print(class_report)

结语

本文我们详细介绍了XGBoost算法的原理和特点,并展示了其在回归和分类任务中的应用。首先介绍了XGBoost算法的基本思想和公式,然后展示了如何在合成数据集上使用XGBoost进行回归任务,以及如何在合成分类数据集上使用XGBoost进行分类任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flutter 项目设置 Flutter 版本

即便使用了 fvm 设置了版本,AdroidStudio Setting 中如果不修改路径,Editor 依然会编译错误。目前还没看懂如何通过命令、文件来记录AdroidStudio Setting中的设置。 fvm list 来查看 flutter 路径:

怪物猎人物语游戏加载慢、卡加载解决方法一览

怪物猎人物语是《怪物猎人》系列史上首部RPG类型游戏。本作采用动漫式的画风风格,在玩法上完全不同于以往系列作,但本作完整的保持《怪物猎人》系列的世界观,依靠正统的RPG玩法给玩家带来不同以往的游戏体验。因为游戏快要上线了,…

5分钟搭建大模型应用!腾讯将「实用主义」贯彻到底

让企业像搭积木一样构建大模型应用,简单可上手。 在经历了一年多的技术锤炼后,大模型正在迈向真刀真枪抢落地的关键阶段。 对于更多企业而言,如何将看上去酷炫的大模型技术落到实处成了眼下的重要命题。 与此同时,「甲子光年」…

04 Pytorch tensor

一:老版本的 variable 二:新版 tensor 曾经:求导相关 如今:数据相关 –dtype: 张量的数据类型,三大类,共9种。torch.FloatTensor, torch.cuda.FloatTensor –shape: 张量的形状。如:&#x…

智慧校园软件开发:为学校量身定制的技术解决方案

为了满足智慧校园的需求,一套全面的软件解决方案被设计出来,旨在优化学校管理和提升教学质量。首先,通过实施统一的认证门户,结合OAuth2和SSO技术,确保不同用户群体能便捷且安全地访问所需资源。 教务管理系统被构建成…

信创数据库沙龙 | 全国预告

#数据库沙龙 #国产数据库 #信创数据库

虚拟DOM

目录 由状态到UI状态渲染命令式操作DOM声明式操作DOM 效率的取舍虚拟DOMVNodePatch 由状态到UI 状态 状态可以是JavaScript中的任意类型。Object、Array、String、Number、Boolean等都可以作为状态,这些状态可能最终会以段落、表单、链接或按钮等元素呈现在用户界…

课程设计---哈夫曼树的编码与解码(Java详解)

目录 一.设计任务&&要求: 二.方案设计报告: 2.1 哈夫曼树编码&译码的设计原理: 2.3设计目的: 2.3设计的主要过程: 2.4程序方法清单: 三.整体实现源码: 四.运行结果展示&…

javaSE:继承

在谈继承之前,我们先观察下面这个代码: //定义一个猫类 class Cat {public String name;public int age;public float weigth;public void eat(){System.out.println(this.name"正在吃饭");}public void mimi(){System.out.println(this.nam…

YoloV9改进策略:注意力篇|BackBone改进|自研像素和通道并行注意力模块(独家原创)

摘要 本文使用FFA-Net的注意力改进YoloV9,FFA-Net提出了通道注意力和像素注意力相结合的方式,提高Block的表征能力,我把这两种注意力结合起来改进YoloV8的BackBone,取得了非常好的效果,即插即用,简单易懂,非常适合大家入手。 论文翻译:《FFA-Net:用于单图像去雾的特征…

nccl 03 记 回顾:从下载,编译到调试 nccl-test

1, 下载与编译 1.1 源码下载 $ git clone https://github.com/NVIDIA/nccl.git 1.2 编译 1.2.1 一般编译: $ make -j src.build 1.2.2 特定架构gpu 编译 $ make -j src.build NVCC_GENCODE"-gencodearchcompute_80,codesm_80" A10…

探究布局模型:从LayoutLM到LayoutLMv2与LayoutXLM

LAYOUT LM 联合建模文档的layout信息和text信息, 预训练 文档理解模型。 模型架构 使用BERT作为backbone, 加入2-D绝对位置信息,图像信息 ,分别捕获token在文档中的相对位置以及字体、文字方向、颜色等视觉信息。 2D位置嵌入 …

装备制造行业数据分析指标体系

数字化飞速发展的时代,多品种、定制化的产品需求、越来越短的产品生命周期、完善的售后服务、极佳的客户体验和快速的交货速度等,使得装备制造行业的经营环境越来越复杂,企业竞争从拼产品、拼价格迈向拼服务,装备制造企业正处于数…

阿里云 debian10.3 sudo apt-get updat 报错的解决方案

阿里云全新的debian10.3(buster)镜像,却无法正常执行 sudo apt-get update。主要报错信息如下: Err:6 http://mirrors.cloud.aliyuncs.com/debian buster-backports Release404 Not Found [IP: 100.100.2.148 80] Err:3 http://mirrors.cloud.aliyuncs…

无引擎游戏开发(1):EasyX图形库引入 + 跟随鼠标移动的小球

来自bilibili up主的Voidmatrix的视频教程:【从零开始的C游戏开发】 一、图形库引入 EasyX在国内文档最多,而且功能函数齐全,最适合入门。 环境配置:vs2022 (官网下载免费版) 百度搜EasyX官方&#xff0…

后方穿行预警系统技术规范(简化版)

后方穿行预警系统技术规范(简化版) 1 系统概述2 预警区域3 预警目标4 预警条件5 指标需求1 系统概述 RCTA后方穿行预警系统工作在驾驶员有倒车意向的时候。在倒车过程中当驾驶员视线因周围障碍物被遮挡而产生碰撞风险时,系统通过光学信号对驾驶员进行提醒。 2 预警区域 RCT…

前端入门篇(五十二)练习6:transition过渡小动画

所以应该先找到第n个li,找到li再找img,li没有找错,底下又各自只有一个img,解决 ul li:nth-child(1) img { } 描述文字从下往上: 一开始描述也在框框下面,当hover时,translateY(0)&#xff0…

【JS重点18】原型链(面试重点)

一:原型链底层原理 以下面一段代码为例,基于原型对象(Star构造函数的原型对象)的继承使得不同构造函数的原型对象关联在一起(此处是最大的构造函数Object原型对象),并且这种关联的关系是一种链…

CleanShot X for Mac v4.7 屏幕滚动长截图录像工具(保姆级教程,小白轻松上手,简单易学)

Mac分享吧 文章目录 一、下载软件二、部分特有功能效果1、截图软件的普遍常用功能(画框、箭头、加文字等)都具备,不再详细介绍2、ABCD、1234等信息标注(每按一下鼠标,即各是A、B、C、D...等)3、截图更换背…

SQL注入-下篇

HTTP注入 一、Referer注入 概述 当你访问一个网站的时候,你的浏览器需要告诉服务器你是从哪个地方访问服务器的。如直接在浏览器器的URL栏输入网址访问网站是没有referer的,需要在一个打开的网站中,点击链接跳转到另一个页面。 Less-19 判…