XGBoost算法Python代码实现(单棵树类)

### XGBoost单棵树类
class XGBoost_Single_Tree(BinaryDecisionTree):# 结点分裂方法def node_split(self, y):# 中间特征所在列feature = int(np.shape(y)[1]/2)# 左子树为真实值,右子树为预测值y_true, y_pred = y[:, :feature], y[:, feature:]return y_true, y_pred# 信息增益计算方法def gain(self, y, y_pred):# 梯度计算Gradient = np.power((y * self.loss.gradient(y, y_pred)).sum(), 2)# Hessian矩阵计算Hessian = self.loss.hess(y, y_pred).sum()return 0.5 * (Gradient / Hessian)# 树分裂增益计算# 式(12.28)def gain_xgb(self, y, y1, y2):# 结点分裂y_true, y_pred = self.node_split(y)y1, y1_pred = self.node_split(y1)y2, y2_pred = self.node_split(y2)true_gain = self.gain(y1, y1_pred)false_gain = self.gain(y2, y2_pred)gain = self.gain(y_true, y_pred)return true_gain + false_gain - gain# 计算叶子结点最优权重def leaf_weight(self, y):y_true, y_pred = self.node_split(y)# 梯度计算gradient = np.sum(y_true * self.loss.gradient(y_true, y_pred), axis=0)# hessian矩阵计算hessian = np.sum(self.loss.hess(y_true, y_pred), axis=0)# 叶子结点得分leaf_weight =  gradient / hessianreturn leaf_weight# 树拟合方法def fit(self, X, y):self.impurity_calculation = self.gain_xgbself._leaf_value_calculation = self.leaf_weightsuper(XGBoost_Single_Tree, self).fit(X, y)

这段代码定义了一个基于 XGBoost 的单棵决策树类,扩展自 BinaryDecisionTree,用于实现梯度提升树的部分功能。以下是对代码的详细解读,包括其核心函数和算法设计思想。


代码结构分析

1. node_split(self, y)

功能:

  • 将输入数据 y 分为真实值 y true y_{\text{true}} ytrue 和预测值 y pred y_{\text{pred}} ypred,便于计算梯度和 Hessian。

实现:

  • y y y 被假定为一个二维数组,其中前半部分是真实值,后半部分是预测值。
  • 使用切片操作将 y y y 拆分为 y true y_{\text{true}} ytrue y pred y_{\text{pred}} ypred
feature = int(np.shape(y)[1] / 2)
y_true, y_pred = y[:, :feature], y[:, feature:]

2. gain(self, y, y_pred)

功能:

  • 计算节点的损失函数增益(单节点增益),公式为:
    Gain = 1 2 ( ∑ Gradient ) 2 Hessian \text{Gain} = \frac{1}{2} \frac{\left( \sum \text{Gradient} \right)^2}{\text{Hessian}} Gain=21Hessian(Gradient)2

实现:

  • 梯度计算 ( y ∗ gradient ) . sum() (y * \text{gradient}).\text{sum()} (ygradient).sum() 计算所有样本的梯度总和,然后取平方。
  • Hessian 计算 hessian . sum() \text{hessian}.\text{sum()} hessian.sum() 求二阶导数的总和。
  • 增益公式:结合梯度平方和 Hessian 求增益。
Gradient = np.power((y * self.loss.gradient(y, y_pred)).sum(), 2)
Hessian = self.loss.hess(y, y_pred).sum()
return 0.5 * (Gradient / Hessian)

3. gain_xgb(self, y, y1, y2)

功能:

  • 计算节点分裂前后的总增益,公式为:
    Gain Split = Gain Left + Gain Right − Gain Parent \text{Gain}_{\text{Split}} = \text{Gain}_{\text{Left}} + \text{Gain}_{\text{Right}} - \text{Gain}_{\text{Parent}} GainSplit=GainLeft+GainRightGainParent

实现:

  • 分裂节点:调用 node_split y , y 1 , y 2 y, y_1, y_2 y,y1,y2 分别拆分为真实值和预测值。
  • 左右子节点增益:分别计算 Gain Left \text{Gain}_{\text{Left}} GainLeft Gain Right \text{Gain}_{\text{Right}} GainRight
  • 父节点增益:计算分裂前的增益 Gain Parent \text{Gain}_{\text{Parent}} GainParent
  • 总增益:将左右子节点增益之和减去父节点增益。
y_true, y_pred = self.node_split(y)
y1, y1_pred = self.node_split(y1)
y2, y2_pred = self.node_split(y2)
true_gain = self.gain(y1, y1_pred)
false_gain = self.gain(y2, y2_pred)
gain = self.gain(y_true, y_pred)
return true_gain + false_gain - gain

4. leaf_weight(self, y)

功能:

  • 计算叶子节点的最优权重,用于更新叶子节点的预测值,公式为:
    w ∗ = ∑ Gradient ∑ Hessian w^* = \frac{\sum \text{Gradient}}{\sum \text{Hessian}} w=HessianGradient

实现:

  • 梯度计算:计算真实值 y true y_{\text{true}} ytrue 和预测值 y pred y_{\text{pred}} ypred 的梯度和。
  • Hessian 计算:计算真实值和预测值的二阶导数和。
  • 权重计算:直接使用公式 w ∗ = Gradient Hessian w^* = \frac{\text{Gradient}}{\text{Hessian}} w=HessianGradient
gradient = np.sum(y_true * self.loss.gradient(y_true, y_pred), axis=0)
hessian = np.sum(self.loss.hess(y_true, y_pred), axis=0)
leaf_weight = gradient / hessian
return leaf_weight

5. fit(self, X, y)

功能:

  • 训练决策树,设置增益计算方法和叶子节点权重计算方法。

实现:

  • gain_xgb 设置为 impurity 计算方法,用于指导树的分裂。
  • leaf_weight 设置为叶子节点的值计算方法。
  • 调用父类 BinaryDecisionTreefit 方法,完成树的训练。
self.impurity_calculation = self.gain_xgb
self._leaf_value_calculation = self.leaf_weight
super(XGBoost_Single_Tree, self).fit(X, y)

整体流程

  1. 数据拆分

    • 使用 node_split 方法将数据拆分为真实值和预测值部分。
  2. 节点增益计算

    • 使用 gain 方法计算单节点的增益。
    • 使用 gain_xgb 方法计算分裂前后的增益。
  3. 叶子节点权重计算

    • 使用 leaf_weight 方法计算叶子节点的最优预测值。
  4. 树的训练

    • 通过 fit 方法调用父类的逻辑,完成树的构建。

与 XGBoost 的公式对比

代码中的实现与 XGBoost 中的增益公式略有不同:

  1. 增益公式的差异

    • 当前实现的增益公式是简化的,缺少正则化参数 λ \lambda λ 和分裂成本 γ \gamma γ
    • 完整的 XGBoost 增益公式为:
      Gain = 1 2 [ G L 2 H L + λ + G R 2 H R + λ − ( G L + G R ) 2 H L + H R + λ ] − γ \text{Gain} = \frac{1}{2} \left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}\right] - \gamma Gain=21[HL+λGL2+HR+λGR2HL+HR+λ(GL+GR)2]γ
    • 当前实现适合快速计算增益,但与 XGBoost 的公式不完全一致。
  2. 正则化的实现

    • 当前实现未引入正则化参数 λ \lambda λ γ \gamma γ,这可能会导致树结构过于复杂或分裂无效。
    • 在实际 XGBoost 应用中,正则化项能够有效控制模型复杂度,提升泛化能力。

总结

这段代码实现了 XGBoost 单棵决策树的基础功能,包括节点增益计算、叶子节点权重更新和树的训练逻辑。尽管实现有所简化,但它为构建完整的 XGBoost 算法提供了一个良好的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/58753.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python的编程基础分支,循环与函数的应用知识

编程基础是学习任何编程语言的必备知识之一。在Python中,分支、循环和函数是常用的编程概念,它们可以让我们编写出更复杂、更灵活的程序。 分支 分支是根据条件来决定程序执行的不同路径。在Python中,我们使用if语句来实现分支。 if 条件:# …

营业执照OCR识别API接口如何用C#调用

服务器和计算设备的性能不断提升,为 OCR 识别提供了更强大的计算能力支持。更快的 CPU、GPU 以及分布式计算技术的应用,使得营业执照图片的处理速度大幅加快,能够在更短的时间内完成大量营业执照的识别工作。 研发人员不断对 OCR 识别算法进…

qt QLocale详解

1、概述 QLocale是Qt框架中的一个类,用于处理与本地化相关的操作。它能够方便地实现日期、时间、数字和货币的格式化和解析,支持不同的语言、区域设置和字符集。QLocale提供了一种跨平台的方式来获取当前系统的语言设置,并返回该语言的本地化…

微服务架构面试内容整理-Eureka

Spring Cloud Netflix 是一个为构建基于 Spring Cloud 的微服务应用提供的解决方案,利用 Netflix 的开源组件来实现常见的分布式系统功能。以下是 Spring Cloud Netflix 的一些主要组件和特点: 服务注册与发现:Eureka 是一个 RESTful 服务,用于注册和发现微服务。服务实例在…

缓存、注解、分页

一.缓存 作用:应用查询上,内存中的块区域。 缓存查询结果,减少与数据库的交互,从而提高运行效率。 1.SqlSession 缓存 1. 又称为一级缓存,mybatis自动开启。 2. 作用范围:同一…

uniapp vue3 使用echarts-gl 绘画3d图表

我自己翻遍了网上,以及插件市场,其实并没有uniapp 上使用echarts-gl的样例,大多数都是使用插件市场的echarts的插件 开始自己尝试直接用echartsgl 没有成功,后来尝试使用threejs 但是也遇到一些问题,最后我看官网的时…

【言语理解】片段阅读整体概述

1.1 题型分类 片段阅读一般有以下六种: 中心理解题 “这段文字意在说明:” “这段文字意在强调:” “这段文字主要介绍了:” “下列对文意概括最恰当的是:”标题拟定题 “最适合做这段文字标题的是:”下文…

linux搭建大数据环境

前期准备工作 友情提醒提前安装好vmware软件,准备好连接虚拟机的客户端 一. 基础环境 1.配置ip地址 修改ip配置文件 [rootnode1 /]# vim /etc/sysconfig/network-scripts/ifcfg-ens33 TYPE"Ethernet" PROXY_METHOD"none" BROWSER_ONLY"no" # …

什么是 OpenTelemetry?

OpenTelemetry 定义 OpenTelemetry (OTel) 是一个开源可观测性框架,允许开发团队以单一、统一的格式生成、处理和传输遥测数据(telemetry data)。它由云原生计算基金会 (CNCF) 开发,旨在提供标准化协议和工具,用于收集…

ESP32 gptimer通用定时器初始化报错:assert failed: timer_ll_set_clock_prescale

背景:IDF版本V5.1.2 ,配置ESP32 通用定时器,实现100HZ,占空比50% 的PWM波形。 根据乐鑫官方的IDF指导文档设置内部计数器的分辨率,计数器每滴答一次相当于 1 / resolution_hz 秒。 (ESP-IDF编程指导文档&a…

AIGC在游戏设计中的应用及影响

文章目录 一、AIGC的基本概念与背景AIGC的主要应用领域AIGC技术背景 二、AIGC在游戏设计中的应用1. 自动化游戏地图与关卡设计示例:自动生成2D平台游戏关卡 2. 角色与物品生成示例:使用GAN生成虚拟角色 3. 游戏剧情与任务文本生成示例:基于GP…

【NOIP普及组】统计单词数

【NOIP普及组】统计单词数 💐The Begin💐点点关注,收藏不迷路💐 一般的文本编辑器都有查找单词的功能,该功能可以快速定位特定单词在文章中的位置,有的还能统计出特定单词在文章中出现的次数。 现在&#x…

Spring Security(5.x, 6.x ) RBAC访问控制

在 Spring Security 中,基于不同版本实现 RBAC(基于角色的访问控制)功能有一些不同的方式。RBAC 的基本原理是:定义用户、角色和权限的关系,并控制不同用户对资源的访问。 Spring Security 不同版本的实现主要在配置方…

Unity 如何优雅的限定文本长度, 包含对特殊字符,汉字,数字的处理。实际的案例包括 用户昵称

常规限定文本长度 ( 通过 UntiyEngine.UI.Inputfiled 附带的长度限定 ) 痛点1 无法对中文,数字,英文进行识别,同样数量的汉字和同样数量的英文像素长度是不一样的,当我们限定固定长度后,在界面上的排版不够美观 痛点2…

多个服务器共享同一个Redis Cluster集群,并且可以使用Redisson分布式锁

Redisson 是一个高级的 Redis 客户端,它支持多种分布式 Java 对象和服务。其中之一就是分布式锁(RLock),它可以跨多个应用实例在多个服务器上使用同一个 Redis 集群,为这些实例提供锁服务。 当你在不同服务器上运行的…

jmeter常用配置元件介绍总结之函数助手

系列文章目录 1.windows、linux安装jmeter及设置中文显示 2.jmeter常用配置元件介绍总结之安装插件 3.jmeter常用配置元件介绍总结之取样器 jmeter常用配置元件介绍总结之函数助手 1.进入函数助手对话框2.常用函数的使用介绍2.1.RandomFromMultipleVars函数2.2.Random函数2.3.R…

发现了NitroShare的一个bug

NitroShare 是一个跨平台的局域网开源网络文件传输应用程序,它利用广播发现机制在本地网络中找到其他安装了 NitroShare 的设备,从而实现这些设备之间的文件和文件夹发送。 NitroShare 支持 Windows、macOS 和 Linux 操作系统。 NitroShare允许我们为…

【 ElementUI 组件Steps 步骤条使用新手详细教程】

本文介绍如何使用 ElementUI 组件库中的步骤条组件完成分步表单设计。 效果图: 基础用法​ 简单的步骤条。 设置 active 属性,接受一个 Number,表明步骤的 index,从 0 开始。 需要定宽的步骤条时,设置 space 属性即…

互联网技术净土?原生鸿蒙开启全新技术征程

鸿蒙生态与开发者的崭新机会 HarmonyOS NEXT承载着华为对未来操作系统的深刻理解,如今已发展为坚实的数字底座。它不仅在技术层面取得了全面突破,还在中国操作系统市场中站稳了脚跟。 当前,HarmonyOS NEXT的代码行数已超过1.1亿&#xff0c…

[linux驱动开发--API框架]--platform、gpio、pinctrl

1. 结构体定义和实例化 // 这个结构体样式并不固定,按需增减成员,可以参考内核的其他驱动代码 struct leddev_dev{dev_t devid; /* 设备号*/struct cdev cdev; /* cdev*/struct class *class; /* 类*/struct device *d…