XGB-11:随机森林

XGBoost通常用于训练梯度提升决策树和其他梯度提升模型。随机森林使用与梯度提升决策树相同的模型表示和推断,但使用不同的训练算法。可以使用XGBoost来训练独立的随机森林,或者将随机森林作为梯度提升的基模型。这里我们专注于训练独立的随机森林。

XGB从早期开始就有用于训练随机森林的API,而Scikit-Learn在0.82版本之后才有封装。

使用XGBoost API训练独立的随机森林

要启用随机森林训练,必须设置以下参数:

  • booster 应设置为 gbtree,因为正在训练森林。由于这是默认值,通常不需要显式设置此参数。

  • subsample 必须设置为小于 1 的值,以启用对训练样本(行)的随机选择。

  • colsample_by 参数之一必须设置为小于 1 的值,以启用对列的随机选择。通常,colsample_bynode 应设置为小于 1 的值,以在每次树分裂时随机抽样列。

  • num_parallel_tree 应设置为正在训练的森林的大小。

  • num_boost_round 应设置为 1,以防止 XGBoost 提升多个随机森林。请注意,这是train() 的关键字参数,不是参数字典的一部分。

  • 在训练随机森林回归时,应将 eta(别名:learning_rate)设置为 1。

  • random_state 可以用于设置随机数生成器的种子。

其他参数应以类似于梯度提升时设置的方式进行设置。例如,对于回归任务,objective 通常将设置为 reg:squarederror,而对于分类任务,将设置为 binary:logisticlambda 应根据所需的正则化权重进行设置,等等。

如果 num_parallel_treenum_boost_round 都大于 1,则训练将使用随机森林和梯度提升策略的组合。它将执行 num_boost_round 轮,在每一轮中提升 num_parallel_tree 棵树的随机森林。如果未启用提前停止,最终模型将由 num_parallel_tree * num_boost_round 棵树组成。

以下是在 GPU 上使用 xgboost 训练随机森林的示例参数字典:

params = {"colsample_bynode": 0.8,"learning_rate": 1,"max_depth": 5,"num_parallel_tree": 100,"objective": "binary:logistic","subsample": 0.8,"tree_method": "hist","device": "cuda",
}

然后可以按如下方式训练随机森林模型:

bst = train(params, dmatrix, num_boost_round=1)
import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_errordiabetes = load_diabetes()
X = diabetes.data
y = diabetes.target# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Create a DMatrix for XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)# Set parameters for random forest training
params = {"booster": "gbtree","subsample": 0.8,"colsample_bynode": 0.8,"num_parallel_tree": 100,"num_boost_round": 1,"eta": 1,"random_state": 42,"objective": "reg:squarederror",
}# Train the random forest model
model = xgb.train(params, dtrain)# Make predictions on the test set
y_pred = model.predict(dtest)# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

基于 Scikit-Learn-Like API 实现随机森林

XGBRFClassifierXGBRFRegressor 是类似于 Scikit-Learn 的类,提供了随机森林的功能。 它们基本上是 XGBClassifierXGBRegressor 的版本,用于训练随机森林而不是梯度提升, 并相应地调整了一些参数的默认值和含义。具体来说:

  • n_estimators 指定要训练的森林的大小;它被转换为 num_parallel_tree,而不是 boosting 轮数的数量
  • learning_rate 默认设置为 1
  • colsample_bynodesubsample 默认设置为 0.8
  • booster 始终为 gbtree

例如,可以使用以下代码训练一个随机森林回归器:

from sklearn.model_selection import KFold# Your code ...kf = KFold(n_splits=2)
for train_index, test_index in kf.split(X, y):xgb_model = xgb.XGBRFRegressor(random_state=42).fit(X[train_index], y[train_index])

注意,与使用 train() 相比,这些类的参数选择较少。特别是,使用此 API 无法将随机森林与梯度提升结合起来。

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from xgboost import XGBRFRegressor
from sklearn.model_selection import KFolddiabetes = load_diabetes()
X = diabetes.data
y = diabetes.targetkf = KFold(n_splits=2)
for train_index, test_index in kf.split(X, y):xgb_model = xgb.XGBRFRegressor(random_state=42).fit(X[train_index], y[train_index])# Make predictions on the test set
y_pred = xgb_model.predict(X_test)# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

注意事项

  • XGBoost 使用二阶逼近来近似目标函数。这可能导致与使用目标函数的精确值的随机森林实现不同的结果
  • 在子采样训练样本时,XGBoost 不执行替换操作。每个训练案例在子采样集中可能出现 0 次或 1 次

参考

  • https://xgboost.readthedocs.io/en/latest/tutorials/rf.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/699910.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决Mysql的Access denied for user权限不足问题

当用客户端工具连接数据库 以root 用户登录后 无法给相关用户授权数据库等操作: 原因: root%表示 root用户 通过任意其他端访问操作 被拒绝! 授权即可: 登录server端: mysql -uroot -pxxxxx(使用账号密码登录linux mysql服务内部) 然后输入如下sql命令…

Sora:开启视频生成新时代的强大人工智能模型

目录 一、Sora模型的诞生与意义 二、Sora模型的技术特点与创新 三、Sora模型的应用前景与影响 四、面临的挑战与未来发展 1、技术挑战 2、道德和伦理问题 3、计算资源需求 4、未来发展方向 随着信息技术的飞速发展,人工智能(AI)已成为…

vue3中使用vuedraggable实现拖拽el-tree数据进分组

看效果: 可以实现单个拖拽、双击添加、按住ctrl键实现多个添加,或者按住shift键实现范围添加,添加到框中的数据,还能拖拽排序 先安装 vuedraggable 这是他的官网 vue.draggable中文文档 - itxst.com npm i vuedraggable -S 直接…

tomcat通过JAVA_OPTS注入自定义变量 —— 筑梦之路

背景说明 tomcat部署的java应用在k8s集群或容器中,想要给tomcat传自定义变量,应该如何实现? 解决方法 1. 在k8s集群或容器环境中通过env或者configmap方式添加自定义的环境变量 比如: my_key: aaaa 2. tomcat下新增脚本&am…

拓扑空间简介

目录 介绍集合论与映射映射相关定义映射(map)映射的一种分类:一一的和到上的 拓扑空间背景介绍开子集开子集的选择 拓扑拓扑空间常见拓扑拓扑子空间同胚其他重要定义 开覆盖紧致性有限开覆盖紧致性 R R R的紧致性 习题 介绍 这是对梁灿彬的《…

shim error: docker-runc not installed on system

问题描述:shim error: docker-runc not installed on system 解决办法: 方式一: cd /usr/libexec/docker/sudo ln -s docker-runc-current docker-runc 方式二: vi /etc/docker/daemon.json # 添加内容如下: {"…

【软件架构】01-架构的概述

1、定义 软件架构就是软件的顶层结构 RUP(统一过程开发)4 1 视图 1)逻辑视图: 描述系统的功能、组件和它们之间的关系。它主要关注系统的静态结构,包括类、接口、包、模块等,并用于表示系统的组织结构…

全栈笔记_工具篇(nvm免安装版自动配置,无需手动设置环境变量)

将免安装压缩包nvm-noinstall.zip解压到指定目录,如:C:\nvm 修改install.cmd: @echo off set /P NVM_PATH="Enter the absolute path where the nvm-windows zip file is extracted/copied to: " set NVM_HOME=%NVM_PATH% setx NVM_HOME "%NVM_HOME%"fo…

C++入门学习(三十六)函数的声明

程序是自上而下运行的&#xff0c;比如我下面的代码&#xff1a; #include <iostream> #include<string> using namespace std;int main() { int a1; int b2;int sumaddNumbers(a,b); cout<<sum;return 0; }int addNumbers(int a, int b) { int sum …

MFC 配置Halcon

1.新建一个MFC 工程&#xff0c;Halcon 为64位&#xff0c;所以先将工程改为x64 > VC 目录设置包含目录和库目录 包含目录 库目录 c/c ->常规 链接器 ->常规 > 链接器输入 在窗口中添加头文件 #include "HalconCpp.h" #include "Halcon.h"…

简单讲解并梳理微信小程序默认几个文件和文件夹结构及其作用

那么 我们来说一下 小程序整个项目结构 它各个文件 和 整体结构 这是我们新创建的一个小程序项目 我们从上到下 分别来看一下 这些文件和目录结构的作用 首先是 pages 它的作用在于存储整个项目所有的 page页面文件 我们小程序官方 是推荐我们将所有page 界面都放在pages目录…

稀疏计算、彩票假说、MoE、SparseGPT

稀疏计算可能是未来10年内最有潜力的深度学习方向之一&#xff0c;稀疏计算模拟了对人脑的观察&#xff0c;人脑在处理信息的时候只有少数神经元在活动&#xff0c;多数神经元是不工作的。而稀疏计算的基本思想是&#xff1a;在计算过程中&#xff0c;将一些不重要的参数设置为…

一招解决 vue数据格式校验时候 async-validator: [‘XXXX is not a number‘]

在vue中 amt数字需要进行纯数字校验&#xff1a; 格式都没问题&#xff0c;但是输入纯数字也会报错&#xff0c;报错如下&#xff1a; async-validator:[‘amt is not a number’] 网上找了一些&#xff0c;但是均为能奏效&#xff0c;尝试如下&#xff1a; 尝试1&#x…

软件保护技术

本文已收录至《全国计算机等级考试——信息 安全技术》专栏 软件保护 软件保护技术其实是一个很大的概念&#xff0c;技术上分为很多不同的分支&#xff0c;主要包括加密、防篡改、软件水印、软件多样化、反逆向技术、虚拟机、基于网络的保护和基于硬件的保护等。 加密是指对…

基于Python网络爬虫的IT招聘就业岗位可视化分析推荐系统

文章目录 基于Python网络爬虫的IT招聘就业岗位可视化分析推荐系统项目概述招聘岗位数据爬虫分析系统展示用户注册登录系统首页IT招聘数据开发岗-javaIT招聘数据开发岗-PythonIT招聘数据开发岗-Android算法方面运维方面测试方面招聘岗位薪资多维度精准预测招聘岗位分析推荐 结语…

FlinkCDC详解

1、FlinkCDC是什么 1.1 CDC是什么 CDC是Chanage Data Capture&#xff08;数据变更捕获&#xff09;的简称。其核心原理就是监测并捕获数据库的变动&#xff08;例如增删改&#xff09;&#xff0c;将这些变更按照发生顺序捕获&#xff0c;将捕获到的数据&#xff0c;写入数据…

TensorFlow 的特点和应用场景介绍

TensorFlow是一个开源的机器学习框架,最初由Google Brain团队开发并于2015年发布。它被设计用于构建、训练和部署各种机器学习算法和深度神经网络模型。TensorFlow具有以下特点: 强大的计算图:TensorFlow使用计算图来表示复杂的计算任务。计算图是由节点(表示操作)和边(表…

Jenkins中Publish Over SSH插件使用(1)

SSH插件 前言Publish Over SSH插件是jenkins里面必不可少的插件之一&#xff0c;主要的功能有两个把jenkins服务器上的文件&#xff0c;传输到远程nginx&#xff0c; 远程执行shell命令和脚本。 1. SSH插件下载与配置 1.1 下载Publish over SSH插件 系统管理—》管理插件 …

Python Web开发记录 Day1:HTML

名人说&#xff1a;莫道桑榆晚&#xff0c;为霞尚满天。——刘禹锡&#xff08;刘梦得&#xff0c;诗豪&#xff09; 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 一、HTML1、前端引入和HTML标签①前端引入②浏览…

Linux java查看内存消耗 linux查看java程序内存(转载)

Linux java查看内存消耗 linux查看java程序内存 目录 一、jps命令。 二、ps命令。 三、top命令。 四、free命令。 五、df命令。 查看应用的CPU、内存使用情况&#xff0c;使用jps、ps、top、free、df命令查看。 一、jps命令。 可以列出本机所有java应用程序的进程pid。…