构建简单的梯度提升决策树(GBDT)模型:MATLAB 实现详解

梯度提升决策树(Gradient Boosting Decision Trees,GBDT)是一种强大的集成学习方法,广泛用于回归和分类任务。GBDT 的思想是通过串联多个弱学习器(通常是决策树),逐步优化预测残差,从而提高模型的准确性。在本文中,我们将使用 MATLAB 实现一个简化版的 GBDT 模型,并通过可视化工具来评估模型的效果。

数据准备

首先,我们从 Excel 文件中读取数据。特征位于前几列,而目标值(回归目标)位于最后一列。接下来,我们将数据划分为训练集和测试集,方便后续模型的训练与评估。

% 读取数据
data = readtable('data.xlsx');
X = data{:, 1:end-1};  % 前面所有列为特征
y = data{:, end};      % 最后一列为目标值

 

构建 GBDT 模型

在这段代码中,我们使用了 simpleGBDT 函数来实现梯度提升决策树。函数的核心思路是通过逐步拟合前一棵树的残差来建立多棵决策树模型。具体实现步骤如下:

超参数设定

我们首先定义了一些重要的超参数,包括树的数量、学习率、决策树的最大深度以及测试集的比例。

% 调用 GBDT 模型
num_trees = 35000;         % 迭代次数(树的数量)
learning_rate = 0.001;     % 学习率
max_depth = 25;            % 决策树的最大深度
test_ratio = 0.1;          % 测试集比例

核心 GBDT 训练过程

simpleGBDT 函数中,我们首先对数据进行训练集和测试集的划分。初始化时,预测值设为训练集的均值。每一轮迭代中,计算当前的残差,训练一棵新的决策树来拟合这些残差。然后,通过更新预测值来逐步提高模型的准确性。

function [model, metrics] = simpleGBDT(X, y, num_trees, learning_rate, max_depth, test_ratio)% 将数据集分为训练集和测试集n = size(X, 1);split_index = floor((1 - test_ratio) * n);X_train = X(1:split_index, :);y_train = y(1:split_index);X_test = X(split_index+1:end, :);y_test = y(split_index+1:end);% 初始化预测值为训练集的均值F_train = mean(y_train) * ones(size(y_train));F_test = mean(y_train) * ones(size(y_test));% 保存每棵树的结构model.trees = cell(num_trees, 1);% 梯度提升树的训练过程for t = 1:num_trees% 计算残差residual = y_train - F_train;% 训练一棵决策树以拟合残差tree = fitrtree(X_train, residual, 'MaxNumSplits', max_depth);% 记录这棵树model.trees{t} = tree;% 更新训练集预测值F_train = F_train + learning_rate * predict(tree, X_train);% 更新测试集预测值F_test = F_test + learning_rate * predict(tree, X_test);end
模型评价指标

我们使用多个常见的回归评价指标来评估模型的性能,包括平均绝对误差(MAE)、均方误差(MSE)、均方根误差(RMSE)、决定系数(R²)和平均绝对百分误差(MAPE)。

    % 评价指标MAE = mean(abs(y_test - F_test));MSE = mean((y_test - F_test).^2);RMSE = sqrt(MSE);R2 = 1 - sum((y_test - F_test).^2) / sum((y_test - mean(y_test)).^2);MAPE = mean(abs((y_test - F_test) ./ y_test)) * 100;% 输出评价指标metrics = struct('MAE', MAE, 'MSE', MSE, 'RMSE', RMSE, 'R2', R2, 'MAPE', MAPE);fprintf('MAE: %.4f\n', MAE);fprintf('MSE: %.4f\n', MSE);fprintf('RMSE: %.4f\n', RMSE);fprintf('R2: %.4f\n', R2);fprintf('MAPE: %.4f%%\n', MAPE);

可视化

为了更直观地理解模型的表现,我们绘制了两个图像。第一个图是预测值和真实值的对比,显示模型的拟合效果。第二个图是误差的柱状图,显示每个样本的预测误差。

    % 可视化:真实值与预测值对比figure;plot(y_test, '-o');hold on;plot(F_test, '-x');legend('真实值', '预测值');title('真实值与预测值对比');xlabel('样本');ylabel('值');grid on;% 可视化:误差值柱状图figure;bar(y_test - F_test);title('误差值柱状图');xlabel('样本');ylabel('误差');grid on;
end

结果与分析

运行此模型后,MATLAB 将输出各项评价指标以及生成可视化图表。通过这些图表,用户可以直观地了解模型的拟合情况,观察到哪些数据点误差较大,哪些数据点拟合得较好。

示例输出:
MAE: 0.1345
MSE: 0.0456
RMSE: 0.2137
R2: 0.9854
MAPE: 5.23%

如图所示,模型的平均绝对误差(MAE)较小,说明模型整体上预测的误差不大。同时,决定系数(R²)接近 1,表明模型对数据的拟合度非常高。误差柱状图也能够帮助我们识别那些误差较大的数据点。

总结

在这篇文章中,我们展示了如何在 MATLAB 中实现一个简单的 GBDT 模型。通过这种方法,模型能够逐步优化预测残差,从而实现更高的预测精度。尽管此实现相对简化,但它提供了梯度提升树的基本原理和流程。

如果你对机器学习算法有更多兴趣,欢迎继续探索 GBDT 的其他优化版本,如 XGBoost 或 LightGBM。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/58254.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文案语音图片视频管理分析系统-视频矩阵

文案语音图片视频管理分析系统-视频矩阵 1.产品介绍 产品介绍方案 产品名称: 智驭视频矩阵深度分析系统(SmartVMatrix) 主要功能: 深度学习驱动的视频内容分析多源视频整合与智能分类高效视频检索与编辑实时视频监控与异常预警…

[LeetCode] 39. 组合总和

题目描述: 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 ,并以列表形式返回。你可以按 任意顺序 返回这些组合。 candidates 中的 同一个 数字可以…

夯实根基之MySql从入门到精通(一)

夯实根基之MySql从入门到精通(一) 引言1. 使用MySQL数据库和表2. MySql 数据类型2.1. 数字类型2.2. 日期和时间类型2.3. 字符串类型2.4. JSON类型2.5. 其他类型 3. MySql运算符3.1. 数学运算符3.2.比较运算符3.3.逻辑运算符3.4. 位运算符3.5.字符串运算符…

openlayers 封装加载本地geojson数据 - vue3

Geojson数据是矢量数据,主要是点、线、面数据集合 Geojson数据获取:DataV.GeoAtlas地理小工具系列 实现代码如下: import {ref,toRaw} from vue; import { Vector as VectorLayer } from ol/layer.js; import { Vector as VectorSource } fr…

OpenCV图像处理方法:腐蚀操作

腐蚀操作 前提 图像数据为二值的(黑/白) 作用 去掉图片中字上的毛刺 显示图片 读取一个图像文件,并在一个窗口中显示它。用户可以查看这个图像,直到按下任意键,然后程序会关闭显示图像的窗口 # cv2是OpenCV库的P…

【运维心得】U盘启动安装Dell服务器踩坑指南

目录 第一坑:没有键盘选择 第二坑:没有修改mount路径 最近碰到一台Dell服务器R720需要重新安装centos操作系统,由于之前已经配置好了Raid,这里就节省了配置磁盘的步骤。 以前都是通过光盘安装的,考虑到R720是支持U盘…

RAGChecker:显著超越RAGAS,一个精细化评估和诊断 RAG 系统的创新框架

RAG应用已经是当下利用大模型能力的典型应用代表,也获得了极大的推广,各种提升RAG性能的技术层出不穷。然而,如何全面、准确地评估 RAG 系统一直是一个挑战。传统评估方法存在诸多局限性:无法有效评估长文本回复、难以区分检索和生成模块的错误来源、与人…

Jmeter自动化实战

一、前言 由于系统业务流程很复杂,在不同的阶段需要不同的数据,且数据无法重复使用,每次造新的数据特别繁琐,故想着能不能使用jmeter一键造数据 二、创建录制模板 可参考:jmeter录制接口 首先创建一个录制模板 因为会有各种请求头,cookies,签名,认证信息等原因,导致手动复制…

JDK的下载

目录 JDK官网 Windows Ubantu 1.安装JDK 2.确定JDK版本 卸载OpenJDK Centos 1.下载JDK 2.安装JDK 3.验证JDK JDK官网 官网网址:Java Downloads | Oracle Windows 双击运⾏exe⽂件, 选择安装⽬录, 直⾄安装完成 Ubantu 1.安装JDK 更新软件包 sudo apt u…

【YOLO 系列】基于YOLO的工业自动化轴承缺陷检测系统【python源码+Pyqt5界面+数据集+训练代码】

前言 轴承作为机械设备中的关键部件,其性能直接影响到设备的稳定性和寿命。轴承缺陷的早期检测对于预防设备故障、减少维护成本和提高生产效率至关重要。然而,传统的轴承缺陷检测方法往往依赖于人工检查,这不仅效率低下,而且容易…

taro微信小程序assets静态图片不被编译成base64

taro 的微信小程序项目,不希望把在样式文件( css 、 less 、 scss )中引入的 assets/images 文件夹下的图片编译成 base64 。 可以在config/index.ts文件中的mini进行配置。 参考:taro小程序打包时静态图片无法关闭base64转换 …

告别局域网限制:宝塔FTP结合内网穿透工具实现远程高效文件传输

文章目录 前言1. Linux安装Cpolar2. 创建FTP公网地址3. 宝塔FTP服务设置4. FTP服务远程连接小结 5. 固定FTP公网地址6. 固定FTP地址连接 前言 本文主要介绍宝塔FTP文件传输服务如何搭配内网穿透工具,实现随时随地远程连接局域网环境搭建的宝塔FTP文件服务并进行文件…

2024 前端面试题!!! html css js相关

常见的块元素、行内元素以及行内块元素,三者有何不同?​​​​​​​ HTML、XML、XHTML它们之间有什么区别?​​​​​​​ DOCTYPE(⽂档类型) 的作⽤ Doctype是HTML5的文档声明,通过它可以告诉浏览器,使用哪一个HTM…

业务逻辑与代码分离:规则引擎如何实现高效管理?

在这个快速变化、高度信息化的时代,软件系统和业务流程的复杂性日益增加。为了应对这种复杂性,越来越多的企业开始采用规则引擎来应对这种复杂性。我们这次结合JVS规则引擎来解析为什么越来越多人使用规则引擎。 规则引擎定义 规则引擎是一种用于管理和…

关键词排名技巧实用指南提升网站流量的有效策略

内容概要 在数字营销的世界中,关键词排名的影响不可小觑。关键词是用户在搜索引擎中输入的词语,通过精确选择和优化这些关键词,网站能够更轻松地被目标用户发现。提升关键词排名的第一步是了解基本概念,包括关键词的分类、重要性…

数据结构与算法——树与二叉树

树与二叉树 1.树的定义与相关概念 树的示例&#xff1a; 树的集合形式定义 Tree(K,R) 元素集合&#xff1a;K{ki|0<i<n,n>0,ki∈ElemType}&#xff08;n为树中结点数&#xff0c;n0则树为空&#xff0c;n>0则为非空树&#xff09; 对于一棵非空树&#xff0c…

51单片机应用开发---定时器(定时1S,LED以1S间隔闪烁)

实现目标 1、掌握定时器的配置流程&#xff1b; 2、掌握定时器初值的计算方法&#xff1b; 3、具体实现&#xff1a;&#xff08;1&#xff09;1mS中断1次&#xff0c;计数1000次中断&#xff0c;实现定时1S功能&#xff1b;&#xff08;2&#xff09;LED1每隔1S状态取反。 …

TCP/IP Attack Lab

网络拓扑&#xff1a; Task 1: SYN Flooding Attack 收到攻击之前&#xff0c;在Victim主机查看网络连接的状态: 在攻击之前使用User1主机(10.9.0.6)访问Victim(10.9.0.5)主机的 Telnet服务: Task 1.1: Launching the Attack Using Python 在Atacker上建立文件attack-1.py…

VictoriaMetrics 中文教程(10)集群版介绍

VictoriaMetrics 中文教程系列文章&#xff1a; VictoriaMetrics 中文教程&#xff08;01&#xff09;简介VictoriaMetrics 中文教程&#xff08;02&#xff09;安装VictoriaMetrics 中文教程&#xff08;03&#xff09;如何配置 Prometheus 使其把数据远程写入 VictoriaMetri…

Vue 3 插件常见用途和场景

Vue 3插件是一个用于增强Vue应用功能的库或模块&#xff0c;其常见用途和场景包括&#xff1a; 常见用途 添加全局方法或属性&#xff1a; 插件可以向Vue实例添加全局方法或属性&#xff0c;使开发者能够在应用的任何部分方便地调用这些方法或属性。 添加全局资源&#xff1a…