XGB-26:model

切片树模型|Slice tree model

当XGBoost中的 booster 参数设置为 gbtreedart 时,算法构建了一个由多棵树组成的树模型。这个树模型可以被切片成多个子模型,每个子模型包含原始模型中一部分树。这个切片过程允许创建更小、更专业的模型,专注于原始模型性能或行为的特定方面。

import numpy as np
import pandas as pdimport xgboost as xgb
from sklearn.datasets import make_classificationnum_classes = 3X, y = make_classification(n_samples=1000, n_informative=5, n_classes=num_classes)dtrain = xgb.DMatrix(data=X, label=y)num_parallel_tree = 4
num_boost_round = 16# total number of built trees is num_parallel_tree * num_classes * num_boost_round# We build a boosted random forest for classification here.
booster = xgb.train({'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3},num_boost_round=num_boost_round, dtrain=dtrain)# This is the sliced model, containing [3, 7) forests
# step is also supported with some limitations like negative step is invalid.
sliced: xgb.Booster = booster[3:7]# Access individual tree layer
trees = [_ for _ in booster]
assert len(trees) == num_boost_round
sliced_trees = [_ for _ in sliced]
assert len(sliced_trees) == 4print(trees)
print(sliced_trees)
[<xgboost.core.Booster at 0x7f04a6e20820>,
<xgboost.core.Booster at 0x7f04a6d89250>,
<xgboost.core.Booster at 0x7f04a6fac550>,
<xgboost.core.Booster at 0x7f048b1eaee0>,
<xgboost.core.Booster at 0x7f048b1ea520>,
<xgboost.core.Booster at 0x7f048b1ea850>,
<xgboost.core.Booster at 0x7f048b1ea700>,
<xgboost.core.Booster at 0x7f048b1eabe0>,
<xgboost.core.Booster at 0x7f04a6e3ddf0>,
<xgboost.core.Booster at 0x7f04a6e3d850>,
<xgboost.core.Booster at 0x7f048b1e72e0>,
<xgboost.core.Booster at 0x7f048b1e7970>,
<xgboost.core.Booster at 0x7f048b1e75b0>,
<xgboost.core.Booster at 0x7f048b1e79d0>,
<xgboost.core.Booster at 0x7f048b1e7f10>,
<xgboost.core.Booster at 0x7f048b1e74c0>][<xgboost.core.Booster at 0x7f04a6b82f40>,
<xgboost.core.Booster at 0x7f04736dfeb0>,
<xgboost.core.Booster at 0x7f04736dfbe0>,
<xgboost.core.Booster at 0x7f04736dfd00>]

切片模型是所选树的副本,这意味着在切片过程中模型本身是不可变的。这个特性是早停回调中 save_best 选项的基础。

import numpy as np
import pandas as pdimport xgboost as xgb
from sklearn.datasets import make_classificationfrom scipy.special import logit
from sklearn.datasets import load_svmlight_filedef individual_tree() -> None:"""Get prediction from each individual tree and combine them together."""X_train, y_train = load_svmlight_file(train)X_test, y_test = load_svmlight_file(test)Xy_train = xgb.QuantileDMatrix(X_train, y_train)n_rounds = 4# Specify the base score, otherwise xgboost will estimate one from the training data.base_score = 0.5params = {"max_depth": 2,"eta": 1,"objective": "reg:logistic","tree_method": "hist","base_score": base_score,}booster = xgb.train(params, Xy_train, num_boost_round=n_rounds)# Use logit to inverse the base score back to raw leaf value (margin)scores = np.full((X_test.shape[0],), logit(base_score))for i in range(n_rounds):# - Use output_margin to get raw leaf values# - Use iteration_range to get prediction for only one tree# - Use previous prediction as base marign for the modelXy_test = xgb.DMatrix(X_test, base_margin=scores)if i == n_rounds - 1:# last round, get the transformed predictionscores = booster.predict(Xy_test, iteration_range=(i, i + 1), output_margin=False)else:# get raw leaf value for accumulationscores = booster.predict(Xy_test, iteration_range=(i, i + 1), output_margin=True)full = booster.predict(xgb.DMatrix(X_test), output_margin=False)np.testing.assert_allclose(scores, full)def model_slices() -> None:"""Inference with each individual tree using model slices."""X_train, y_train = load_svmlight_file(train)X_test, y_test = load_svmlight_file(test)Xy_train = xgb.QuantileDMatrix(X_train, y_train)n_rounds = 4# Specify the base score, otherwise xgboost will estimate one from the training data.base_score = 0.5params = {"max_depth": 2,"eta": 1,"objective": "reg:logistic","tree_method": "hist","base_score": base_score,}booster = xgb.train(params, Xy_train, num_boost_round=n_rounds)trees = [booster[t] for t in range(n_rounds)]# Use logit to inverse the base score back to raw leaf value (margin)scores = np.full((X_test.shape[0],), logit(base_score))for i, t in enumerate(trees):# Feed previous scores into base margin.Xy_test = xgb.DMatrix(X_test, base_margin=scores)if i == n_rounds - 1:# last round, get the transformed predictionscores = t.predict(Xy_test, output_margin=False)else:# get raw leaf value for accumulationscores = t.predict(Xy_test, output_margin=True)full = booster.predict(xgb.DMatrix(X_test), output_margin=False)np.testing.assert_allclose(scores, full)individual_tree()
model_slices()

两个函数演示如何使用XGBoost库来获取每棵单独树的预测,并将它们结合起来得到最终的预测结果。

  1. individual_tree 函数:
    • 使用 scipy.special.logit 函数来反转基础分数(base score)回到原始叶值(margin
    • 加载训练和测试数据集。
    • 定义XGBoost参数,并使用 xgb.train 方法来训练模型。
    • 使用 booster.predict 方法来获取每棵树的预测。对于最后一轮,获取转换后的预测;对于其他轮,获取原始叶值以便累加。
    • 使用 np.testing.assert_allclose 方法来验证单独树预测的结合结果与使用整个模型得到的预测结果是否相近。
  2. model_slices 函数:
    • individual_tree 函数类似,也是加载训练和测试数据集,定义XGBoost参数,并使用 xgb.train 方法来训练模型。
    • 使用 booster[t] 来获取每棵树,并将其存储在一个列表中。
    • 使用每棵树的 predict 方法来获取预测。对于最后一轮,获取转换后的预测;对于其他轮,获取原始叶值以便累加。
    • 同样使用 np.testing.assert_allclose 方法来验证单独树预测的结合结果与使用整个模型得到的预测结果是否相近。

参考

  • https://xgboost.readthedocs.io/en/latest/python/model.html
  • https://xgboost.readthedocs.io/en/latest/python/examples/individual_trees.html#sphx-glr-python-examples-individual-trees-py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822964.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日OJ题_01背包④_力扣1049. 最后一块石头的重量 II

目录 力扣1049. 最后一块石头的重量 II 问题解析 解析代码 滚动数组优化代码 力扣1049. 最后一块石头的重量 II 1049. 最后一块石头的重量 II 难度 中等 有一堆石头&#xff0c;用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合&#xff0c;从…

数字乳腺癌组织病理学图像分类的Vision Transformer及其变体

Vision Transformer作为一种基于自注意力机制的高效图像分类工具被提出。近年来出现了基于Poolingbased Vision Transformer (PiT)、卷积视觉变压器(CvT)、CrossFormer、CrossViT、NesT、MaxViT和分离式视觉变压器(SepViT)等新模型。 它们被用于BreakHis和IDC数据集上的图像分…

软件项目管理 - PERT 图

文章目录 1 概述1.1 PERT 图1.2 基础概念 2 相关计算2.1 最早时刻2.2 最迟时刻2.3 关键路径2.4 松弛时间 1 概述 1.1 PERT 图 PERT&#xff1a;Program Evaluation and Review Technique&#xff08;项目评估与评审技术&#xff09; PERT 图是一个有向图&#xff0c;图中的箭…

【C++造神计划】数学运算

​ 数学运算库 //二者选一 #include <cmath> #include <math.h>// #include <math.h> #include <cmath> #include <stdio.h>int main() {float res;res sqrt(2);res abs(-5.3);res sin(0.5*M_PI);res asin(res);res cos(0.5*M_PI);res a…

通过Dockerfile 创建 kali-novnc

创建Dockerfile # 使用官方Kali镜像作为基础镜像 FROM kalilinux/kali-rolling# 设置工作目录 WORKDIR /app# 将当前目录下的所有文件复制到工作目录中 COPY ./run.sh .# 安装项目依赖 RUN apt update -y RUN apt upgrade -y RUN apt install dbus-x11 xfce4 tightvncserver …

【c++】stack和queue使用 stack和queue模拟实现

主页&#xff1a;醋溜马桶圈-CSDN博客 专栏&#xff1a;c_醋溜马桶圈的博客-CSDN博客 gitee&#xff1a;mnxcc (mnxcc) - Gitee.com 目录 1. stack的介绍和使用 1.1 stack的介绍 1.2 stack的使用 1.3 stack的模拟实现 2. queue的介绍和使用 2.1 queue的介绍 2.2 queue的…

Unity中使用NewtonJson序列化继承类时报错解决方法参考

在Unity中使用NewtonJson时&#xff0c;如果要序列化的类继承了别的类&#xff0c;可能报如下错误&#xff1a; JsonSerializationException: Self referencing loop detected for property...... 解决的方法是新建一个JsonSerializerSettings对象&#xff0c;并设置对象的Ref…

Golang面试题五(GC)

目录 1.Golang GC版本 2.常见的垃圾回收算法有以下几种 3.怎么找到程序中无用的对象 引用计数法 根搜索法 GC roots对象 4.java与go的GC对比 5.三色标记法 1.Golang GC版本 Go 1.3版本&#xff1a;普通标记清除法&#xff0c;整体过程需要启动STW&#xff0c;效率极低。…

SpringBoot之JWT令牌校验

SpringBoot之JWT令牌校验 本文根据黑马b站springboot3vue3课程 JWT &#xff08;JSON Web Token&#xff09;是一种开放标准&#xff08;RFC 7519&#xff09;&#xff0c;用于在不同实体之间安全地传输信息。它由三个部分组成&#xff1a;头部&#xff08;Header&#xff09;…

如何实现音乐音频合并?分享3种简单的合并技巧!音频合并的方法

音乐合并&#xff0c;作为一种音乐创作与编辑的手法&#xff0c;已经逐渐在音乐制作领域占据了一席之地。音乐合并不仅是对音乐元素的重新组合&#xff0c;更是对音乐内涵的深化和拓展。它可以将不同的音乐风格和元素巧妙地融合在一起&#xff0c;创造出全新的听觉体验。 一&a…

DonkeyDocker-v1-0渗透思路

MY_BLOG https://xyaxxya.github.io/2024/04/13/DonkeyDocker-v1-0%E6%B8%97%E9%80%8F%E6%80%9D%E8%B7%AF/ date: 2024-04-13 19:15:10 tags: 内网渗透Dockerfile categories: 内网渗透vulnhub 靶机下载地址 https://www.vulnhub.com/entry/donkeydocker-1,189/ 靶机IP&a…

芯片设计围炉札记

文章目录 语言Verilog 和 VHDL 区别 芯片验证 语言 System Verilog的概念以及与verilog的对比 IC 设计软件分析 Verilog 和 VHDL 区别 Verilog HDL 和 VHDL 的区别如下&#xff1a; 语法结构&#xff1a;Verilog的语法结构类似于C语言&#xff0c;而VHDL的语法结构则更接近…

苍穹外卖学习记录(一)

1.JWT令牌认证 JSON Web Token (JWT)是一个开放标准(RFC 7519)&#xff0c;它定义了一种紧凑的、自包含的方式&#xff0c;用于作为JSON对象在各方之间安全地传输信息。该信息可以被验证和信任&#xff0c;因为它是数字签名的。 JWT是目前最常用的一种令牌规范&#xff0c;它最…

QT-编译报库错误(LF/CRLF)

QT-安装后环境问题记录 版本和环境问题 版本和环境 QT5.15.2 Windows10 QT Creator 问题 在QT夸端开发的项目中 &#xff0c;使用QTCreator打开项目pro文件&#xff0c;编译报出很多系统库 及本地文件中的一些问题&#xff0c;具体如图&#xff1a; 后续&#xff0c;我以为…

L1-059 敲笨钟

原题链接&#xff1a;https://pintia.cn/problem-sets/994805046380707840/exam/problems/1111914599412858880?type7&page0 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 微博上有个自称“大笨钟V”的家伙&#xff0c;每天敲钟催促码农们爱惜身体早点睡觉…

数字人结合动捕设备化身虚拟主持人如何玩转大型活动?

由十五运会和残特奥会广州赛区执委会、广州市政府新闻办、广州市科学技术局联合举办的“科技赋能 畅想全运”故事会上&#xff0c;数字人“小运”结合动捕设备化身虚拟主持人惊喜亮相&#xff0c;与真人主持人趣味互动&#xff0c;并向大众介绍了其后续将在大运会上给运动员、工…

photoshop2022增效工具ICOFormat.8bi(PS ico插件)

先退出关闭ps 1、下载插件压缩包&#xff0c;解压出ICOFormat.8bi文件&#xff0c;有两个版本ICOFormat64.8bi对应32位版、ICOFormat64.8bi对应64位版本。 2、把解压后的ICOFormat64.8bi文件覆盖到Photoshop安装目录: C:\Program Files\Adobe\Adobe Photoshop 2022\Required…

【机器学习算法介绍】(6)随机森林

随机森林&#xff08;Random Forest&#xff09;是一种集成学习方法&#xff0c;主要用于分类和回归任务。它通过构建多个决策树&#xff08;Decision Trees&#xff09;并汇总它们的预测结果来提高整体模型的性能。随机森林的核心思想在于“集体智慧”——单个模型&#xff08…

redis的主从复制(docker方式快速入门和实战)

目录 一、主从复制简介 二、配置主从服务器 2.1使用配置文件的形式来主从复制 2.2使用纯代码的方式来进行主从复制&#xff1b; 2.3脱离主服务器 三、一些注意事项 一、主从复制简介 主从复制&#xff0c;是指将一台Redis服务器的数据&#xff0c;复制到其他的Redis服务器…

APEX开发过程中需要注意的小细节5.5

oracle保留小数点后两位的函数 在日常开发中经常用到百分比做数据对比&#xff0c;但是有可能得到的数据是一个多位小数&#xff0c;结果如下所示&#xff1a; 如果想截取部分小数如保留小数点后两位可以怎么做呢&#xff1f; 在Oracle中&#xff0c;可以使用ROUND函数来四舍…